Description
jax.nn.scaled_matmul fails immediately on CPU with an internal NotImplementedError from MLIR lowering, rather than raising a clear user-facing error indicating that the function requires a GPU with cuDNN.
Minimal reproducer
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(0)
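x = jax.random.normal(key, [3, 16, 16], dtype=jnp.bfloat16)
# The same array is passed as both operands and both scale tensors.
out = jax.nn.scaled_matmul(x, x, x, x)  # raises NotImplementedError when JAX_PLATFORMS=cpu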
Error
NotImplementedError: MLIR translation rule for primitive 'scaled_matmul' not found for platform cpu
Expected behavior
A clear error message stating that scaled_matmul requires a GPU with cuDNN support, not an internal MLIR lowering error.
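Until the library raises such an error itself, a caller can approximate the expected behavior with a backend check before the call. The sketch below is a hypothetical user-side workaround, not part of JAX's API: require_gpu_backend and checked_scaled_matmul are illustrative names, and it assumes (as the error above suggests) that scaled_matmul only has a lowering on the "gpu" backend as reported by jax.default_backend().

```python
def require_gpu_backend(backend: str, fn_name: str) -> None:
    """Raise a readable error if `backend` is not the GPU backend.

    Hypothetical helper; assumes scaled_matmul is only lowered for "gpu".
    """
    if backend != "gpu":
        raise RuntimeError(
            f"{fn_name} requires a GPU backend with cuDNN support; "
            f"the active JAX backend is {backend!r}."
        )


def checked_scaled_matmul(lhs, rhs, lhs_scales, rhs_scales, **kwargs):
    """Hypothetical wrapper: check the default backend, then call jax.nn.scaled_matmul."""
    # Imported here so that require_gpu_backend itself has no JAX dependency.
    import jax

    require_gpu_backend(jax.default_backend(), "jax.nn.scaled_matmul")
    return jax.nn.scaled_matmul(lhs, rhs, lhs_scales, rhs_scales, **kwargs)
```

With JAX_PLATFORMS=cpu, replacing the last line of the reproducer with checked_scaled_matmul(x, x, x, x) should surface the RuntimeError message instead of the MLIR translation error. The check only inspects the default backend, so it would not catch arrays explicitly placed on a non-GPU device while a GPU is the default.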
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.10.1
jaxlib: 0.10.1
numpy: 2.4.6
python: 3.12.13 | packaged by Anaconda, Inc. | (main, Mar 19 2026, 20:20:58) [GCC 14.3.0]
device info: cpu-1, 1 local devices
JAX_PLATFORMS=cpu