Description
With BF16 inputs:
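```python
import jax
import jax.numpy as jnp

# Disallow reduce_sum in 16-bit floating-point dtypes.
jax.config.update("jax_allow_f16_reductions", False)

x = jnp.ones((100,), dtype=jnp.bfloat16)
f = jax.jit(jax.grad(jnp.max))  # all 100 elements tie for the max
f(x)
```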
fails with
```
google3/third_party/py/jax/_src/lax/lax.py in _reduce_sum_dtype_rule(operand, axes, **_)
   8432 not config.allow_f16_reductions.value and
   8433 not all(core.definitely_equal(operand.shape[d], 1) for d in axes)):
-> 8434 raise ValueError(f"reduce_sum on operand {operand.str_short(True)} is not "
   8435 "allowed when jax_allow_f16_reductions=False.")
   8436 return dt

ValueError: reduce_sum on operand bf16[100] is not allowed when jax_allow_f16_reductions=False.
```
The cause: the `jnp.max` VJP needs to count the number of values in the array that are exactly equal to the max, and divide the incoming gradient by this count. The implementation checks for equality, producing a bool array, and then sums that bool array; this sum is erroneously computed in BF16 rather than in an integer type appropriate for counts. BF16 has only 8 significant bits, so integer counts above 256 are not all exactly representable and can be rounded. This could lead to unstable numerics, and it might be even worse for lower-precision types (FP8/FP4).
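A quick illustration of why a BF16 count is fragile. BF16 carries 8 significant bits, so exact integer counts past 256 start to round. The sketch below only shows the representability problem; whether a real `reduce_sum` on a given backend hits these exact values depends on XLA's accumulation order and internal precision, which are implementation details.

```python
import jax.numpy as jnp

# Exact tie counts that a BF16 accumulator might need to hold.
counts = jnp.array([255, 256, 257, 301, 1001], dtype=jnp.int32)

# Casting to BF16 rounds any count that needs more than 8 significant bits
# (round-to-nearest-even): 257 -> 256, 301 -> 300, 1001 -> 1000.
counts_bf16 = counts.astype(jnp.bfloat16)

# Adding one more tie to a BF16 running total of 256 leaves it at 256,
# so a one-at-a-time BF16 accumulation of ones stalls there.
stalled = jnp.bfloat16(256) + jnp.bfloat16(1)

print(counts_bf16)
print(stalled)
```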
Note that `jnp.count_nonzero` is effectively what the VJP needs to call, and its implementation does the correct thing: it casts the bool array to an integer array before reducing to count, rather than casting to the input type.
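Until the built-in VJP is changed, a user-side workaround could route the tie count through `jnp.count_nonzero` with `jax.custom_vjp`. This is a minimal sketch, not JAX's own implementation: `max_int_count` is a hypothetical name, it only covers a full reduction (no `axis` argument), and it assumes the even split among ties that `jax.grad(jnp.max)` already uses. Because the only sum it performs is the integer one inside `jnp.count_nonzero`, it would be expected to avoid the BF16 `reduce_sum` that `jax_allow_f16_reductions=False` rejects, though that flag is not set here.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def max_int_count(x):
    """Hypothetical drop-in for jnp.max(x) over all elements."""
    return jnp.max(x)


def _max_fwd(x):
    m = jnp.max(x)
    # Keep the input and its max so the backward pass can locate the ties.
    return m, (x, m)


def _max_bwd(res, g):
    x, m = res
    is_max = x == m  # bool mask of elements equal to the max
    count = jnp.count_nonzero(is_max)  # integer count: no floating-point rounding
    # Split the cotangent evenly across the ties; divide in float32, round once.
    share = (g.astype(jnp.float32) / count).astype(x.dtype)
    return (jnp.where(is_max, share, jnp.zeros_like(x)),)


max_int_count.defvjp(_max_fwd, _max_bwd)

x = jnp.ones((100,), dtype=jnp.bfloat16)
grads = jax.jit(jax.grad(max_int_count))(x)
print(grads.dtype, grads[:3])
```

Dividing by the exact integer count in float32 means the per-tie share is rounded to BF16 once, instead of being divided by a count that may already have been rounded.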
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.11.0
jaxlib: 0.11.0
numpy: 2.4.6
python: 3.13.11