[CPU][RISC-V] Add RVV micro GEMM for WNA16 #44324
Add RVV as a new ISA type and wire it through the WNA16 GEMM dispatch. RVV reuses the VEC micro-kernel implementation and follows the same dequantize path (has_zp / use_desc_act) as x86 AMX/VEC. On the Python side, detect RISC-V architecture and return "rvv" as the ISA hint, skipping AMX detection. Co-authored-by: lyd1992 <liuyudong@iscas.ac.cn> Co-authored-by: wcy2003 <233313160abc@gmail.com> Signed-off-by: wcy <233313160abc@gmail.com>
Add a dedicated RVV MicroGemm implementation for WNA16. The RVV path uses an Mx8 micro-kernel with K unrolled by 4 and scalar-vector RVV FMA, while keeping the existing packed B layout and MicroGemm interface. Co-authored-by: lyd1992 <liuyudong@iscas.ac.cn> Co-authored-by: wcy2003 <233313160abc@gmail.com> Signed-off-by: wcy <233313160abc@gmail.com>
Hi @bigPYJ1151, gentle ping when you have a chance to review this PR. This adds an RVV-specific I also added test/benchmark results in the PR description: the RVV path matches the existing VEC path with Thanks!
Hi, @bigPYJ1151
Signed-off-by: wcy <233313160abc@gmail.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com>
Signed-off-by: wcy <233313160abc@gmail.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Signed-off-by: Qiang Li <qiang.li2@amd.com>
This PR adds an RVV-specific micro GEMM implementation for CPU WNA16 on
RISC-V and wires the W4A16 GPTQ CPU path to use it.
Purpose
Add an RVV-specific micro GEMM kernel for WNA16
The existing VEC path already uses the generic vector abstraction on
RISC-V, but it still follows the generic FP32Vec16 tile shape. For the
WNA16 micro-kernel, this creates higher register pressure on current RVV
targets.
This PR adds MicroGemm<ISA::RVV, scalar_t> with an RVV-specific inner
kernel. The new kernel uses an internal Mx8 tile, keeps the external N=32
packed weight layout compatible, uses scalar-vector FMA for the activation
broadcast pattern, and unrolls K by 4.
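The loop structure described above can be modeled in plain Python. This is a minimal reference sketch, not the PR's RVV code: the function names are hypothetical, B is assumed to be a plain row-major K x N matrix rather than the real packed layout, and the K unroll is written as a short inner loop that a real kernel would expand by hand.

```python
TILE_N = 8     # internal tile width named in the PR description
K_UNROLL = 4   # K unroll factor named in the PR description


def fma_scalar_vector(acc, a_scalar, b_vec):
    """Model of a scalar-vector FMA: acc[j] += a_scalar * b_vec[j]."""
    return [c + a_scalar * b for c, b in zip(acc, b_vec)]


def micro_gemm_mx8(a, b):
    """Compute C = A @ B, where a is M x K and b is K x N (lists of rows)."""
    m, k, n = len(a), len(b), len(b[0])
    assert n % TILE_N == 0, "sketch assumes N is a multiple of the tile width"
    c = [[0.0] * n for _ in range(m)]
    k_main = k - k % K_UNROLL
    for n0 in range(0, n, TILE_N):
        # One 8-wide accumulator per row of the M x 8 tile.
        acc = [[0.0] * TILE_N for _ in range(m)]
        for k0 in range(0, k_main, K_UNROLL):
            # Four consecutive K steps per iteration (the unrolled body).
            for u in range(K_UNROLL):
                b_vec = b[k0 + u][n0:n0 + TILE_N]
                for i in range(m):
                    # a[i][k0 + u] is broadcast as a scalar against b_vec.
                    acc[i] = fma_scalar_vector(acc[i], a[i][k0 + u], b_vec)
        for kk in range(k_main, k):
            # Remainder loop for K not divisible by the unroll factor.
            b_vec = b[kk][n0:n0 + TILE_N]
            for i in range(m):
                acc[i] = fma_scalar_vector(acc[i], a[i][kk], b_vec)
        for i in range(m):
            c[i][n0:n0 + TILE_N] = acc[i]
    return c
```

With N=32, the outer loop visits four 8-wide sub-tiles, which is one way an internal Mx8 tile can coexist with an external N=32 layout.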
Wire W4A16 GPTQ to the RVV GEMM backend
The W4A16 GPTQ CPU path now dispatches to MicroGemm<ISA::RVV, scalar_t>
when isa_hint == "rvv". The dequantization path remains shared with the
existing WNA16 implementation; only the micro GEMM backend changes after
the packed B buffer is prepared.
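The hint-based selection can be sketched as follows. This is an illustration under assumptions, not vLLM's actual code: `select_isa_hint` and `dispatch_gemm` are hypothetical names, the "amx" hint string and the `has_amx` flag are assumed, and using `platform.machine()` (which reports values such as "riscv64" on RISC-V Linux) is only one possible way to detect the architecture.

```python
import platform


def select_isa_hint(machine=None, has_amx=False):
    """Pick an ISA hint string for the WNA16 GEMM (hypothetical helper)."""
    machine = (machine or platform.machine()).lower()
    if machine.startswith("riscv"):
        # RISC-V: return "rvv" and skip AMX detection entirely.
        return "rvv"
    if has_amx:
        return "amx"  # assumed hint string for the x86 AMX path
    return "vec"


def dispatch_gemm(isa_hint, backends, *args):
    """Call the micro GEMM backend registered for isa_hint."""
    if isa_hint not in backends:
        raise ValueError(f"unsupported isa_hint: {isa_hint!r}")
    # Dequantization is shared; only the backend chosen here differs.
    return backends[isa_hint](*args)
```

A caller would register one callable per hint, for example `{"vec": vec_gemm, "rvv": rvv_gemm}`, where both names stand in for the real kernels.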
The Python W4A16 kernel dispatch was also verified to pass isa_hint="rvv"
into ops.cpu_gemm_wna16 on RISC-V.
Follow-up
The current RVV micro GEMM tile shape is tuned for the tested VLEN=128 target.
A future optimization can select different tile sizes according to the target
VLEN.
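The dependence on VLEN follows from the RVV relation VLMAX = LMUL * VLEN / SEW. The sketch below only works through that arithmetic; the helper names are hypothetical, and whether the kernel in this PR holds an 8-wide FP32 row in one LMUL=2 register group or in two LMUL=1 registers is not stated here.

```python
def vlmax(vlen_bits, sew_bits, lmul=1):
    """Elements per vector register group: VLMAX = LMUL * VLEN / SEW."""
    return lmul * vlen_bits // sew_bits


def lmul_for_tile(vlen_bits, tile_n, sew_bits=32):
    """Smallest integer LMUL whose register group holds tile_n elements."""
    for lmul in (1, 2, 4, 8):
        if vlmax(vlen_bits, sew_bits, lmul) >= tile_n:
            return lmul
    raise ValueError("tile does not fit in a single register group")
```

For VLEN=128 and FP32 accumulators, one register holds 4 elements, so an 8-wide row needs the equivalent of two registers; at VLEN=256 the same row fits in one, which is why a wider tile could become attractive on larger-VLEN targets.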
Test Plan
1. RVV GEMM test
The synthetic benchmark directly calls the public CPU WNA16 op twice with the
same GPTQ WNA16 inputs and uses torch.profiler to measure the
cpu_wna16::gemm event:
ops.cpu_gemm_wna16(..., isa_hint="vec")
ops.cpu_gemm_wna16(..., isa_hint="rvv")
The outputs are compared before timing the GEMM profiler event. All tested
GPTQ WNA16 shapes have max|rvv-vec| = 0.
2. WNA16 dispatch test
The WNA16 model smoke tests were also run on RISC-V. The following
WNA16-related model cases passed:
OMP_NUM_THREADS=60 python -m pytest \
  tests/quantization/test_cpu_wna16.py \
  -vv -s --tb=short \
  -k "AWQ or GPTQ or w4a16 or int4"
Test Result:
1. RVV GEMM test
Measured with GPTQ/AWQ WNA16, BF16 activation/output, no zero-points. The table
reports the cpu_wna16::gemm profiler event only. The benchmark uses
torch.profiler around the public ops.cpu_gemm_wna16 call and extracts the
GEMM event to compare the VEC and RVV micro GEMM backends while keeping the
same dequantization path and packed B-buffer layout.
2. tests/quantization/test_cpu_wna16.py
TheBloke/TinyLlama-1.1B-Chat-v1.0-AWQ
TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ
Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4
RedHatAI/Qwen3-1.7B-quantized.w4a16
OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc
The remaining full-file failures were from FP8/MXFP4 model cases, which are
outside the W4A16 GPTQ RVV scope of this PR.