[Perf][DSv4/DSv3.2] Add cluster-cooperative topK kernel for low-latency scenarios #43008
Code Review
This pull request introduces a cluster-persistent TopK implementation for SM90+ architectures, leveraging TMA and DSMEM for improved performance at K=1024. The review feedback identifies critical correctness issues related to potential out-of-bounds memory accesses during TMA load operations due to incorrect size calculations. It also suggests using unsigned integers for counters to improve type safety and consistency.
for (uint32_t i = 0; i < kNumStages8; i++) {
  if (i >= num_iters) break;
  const auto off = i * kSizePerStage;
  const auto sz = min(kSizePerStage, len_aligned - off) * sizeof(float);
The tma_load operation uses len_aligned - off to determine the size of the data to load. len_aligned is the my_len value rounded up to the nearest multiple of 4. If my_len is not a multiple of 4, len_aligned will be greater than my_len. This can lead to tma_load attempting to read beyond the actual data available (my_len), resulting in an out-of-bounds memory access. The size should be capped by the actual my_len.
For example, if my_len = 5 and kAlign = 4, then len_aligned = 8. If off = 0, sz would be calculated based on 8 - 0 = 8 elements, but only 5 are valid. This is a critical correctness issue.
const auto sz = min(kSizePerStage, my_len - off) * sizeof(float);
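To make the arithmetic in this comment concrete, here is a minimal host-side C++ sketch (not the PR's kernel code) that compares the element count derived from the rounded-up length with one clamped to the real length. The helper names `align_up`, `stage_elems_aligned`, and `stage_elems_clamped` are hypothetical, and the `kSizePerStage` value is an assumption; only `kSizePerStage`, `len_aligned`, `my_len`, and `off` mirror names in the quoted diff.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Assumed stage capacity in floats; the real kSizePerStage is not shown in the diff.
constexpr uint32_t kSizePerStage = 1024;

// Round len up to the next multiple of align (align must be a power of two).
constexpr uint32_t align_up(uint32_t len, uint32_t align) {
  return (len + align - 1) & ~(align - 1);
}

// Elements requested for a stage when the size is derived from the padded length.
inline uint32_t stage_elems_aligned(uint32_t my_len, uint32_t off, uint32_t align = 4) {
  const uint32_t len_aligned = align_up(my_len, align);
  assert(off <= len_aligned);
  return std::min(kSizePerStage, len_aligned - off);
}

// Elements requested when clamped to the real length. Returns 0 past the end
// instead of letting the unsigned subtraction wrap around.
inline uint32_t stage_elems_clamped(uint32_t my_len, uint32_t off) {
  return off >= my_len ? 0u : std::min(kSizePerStage, my_len - off);
}
```

With `my_len = 5` and `off = 0`, the padded variant requests 8 floats and the clamped one requests 5, which is the three-element over-read the comment describes.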
for (uint32_t i = 0; i < kNumStages4; i++) {
  if (i >= ni) break;
  const auto o = i * kSizePerStage;
  const auto sz = min(kSizePerStage, la - o) * sizeof(float);
Similar to the issue on line 505, the tma_load in stream_pass4 uses la - o to determine the size. la is the length value rounded up to the nearest multiple of 4. This can cause an out-of-bounds read if length is not a multiple of 4, as tma_load might attempt to read beyond the actual length of the data. The size should be capped by the actual length.
This is a critical correctness issue.
const auto sz = min(kSizePerStage, length - o) * sizeof(float);
const auto u = (sl + kA-1)/kA, b = u/CS, e = u%CS;
const auto lu = b + (rank < e ? 1u : 0u);
const auto ou = rank * b + min(rank, e);
const auto ms = ou * kA, ml = min(ms + lu * kA, sl) - ms;
This line in large_topk_twopass4 appears to be a copy-paste error from large_topk_fused8. It uses len_aligned - off, but len_aligned is defined in large_topk_fused8 and not in this scope, so it should either fail to compile or pick up an unintended value, depending on context. It should use ml - off, where ml is the actual length of the current partition. Using an undefined or incorrect variable for the tma_load size is a critical correctness issue.
const auto sz = min(kSizePerStage, ml - off) * sizeof(float);
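The partition arithmetic in the quoted diff can be checked on the host. The sketch below is an illustration, not the PR's code: `partition_for_rank` and the `Partition` struct are hypothetical names, and the comments map each local variable to its counterpart in the diff (`u`, `b`, `e`, `lu`, `ou`, `ms`, `ml`). One deliberate difference is that both ends of the range are clamped to `sl`, so a rank that receives no units gets a length of 0 rather than an unsigned wrap-around.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

struct Partition {
  uint32_t start;  // first element owned by this rank (ms in the diff)
  uint32_t len;    // number of valid elements owned by this rank (ml in the diff)
};

// Split sl elements across cluster_size ranks in units of k_align elements.
// The first (units % cluster_size) ranks receive one extra unit.
inline Partition partition_for_rank(uint32_t sl, uint32_t rank,
                                    uint32_t cluster_size, uint32_t k_align) {
  assert(cluster_size > 0 && k_align > 0 && rank < cluster_size);
  const uint32_t units = (sl + k_align - 1) / k_align;             // u
  const uint32_t base = units / cluster_size;                      // b
  const uint32_t extra = units % cluster_size;                     // e
  const uint32_t my_units = base + (rank < extra ? 1u : 0u);       // lu
  const uint32_t unit_off = rank * base + std::min(rank, extra);   // ou
  // Clamp both ends to sl (an addition of this sketch) so that idle ranks
  // report len == 0 instead of wrapping around.
  const uint32_t start = std::min(unit_off * k_align, sl);
  const uint32_t end = std::min((unit_off + my_units) * k_align, sl);
  return {start, end - start};
}
```

For `sl = 10`, four ranks, and an alignment of 4, the ranks own 4, 4, 2, and 0 elements, and the lengths sum to `sl`. A per-stage load size derived from `len - off` for each rank then never extends past the valid data.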
struct alignas(16) MatchBin { uint32_t bin, above_count, equal_count; };
struct alignas(8) Tie { uint32_t idx; float score; };
struct ClusterState { int output_counter; };
The output_counter in ClusterState is declared as an int. While unlikely to overflow with current K and kMaxTies values, it's generally safer and more consistent to use uint32_t for counters that are always non-negative, especially when uint32_t values (la, le) are cast to int before being added atomically. This prevents any potential issues if la or le were to exceed INT_MAX in future modifications.
struct ClusterState { uint32_t output_counter; };
Hi @LopezCastroRoberto, the pre-commit checks have failed. Please run:
uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files
Then, commit the changes and push to your branch.
This pull request has merge conflicts that must be resolved before it can be
Update
Seems like increasing the Cluster Size to 16 for On Blackwell:
src: https://docs.nvidia.com/cuda/blackwell-tuning-guide/index.html#thread-block-clusters
Microbenchmarks
Port cooperative cluster top-k kernels and launchers to csrc/libtorch_stable/, gate registration with VLLM_ENABLE_COOPERATIVE_TOPK, and route decode sparse indexer to cooperative_topk when eligible. Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Co-authored-by: OpenAI Codex <codex@openai.com> Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
…cy scenarios (vllm-project#43008) Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
@LopezCastroRoberto - this PR breaks DeepSeek V4 Flash on DGX Spark (sm120). It fails during startup CUDA graph memory profiling.
The crash is in the sparse attention indexer path added/changed by this PR: sparse_attn_indexer -> torch.ops._C.cooperative_topk
Error: RuntimeError: launch_cooperative_cluster,
This reproduces with DeepSeek-V4-Flash, MTP enabled, TP=2, kv-cache fp8, on SM12.1. It also reproduces whether VLLM_USE_BREAKABLE_CUDAGRAPH is auto-enabled or explicitly disabled, so breakable CUDA graphs do not appear to avoid the failing kernel.
The selector currently uses has_device_capability(90), so SM100/SM120 take the cooperative_topk path. If I locally restrict cooperative_topk to exact SM90 and let SM120 fall back to persistent_topk, the model starts successfully, completes CUDA graph profiling/capture, and reaches API server startup.
Could cooperative_topk be guarded to SM90 only, or otherwise validated or given a fallback for SM100/SM120? @mgoin - FYI.
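The difference between the two guards discussed in this report can be sketched as follows. This is a standalone C++ illustration with hypothetical names, not vLLM's selector code. It assumes the compute capability is encoded as major * 10 + minor (90 for SM90, 121 for SM12.1), and that devices which do not take the cooperative path fall back to persistent_topk.

```cpp
#include <cassert>
#include <string>

// Compute capability encoded as major * 10 + minor, e.g. 90 for SM90.
// All names here are hypothetical; this is not vLLM's selector code.
inline bool at_least_capability(int device_cc, int required_cc) {
  return device_cc >= required_cc;
}

// Behaviour described in the report: anything >= SM90 is eligible.
inline std::string select_topk_min_cc(int device_cc) {
  assert(device_cc > 0);
  return at_least_capability(device_cc, 90) ? "cooperative_topk" : "persistent_topk";
}

// Guard proposed in the report: only exact SM90 takes the cooperative path.
inline std::string select_topk_exact_sm90(int device_cc) {
  assert(device_cc > 0);
  return device_cc == 90 ? "cooperative_topk" : "persistent_topk";
}
```

With the minimum-capability check, a device reporting 121 is routed to cooperative_topk. With the exact check it falls back to persistent_topk, which is the local change the reporter says lets the model start.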
…cy scenarios (vllm-project#43008) Signed-off-by: LopezCastroRoberto <rocastro@redhat.com> Signed-off-by: Qiang Li <qiang.li2@amd.com>
…cy scenarios (vllm-project#43008) Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Summary
Adds a cluster-cooperative topK kernel for low-latency cases. It uses TMA and DSMEM, so an SM90+ architecture is required. The new kernel has been extensively tuned to cover the whole low-latency regime (i.e. bs ≤ 32) with cluster-level cooperation via cluster.sync() and a distributed SMEM histogram reduction, eliminating the complexity of the persistent scheduler and multi-CTA spin-barrier coordination in persistent_topK v1 (PR #37421).
This is also expected to reduce GPU contention when other streams are running in parallel: for some configs, the persistent scheduler in topK v1 occupied all the GPU resources, starving concurrent work on other streams. The new version also avoids the headroom pre-allocation that was needed to prevent the persistent kernel from deadlocking under occupancy pressure.
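As a rough mental model of the histogram reduction mentioned in this summary, the host-side C++ sketch below merges one histogram per CTA by giving each simulated CTA a contiguous slice of bins to sum across all peers. The slicing scheme, the function name `reduce_cluster_histograms`, and the `Histogram` alias are assumptions for illustration; the PR text does not spell out how the kernel divides the work, and on the GPU the peer reads would go through distributed shared memory between cluster.sync() barriers.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

constexpr uint32_t kBins = 4096;
using Histogram = std::array<uint32_t, kBins>;

// Each simulated CTA reduces one contiguous slice of bins by reading the same
// slice from every peer's histogram, then writes it into the merged result.
inline Histogram reduce_cluster_histograms(const std::vector<Histogram>& per_cta) {
  assert(!per_cta.empty());
  const uint32_t cluster_size = static_cast<uint32_t>(per_cta.size());
  const uint32_t slice = (kBins + cluster_size - 1) / cluster_size;
  Histogram merged{};
  for (uint32_t rank = 0; rank < cluster_size; ++rank) {
    const uint32_t begin = std::min(rank * slice, kBins);
    const uint32_t end = std::min(begin + slice, kBins);
    for (uint32_t bin = begin; bin < end; ++bin) {
      uint32_t sum = 0;
      for (const Histogram& peer : per_cta) sum += peer[bin];
      merged[bin] = sum;
    }
  }
  return merged;
}
```

Splitting the bins this way means every bin is summed by exactly one owner, so no atomics are needed for the merge itself in this model.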
Approach
Additional features added to this algorithm by this PR (bs ≤ 32):
- histogram_4096_topk<12> - evolved from v1's histogram_2048_topk, widened to a 4096-bin coarse histogram with warp-ballot tie-breaking for ≤64 ties, eliminating most radix refinement rounds.
- redux.sync.add hardware warp reduce, replacing the __shfl_xor_sync butterfly tree with a single PTX instruction for warp-wide reduction.
For bs > 32, we inherited FilteredTopK from topK v1 (PR #37421):
- histogram_4096_topk<12, 8> fast path for sl ≤ 32K (32 floats per thread, 4096-bin histogram)
Architecture of topK v2
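The coarse-histogram step behind the histogram_4096_topk entries listed under Approach can be illustrated on the host. The sketch below is a simplified model under stated assumptions, not the kernel: it maps each float to an order-preserving 32-bit key, uses the top 12 bits as one of 4096 bins, and scans from the highest bin down to find the bin containing the k-th largest score. `ordered_key`, `coarse_bin`, and `find_threshold_bin` are hypothetical names; `MatchBin` borrows its field names from the struct quoted in the review, and NaN handling is ignored.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

constexpr uint32_t kBits = 12;
constexpr uint32_t kBins = 1u << kBits;  // 4096 coarse bins

struct MatchBin { uint32_t bin, above_count, equal_count; };

// Map a float to a uint32 key whose unsigned order matches the float order.
inline uint32_t ordered_key(float x) {
  uint32_t u;
  std::memcpy(&u, &x, sizeof(u));
  return (u & 0x80000000u) ? ~u : (u | 0x80000000u);
}

// The top kBits bits of the key select the coarse bin.
inline uint32_t coarse_bin(float x) { return ordered_key(x) >> (32 - kBits); }

// Find the bin that contains the k-th largest score.
// above_count: scores in strictly higher bins; equal_count: scores in the bin.
inline MatchBin find_threshold_bin(const std::vector<float>& scores, uint32_t k) {
  assert(k >= 1 && k <= scores.size());
  std::array<uint32_t, kBins> hist{};
  for (float s : scores) ++hist[coarse_bin(s)];
  uint32_t above = 0;
  for (uint32_t b = kBins; b-- > 0;) {
    if (above + hist[b] >= k) return {b, above, hist[b]};
    above += hist[b];
  }
  return {0, above, 0};  // not reached when the precondition holds
}
```

Everything in bins above the threshold bin is selected outright. The remaining `k - above_count` entries have to be chosen from the `equal_count` candidates in the threshold bin, which is where tie-breaking or further refinement comes in.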
Microbenchmarking - vLLM topK v2 vs. v1 (B300)
topK=512:
topK=1024:
topK=2048:
E2E results (B300)
vllm serve deepseek-ai/DeepSeek-V4-Flash -tp 4 --kv-cache-dtype fp8
vllm bench serve --model deepseek-ai/DeepSeek-V4-Flash --input-len 512000 --output-len 2048 --num-prompts 8 --max-concurrency 1
MAIN:
PR:
~10% TPOT improvement
E2E results with concurrency=1 for increasing ISL - Averaged on 3 runs on B300
Conclusions: TPOT is essentially flat for ISL from 32K up to 512K; note that the difference is 7.1ms vs 7.3ms. For 1M it increases a bit more than expected, which should probably be studied separately. These conclusions match the microbenchmark results (topK=512): for 32-128K the topK v2 kernel gets a 70-80% perf improvement w.r.t. v1, while for 262K the improvement is ~40%.
UPDATE: Check out the latest follow-up comments on this PR below.
Accuracy
GSM8K
python tests/evals/gsm8k/gsm8k_eval.py
MAIN:
PR:
MRCR 2-needle eval — MAIN vs PR (DeepSeek-V4-Flash, B300, TP=4)
Potential TODOs:
logits.stride(0) % 4 == 0 is currently enforced via TORCH_CHECK. It is always true when stride = max_model_len from the model config, but an odd user-supplied --max-model-len would crash. Decide whether to keep topK v1 as a fallback, pad in the dispatcher, or just trigger TORCH_CHECK for --max-model-len configs that are not a multiple of 4.
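Two of the options in this TODO can be written down as small host-side helpers. This is a hedged sketch with hypothetical names (`TopKPath`, `choose_path`, `padded_stride`), not code from the PR: the first helper models keeping an older kernel as a fallback for unaligned strides, the second models padding the row stride in the dispatcher.

```cpp
#include <cassert>
#include <cstdint>

enum class TopKPath { Cooperative, Fallback };

// Option: keep the v1 kernel as a fallback for strides that are not a multiple of 4.
inline TopKPath choose_path(int64_t row_stride) {
  assert(row_stride > 0);
  return (row_stride % 4 == 0) ? TopKPath::Cooperative : TopKPath::Fallback;
}

// Option: pad the row stride in the dispatcher to the next multiple of 4.
inline int64_t padded_stride(int64_t row_stride) {
  assert(row_stride > 0);
  return (row_stride + 3) / 4 * 4;
}
```

For a stride of 100001, `choose_path` selects the fallback and `padded_stride` returns 100004. The padding route costs at most three extra floats per row but requires the logits buffer to be allocated with the padded stride.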