
jnp.max uses input precision for VJP reduction #38885

Description

@sbodenstein


With BF16 inputs:

import jax
import jax.numpy as jnp

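jax.config.update("jax_allow_f16_reductions", False)  # disallow reduce_sum on 16-bit float operands such as bf16
x = jnp.ones((100,), dtype=jnp.bfloat16)  # all 100 elements tie for the max
f = jax.jit(jax.grad(jnp.max))
f(x)  # raises ValueError from _reduce_sum_dtype_rule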

fails with

google3/third_party/py/jax/_src/lax/lax.py in _reduce_sum_dtype_rule(operand, axes, **_)
   8432       not config.allow_f16_reductions.value and
   8433       not all(core.definitely_equal(operand.shape[d], 1) for d in axes)):
-> 8434     raise ValueError(f"reduce_sum on operand {operand.str_short(True)} is not "
   8435                      "allowed when jax_allow_f16_reductions=False.")
   8436   return dt

ValueError: reduce_sum on operand bf16[100] is not allowed when jax_allow_f16_reductions=False.

The cause: the jnp.max VJP needs to count how many elements of the array are exactly equal to the max, and divide the cotangent by this count. The implementation checks for equality, producing a bool array, and then sums that bool array; this sum is erroneously done in BF16 (the input dtype) rather than in an integer type appropriate for counts. Besides tripping the check above, this can make the count, and therefore the gradient, inaccurate: BF16 can represent every integer exactly only up to 256. This is likely even worse for lower-precision types (FP8/FP4).
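To make the numerics concern concrete, here is a sketch that counts 1000 tied maxima by adding 1 at a time in different dtypes. It is illustrative only: sequential_count is a hypothetical helper that assumes a strictly sequential accumulation order, whereas XLA may reduce in a different order or accumulate at higher internal precision. It uses the ml_dtypes package, which provides the bfloat16 and float8 NumPy types that JAX builds on.

```python
import numpy as np
from ml_dtypes import bfloat16, float8_e4m3fn  # NumPy low-precision dtypes used by JAX


def sequential_count(mask, dtype):
    """Hypothetical helper: count True entries by adding 1 at a time in `dtype`.

    Models a worst-case, strictly sequential accumulation; XLA may reduce in
    another order or at higher internal precision.
    """
    acc = dtype(0)
    for m in mask:
        if m:
            acc = dtype(acc + dtype(1))  # round the running count back to `dtype`
    return acc


ties = np.ones(1000, dtype=bool)  # e.g. 1000 elements all equal to the max
print(int(sequential_count(ties, np.int32)))         # 1000
print(float(sequential_count(ties, bfloat16)))       # 256.0
print(float(sequential_count(ties, float8_e4m3fn)))  # 16.0
```

Under this worst-case order the BF16 count gets stuck at 256 and the float8_e4m3fn count at 16, because adding 1 to those values rounds back down. An integer count stays exact and is rounded at most once, when it is converted for the division.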

Note that jnp.count_nonzero computes effectively what the VJP needs, and its implementation does the right thing: it casts the bool array to an integer array before reducing to a count, rather than casting to the input dtype.
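Following that suggestion, a user-side workaround could wrap the max in jax.custom_vjp and count ties with jnp.count_nonzero. This is a minimal sketch, not JAX's implementation or a proposed patch: max_int_count, _max_fwd and _max_bwd are hypothetical names, it only handles a full reduction (no axis argument), and it aims to match the tie-splitting gradient of jnp.max.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def max_int_count(x):
    """Hypothetical full-reduction max whose VJP counts ties in an integer dtype."""
    return jnp.max(x)


def _max_fwd(x):
    ans = jnp.max(x)
    # Save the input and the max for the backward pass.
    return ans, (x, ans)


def _max_bwd(res, g):
    x, ans = res
    is_max = x == ans                  # bool mask of positions tied for the max
    count = jnp.count_nonzero(is_max)  # integer count; no reduction in x.dtype
    share = g / count.astype(g.dtype)  # split the cotangent evenly among ties
    grad_x = jnp.where(is_max, share, jnp.zeros_like(share))
    return (grad_x.astype(x.dtype),)


max_int_count.defvjp(_max_fwd, _max_bwd)

x = jnp.array([1.0, 3.0, 3.0, 2.0])
print(jax.grad(max_int_count)(x))  # 0.5 at each tied position, 0 elsewhere
```

The count is converted to the cotangent dtype only once, for the division, instead of being accumulated in BF16. Since the backward pass no longer calls reduce_sum on a BF16 operand, this should avoid the ValueError above when jax_allow_f16_reductions=False, but that is an assumption not verified against the reported JAX version.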

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.11.0
jaxlib: 0.11.0
numpy:  2.4.6
python: 3.13.11

Metadata


Assignees

No one assigned

Labels

bug (Something isn't working)

Type

No type


Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
