[Kernel][Bugfix] Fix INT8 per-token-head KV cache rounding in Triton reshape-and-cache#45361
Conversation
|
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. Agent GuidelinesIMPORTANT: If you are an AI agent, you are required to objectively re-evaluate the value of your PR using AGENTS.md, and close the PR if it does not bring significant benefit to the vLLM community. Failure to do so may result in an immediate ban. 🚀 |
3b05907 to
db35384
Compare
|
JartX@d32166c Resolved also here |
|
resolve the precommit please |
|
@JartX Thanks! I checked the pre-run-check log, and it looks like this is the first-time contributor gate rather than an actual pre-commit failure. The PR needs the I’ve run the local checks listed in the PR description. Could a maintainer please add the appropriate label if the PR looks good to proceed, so the pre-commit / CI jobs can run? |
|
You're right. Please align your commit with this pull request: int4 was already rounded by default. It's to follow design patterns and keep the upstream as clean as possible by eliminating dependencies. Thanks a lot for the catch! This passed the acc tests because it only manifests in long contexts, more so than gsm8k. |
|
I'll try to move it so the bug fix arrives as soon as possible. @tjtanaa @yewentao256 please the bug is real check That error occurred because the gsm8k acc doesn't cover contexts as long as the ones in which it manifests. I've also implemented it here: After the removal of the dedicated kernels, it was lost, so I've reintroduced standard rounding. This aligns with what's pending implementation. In the downstream RDNA3, which is what the community is compiling, it was already implemented in: |
| if current_platform.is_rocm(): | ||
|
|
||
| def _is_supported_kv_cache_dtype(kv_cache_dtype: str) -> bool: | ||
| return kv_cache_dtype in _NATIVE_KV_CACHE_DTYPES or is_quantized_kv_cache( | ||
| kv_cache_dtype | ||
| ) | ||
| @triton.jit | ||
| def _round_to_nearest(x): | ||
| return tl.extra.hip.libdevice.round(x) | ||
|
|
||
|
|
||
| elif current_platform.is_xpu(): | ||
|
|
||
| @triton.jit | ||
| def _round_to_nearest(x): | ||
| return tl.extra.intel.libdevice.round(x) | ||
|
|
||
|
|
||
| else: | ||
|
|
||
| @triton.jit | ||
| def _round_to_nearest(x): | ||
| return tl.extra.cuda.libdevice.round(x) |
Signed-off-by: ZedongLiu <113341356+Zedong-Liu@users.noreply.github.com>
db35384 to
4921619
Compare
|
@JartX @yewentao256 Thanks for the pointer. I updated the PR to align with #40835 / 6cfb19c: INT8 rounding now uses pure Triton arithmetic under IS_INT_QUANT, without tl.extra.*, libdevice, or platform-specific helpers. FP8 keeps the existing clamp + dtype cast behavior. |
|
@Zedong-Liu please disclose the test results of your test plan (snippet of logs). There are no test results disclosed. We will only merge this PR once it has been validated as there are no Radeon GPU on the CI. |
Thanks, here are the local validation results. Environment:
Focused regression test: conda run -n vllm python -m pytest \
tests/quantization/test_per_token_kv_cache.py::test_int8_per_token_head_raw_cache_matches_round_reference -qResult: Additional checks: conda run -n vllm python -m py_compile \
vllm/v1/attention/ops/triton_reshape_and_cache_flash.py \
tests/quantization/test_per_token_kv_cache.py
git diff --checkResult: both passed. This focused test checks the exact raw INT8 KV cache values written by the Triton reshape-and-cache kernel against the PyTorch round reference, so it directly covers the regression fixed by this PR. For downstream validation on our long-context extraction workload, the relative score loss vs bf16 dropped from about 5% with the truncating INT8 path to about 0.5% with this rounding fix. |
|
Hi @Zedong-Liu can you update your branch with main please? |
…ad-kv-rounding-main
a934737 to
23144e5
Compare
|
@JartX @tjtanaa Environment:
Focused regression test: conda run -n vllm python -m pytest \
tests/quantization/test_per_token_kv_cache.py::test_int8_per_token_head_raw_cache_matches_round_reference -qResult: Additional checks: conda run -n vllm python -m py_compile \
vllm/v1/attention/ops/triton_reshape_and_cache_flash.py \
tests/quantization/test_per_token_kv_cache.py
git diff --checkResult: both passed. This focused test checks the exact raw INT8 KV cache values written by the Triton reshape-and-cache kernel against the PyTorch round reference, so it directly covers the regression fixed by this PR. For downstream validation on our long-context extraction workload, the relative score loss vs bf16 dropped from about 5% with the truncating INT8 path to about 0.5% with this rounding fix. |
|
Thanks again for the review and guidance. All required checks are now passing, the PR has the necessary approval, and the branch is up to date with no conflicts. Could a maintainer please merge it when convenient? |
…reshape-and-cache (vllm-project#45361) Signed-off-by: ZedongLiu <113341356+Zedong-Liu@users.noreply.github.com>
…reshape-and-cache (vllm-project#45361) Signed-off-by: ZedongLiu <113341356+Zedong-Liu@users.noreply.github.com>
…reshape-and-cache (vllm-project#45361) Signed-off-by: ZedongLiu <113341356+Zedong-Liu@users.noreply.github.com> Signed-off-by: Qiang Li <qiang.li2@amd.com>
…reshape-and-cache (vllm-project#45361) Signed-off-by: ZedongLiu <113341356+Zedong-Liu@users.noreply.github.com> (cherry picked from commit 9c450b1)
What this PR does
This PR fixes the Triton reshape-and-cache path for
kv_cache_dtype=int8_per_token_head.The per-token-head KV cache kernel computes one absmax scale per
(token, head), but the INT8 path previously clamped the scaled K/V values and then relied on the implicit float-to-int8 store. That store truncates toward zero, while INT8 absmax quantization should round to nearest before storing.This change makes the INT8 path explicitly round the scaled values before clamping/storing them to the KV cache. FP8 per-token-head keeps its existing clamp-and-cast behavior.
Why
This is a bug fix / semantic consistency fix.
The corrected behavior matches standard dynamic absmax INT8 quantization semantics:
It also aligns this path with vLLM's existing per-token INT8 quantization utility and the PyTorch reference used by the tests.
The previous round-trip tests compared only dequantized values with a loose tolerance, so truncate-vs-round differences in the raw INT8 cache values were not caught. Removing the truncation bias can also avoid accuracy degradation in downstream workloads that use this KV cache mode; in our downstream validation, the observed relative score loss versus bf16 decreased from about 5% to about 0.5%.
Duplicate PR check
I checked the currently open PRs for
int8_per_token_head round,triton_reshape_and_cache_flash int8 round, andper-token-head KV cache int8. The related open PRs cover ROCm kernels, INT4/INT2 per-token-head features, hybrid KV manager/page-size fixes, or older INT8 KV cache feature work; I did not find an open PR that fixes this INT8 per-token-head cache-write rounding bug.Tests
conda run -n vllm python -m pytest tests/quantization/test_per_token_kv_cache.py::test_int8_per_token_head_raw_cache_matches_round_reference -qconda run -n vllm python -m py_compile vllm/v1/attention/ops/triton_reshape_and_cache_flash.py tests/quantization/test_per_token_kv_cache.pygit diff --checkAdded a raw INT8 cache-value test that checks the Triton output against a PyTorch round-to-nearest reference.
AI assistance was used in preparing this PR.