jax.nn.scaled_matmul raises an internal NotImplementedError on CPU instead of a clear error #38813

@anna-researcher

Description

jax.nn.scaled_matmul fails immediately on CPU with an internal NotImplementedError from MLIR lowering, rather than a clear user-facing error stating that the function requires a GPU with cuDNN.

Minimal reproducer

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, [3, 16, 16], dtype=jnp.bfloat16)

out = jax.nn.scaled_matmul(x, x, x, x)

Error

NotImplementedError: MLIR translation rule for primitive 'scaled_matmul' not found for platform cpu

Expected behavior

A clear error message stating that scaled_matmul requires a GPU with cuDNN support, not an internal MLIR lowering error.
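For illustration, the expected behavior could look roughly like the wrapper below. This is a minimal sketch, not JAX's implementation. `checked_scaled_matmul` and its `backend` parameter are hypothetical names, and the sketch assumes `jax.default_backend()` is a reasonable proxy for where the call will be lowered. It does not distinguish CUDA from ROCm, does not check the installed cuDNN version, and ignores arrays explicitly placed on a non-default device.

```python
import jax
import jax.numpy as jnp


def checked_scaled_matmul(lhs, rhs, lhs_scales, rhs_scales, backend=None):
    """Hypothetical wrapper that fails fast with a readable message off-GPU."""
    # Assumption: the default backend is where this call will be lowered.
    backend = backend or jax.default_backend()
    if backend != "gpu":
        raise RuntimeError(
            "jax.nn.scaled_matmul requires a GPU with cuDNN support; "
            f"the current JAX backend is {backend!r}."
        )
    return jax.nn.scaled_matmul(lhs, rhs, lhs_scales, rhs_scales)


if __name__ == "__main__":
    x = jnp.ones((3, 16, 16), dtype=jnp.bfloat16)
    try:
        # Force the CPU path to show the message the reproducer would get.
        checked_scaled_matmul(x, x, x, x, backend="cpu")
    except RuntimeError as err:
        print(err)
```

Checking the backend before calling keeps the failure at the Python call site, so the user sees what is required instead of a lowering-rule lookup failure.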

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.10.1
jaxlib: 0.10.1
numpy: 2.4.6
python: 3.12.13 | packaged by Anaconda, Inc. | (main, Mar 19 2026, 20:20:58) [GCC 14.3.0]
device info: cpu-1, 1 local devices
JAX_PLATFORMS=cpu
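As a CPU stopgap, a plain `jax.numpy` reference can be written from the semantics described in the `jax.nn.scaled_matmul` docstring. Each scale is repeated across its block of the contracting (last) dimension and multiplied into its operand, and the scaled operands are then contracted as `BMK,BNK->BMN`. The sketch below assumes that reading of the docstring and is only a numerical reference. `scaled_matmul_reference` is a hypothetical name, it performs no shape validation, and it does not reproduce the performance or exact rounding of the cuDNN kernel.

```python
import jax
import jax.numpy as jnp


def scaled_matmul_reference(lhs, rhs, lhs_scales, rhs_scales,
                            preferred_element_type=jnp.float32):
    """Hypothetical CPU reference for block-scaled matmul.

    Shapes (assumed from the docstring): lhs (B, M, K), rhs (B, N, K),
    lhs_scales (B, M, K_a), rhs_scales (B, N, K_b), with K divisible by
    K_a and K_b. Returns (B, M, N) in preferred_element_type.
    """
    # Infer block sizes from the ratio of contracting dimensions.
    lhs_block = lhs.shape[-1] // lhs_scales.shape[-1]
    rhs_block = rhs.shape[-1] // rhs_scales.shape[-1]
    # Upcast before scaling so bfloat16 inputs are computed in float32.
    lhs_scaled = lhs.astype(preferred_element_type) * jnp.repeat(
        lhs_scales.astype(preferred_element_type), lhs_block, axis=-1)
    rhs_scaled = rhs.astype(preferred_element_type) * jnp.repeat(
        rhs_scales.astype(preferred_element_type), rhs_block, axis=-1)
    # Contract over the shared last dimension K, batched over B.
    return jnp.einsum("bmk,bnk->bmn", lhs_scaled, rhs_scaled)


# Same inputs as the reproducer; this runs on CPU.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, [3, 16, 16], dtype=jnp.bfloat16)
out = scaled_matmul_reference(x, x, x, x)
print(out.shape, out.dtype)  # (3, 16, 16) float32
```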

Labels: bug

Type

No type

Fields

No fields configured for issues without a type.

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions