From d93e4e4bd0d6f1a372a889a6f6ca9f46db19d2aa Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Sat, 13 Jun 2026 23:25:03 -0700 Subject: [PATCH 01/10] feat: enable native mxfp8 moe for minimax m3 mi300x --- .github/configs/amd-master.yaml | 8 +- .../fixed_seq_len/minimaxm3_fp8_mi300x.sh | 25 +- .../minimaxm3_mi300x_mxfp8.patch | 656 ++++++++++++++++++ 3 files changed, 681 insertions(+), 8 deletions(-) create mode 100644 benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch diff --git a/.github/configs/amd-master.yaml b/.github/configs/amd-master.yaml index f18b3f94e..c9225a984 100644 --- a/.github/configs/amd-master.yaml +++ b/.github/configs/amd-master.yaml @@ -2847,10 +2847,10 @@ minimaxm3-fp8-mi355x-vllm-mtp: - { tp: 4, conc-start: 1, conc-end: 64, spec-decoding: mtp } - { tp: 8, ep: 8, dp-attn: true, conc-start: 128, conc-end: 256, spec-decoding: mtp } -# MiniMax-M3 MXFP8 MI300X day-zero recipe. Reuse the dedicated ROCm image and -# MI355X serving shape, but retain the default BF16 KV cache because this -# checkpoint lacks calibrated ROCm FP8 attention scales. Use the TP8-only H100 -# search space: TP8 for latency and TP8+EP8 (TEP) at high concurrency. +# MiniMax-M3 MXFP8 MI300X recipe. Apply the checked-in native gfx94x MXFP8 MoE +# patch to the dedicated ROCm image, but retain the default BF16 KV cache +# because this checkpoint lacks calibrated ROCm FP8 attention scales. Use the +# TP8-only H100 search space: TP8 for latency and TP8+EP8 at high concurrency. minimaxm3-fp8-mi300x-vllm: image: vllm/vllm-openai-rocm:minimax-m3 model: MiniMaxAI/MiniMax-M3-MXFP8 diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh index e3522e00a..438725883 100755 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh @@ -1,10 +1,12 @@ #!/usr/bin/env bash # MiniMax-M3 MXFP8 MI300X (gfx942) single-node vLLM recipe. -# Reuses the dedicated ROCm image and the MI355X serving shape. Block size 128 -# is mandatory for MSA sparse attention. Keep the default BF16 KV cache on -# gfx942: the checkpoint has no calibrated q/prob scales for ROCm FP8 -# attention, and vLLM's fallback scale of 1.0 corrupts model accuracy. +# Reuses the dedicated ROCm image and applies the checked-in gfx94x MXFP8 MoE +# patch before starting vLLM. Block size 128 is mandatory for MSA sparse +# attention. Keep the default BF16 KV cache on gfx942: the checkpoint has no +# calibrated q/prob scales for ROCm FP8 attention, and vLLM's fallback scale of +# 1.0 corrupts model accuracy. +# Target image vLLM revision: 4a560dd8db67c270f5e2afb614558271b76f2294. source "$(dirname "$0")/../../benchmark_lib.sh" @@ -24,6 +26,21 @@ if [[ -n "$SLURM_JOB_ID" ]]; then echo "JOB $SLURM_JOB_ID running on $SLURMD_NODENAME" fi +VLLM_PACKAGE_ROOT="$( + python - <<'PY' +from pathlib import Path + +import vllm + +print(Path(vllm.__file__).resolve().parent.parent) +PY +)" +MXFP8_PATCH="$(dirname "$0")/minimaxm3_mi300x_mxfp8.patch" +MXFP8_ORACLE="$VLLM_PACKAGE_ROOT/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py" +if ! grep -q "Using fused CDNA3 (gfx94x)" "$MXFP8_ORACLE"; then + patch --batch --forward -d "$VLLM_PACKAGE_ROOT" -p1 < "$MXFP8_PATCH" +fi + if [[ "$MODEL" != /* ]]; then hf download "$MODEL"; fi if [ -n "$ROCR_VISIBLE_DEVICES" ]; then diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch new file mode 100644 index 000000000..92e5e9890 --- /dev/null +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch @@ -0,0 +1,656 @@ +diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py +index 33851fdc8..8bcfa9d13 100644 +--- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py ++++ b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py +@@ -1,24 +1,25 @@ + # SPDX-License-Identifier: Apache-2.0 + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +-"""Native MXFP8 (1x32 block, E8M0 scale) MoE for AMD CDNA4 (gfx950) via Triton +-``tl.dot_scaled`` (hardware microscaling matmul). ++"""Fused MXFP8 (1x32 block, E8M0 scale) MoE for AMD CDNA3/CDNA4. + + The expert GEMMs consume the FP8 E4M3 weights and their E8M0 block scales + directly (no dequant-to-BF16), and activations are MXFP8-quantized per token. +-On CDNA4 ``dot_scaled`` maps to the native MX matrix-core ops; on other archs +-Triton upcasts to BF16 (so this stays correct, just not faster) — but the +-oracle only selects this path on gfx950 and routes everything else to the +-BF16 ``Mxfp8EmulationTritonExperts`` fallback. ++CDNA4 uses ``tl.dot_scaled`` and native MX matrix-core ops. CDNA3 stores the ++weights as E4M3FNUZ, runs one native FP8 ``tl.dot`` per 32-value MX block, and ++applies the E8M0 scale products in-register. Both paths keep weights compressed ++in HBM instead of expanding them to persistent BF16. + + Structure mirrors vLLM's ``fused_moe_kernel``: tokens are sorted by expert + (``moe_align_block_size``); each program computes a ``[BLOCK_M, BLOCK_N]`` tile +-for one expert, accumulating over K with ``dot_scaled``. SwiGLU-OAI activation +-and the top-k weighted reduction run in PyTorch between/after the two GEMMs. ++for one expert, accumulating over K with the architecture-specific fused path. ++SwiGLU-OAI activation and the top-k weighted reduction run between/after the ++two GEMMs. + """ + + import torch + + import vllm.model_executor.layers.fused_moe.modular_kernel as mk ++from vllm import _custom_ops as ops + from vllm.logger import init_logger + from vllm.model_executor.layers.fused_moe.experts.mxfp8_emulation_moe import ( + Mxfp8TritonExpertsBase, +@@ -35,8 +36,29 @@ from vllm.triton_utils import tl, triton + logger = init_logger(__name__) + + ++def _select_split_k( ++ max_post_padded: int, ++ block_m: int, ++ N: int, ++ K: int, ++) -> int: ++ if not (current_platform.is_fp8_fnuz() and K >= 2048 and N <= 1024): ++ return 1 ++ ++ base_programs = triton.cdiv(max_post_padded, block_m) * triton.cdiv(N, 128) ++ if base_programs >= 256: ++ return 1 ++ ++ target_split = triton.cdiv(256, max(base_programs, 1)) ++ return min( ++ 8, ++ 1 << (target_split - 1).bit_length(), ++ triton.cdiv(K, 32), ++ ) ++ ++ + @triton.jit +-def _mxfp8_grouped_gemm_kernel( ++def _mxfp8_grouped_gemm_dot_scaled_kernel( + a_ptr, + a_scale_ptr, + b_ptr, +@@ -67,9 +89,11 @@ def _mxfp8_grouped_gemm_kernel( + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, ++ SPLIT_K: tl.constexpr, + ): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) ++ pid_k = tl.program_id(2) + num_post = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_M >= num_post: + return +@@ -101,28 +125,194 @@ def _mxfp8_grouped_gemm_kernel( + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + n_mask = offs_n < N +- for _ in range(0, tl.cdiv(K, BLOCK_K)): +- a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0) +- b = tl.load(b_ptrs, mask=n_mask[:, None], other=0.0) +- asc = tl.load(as_ptrs, mask=token_mask[:, None], other=0) +- bsc = tl.load(bs_ptrs, mask=n_mask[:, None], other=0) +- acc += tl.dot_scaled(a, asc, "e4m3", b.T, bsc, "e4m3") +- +- a_ptrs += BLOCK_K * stride_ak +- b_ptrs += BLOCK_K * stride_bk +- as_ptrs += (BLOCK_K // 32) * stride_ask +- bs_ptrs += (BLOCK_K // 32) * stride_bsk ++ if SPLIT_K == 1: ++ for _ in range(0, tl.cdiv(K, BLOCK_K)): ++ a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0) ++ b = tl.load(b_ptrs, mask=n_mask[:, None], other=0.0) ++ asc = tl.load(as_ptrs, mask=token_mask[:, None], other=0) ++ bsc = tl.load(bs_ptrs, mask=n_mask[:, None], other=0) ++ acc += tl.dot_scaled(a, asc, "e4m3", b.T, bsc, "e4m3") ++ ++ a_ptrs += BLOCK_K * stride_ak ++ b_ptrs += BLOCK_K * stride_bk ++ as_ptrs += (BLOCK_K // 32) * stride_ask ++ bs_ptrs += (BLOCK_K // 32) * stride_bsk ++ else: ++ num_k_tiles = tl.cdiv(K, BLOCK_K) ++ tiles_per_split = tl.cdiv(num_k_tiles, SPLIT_K) ++ k_tile = pid_k * tiles_per_split ++ k_tile_end = min(k_tile + tiles_per_split, num_k_tiles) ++ a_ptrs += k_tile * BLOCK_K * stride_ak ++ b_ptrs += k_tile * BLOCK_K * stride_bk ++ as_ptrs += k_tile * (BLOCK_K // 32) * stride_ask ++ bs_ptrs += k_tile * (BLOCK_K // 32) * stride_bsk ++ while k_tile < k_tile_end: ++ a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0) ++ b = tl.load(b_ptrs, mask=n_mask[:, None], other=0.0) ++ asc = tl.load(as_ptrs, mask=token_mask[:, None], other=0) ++ bsc = tl.load(bs_ptrs, mask=n_mask[:, None], other=0) ++ acc += tl.dot_scaled(a, asc, "e4m3", b.T, bsc, "e4m3") ++ ++ a_ptrs += BLOCK_K * stride_ak ++ b_ptrs += BLOCK_K * stride_bk ++ as_ptrs += (BLOCK_K // 32) * stride_ask ++ bs_ptrs += (BLOCK_K // 32) * stride_bsk ++ k_tile += 1 + + if MUL_WEIGHT: + w = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0) + acc = acc * w[:, None] + + c_ptrs = c_ptr + offs_token[:, None] * stride_cm + offs_n[None, :] * stride_cn +- tl.store( +- c_ptrs, +- acc.to(c_ptr.dtype.element_ty), +- mask=token_mask[:, None] & n_mask[None, :], ++ c_mask = token_mask[:, None] & n_mask[None, :] ++ if SPLIT_K == 1: ++ tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask) ++ else: ++ tl.atomic_add(c_ptrs, acc, mask=c_mask) ++ ++ ++@triton.jit ++def _mxfp8_grouped_gemm_fnuz_kernel( ++ a_ptr, ++ a_scale_ptr, ++ b_ptr, ++ b_scale_ptr, ++ c_ptr, ++ topk_weights_ptr, ++ sorted_token_ids_ptr, ++ expert_ids_ptr, ++ num_tokens_post_padded_ptr, ++ N, ++ K, ++ num_valid_tokens, ++ top_k, ++ stride_am, ++ stride_ak, ++ stride_asm, ++ stride_ask, ++ stride_be, ++ stride_bn, ++ stride_bk, ++ stride_bse, ++ stride_bsn, ++ stride_bsk, ++ stride_cm, ++ stride_cn, ++ A_DIV: tl.constexpr, ++ MUL_WEIGHT: tl.constexpr, ++ BLOCK_M: tl.constexpr, ++ BLOCK_N: tl.constexpr, ++ BLOCK_K: tl.constexpr, ++ SPLIT_K: tl.constexpr, ++): ++ pid_m = tl.program_id(0) ++ pid_n = tl.program_id(1) ++ pid_k = tl.program_id(2) ++ num_post = tl.load(num_tokens_post_padded_ptr) ++ if pid_m * BLOCK_M >= num_post: ++ return ++ ++ offs_tid = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) ++ offs_token = tl.load(sorted_token_ids_ptr + offs_tid).to(tl.int64) ++ token_mask = offs_token < num_valid_tokens ++ off_e = tl.load(expert_ids_ptr + pid_m).to(tl.int64) ++ ++ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) ++ offs_k = tl.arange(0, 32) ++ a_row = offs_token // A_DIV ++ ++ a_ptrs = a_ptr + a_row[:, None] * stride_am + offs_k[None, :] * stride_ak ++ as_ptrs = a_scale_ptr + a_row * stride_asm ++ b_ptrs = ( ++ b_ptr ++ + off_e * stride_be ++ + offs_n[:, None] * stride_bn ++ + offs_k[None, :] * stride_bk + ) ++ bs_ptrs = b_scale_ptr + off_e * stride_bse + offs_n * stride_bsn ++ ++ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) ++ n_mask = offs_n < N ++ if SPLIT_K == 1: ++ for _ in range(0, tl.cdiv(K, BLOCK_K)): ++ for k_offset in tl.static_range(0, BLOCK_K, 32): ++ a = tl.load( ++ a_ptrs + k_offset * stride_ak, ++ mask=token_mask[:, None], ++ other=0.0, ++ ) ++ b = tl.load( ++ b_ptrs + k_offset * stride_bk, ++ mask=n_mask[:, None], ++ other=0.0, ++ ) ++ asc = tl.load( ++ as_ptrs + (k_offset // 32) * stride_ask, ++ mask=token_mask, ++ other=0, ++ ).to(tl.float32) ++ bsc = tl.load( ++ bs_ptrs + (k_offset // 32) * stride_bsk, ++ mask=n_mask, ++ other=0, ++ ).to(tl.float32) ++ block_scale = tl.exp2(asc[:, None] + bsc[None, :] - 254.0) ++ acc += tl.dot(a, b.T) * block_scale ++ ++ a_ptrs += BLOCK_K * stride_ak ++ b_ptrs += BLOCK_K * stride_bk ++ as_ptrs += (BLOCK_K // 32) * stride_ask ++ bs_ptrs += (BLOCK_K // 32) * stride_bsk ++ else: ++ num_k_tiles = tl.cdiv(K, BLOCK_K) ++ tiles_per_split = tl.cdiv(num_k_tiles, SPLIT_K) ++ k_tile = pid_k * tiles_per_split ++ k_tile_end = min(k_tile + tiles_per_split, num_k_tiles) ++ a_ptrs += k_tile * BLOCK_K * stride_ak ++ b_ptrs += k_tile * BLOCK_K * stride_bk ++ as_ptrs += k_tile * (BLOCK_K // 32) * stride_ask ++ bs_ptrs += k_tile * (BLOCK_K // 32) * stride_bsk ++ while k_tile < k_tile_end: ++ for k_offset in tl.static_range(0, BLOCK_K, 32): ++ a = tl.load( ++ a_ptrs + k_offset * stride_ak, ++ mask=token_mask[:, None], ++ other=0.0, ++ ) ++ b = tl.load( ++ b_ptrs + k_offset * stride_bk, ++ mask=n_mask[:, None], ++ other=0.0, ++ ) ++ asc = tl.load( ++ as_ptrs + (k_offset // 32) * stride_ask, ++ mask=token_mask, ++ other=0, ++ ).to(tl.float32) ++ bsc = tl.load( ++ bs_ptrs + (k_offset // 32) * stride_bsk, ++ mask=n_mask, ++ other=0, ++ ).to(tl.float32) ++ block_scale = tl.exp2(asc[:, None] + bsc[None, :] - 254.0) ++ acc += tl.dot(a, b.T) * block_scale ++ ++ a_ptrs += BLOCK_K * stride_ak ++ b_ptrs += BLOCK_K * stride_bk ++ as_ptrs += (BLOCK_K // 32) * stride_ask ++ bs_ptrs += (BLOCK_K // 32) * stride_bsk ++ k_tile += 1 ++ ++ if MUL_WEIGHT: ++ w = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0) ++ acc = acc * w[:, None] ++ ++ c_ptrs = c_ptr + offs_token[:, None] * stride_cm + offs_n[None, :] * stride_cn ++ c_mask = token_mask[:, None] & n_mask[None, :] ++ if SPLIT_K == 1: ++ tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask) ++ else: ++ tl.atomic_add(c_ptrs, acc, mask=c_mask) + + + def _grouped_gemm_mxfp8( +@@ -143,16 +333,48 @@ def _grouped_gemm_mxfp8( + ) -> torch.Tensor: + M_routed = num_valid_tokens + E, N, K = w.shape +- assert K % 128 == 0, f"MXFP8 native MoE requires K%128==0, got K={K}" ++ k_alignment = 32 if current_platform.is_fp8_fnuz() else 128 ++ assert K % k_alignment == 0, ( ++ f"MXFP8 native MoE requires K%{k_alignment}==0, got K={K}" ++ ) ++ BLOCK_K = ( ++ 128 ++ if current_platform.is_fp8_fnuz() and K % 128 == 0 and block_m <= 16 ++ else 64 ++ if current_platform.is_fp8_fnuz() and K % 64 == 0 ++ else 32 ++ if current_platform.is_fp8_fnuz() ++ else 128 ++ ) ++ # moe_align_block_size allocates for the worst case where every expert is ++ # active. At small batches that can be much larger than the number of ++ # blocks that can contain valid assignments. Limit the launch to the ++ # tighter static upper bound; the device-side num_post check handles the ++ # remaining tail. ++ max_post_padded = min(sorted_token_ids.shape[0], M_routed * block_m) ++ BLOCK_N = 128 ++ m_blocks = triton.cdiv(max_post_padded, block_m) ++ n_blocks = triton.cdiv(N, BLOCK_N) ++ split_k = _select_split_k(max_post_padded, block_m, N, K) ++ + # Under expert parallelism (expert_map set) tokens routed to non-local + # experts are dropped from sorted_token_ids, so their output rows are never +- # written — zero them so the downstream reduction ignores their garbage. +- alloc = torch.zeros if expert_map is not None else torch.empty +- out = alloc((M_routed, N), dtype=out_dtype, device=a_q.device) +- BLOCK_N = 128 +- BLOCK_K = 128 +- grid = (triton.cdiv(sorted_token_ids.shape[0], block_m), triton.cdiv(N, BLOCK_N)) +- _mxfp8_grouped_gemm_kernel[grid]( ++ # written. Split-K also needs a zeroed FP32 accumulation buffer. ++ kernel_out_dtype = torch.float32 if split_k > 1 else out_dtype ++ needs_zero = expert_map is not None or split_k > 1 ++ alloc = torch.zeros if needs_zero else torch.empty ++ out = alloc((M_routed, N), dtype=kernel_out_dtype, device=a_q.device) ++ grid = (m_blocks, n_blocks, split_k) ++ kernel = ( ++ _mxfp8_grouped_gemm_fnuz_kernel ++ if current_platform.is_fp8_fnuz() ++ else _mxfp8_grouped_gemm_dot_scaled_kernel ++ ) ++ if current_platform.is_fp8_fnuz() and ( ++ a_q.dtype != torch.float8_e4m3fnuz or w.dtype != torch.float8_e4m3fnuz ++ ): ++ raise ValueError("gfx94x MXFP8 MoE requires E4M3FNUZ inputs.") ++ kernel[grid]( + a_q, + a_scale, + w, +@@ -183,7 +405,8 @@ def _grouped_gemm_mxfp8( + BLOCK_M=block_m, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, +- num_warps=8, ++ SPLIT_K=split_k, ++ num_warps=(4 if current_platform.is_fp8_fnuz() and block_m <= 32 else 8), + ) + return out + +@@ -202,12 +425,20 @@ def fused_moe_mxfp8_native( + limit: float | None, + global_num_experts: int, + expert_map: torch.Tensor | None, ++ output: torch.Tensor | None = None, + ) -> torch.Tensor: + T, H = hidden_states.shape + top_k = topk_ids.shape[1] + M = T * top_k + +- block_m = 64 ++ if current_platform.is_fp8_fnuz(): ++ # Padding is per expert, so tile from the average expert occupancy ++ # rather than the total routed-token count. MiniMax-M3 has 128 experts; ++ # a 64-row tile wastes most of both GEMMs at low occupancy. ++ tokens_per_expert = max(1, M // global_num_experts) ++ block_m = max(16, min(1 << (tokens_per_expert - 1).bit_length(), 64)) ++ else: ++ block_m = 64 + sorted_ids, expert_ids, num_post = moe_align_block_size( + topk_ids, + block_m, +@@ -218,6 +449,12 @@ def fused_moe_mxfp8_native( + + # GEMM1: x (mxfp8) @ w13^T -> [M, 2I] + a_q, a_s = mxfp8_e4m3_quantize(hidden_states) ++ max_post_padded = min(sorted_ids.shape[0], M * block_m) ++ g1_dtype = ( ++ torch.float32 ++ if _select_split_k(max_post_padded, block_m, w13.shape[1], w13.shape[2]) > 1 ++ else hidden_states.dtype ++ ) + g1 = _grouped_gemm_mxfp8( + a_q, + a_s, +@@ -229,7 +466,7 @@ def fused_moe_mxfp8_native( + M, + top_k, + block_m, +- hidden_states.dtype, ++ g1_dtype, + a_div=top_k, + expert_map=expert_map, + ) # [M, 2I] +@@ -256,17 +493,27 @@ def fused_moe_mxfp8_native( + M, + top_k, + block_m, +- torch.float32, ++ hidden_states.dtype if current_platform.is_fp8_fnuz() else torch.float32, + a_div=1, + mul_weight_by=topk_weights.reshape(-1).to(torch.float32), + expert_map=expert_map, + ) # [M, H] == [T*top_k, H] + +- return g2.view(T, top_k, H).sum(dim=1).to(hidden_states.dtype) ++ if current_platform.is_fp8_fnuz(): ++ if output is None: ++ output = torch.empty_like(hidden_states) ++ ops.moe_sum(g2.view(T, top_k, H), output) ++ return output ++ ++ result = g2.view(T, top_k, H).sum(dim=1).to(hidden_states.dtype) ++ if output is not None: ++ output.copy_(result) ++ return output ++ return result + + + class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +- """Native MXFP8 MoE (CDNA4 ``dot_scaled``) on gfx950.""" ++ """Fused MXFP8 MoE on gfx94x/gfx95x.""" + + @property + def quant_dtype(self) -> torch.dtype | str | None: +@@ -283,7 +530,9 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): + + @staticmethod + def _supports_current_device() -> bool: +- return current_platform.is_rocm() and current_platform.supports_mx() ++ return current_platform.is_rocm() and ( ++ current_platform.supports_mx() or current_platform.is_fp8_fnuz() ++ ) + + def apply( + self, +@@ -322,5 +571,6 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): + limit=limit, + global_num_experts=global_num_experts, + expert_map=expert_map, ++ output=output, + ) +- output.copy_(out) ++ assert out is output +diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py +index acbf2cb46..1fcf67678 100644 +--- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py ++++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py +@@ -55,9 +55,9 @@ class Fp8MoeBackend(Enum): + # Dequantize-to-BF16 emulation for MXFP8 on devices without a native + # MXFP8 MoE kernel (e.g. ROCm). Weights pass through unchanged here. + EMULATION = "EMULATION" +- # MXFP8 MoE via a Triton ``dot_scaled`` kernel that lowers to CDNA4 +- # (gfx950) native MX matrix-core ops. Weights stay in MXFP8 (no load-time +- # format conversion); the FP8 values + E8M0 scales are consumed directly. ++ # Fused ROCm MXFP8 MoE. CDNA4 (gfx95x) uses native ``dot_scaled`` MX ops; ++ # CDNA3 (gfx94x) uses E4M3FNUZ FP8 partial dots with in-register E8M0 scale ++ # application. Both consume compressed weights directly. + NATIVE_MXFP8 = "NATIVE_MXFP8" + + +@@ -463,6 +463,13 @@ def convert_to_fp8_moe_kernel_format( + ) + + w13, w2 = prepare_fp8_moe_layer_for_cpu(w13, w2) ++ elif fp8_backend == Fp8MoeBackend.NATIVE_MXFP8 and current_platform.is_fp8_fnuz(): ++ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( ++ normalize_mxfp8_e4m3fn_to_e4m3fnuz, ++ ) ++ ++ w13, w13_scale = normalize_mxfp8_e4m3fn_to_e4m3fnuz(w13, w13_scale) ++ w2, w2_scale = normalize_mxfp8_e4m3fn_to_e4m3fnuz(w2, w2_scale) + else: + if fp8_backend not in [ + Fp8MoeBackend.TRITON, +@@ -470,8 +477,8 @@ def convert_to_fp8_moe_kernel_format( + Fp8MoeBackend.VLLM_CUTLASS, + Fp8MoeBackend.BATCHED_VLLM_CUTLASS, + Fp8MoeBackend.XPU, +- # EMULATION dequantizes weights at runtime; NATIVE_MXFP8 consumes +- # the MXFP8 weights as-is — neither needs a load-time layout change. ++ # EMULATION consumes checkpoint layout directly. CDNA4 NATIVE_MXFP8 ++ # also needs no layout change; CDNA3 normalization is handled above. + Fp8MoeBackend.EMULATION, + Fp8MoeBackend.NATIVE_MXFP8, + ]: +diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py +index d0d7c7648..cb3e5d446 100644 +--- a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py ++++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py +@@ -79,12 +79,20 @@ def _select_kernel_cls( + def _select_rocm_mxfp8_backend() -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]: + """ROCm fallback when vendor MXFP8 backends are unavailable.""" + +- if current_platform.supports_mx(): ++ if current_platform.supports_mx() or current_platform.is_fp8_fnuz(): + from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( + Mxfp8NativeTritonExperts, + ) + +- logger.info_once("Using native CDNA4 (gfx950) MXFP8 dot_scaled MoE backend.") ++ if current_platform.supports_mx(): ++ logger.info_once( ++ "Using native CDNA4 (gfx95x) MXFP8 dot_scaled MoE backend." ++ ) ++ else: ++ logger.info_once( ++ "Using fused CDNA3 (gfx94x) MXFP8 FP8 MoE backend; weights " ++ "remain compressed and 1x32 scales are applied in-kernel." ++ ) + return Fp8MoeBackend.NATIVE_MXFP8, Mxfp8NativeTritonExperts + + from vllm.model_executor.layers.fused_moe.experts.mxfp8_emulation_moe import ( +diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py +index e6063b463..fa5b01615 100644 +--- a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py ++++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py +@@ -11,6 +11,32 @@ MXFP8_SCALE_DTYPE = torch.uint8 + MXFP8_BLOCK_SIZE = 32 + + ++def normalize_mxfp8_e4m3fn_to_e4m3fnuz( ++ values: torch.Tensor, ++ scales: torch.Tensor, ++) -> tuple[torch.Tensor, torch.Tensor]: ++ """Convert OCP E4M3 MXFP8 storage to AMD E4M3FNUZ in place. ++ ++ For an identical byte pattern, E4M3FNUZ represents half the E4M3FN value. ++ Incrementing the E8M0 exponent preserves the dequantized value without ++ expanding the one-byte weights. OCP negative zero (0x80) is NaN in FNUZ, ++ so it must be canonicalized to positive zero before reinterpreting. ++ """ ++ if values.dtype == torch.float8_e4m3fnuz: ++ return values, scales ++ if values.dtype != torch.float8_e4m3fn: ++ raise ValueError(f"Expected E4M3FN or E4M3FNUZ values, got {values.dtype}.") ++ if scales.dtype != MXFP8_SCALE_DTYPE: ++ raise ValueError(f"Expected {MXFP8_SCALE_DTYPE} scales, got {scales.dtype}.") ++ if int(scales.max().item()) >= 254: ++ raise ValueError("Cannot convert MXFP8 scale exponent 254 to E4M3FNUZ.") ++ ++ value_bits = values.view(torch.int8) ++ value_bits.masked_fill_(value_bits == -128, 0) ++ scales.add_(1) ++ return value_bits.view(torch.float8_e4m3fnuz), scales ++ ++ + def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor: + """Swizzle MXFP8 scales from row-major 2D to F8_128x4 layout.""" + scaling_vector_size = MXFP8_BLOCK_SIZE # 32 for MXFP8 +@@ -38,6 +64,7 @@ def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor: + def _mxfp8_e4m3_quantize_torch( + x: torch.Tensor, + is_sf_swizzled_layout: bool = False, ++ value_dtype: torch.dtype = MXFP8_VALUE_DTYPE, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Naive MXFP8 quantization. + For each block of 32 elements along the last dimension, compute a +@@ -65,7 +92,7 @@ def _mxfp8_e4m3_quantize_torch( + descale = torch.exp2(scale_biased - 127.0) + x_scaled = x_blocked / descale.unsqueeze(-1) + +- x_fp8 = x_scaled.view(orig_shape).to(MXFP8_VALUE_DTYPE) ++ x_fp8 = x_scaled.view(orig_shape).to(value_dtype) + + if x.ndim == 2: + M, K = x.shape +@@ -139,6 +166,7 @@ def _mxfp8_e4m3_quantize_triton( + x: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Fused 2D MXFP8 quant (non-swizzled, row-major [M, K//32] scales).""" ++ from vllm.platforms import current_platform + from vllm.triton_utils import triton + + global _MXFP8_QUANT_KERNEL +@@ -147,7 +175,10 @@ def _mxfp8_e4m3_quantize_triton( + + M, K = x.shape + x = x.contiguous() +- xq = torch.empty((M, K), dtype=MXFP8_VALUE_DTYPE, device=x.device) ++ value_dtype = ( ++ torch.float8_e4m3fnuz if current_platform.is_fp8_fnuz() else MXFP8_VALUE_DTYPE ++ ) ++ xq = torch.empty((M, K), dtype=value_dtype, device=x.device) + scales = torch.empty( + (M, K // MXFP8_BLOCK_SIZE), dtype=MXFP8_SCALE_DTYPE, device=x.device + ) +@@ -233,7 +264,19 @@ def mxfp8_e4m3_quantize_fake( + alignment: int = 0, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Fake implementation for torch.compile tracing.""" +- fp_data = torch.empty_like(x, dtype=MXFP8_VALUE_DTYPE) ++ from vllm.platforms import current_platform ++ ++ value_dtype = ( ++ torch.float8_e4m3fnuz ++ if ( ++ current_platform.is_fp8_fnuz() ++ and not is_sf_swizzled_layout ++ and x.ndim == 2 ++ and x.shape[-1] % MXFP8_BLOCK_SIZE == 0 ++ ) ++ else MXFP8_VALUE_DTYPE ++ ) ++ fp_data = torch.empty_like(x, dtype=value_dtype) + + block_size = MXFP8_BLOCK_SIZE + +diff --git a/vllm/models/minimax_m3/amd/ops/swiglu_oai.py b/vllm/models/minimax_m3/amd/ops/swiglu_oai.py +index 836649b72..9572c5109 100644 +--- a/vllm/models/minimax_m3/amd/ops/swiglu_oai.py ++++ b/vllm/models/minimax_m3/amd/ops/swiglu_oai.py +@@ -24,6 +24,7 @@ HIP graphs already eliminate — measured end-to-end throughput is identical + + import torch + ++from vllm.platforms import current_platform + from vllm.triton_utils import tl, triton + + +@@ -132,11 +133,10 @@ def swiglu_oai_quantize_mxfp8( + ) -> tuple[torch.Tensor, torch.Tensor]: + """SwiGLU-OAI on split-layout ``[M, 2I]`` fused with MXFP8 activation-quant. + +- Returns ``(act_q [M, I] float8_e4m3fn, act_scale [M, I//32] uint8 E8M0)``, +- identical to ``mxfp8_e4m3_quantize(swiglu_oai_split(gate_up))`` but in a +- single Triton pass (no bf16 intermediate). Used between the two GEMMs of the +- native MXFP8 MoE. Numerically equivalent to the unfused chain (bit-exact on +- measured MoE shapes); marginally more accurate (fp32 act, no bf16 round-trip). ++ Returns platform-native E4M3 values plus ``[M, I//32]`` uint8 E8M0 scales, ++ equivalent to ``mxfp8_e4m3_quantize(swiglu_oai_split(gate_up))`` but in a ++ single Triton pass (no bf16 intermediate). gfx94x emits E4M3FNUZ so ++ ``tl.dot`` lowers to the native CDNA3 FP8 matrix cores. + """ + from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( + MXFP8_BLOCK_SIZE, +@@ -151,7 +151,10 @@ def swiglu_oai_quantize_mxfp8( + ) + g1 = gate_up.reshape(-1, two_i).contiguous() + M = g1.shape[0] +- aq = torch.empty((M, n_inter), dtype=MXFP8_VALUE_DTYPE, device=g1.device) ++ value_dtype = ( ++ torch.float8_e4m3fnuz if current_platform.is_fp8_fnuz() else MXFP8_VALUE_DTYPE ++ ) ++ aq = torch.empty((M, n_inter), dtype=value_dtype, device=g1.device) + asc = torch.empty( + (M, n_inter // MXFP8_BLOCK_SIZE), dtype=MXFP8_SCALE_DTYPE, device=g1.device + ) From 6b7049738d5712ff5510fa1d5c58efc681ab945d Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Sat, 13 Jun 2026 23:41:02 -0700 Subject: [PATCH 02/10] chore: trigger MiniMax M3 MI300X MXFP8 sweep --- perf-changelog.yaml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/perf-changelog.yaml b/perf-changelog.yaml index bcda7dbd0..358c87b3a 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3755,3 +3755,12 @@ - "TP8-only search space (gfx942 192 GB is memory-tight, like H100): TP8 latency rows started at conc 1, TP8+EP8 (TEP) at high concurrency, across 1k1k and 8k1k" - "[AI generated draft test] The shipped ROCm image's AMD MiniMax-M3 model lacks SupportsEagle3, so the recipe patches it in-place at runtime (functionstackx/vllm#1, upstream vllm-project/vllm#45546; validated green on MI355X) before serving — validates EAGLE3 on MI300X ahead of an image rebuild" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1749 + +- config-keys: + - minimaxm3-fp8-mi300x-vllm + description: + - "Replace the gfx942 BF16 MoE emulation path with a fused Triton MXFP8 backend that keeps expert weights compressed and applies 1x32 E8M0 scales in-kernel" + - "Normalize OCP E4M3FN checkpoint values to AMD E4M3FNUZ so the grouped expert GEMMs use native CDNA3 FP8 matrix cores" + - "Preserve the existing TP8 and TP8+EP8 search space, default BF16 KV cache, and BF16 dense-linear path" + - "Reduce model memory from 101.27 GiB to 53.18 GiB per GPU and increase available KV-cache memory from 69.35 GiB to 117.43 GiB" + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1753 From e9fa9b7726642fa87675e691e38ed8638d7fe8e2 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Sun, 14 Jun 2026 00:28:49 -0700 Subject: [PATCH 03/10] perf: tune MiniMax M3 gfx942 MXFP8 tiles --- .../minimaxm3_mi300x_mxfp8.patch | 327 ++++++------------ 1 file changed, 106 insertions(+), 221 deletions(-) diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch index 92e5e9890..07a1da5bf 100644 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch @@ -1,5 +1,18 @@ +diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py +index 71dd7634a..63500487d 100644 +--- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py ++++ b/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py +@@ -4,7 +4,7 @@ + + ``Mxfp8TritonExpertsBase`` stashes E8M0 weight scales for checkpoint layout. + ``Mxfp8EmulationTritonExperts`` dequantizes to BF16 and runs ``TritonExperts`` +-for devices without a native MXFP8 MoE kernel (e.g. ROCm gfx942 / MI300). ++for devices without a fused MXFP8 MoE kernel. + """ + + import torch diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py -index 33851fdc8..8bcfa9d13 100644 +index 33851fdc8..942a40876 100644 --- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py @@ -1,24 +1,25 @@ @@ -36,114 +49,19 @@ index 33851fdc8..8bcfa9d13 100644 from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.experts.mxfp8_emulation_moe import ( Mxfp8TritonExpertsBase, -@@ -35,8 +36,29 @@ from vllm.triton_utils import tl, triton - logger = init_logger(__name__) - - -+def _select_split_k( -+ max_post_padded: int, -+ block_m: int, -+ N: int, -+ K: int, -+) -> int: -+ if not (current_platform.is_fp8_fnuz() and K >= 2048 and N <= 1024): -+ return 1 -+ -+ base_programs = triton.cdiv(max_post_padded, block_m) * triton.cdiv(N, 128) -+ if base_programs >= 256: -+ return 1 -+ -+ target_split = triton.cdiv(256, max(base_programs, 1)) -+ return min( -+ 8, -+ 1 << (target_split - 1).bit_length(), -+ triton.cdiv(K, 32), -+ ) -+ -+ +@@ -36,7 +37,7 @@ logger = init_logger(__name__) + + @triton.jit -def _mxfp8_grouped_gemm_kernel( +def _mxfp8_grouped_gemm_dot_scaled_kernel( a_ptr, a_scale_ptr, b_ptr, -@@ -67,9 +89,11 @@ def _mxfp8_grouped_gemm_kernel( - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -+ SPLIT_K: tl.constexpr, - ): - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) -+ pid_k = tl.program_id(2) - num_post = tl.load(num_tokens_post_padded_ptr) - if pid_m * BLOCK_M >= num_post: - return -@@ -101,28 +125,194 @@ def _mxfp8_grouped_gemm_kernel( - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - n_mask = offs_n < N -- for _ in range(0, tl.cdiv(K, BLOCK_K)): -- a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0) -- b = tl.load(b_ptrs, mask=n_mask[:, None], other=0.0) -- asc = tl.load(as_ptrs, mask=token_mask[:, None], other=0) -- bsc = tl.load(bs_ptrs, mask=n_mask[:, None], other=0) -- acc += tl.dot_scaled(a, asc, "e4m3", b.T, bsc, "e4m3") -- -- a_ptrs += BLOCK_K * stride_ak -- b_ptrs += BLOCK_K * stride_bk -- as_ptrs += (BLOCK_K // 32) * stride_ask -- bs_ptrs += (BLOCK_K // 32) * stride_bsk -+ if SPLIT_K == 1: -+ for _ in range(0, tl.cdiv(K, BLOCK_K)): -+ a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0) -+ b = tl.load(b_ptrs, mask=n_mask[:, None], other=0.0) -+ asc = tl.load(as_ptrs, mask=token_mask[:, None], other=0) -+ bsc = tl.load(bs_ptrs, mask=n_mask[:, None], other=0) -+ acc += tl.dot_scaled(a, asc, "e4m3", b.T, bsc, "e4m3") -+ -+ a_ptrs += BLOCK_K * stride_ak -+ b_ptrs += BLOCK_K * stride_bk -+ as_ptrs += (BLOCK_K // 32) * stride_ask -+ bs_ptrs += (BLOCK_K // 32) * stride_bsk -+ else: -+ num_k_tiles = tl.cdiv(K, BLOCK_K) -+ tiles_per_split = tl.cdiv(num_k_tiles, SPLIT_K) -+ k_tile = pid_k * tiles_per_split -+ k_tile_end = min(k_tile + tiles_per_split, num_k_tiles) -+ a_ptrs += k_tile * BLOCK_K * stride_ak -+ b_ptrs += k_tile * BLOCK_K * stride_bk -+ as_ptrs += k_tile * (BLOCK_K // 32) * stride_ask -+ bs_ptrs += k_tile * (BLOCK_K // 32) * stride_bsk -+ while k_tile < k_tile_end: -+ a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0) -+ b = tl.load(b_ptrs, mask=n_mask[:, None], other=0.0) -+ asc = tl.load(as_ptrs, mask=token_mask[:, None], other=0) -+ bsc = tl.load(bs_ptrs, mask=n_mask[:, None], other=0) -+ acc += tl.dot_scaled(a, asc, "e4m3", b.T, bsc, "e4m3") -+ -+ a_ptrs += BLOCK_K * stride_ak -+ b_ptrs += BLOCK_K * stride_bk -+ as_ptrs += (BLOCK_K // 32) * stride_ask -+ bs_ptrs += (BLOCK_K // 32) * stride_bsk -+ k_tile += 1 - - if MUL_WEIGHT: - w = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0) - acc = acc * w[:, None] - - c_ptrs = c_ptr + offs_token[:, None] * stride_cm + offs_n[None, :] * stride_cn -- tl.store( -- c_ptrs, -- acc.to(c_ptr.dtype.element_ty), -- mask=token_mask[:, None] & n_mask[None, :], -+ c_mask = token_mask[:, None] & n_mask[None, :] -+ if SPLIT_K == 1: -+ tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask) -+ else: -+ tl.atomic_add(c_ptrs, acc, mask=c_mask) -+ -+ +@@ -125,6 +126,108 @@ def _mxfp8_grouped_gemm_kernel( + ) + + +@triton.jit +def _mxfp8_grouped_gemm_fnuz_kernel( + a_ptr, @@ -176,11 +94,9 @@ index 33851fdc8..8bcfa9d13 100644 + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, -+ SPLIT_K: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) -+ pid_k = tl.program_id(2) + num_post = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_M >= num_post: + return @@ -201,95 +117,57 @@ index 33851fdc8..8bcfa9d13 100644 + + off_e * stride_be + + offs_n[:, None] * stride_bn + + offs_k[None, :] * stride_bk - ) ++ ) + bs_ptrs = b_scale_ptr + off_e * stride_bse + offs_n * stride_bsn + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + n_mask = offs_n < N -+ if SPLIT_K == 1: -+ for _ in range(0, tl.cdiv(K, BLOCK_K)): -+ for k_offset in tl.static_range(0, BLOCK_K, 32): -+ a = tl.load( -+ a_ptrs + k_offset * stride_ak, -+ mask=token_mask[:, None], -+ other=0.0, -+ ) -+ b = tl.load( -+ b_ptrs + k_offset * stride_bk, -+ mask=n_mask[:, None], -+ other=0.0, -+ ) -+ asc = tl.load( -+ as_ptrs + (k_offset // 32) * stride_ask, -+ mask=token_mask, -+ other=0, -+ ).to(tl.float32) -+ bsc = tl.load( -+ bs_ptrs + (k_offset // 32) * stride_bsk, -+ mask=n_mask, -+ other=0, -+ ).to(tl.float32) -+ block_scale = tl.exp2(asc[:, None] + bsc[None, :] - 254.0) -+ acc += tl.dot(a, b.T) * block_scale -+ -+ a_ptrs += BLOCK_K * stride_ak -+ b_ptrs += BLOCK_K * stride_bk -+ as_ptrs += (BLOCK_K // 32) * stride_ask -+ bs_ptrs += (BLOCK_K // 32) * stride_bsk -+ else: -+ num_k_tiles = tl.cdiv(K, BLOCK_K) -+ tiles_per_split = tl.cdiv(num_k_tiles, SPLIT_K) -+ k_tile = pid_k * tiles_per_split -+ k_tile_end = min(k_tile + tiles_per_split, num_k_tiles) -+ a_ptrs += k_tile * BLOCK_K * stride_ak -+ b_ptrs += k_tile * BLOCK_K * stride_bk -+ as_ptrs += k_tile * (BLOCK_K // 32) * stride_ask -+ bs_ptrs += k_tile * (BLOCK_K // 32) * stride_bsk -+ while k_tile < k_tile_end: -+ for k_offset in tl.static_range(0, BLOCK_K, 32): -+ a = tl.load( -+ a_ptrs + k_offset * stride_ak, -+ mask=token_mask[:, None], -+ other=0.0, -+ ) -+ b = tl.load( -+ b_ptrs + k_offset * stride_bk, -+ mask=n_mask[:, None], -+ other=0.0, -+ ) -+ asc = tl.load( -+ as_ptrs + (k_offset // 32) * stride_ask, -+ mask=token_mask, -+ other=0, -+ ).to(tl.float32) -+ bsc = tl.load( -+ bs_ptrs + (k_offset // 32) * stride_bsk, -+ mask=n_mask, -+ other=0, -+ ).to(tl.float32) -+ block_scale = tl.exp2(asc[:, None] + bsc[None, :] - 254.0) -+ acc += tl.dot(a, b.T) * block_scale ++ for _ in range(0, tl.cdiv(K, BLOCK_K)): ++ for k_offset in tl.static_range(0, BLOCK_K, 32): ++ a = tl.load( ++ a_ptrs + k_offset * stride_ak, ++ mask=token_mask[:, None], ++ other=0.0, ++ ) ++ b = tl.load( ++ b_ptrs + k_offset * stride_bk, ++ mask=n_mask[:, None], ++ other=0.0, ++ ) ++ asc = tl.load( ++ as_ptrs + (k_offset // 32) * stride_ask, ++ mask=token_mask, ++ other=0, ++ ).to(tl.float32) ++ bsc = tl.load( ++ bs_ptrs + (k_offset // 32) * stride_bsk, ++ mask=n_mask, ++ other=0, ++ ).to(tl.float32) ++ block_scale = tl.exp2(asc[:, None] + bsc[None, :] - 254.0) ++ acc += tl.dot(a, b.T) * block_scale + -+ a_ptrs += BLOCK_K * stride_ak -+ b_ptrs += BLOCK_K * stride_bk -+ as_ptrs += (BLOCK_K // 32) * stride_ask -+ bs_ptrs += (BLOCK_K // 32) * stride_bsk -+ k_tile += 1 ++ a_ptrs += BLOCK_K * stride_ak ++ b_ptrs += BLOCK_K * stride_bk ++ as_ptrs += (BLOCK_K // 32) * stride_ask ++ bs_ptrs += (BLOCK_K // 32) * stride_bsk + + if MUL_WEIGHT: + w = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0) + acc = acc * w[:, None] + + c_ptrs = c_ptr + offs_token[:, None] * stride_cm + offs_n[None, :] * stride_cn -+ c_mask = token_mask[:, None] & n_mask[None, :] -+ if SPLIT_K == 1: -+ tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask) -+ else: -+ tl.atomic_add(c_ptrs, acc, mask=c_mask) - - ++ tl.store( ++ c_ptrs, ++ acc.to(c_ptr.dtype.element_ty), ++ mask=token_mask[:, None] & n_mask[None, :], ++ ) ++ ++ def _grouped_gemm_mxfp8( -@@ -143,16 +333,48 @@ def _grouped_gemm_mxfp8( + a_q: torch.Tensor, # [M, K] fp8 e4m3 + a_scale: torch.Tensor, # [M, K//32] uint8 (E8M0) +@@ -143,16 +246,58 @@ def _grouped_gemm_mxfp8( ) -> torch.Tensor: M_routed = num_valid_tokens E, N, K = w.shape @@ -313,26 +191,34 @@ index 33851fdc8..8bcfa9d13 100644 + # tighter static upper bound; the device-side num_post check handles the + # remaining tail. + max_post_padded = min(sorted_token_ids.shape[0], M_routed * block_m) -+ BLOCK_N = 128 ++ if current_platform.is_fp8_fnuz() and block_m <= 16: ++ # One wave per 32 output columns avoids the register pressure of the ++ # original 128-column tile. At the very smallest routed batch, pairing ++ # two waves in a 64-column program amortizes launch/indexing overhead. ++ BLOCK_N = 64 if M_routed < 32 else 32 ++ num_warps = 2 if M_routed < 32 else 1 ++ elif current_platform.is_fp8_fnuz() and block_m >= 64 and N >= 2048 and K >= 2048: ++ # EP prefill GEMMs remain register-bound at a 128-column tile even with ++ # 64 rows. Two-wave 64-column programs expose more independent work. ++ BLOCK_N = 64 ++ num_warps = 2 ++ else: ++ BLOCK_N = 128 ++ num_warps = 4 if current_platform.is_fp8_fnuz() and block_m <= 32 else 8 + m_blocks = triton.cdiv(max_post_padded, block_m) + n_blocks = triton.cdiv(N, BLOCK_N) -+ split_k = _select_split_k(max_post_padded, block_m, N, K) + # Under expert parallelism (expert_map set) tokens routed to non-local # experts are dropped from sorted_token_ids, so their output rows are never - # written — zero them so the downstream reduction ignores their garbage. -- alloc = torch.zeros if expert_map is not None else torch.empty -- out = alloc((M_routed, N), dtype=out_dtype, device=a_q.device) ++ # written. + alloc = torch.zeros if expert_map is not None else torch.empty + out = alloc((M_routed, N), dtype=out_dtype, device=a_q.device) - BLOCK_N = 128 - BLOCK_K = 128 - grid = (triton.cdiv(sorted_token_ids.shape[0], block_m), triton.cdiv(N, BLOCK_N)) - _mxfp8_grouped_gemm_kernel[grid]( -+ # written. Split-K also needs a zeroed FP32 accumulation buffer. -+ kernel_out_dtype = torch.float32 if split_k > 1 else out_dtype -+ needs_zero = expert_map is not None or split_k > 1 -+ alloc = torch.zeros if needs_zero else torch.empty -+ out = alloc((M_routed, N), dtype=kernel_out_dtype, device=a_q.device) -+ grid = (m_blocks, n_blocks, split_k) ++ grid = (m_blocks, n_blocks) + kernel = ( + _mxfp8_grouped_gemm_fnuz_kernel + if current_platform.is_fp8_fnuz() @@ -346,17 +232,16 @@ index 33851fdc8..8bcfa9d13 100644 a_q, a_scale, w, -@@ -183,7 +405,8 @@ def _grouped_gemm_mxfp8( +@@ -183,7 +328,7 @@ def _grouped_gemm_mxfp8( BLOCK_M=block_m, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, - num_warps=8, -+ SPLIT_K=split_k, -+ num_warps=(4 if current_platform.is_fp8_fnuz() and block_m <= 32 else 8), ++ num_warps=num_warps, ) return out -@@ -202,12 +425,20 @@ def fused_moe_mxfp8_native( +@@ -202,12 +347,20 @@ def fused_moe_mxfp8_native( limit: float | None, global_num_experts: int, expert_map: torch.Tensor | None, @@ -378,29 +263,7 @@ index 33851fdc8..8bcfa9d13 100644 sorted_ids, expert_ids, num_post = moe_align_block_size( topk_ids, block_m, -@@ -218,6 +449,12 @@ def fused_moe_mxfp8_native( - - # GEMM1: x (mxfp8) @ w13^T -> [M, 2I] - a_q, a_s = mxfp8_e4m3_quantize(hidden_states) -+ max_post_padded = min(sorted_ids.shape[0], M * block_m) -+ g1_dtype = ( -+ torch.float32 -+ if _select_split_k(max_post_padded, block_m, w13.shape[1], w13.shape[2]) > 1 -+ else hidden_states.dtype -+ ) - g1 = _grouped_gemm_mxfp8( - a_q, - a_s, -@@ -229,7 +466,7 @@ def fused_moe_mxfp8_native( - M, - top_k, - block_m, -- hidden_states.dtype, -+ g1_dtype, - a_div=top_k, - expert_map=expert_map, - ) # [M, 2I] -@@ -256,17 +493,27 @@ def fused_moe_mxfp8_native( +@@ -256,17 +409,27 @@ def fused_moe_mxfp8_native( M, top_k, block_m, @@ -431,7 +294,7 @@ index 33851fdc8..8bcfa9d13 100644 @property def quant_dtype(self) -> torch.dtype | str | None: -@@ -283,7 +530,9 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -283,7 +446,9 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): @staticmethod def _supports_current_device() -> bool: @@ -442,7 +305,7 @@ index 33851fdc8..8bcfa9d13 100644 def apply( self, -@@ -322,5 +571,6 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -322,5 +487,6 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): limit=limit, global_num_experts=global_num_experts, expert_map=expert_map, @@ -519,6 +382,28 @@ index d0d7c7648..cb3e5d446 100644 return Fp8MoeBackend.NATIVE_MXFP8, Mxfp8NativeTritonExperts from vllm.model_executor.layers.fused_moe.experts.mxfp8_emulation_moe import ( +diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py +index 33c7c7532..5f250c912 100644 +--- a/vllm/model_executor/layers/quantization/modelopt.py ++++ b/vllm/model_executor/layers/quantization/modelopt.py +@@ -2086,7 +2086,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): + def _dequant_mxfp8_weights_to_bf16(self, layer: RoutedExperts) -> None: + """One-time MXFP8->BF16 weight dequant for the emulation path. + +- On devices without a native MXFP8 MoE kernel (e.g. gfx942 / MI300), ++ On devices without a fused MXFP8 MoE kernel, + ``Mxfp8EmulationTritonExperts`` otherwise dequantizes every expert + weight to BF16 on *every* forward step -- the dominant cost (conc1 + ~1.3 tok/s). Doing the dequant once here and replacing the MXFP8 +@@ -2158,7 +2158,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): + routing_tables=layer._expert_routing_tables(), + ) + +- # No native MXFP8 MoE kernel on this device (e.g. gfx942): the emulation ++ # No fused MXFP8 MoE kernel on this device: the emulation + # experts would dequant MXFP8->BF16 every forward step. Convert the + # weights to BF16 once, here, so the MoE runs like a BF16 checkpoint. + # Opt out (VLLM_MXFP8_EMULATION_DEQUANT_AT_LOAD=0) to keep the 1-byte diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py index e6063b463..fa5b01615 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py From c3cdc37a301659046c53e75d129c80869dfe9220 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Sun, 14 Jun 2026 03:03:54 -0700 Subject: [PATCH 04/10] perf: recover MiniMax M3 MI300X serving curve --- .github/configs/amd-master.yaml | 9 +- .../fixed_seq_len/minimaxm3_fp8_mi300x.sh | 10 +- .../minimaxm3_mi300x_mxfp8.patch | 290 ++++++++++++++++-- perf-changelog.yaml | 18 +- 4 files changed, 290 insertions(+), 37 deletions(-) diff --git a/.github/configs/amd-master.yaml b/.github/configs/amd-master.yaml index c9225a984..e7dd7cb09 100644 --- a/.github/configs/amd-master.yaml +++ b/.github/configs/amd-master.yaml @@ -2847,10 +2847,11 @@ minimaxm3-fp8-mi355x-vllm-mtp: - { tp: 4, conc-start: 1, conc-end: 64, spec-decoding: mtp } - { tp: 8, ep: 8, dp-attn: true, conc-start: 128, conc-end: 256, spec-decoding: mtp } -# MiniMax-M3 MXFP8 MI300X recipe. Apply the checked-in native gfx94x MXFP8 MoE -# patch to the dedicated ROCm image, but retain the default BF16 KV cache -# because this checkpoint lacks calibrated ROCm FP8 attention scales. Use the -# TP8-only H100 search space: TP8 for latency and TP8+EP8 at high concurrency. +# MiniMax-M3 MXFP8 MI300X recipe. Apply the checked-in hybrid gfx94x MXFP8 MoE +# patch to the dedicated ROCm image: BF16 for small TP batches and EP, native +# compressed MXFP8 for larger TP batches and long context. Retain the default +# BF16 KV cache because this checkpoint lacks calibrated ROCm FP8 attention +# scales. Use TP8 for latency and TP8+EP8 at high concurrency. minimaxm3-fp8-mi300x-vllm: image: vllm/vllm-openai-rocm:minimax-m3 model: MiniMaxAI/MiniMax-M3-MXFP8 diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh index e726309e8..50aa7c236 100755 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh @@ -1,11 +1,11 @@ #!/usr/bin/env bash # MiniMax-M3 MXFP8 MI300X (gfx942) single-node vLLM recipe. -# Reuses the dedicated ROCm image and applies the checked-in gfx94x MXFP8 MoE -# patch before starting vLLM. Block size 128 is mandatory for MSA sparse -# attention. Keep the default BF16 KV cache on gfx942: the checkpoint has no -# calibrated q/prob scales for ROCm FP8 attention, and vLLM's fallback scale of -# 1.0 corrupts model accuracy. +# Reuses the dedicated ROCm image and applies the checked-in hybrid gfx94x +# MXFP8 MoE patch before starting vLLM. Block size 128 is mandatory for MSA +# sparse attention. Keep the default BF16 KV cache on gfx942: the checkpoint +# has no calibrated q/prob scales for ROCm FP8 attention, and vLLM's fallback +# scale of 1.0 corrupts model accuracy. # Target image vLLM revision: 4a560dd8db67c270f5e2afb614558271b76f2294. source "$(dirname "$0")/../../benchmark_lib.sh" diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch index 07a1da5bf..31deb4740 100644 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch @@ -1,3 +1,17 @@ +diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py +index 0755699d1..4f6046d88 100644 +--- a/vllm/model_executor/layers/fused_moe/config.py ++++ b/vllm/model_executor/layers/fused_moe/config.py +@@ -1276,7 +1276,9 @@ class FusedMoEConfig: + + moe_backend: MoEBackend = "auto" + max_num_tokens: int = SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP ++ max_model_len: int = 0 + has_bias: bool = False ++ has_shared_experts: bool = False + is_lora_enabled: bool = False + + # SwiGLU clamp limit. When set, backends that do not implement the clamp diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py index 71dd7634a..63500487d 100644 --- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py @@ -12,10 +26,10 @@ index 71dd7634a..63500487d 100644 import torch diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py -index 33851fdc8..942a40876 100644 +index 33851fdc8..943a4aedf 100644 --- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py -@@ -1,24 +1,25 @@ +@@ -1,28 +1,35 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Native MXFP8 (1x32 block, E8M0 scale) MoE for AMD CDNA4 (gfx950) via Triton @@ -47,10 +61,37 @@ index 33851fdc8..942a40876 100644 import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import _custom_ops as ops from vllm.logger import init_logger ++from vllm.model_executor.layers.fused_moe.config import ( ++ FusedMoEConfig, ++ FusedMoEQuantConfig, ++ biased_moe_quant_config, ++) from vllm.model_executor.layers.fused_moe.experts.mxfp8_emulation_moe import ( Mxfp8TritonExpertsBase, -@@ -36,7 +37,7 @@ logger = init_logger(__name__) + ) ++from vllm.model_executor.layers.fused_moe.experts.triton_moe import TritonExperts + from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size, + ) +@@ -34,9 +41,24 @@ from vllm.triton_utils import tl, triton + logger = init_logger(__name__) + ++_BF16_DECODE_TOKEN_THRESHOLD = 16 ++ ++ ++def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: ++ """Limit duplicate BF16 weights to the short-context MiniMax-M3 TP case.""" ++ return ( ++ moe_config.ep_size == 1 ++ and moe_config.has_shared_experts ++ and moe_config.num_experts == 128 ++ and moe_config.experts_per_token == 4 ++ and moe_config.hidden_dim == 6144 ++ and moe_config.intermediate_size == 3072 ++ and 0 < moe_config.max_model_len <= 4096 ++ ) ++ @triton.jit -def _mxfp8_grouped_gemm_kernel( @@ -58,7 +99,7 @@ index 33851fdc8..942a40876 100644 a_ptr, a_scale_ptr, b_ptr, -@@ -125,6 +126,108 @@ def _mxfp8_grouped_gemm_kernel( +@@ -125,6 +147,116 @@ def _mxfp8_grouped_gemm_kernel( ) @@ -138,13 +179,21 @@ index 33851fdc8..942a40876 100644 + as_ptrs + (k_offset // 32) * stride_ask, + mask=token_mask, + other=0, -+ ).to(tl.float32) ++ ).to(tl.uint16) + bsc = tl.load( + bs_ptrs + (k_offset // 32) * stride_bsk, + mask=n_mask, + other=0, -+ ).to(tl.float32) -+ block_scale = tl.exp2(asc[:, None] + bsc[None, :] - 254.0) ++ ).to(tl.uint16) ++ ++ # E8M0 and BF16 use the same eight-bit biased exponent. Shift each ++ # scale byte into a BF16 exponent field, as Marlin does, then form ++ # the per-token/per-output scale product around the FP8 dot. ++ asc_scale = (asc << 7).to(tl.bfloat16, bitcast=True) ++ bsc_scale = (bsc << 7).to(tl.bfloat16, bitcast=True) ++ block_scale = asc_scale[:, None].to(tl.float32) * bsc_scale[None, :].to( ++ tl.float32 ++ ) + acc += tl.dot(a, b.T) * block_scale + + a_ptrs += BLOCK_K * stride_ak @@ -167,7 +216,7 @@ index 33851fdc8..942a40876 100644 def _grouped_gemm_mxfp8( a_q: torch.Tensor, # [M, K] fp8 e4m3 a_scale: torch.Tensor, # [M, K//32] uint8 (E8M0) -@@ -143,16 +246,58 @@ def _grouped_gemm_mxfp8( +@@ -143,16 +275,58 @@ def _grouped_gemm_mxfp8( ) -> torch.Tensor: M_routed = num_valid_tokens E, N, K = w.shape @@ -232,7 +281,7 @@ index 33851fdc8..942a40876 100644 a_q, a_scale, w, -@@ -183,7 +328,7 @@ def _grouped_gemm_mxfp8( +@@ -183,7 +357,7 @@ def _grouped_gemm_mxfp8( BLOCK_M=block_m, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, @@ -241,7 +290,7 @@ index 33851fdc8..942a40876 100644 ) return out -@@ -202,12 +347,20 @@ def fused_moe_mxfp8_native( +@@ -202,12 +376,20 @@ def fused_moe_mxfp8_native( limit: float | None, global_num_experts: int, expert_map: torch.Tensor | None, @@ -263,7 +312,7 @@ index 33851fdc8..942a40876 100644 sorted_ids, expert_ids, num_post = moe_align_block_size( topk_ids, block_m, -@@ -256,17 +409,27 @@ def fused_moe_mxfp8_native( +@@ -256,17 +438,62 @@ def fused_moe_mxfp8_native( M, top_k, block_m, @@ -291,10 +340,45 @@ index 33851fdc8..942a40876 100644 class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): - """Native MXFP8 MoE (CDNA4 ``dot_scaled``) on gfx950.""" + """Fused MXFP8 MoE on gfx94x/gfx95x.""" ++ ++ def __init__( ++ self, ++ moe_config: FusedMoEConfig, ++ quant_config: FusedMoEQuantConfig, ++ ): ++ super().__init__(moe_config, quant_config) ++ self.w1_bf16: torch.Tensor | None = None ++ self.w2_bf16: torch.Tensor | None = None ++ self.bf16_experts: TritonExperts | None = None ++ if current_platform.is_fp8_fnuz() and _should_use_bf16_decode_fallback( ++ moe_config ++ ): ++ bf16_config = biased_moe_quant_config( ++ None, ++ None, ++ gemm1_alpha=quant_config.gemm1_alpha, ++ gemm1_beta=quant_config.gemm1_beta, ++ gemm1_clamp_limit=quant_config.gemm1_clamp_limit, ++ ) ++ self.bf16_experts = TritonExperts(moe_config, bf16_config) ++ ++ @property ++ def requires_bf16_decode_weights(self) -> bool: ++ return self.bf16_experts is not None ++ ++ def bind_bf16_weights( ++ self, ++ w1_bf16: torch.Tensor, ++ w2_bf16: torch.Tensor, ++ ) -> None: ++ if self.bf16_experts is None: ++ raise RuntimeError("BF16 decode experts are not enabled for this config.") ++ self.w1_bf16 = w1_bf16 ++ self.w2_bf16 = w2_bf16 @property def quant_dtype(self) -> torch.dtype | str | None: -@@ -283,7 +446,9 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -283,7 +510,9 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): @staticmethod def _supports_current_device() -> bool: @@ -305,7 +389,39 @@ index 33851fdc8..942a40876 100644 def apply( self, -@@ -322,5 +487,6 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -303,6 +532,31 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): ++ if ( ++ self.bf16_experts is not None ++ and hidden_states.shape[0] <= _BF16_DECODE_TOKEN_THRESHOLD ++ ): ++ if self.w1_bf16 is None or self.w2_bf16 is None: ++ raise RuntimeError("BF16 decode weights were not bound after loading.") ++ self.bf16_experts.apply( ++ output=output, ++ hidden_states=hidden_states, ++ w1=self.w1_bf16, ++ w2=self.w2_bf16, ++ topk_weights=topk_weights, ++ topk_ids=topk_ids, ++ activation=activation, ++ global_num_experts=global_num_experts, ++ expert_map=expert_map, ++ a1q_scale=None, ++ a2_scale=None, ++ workspace13=workspace13, ++ workspace2=workspace2, ++ expert_tokens_meta=expert_tokens_meta, ++ apply_router_weight_on_input=apply_router_weight_on_input, ++ ) ++ return ++ + alpha = self.quant_config.gemm1_alpha + alpha = 1.702 if alpha is None else float(alpha) + beta = self.quant_config.gemm1_beta +@@ -322,5 +576,6 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): limit=limit, global_num_experts=global_num_experts, expert_map=expert_map, @@ -313,6 +429,24 @@ index 33851fdc8..942a40876 100644 ) - output.copy_(out) + assert out is output +diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py +index 225484385..912ed6152 100644 +--- a/vllm/model_executor/layers/fused_moe/layer.py ++++ b/vllm/model_executor/layers/fused_moe/layer.py +@@ -318,7 +318,13 @@ def FusedMoE( + moe_backend=vllm_config.kernel_config.moe_backend, + router_logits_dtype=router_logits_dtype, + max_num_tokens=max_num_batched_tokens, ++ max_model_len=( ++ vllm_config.model_config.max_model_len ++ if vllm_config.model_config is not None ++ else 0 ++ ), + has_bias=has_bias, ++ has_shared_experts=shared_experts is not None, + is_lora_enabled=vllm_config.lora_config is not None, + activation=moe_activation, + device=vllm_config.device_config.device, diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index acbf2cb46..1fcf67678 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -356,14 +490,32 @@ index acbf2cb46..1fcf67678 100644 Fp8MoeBackend.NATIVE_MXFP8, ]: diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py -index d0d7c7648..cb3e5d446 100644 +index d0d7c7648..bc00da41e 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py -@@ -79,12 +79,20 @@ def _select_kernel_cls( - def _select_rocm_mxfp8_backend() -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]: +@@ -76,15 +76,37 @@ def _select_kernel_cls( + ) + + +-def _select_rocm_mxfp8_backend() -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]: ++def _select_rocm_mxfp8_backend( ++ config: FusedMoEConfig, ++) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]: """ROCm fallback when vendor MXFP8 backends are unavailable.""" - if current_platform.supports_mx(): ++ if current_platform.is_fp8_fnuz() and config.ep_size > 1: ++ from vllm.model_executor.layers.fused_moe.experts.mxfp8_emulation_moe import ( ++ Mxfp8EmulationTritonExperts, ++ ) ++ ++ logger.info_once( ++ "Using BF16 MXFP8 emulation for gfx94x expert parallelism; the " ++ "native CDNA3 path is optimized for decode-sized TP workloads and " ++ "is slower for the large local batches reached during EP prefill." ++ ) ++ return Fp8MoeBackend.EMULATION, Mxfp8EmulationTritonExperts ++ + if current_platform.supports_mx() or current_platform.is_fp8_fnuz(): from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( Mxfp8NativeTritonExperts, @@ -382,11 +534,54 @@ index d0d7c7648..cb3e5d446 100644 return Fp8MoeBackend.NATIVE_MXFP8, Mxfp8NativeTritonExperts from vllm.model_executor.layers.fused_moe.experts.mxfp8_emulation_moe import ( +@@ -134,6 +156,6 @@ def select_mxfp8_moe_backend( + + # simplify the logic for rocm, refactor later when more backends are supported + if current_platform.is_rocm(): +- return _select_rocm_mxfp8_backend() ++ return _select_rocm_mxfp8_backend(config) + + raise ValueError("No MXFP8 MoE backends available.") diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py -index 33c7c7532..5f250c912 100644 +index 33c7c7532..f980635b9 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py -@@ -2086,7 +2086,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): +@@ -44,6 +44,9 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( + make_nvfp4_moe_quant_config, + select_nvfp4_moe_backend, + ) ++from vllm.model_executor.layers.fusion.quant_activation import ( ++ expose_input_quant_key, ++) + from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, +@@ -92,6 +95,7 @@ from vllm.model_executor.parameter import ( + PerTensorScaleParameter, + ) + from vllm.model_executor.utils import replace_parameter, set_weight_attrs ++from vllm.platforms import current_platform + + if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper +@@ -470,6 +474,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase): + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition ++ layer.orig_dtype = params_dtype + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized +@@ -1192,6 +1197,8 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): + + layer.register_parameter("weight_scale", weight_scale) + ++ expose_input_quant_key(layer, self.kernel) ++ + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if ( + torch.unique(layer.input_scale).numel() != 1 +@@ -2086,7 +2093,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): def _dequant_mxfp8_weights_to_bf16(self, layer: RoutedExperts) -> None: """One-time MXFP8->BF16 weight dequant for the emulation path. @@ -395,11 +590,68 @@ index 33c7c7532..5f250c912 100644 ``Mxfp8EmulationTritonExperts`` otherwise dequantizes every expert weight to BF16 on *every* forward step -- the dominant cost (conc1 ~1.3 tok/s). Doing the dequant once here and replacing the MXFP8 -@@ -2158,7 +2158,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): +@@ -2121,6 +2128,43 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): + num_experts, + ) + ++ def _retain_bf16_decode_weights(self, layer: RoutedExperts) -> None: ++ """Keep BF16 weights for decode-sized TP batches on gfx94x.""" ++ from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( ++ Mxfp8NativeTritonExperts, ++ ) ++ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( ++ dequant_mxfp8_to_bf16, ++ ) ++ ++ if self.moe_kernel is None: ++ raise RuntimeError("MXFP8 MoE kernel was not initialized.") ++ experts = self.moe_kernel.fused_experts ++ if not isinstance(experts, Mxfp8NativeTritonExperts): ++ raise TypeError( ++ "Expected Mxfp8NativeTritonExperts for the gfx94x native backend." ++ ) ++ ++ target_dtype = getattr(layer, "orig_dtype", torch.bfloat16) ++ w13_bf16 = dequant_mxfp8_to_bf16(layer.w13_weight, layer.w13_weight_scale).to( ++ target_dtype ++ ) ++ w2_bf16 = dequant_mxfp8_to_bf16(layer.w2_weight, layer.w2_weight_scale).to( ++ target_dtype ++ ) ++ ++ layer.register_buffer("_mxfp8_w13_bf16", w13_bf16, persistent=False) ++ layer.register_buffer("_mxfp8_w2_bf16", w2_bf16, persistent=False) ++ experts.bind_bf16_weights( ++ layer._mxfp8_w13_bf16, ++ layer._mxfp8_w2_bf16, ++ ) ++ ++ logger.info_once( ++ "Retaining BF16 MXFP8 MoE weights for decode-sized gfx94x TP " ++ "batches; larger batches continue to use the native MXFP8 kernel." ++ ) ++ + def process_weights_after_loading(self, layer: RoutedExperts) -> None: + # TODO(bnell): why is this required only for mxfp8? + if getattr(layer, "_already_called_process_weights_after_loading", False): +@@ -2158,7 +2202,20 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): routing_tables=layer._expert_routing_tables(), ) - # No native MXFP8 MoE kernel on this device (e.g. gfx942): the emulation ++ from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( ++ Mxfp8NativeTritonExperts, ++ ) ++ ++ experts = self.moe_kernel.fused_experts ++ if ( ++ self.mxfp8_backend == Fp8MoeBackend.NATIVE_MXFP8 ++ and current_platform.is_fp8_fnuz() ++ and isinstance(experts, Mxfp8NativeTritonExperts) ++ and experts.requires_bf16_decode_weights ++ ): ++ self._retain_bf16_decode_weights(layer) ++ + # No fused MXFP8 MoE kernel on this device: the emulation # experts would dequant MXFP8->BF16 every forward step. Convert the # weights to BF16 once, here, so the MoE runs like a BF16 checkpoint. diff --git a/perf-changelog.yaml b/perf-changelog.yaml index fa394ac31..67dbb5f94 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3756,15 +3756,6 @@ - "[AI generated draft test] The shipped ROCm image's AMD MiniMax-M3 model lacks SupportsEagle3, so the recipe patches it in-place at runtime (functionstackx/vllm#1, upstream vllm-project/vllm#45546; validated green on MI355X) before serving — validates EAGLE3 on MI300X ahead of an image rebuild" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1749 -- config-keys: - - minimaxm3-fp8-mi300x-vllm - description: - - "Replace the gfx942 BF16 MoE emulation path with a fused Triton MXFP8 backend that keeps expert weights compressed and applies 1x32 E8M0 scales in-kernel" - - "Normalize OCP E4M3FN checkpoint values to AMD E4M3FNUZ so the grouped expert GEMMs use native CDNA3 FP8 matrix cores" - - "Preserve the existing TP8 and TP8+EP8 search space, default BF16 KV cache, and BF16 dense-linear path" - - "Reduce model memory from 101.27 GiB to 53.18 GiB per GPU and increase available KV-cache memory from 69.35 GiB to 117.43 GiB" - pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1753 - - config-keys: - minimaxm3-fp8-mi300x-vllm description: @@ -3777,3 +3768,12 @@ - "Run the MiniMax-M3 MXFP8 MI300X EAGLE3 MTP recipe with CUDA graphs instead of --enforce-eager" - "Drop --enforce-eager and set VLLM_USE_BREAKABLE_CUDAGRAPH=0 (matching the non-MTP MI300X recipe, #1750), which avoids the M3-decode breakable-cudagraph path that previously forced eager execution" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1756 + +- config-keys: + - minimaxm3-fp8-mi300x-vllm + description: + - "Add a fused gfx942 MXFP8 MoE kernel, normalize OCP E4M3FN weights to AMD E4M3FNUZ, and reconstruct E8M0 scales with the Marlin-style BF16 exponent bitcast" + - "Dispatch short-context TP decode batches of up to 16 tokens through retained BF16 experts, larger TP batches through native W8A8, and gfx942 expert parallelism through load-time BF16 dequantization" + - "Recover the BF16 low-concurrency and EP serving curve while retaining the native high-concurrency gain: 80.24 tok/s/GPU at TP8 conc-4, 452.47 at conc-64, 626.20 at conc-128, and 767.12 at TP8+EP8 conc-256" + - "Keep only compressed MXFP8 experts for the 8k1k configuration, reducing model memory from 101.27 GiB to 53.18 GiB per GPU and increasing available KV cache from 69.33 GiB to 117.44 GiB" + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1753 From 60a0002545ac4e62f0907a79a5bd91e299094f67 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Sun, 14 Jun 2026 03:13:40 -0700 Subject: [PATCH 05/10] fix: rebuild MI300X patch from pinned vLLM --- .../minimaxm3_mi300x_mxfp8.patch | 37 +++---------------- 1 file changed, 5 insertions(+), 32 deletions(-) diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch index 31deb4740..1045a0a29 100644 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch @@ -543,20 +543,10 @@ index d0d7c7648..bc00da41e 100644 raise ValueError("No MXFP8 MoE backends available.") diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py -index 33c7c7532..f980635b9 100644 +index 33c7c7532..e93189b31 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py -@@ -44,6 +44,9 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( - make_nvfp4_moe_quant_config, - select_nvfp4_moe_backend, - ) -+from vllm.model_executor.layers.fusion.quant_activation import ( -+ expose_input_quant_key, -+) - from vllm.model_executor.layers.linear import ( - LinearBase, - LinearMethodBase, -@@ -92,6 +95,7 @@ from vllm.model_executor.parameter import ( +@@ -92,6 +92,7 @@ from vllm.model_executor.parameter import ( PerTensorScaleParameter, ) from vllm.model_executor.utils import replace_parameter, set_weight_attrs @@ -564,24 +554,7 @@ index 33c7c7532..f980635b9 100644 if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper -@@ -470,6 +474,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase): - layer.logical_widths = output_partition_sizes - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition -+ layer.orig_dtype = params_dtype - weight_dtype = ( - torch.float8_e4m3fn - if self.quant_config.is_checkpoint_fp8_serialized -@@ -1192,6 +1197,8 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): - - layer.register_parameter("weight_scale", weight_scale) - -+ expose_input_quant_key(layer, self.kernel) -+ - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if ( - torch.unique(layer.input_scale).numel() != 1 -@@ -2086,7 +2093,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): +@@ -2086,7 +2087,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): def _dequant_mxfp8_weights_to_bf16(self, layer: RoutedExperts) -> None: """One-time MXFP8->BF16 weight dequant for the emulation path. @@ -590,7 +563,7 @@ index 33c7c7532..f980635b9 100644 ``Mxfp8EmulationTritonExperts`` otherwise dequantizes every expert weight to BF16 on *every* forward step -- the dominant cost (conc1 ~1.3 tok/s). Doing the dequant once here and replacing the MXFP8 -@@ -2121,6 +2128,43 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): +@@ -2121,6 +2122,43 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): num_experts, ) @@ -634,7 +607,7 @@ index 33c7c7532..f980635b9 100644 def process_weights_after_loading(self, layer: RoutedExperts) -> None: # TODO(bnell): why is this required only for mxfp8? if getattr(layer, "_already_called_process_weights_after_loading", False): -@@ -2158,7 +2202,20 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): +@@ -2158,7 +2196,20 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): routing_tables=layer._expert_routing_tables(), ) From a38f5ab2d28f3b1fc4b3d167a51ed6a0e29c4dab Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Sun, 14 Jun 2026 09:37:16 -0700 Subject: [PATCH 06/10] perf: update MiniMax M3 MI300X MXFP8 patch Co-authored-by: OpenAI Codex --- .../minimaxm3_mi300x_mxfp8.patch | 211 ++++++++++++++---- 1 file changed, 173 insertions(+), 38 deletions(-) diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch index 1045a0a29..f5316b6f4 100644 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch @@ -12,6 +12,106 @@ index 0755699d1..4f6046d88 100644 is_lora_enabled: bool = False # SwiGLU clamp limit. When set, backends that do not implement the clamp +diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=AMD_Instinct_MI300X.json +new file mode 100644 +index 000000000..201cfad15 +--- /dev/null ++++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=AMD_Instinct_MI300X.json +@@ -0,0 +1,35 @@ ++{ ++ "1": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 1, ++ "SPLIT_K": 1, ++ "num_warps": 4, ++ "num_stages": 2 ++ }, ++ "1024": { ++ "BLOCK_SIZE_M": 128, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 64, ++ "GROUP_SIZE_M": 8, ++ "SPLIT_K": 1, ++ "num_warps": 8, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "8192": { ++ "BLOCK_SIZE_M": 128, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 64, ++ "GROUP_SIZE_M": 8, ++ "SPLIT_K": 1, ++ "num_warps": 8, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ } ++} +diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=AMD_Instinct_MI300X.json +new file mode 100644 +index 000000000..f9de47ad6 +--- /dev/null ++++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=AMD_Instinct_MI300X.json +@@ -0,0 +1,53 @@ ++{ ++ "1": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 1, ++ "SPLIT_K": 1, ++ "num_warps": 4, ++ "num_stages": 2 ++ }, ++ "128": { ++ "BLOCK_SIZE_M": 64, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 8, ++ "SPLIT_K": 1, ++ "num_warps": 4, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "256": { ++ "BLOCK_SIZE_M": 64, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 8, ++ "SPLIT_K": 1, ++ "num_warps": 4, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "1024": { ++ "BLOCK_SIZE_M": 128, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 64, ++ "GROUP_SIZE_M": 1, ++ "SPLIT_K": 1, ++ "num_warps": 8, ++ "num_stages": 2 ++ }, ++ "8192": { ++ "BLOCK_SIZE_M": 128, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 64, ++ "GROUP_SIZE_M": 16, ++ "SPLIT_K": 1, ++ "num_warps": 8, ++ "num_stages": 2 ++ } ++} diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py index 71dd7634a..63500487d 100644 --- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py @@ -26,7 +126,7 @@ index 71dd7634a..63500487d 100644 import torch diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py -index 33851fdc8..943a4aedf 100644 +index 33851fdc8..70c59f55f 100644 --- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py @@ -1,28 +1,35 @@ @@ -73,23 +173,32 @@ index 33851fdc8..943a4aedf 100644 from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size, ) -@@ -34,9 +41,24 @@ from vllm.triton_utils import tl, triton +@@ -34,9 +41,33 @@ from vllm.triton_utils import tl, triton logger = init_logger(__name__) -+_BF16_DECODE_TOKEN_THRESHOLD = 16 ++_BF16_DECODE_TOKEN_THRESHOLD = 8 ++_BF16_PREFILL_TOKEN_THRESHOLD = 1024 ++_LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE = 5 + + +def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: -+ """Limit duplicate BF16 weights to the short-context MiniMax-M3 TP case.""" ++ """Limit BF16 fallback weights to the exact MiniMax-M3 TP shape.""" + return ( -+ moe_config.ep_size == 1 ++ current_platform.is_fp8_fnuz() ++ and moe_config.ep_size == 1 + and moe_config.has_shared_experts + and moe_config.num_experts == 128 + and moe_config.experts_per_token == 4 + and moe_config.hidden_dim == 6144 + and moe_config.intermediate_size == 3072 -+ and 0 < moe_config.max_model_len <= 4096 ++ and moe_config.max_model_len > 0 ++ ) ++ ++ ++def _should_store_bf16_only(max_model_len: int, layer_index: int) -> bool: ++ return ( ++ max_model_len > 4096 and layer_index % _LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE == 0 + ) + @@ -99,7 +208,7 @@ index 33851fdc8..943a4aedf 100644 a_ptr, a_scale_ptr, b_ptr, -@@ -125,6 +147,116 @@ def _mxfp8_grouped_gemm_kernel( +@@ -125,6 +156,116 @@ def _mxfp8_grouped_gemm_kernel( ) @@ -216,7 +325,7 @@ index 33851fdc8..943a4aedf 100644 def _grouped_gemm_mxfp8( a_q: torch.Tensor, # [M, K] fp8 e4m3 a_scale: torch.Tensor, # [M, K//32] uint8 (E8M0) -@@ -143,16 +275,58 @@ def _grouped_gemm_mxfp8( +@@ -143,16 +284,58 @@ def _grouped_gemm_mxfp8( ) -> torch.Tensor: M_routed = num_valid_tokens E, N, K = w.shape @@ -281,7 +390,7 @@ index 33851fdc8..943a4aedf 100644 a_q, a_scale, w, -@@ -183,7 +357,7 @@ def _grouped_gemm_mxfp8( +@@ -183,7 +366,7 @@ def _grouped_gemm_mxfp8( BLOCK_M=block_m, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, @@ -290,7 +399,7 @@ index 33851fdc8..943a4aedf 100644 ) return out -@@ -202,12 +376,20 @@ def fused_moe_mxfp8_native( +@@ -202,12 +385,20 @@ def fused_moe_mxfp8_native( limit: float | None, global_num_experts: int, expert_map: torch.Tensor | None, @@ -312,7 +421,7 @@ index 33851fdc8..943a4aedf 100644 sorted_ids, expert_ids, num_post = moe_align_block_size( topk_ids, block_m, -@@ -256,17 +438,62 @@ def fused_moe_mxfp8_native( +@@ -256,17 +447,64 @@ def fused_moe_mxfp8_native( M, top_k, block_m, @@ -349,10 +458,9 @@ index 33851fdc8..943a4aedf 100644 + super().__init__(moe_config, quant_config) + self.w1_bf16: torch.Tensor | None = None + self.w2_bf16: torch.Tensor | None = None ++ self.native_weights_available = True + self.bf16_experts: TritonExperts | None = None -+ if current_platform.is_fp8_fnuz() and _should_use_bf16_decode_fallback( -+ moe_config -+ ): ++ if _should_use_bf16_decode_fallback(moe_config): + bf16_config = biased_moe_quant_config( + None, + None, @@ -363,22 +471,25 @@ index 33851fdc8..943a4aedf 100644 + self.bf16_experts = TritonExperts(moe_config, bf16_config) + + @property -+ def requires_bf16_decode_weights(self) -> bool: ++ def requires_bf16_fallback_weights(self) -> bool: + return self.bf16_experts is not None + + def bind_bf16_weights( + self, + w1_bf16: torch.Tensor, + w2_bf16: torch.Tensor, ++ *, ++ native_weights_available: bool, + ) -> None: + if self.bf16_experts is None: + raise RuntimeError("BF16 decode experts are not enabled for this config.") + self.w1_bf16 = w1_bf16 + self.w2_bf16 = w2_bf16 ++ self.native_weights_available = native_weights_available @property def quant_dtype(self) -> torch.dtype | str | None: -@@ -283,7 +510,9 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -283,7 +521,9 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): @staticmethod def _supports_current_device() -> bool: @@ -389,17 +500,22 @@ index 33851fdc8..943a4aedf 100644 def apply( self, -@@ -303,6 +532,31 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -303,6 +543,36 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): -+ if ( -+ self.bf16_experts is not None -+ and hidden_states.shape[0] <= _BF16_DECODE_TOKEN_THRESHOLD ++ num_tokens = hidden_states.shape[0] ++ bf16_experts = self.bf16_experts ++ if bf16_experts is not None and ( ++ not self.native_weights_available ++ or num_tokens >= _BF16_PREFILL_TOKEN_THRESHOLD ++ or num_tokens <= _BF16_DECODE_TOKEN_THRESHOLD + ): + if self.w1_bf16 is None or self.w2_bf16 is None: -+ raise RuntimeError("BF16 decode weights were not bound after loading.") -+ self.bf16_experts.apply( ++ raise RuntimeError( ++ "BF16 fallback weights were not bound after loading." ++ ) ++ bf16_experts.apply( + output=output, + hidden_states=hidden_states, + w1=self.w1_bf16, @@ -421,7 +537,7 @@ index 33851fdc8..943a4aedf 100644 alpha = self.quant_config.gemm1_alpha alpha = 1.702 if alpha is None else float(alpha) beta = self.quant_config.gemm1_beta -@@ -322,5 +576,6 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -322,5 +592,6 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): limit=limit, global_num_experts=global_num_experts, expert_map=expert_map, @@ -543,7 +659,7 @@ index d0d7c7648..bc00da41e 100644 raise ValueError("No MXFP8 MoE backends available.") diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py -index 33c7c7532..e93189b31 100644 +index 33c7c7532..80baf5585 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -92,6 +92,7 @@ from vllm.model_executor.parameter import ( @@ -563,18 +679,20 @@ index 33c7c7532..e93189b31 100644 ``Mxfp8EmulationTritonExperts`` otherwise dequantizes every expert weight to BF16 on *every* forward step -- the dominant cost (conc1 ~1.3 tok/s). Doing the dequant once here and replacing the MXFP8 -@@ -2121,6 +2122,43 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): +@@ -2121,6 +2122,62 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): num_experts, ) -+ def _retain_bf16_decode_weights(self, layer: RoutedExperts) -> None: -+ """Keep BF16 weights for decode-sized TP batches on gfx94x.""" ++ def _retain_bf16_fallback_weights(self, layer: RoutedExperts) -> None: ++ """Keep the BF16 weights selected by the gfx94x TP dispatch policy.""" + from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( + Mxfp8NativeTritonExperts, ++ _should_store_bf16_only, + ) + from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( + dequant_mxfp8_to_bf16, + ) ++ from vllm.model_executor.models.utils import extract_layer_index + + if self.moe_kernel is None: + raise RuntimeError("MXFP8 MoE kernel was not initialized.") @@ -591,23 +709,40 @@ index 33c7c7532..e93189b31 100644 + w2_bf16 = dequant_mxfp8_to_bf16(layer.w2_weight, layer.w2_weight_scale).to( + target_dtype + ) ++ layer_index = extract_layer_index(layer.layer_name) ++ store_bf16_only = _should_store_bf16_only( ++ self.moe.max_model_len, ++ layer_index, ++ ) + -+ layer.register_buffer("_mxfp8_w13_bf16", w13_bf16, persistent=False) -+ layer.register_buffer("_mxfp8_w2_bf16", w2_bf16, persistent=False) ++ if store_bf16_only: ++ replace_parameter(layer, "w13_weight", w13_bf16) ++ replace_parameter(layer, "w2_weight", w2_bf16) ++ else: ++ layer.register_buffer("_mxfp8_w13_bf16", w13_bf16, persistent=False) ++ layer.register_buffer("_mxfp8_w2_bf16", w2_bf16, persistent=False) + experts.bind_bf16_weights( -+ layer._mxfp8_w13_bf16, -+ layer._mxfp8_w2_bf16, ++ w13_bf16, ++ w2_bf16, ++ native_weights_available=not store_bf16_only, + ) + -+ logger.info_once( -+ "Retaining BF16 MXFP8 MoE weights for decode-sized gfx94x TP " -+ "batches; larger batches continue to use the native MXFP8 kernel." -+ ) ++ if self.moe.max_model_len <= 4096: ++ logger.info_once( ++ "Retaining BF16 MXFP8 MoE weights for gfx94x TP decode and " ++ "prefill dispatch." ++ ) ++ else: ++ logger.info_once( ++ "Using BF16-only storage for one-fifth of gfx94x TP MoE " ++ "layers and retaining both MXFP8 and BF16 weights for the " ++ "remaining layers." ++ ) + def process_weights_after_loading(self, layer: RoutedExperts) -> None: # TODO(bnell): why is this required only for mxfp8? if getattr(layer, "_already_called_process_weights_after_loading", False): -@@ -2158,7 +2196,20 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): +@@ -2158,7 +2215,20 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): routing_tables=layer._expert_routing_tables(), ) @@ -621,9 +756,9 @@ index 33c7c7532..e93189b31 100644 + self.mxfp8_backend == Fp8MoeBackend.NATIVE_MXFP8 + and current_platform.is_fp8_fnuz() + and isinstance(experts, Mxfp8NativeTritonExperts) -+ and experts.requires_bf16_decode_weights ++ and experts.requires_bf16_fallback_weights + ): -+ self._retain_bf16_decode_weights(layer) ++ self._retain_bf16_fallback_weights(layer) + + # No fused MXFP8 MoE kernel on this device: the emulation # experts would dequant MXFP8->BF16 every forward step. Convert the From 7678b0bcdcf500a7ec93fbba5278db5c37983ffd Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Sun, 14 Jun 2026 13:34:48 -0700 Subject: [PATCH 07/10] perf(mi300x): pack MiniMax M3 MXFP8 scales Signed-off-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Co-authored-by: OpenAI Codex --- .../minimaxm3_mi300x_mxfp8.patch | 91 ++++++++++++++++--- 1 file changed, 78 insertions(+), 13 deletions(-) diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch index f5316b6f4..76d63591e 100644 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch @@ -126,7 +126,7 @@ index 71dd7634a..63500487d 100644 import torch diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py -index 33851fdc8..70c59f55f 100644 +index 33851fdc8..815d8ef94 100644 --- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py @@ -1,28 +1,35 @@ @@ -325,7 +325,7 @@ index 33851fdc8..70c59f55f 100644 def _grouped_gemm_mxfp8( a_q: torch.Tensor, # [M, K] fp8 e4m3 a_scale: torch.Tensor, # [M, K//32] uint8 (E8M0) -@@ -143,16 +284,58 @@ def _grouped_gemm_mxfp8( +@@ -143,16 +284,71 @@ def _grouped_gemm_mxfp8( ) -> torch.Tensor: M_routed = num_valid_tokens E, N, K = w.shape @@ -334,6 +334,19 @@ index 33851fdc8..70c59f55f 100644 + assert K % k_alignment == 0, ( + f"MXFP8 native MoE requires K%{k_alignment}==0, got K={K}" + ) ++ if w_scale.shape == (E, N, K // 32): ++ scale_stride_e = w_scale.stride(0) ++ scale_stride_n = w_scale.stride(1) ++ scale_stride_k = w_scale.stride(2) ++ elif w_scale.shape == (E, K // 32, N): ++ scale_stride_e = w_scale.stride(0) ++ scale_stride_n = w_scale.stride(2) ++ scale_stride_k = w_scale.stride(1) ++ else: ++ raise ValueError( ++ "MXFP8 weight scales must use [E, N, K/32] or packed " ++ f"[E, K/32, N] layout, got {tuple(w_scale.shape)}." ++ ) + BLOCK_K = ( + 128 + if current_platform.is_fp8_fnuz() and K % 128 == 0 and block_m <= 16 @@ -390,7 +403,20 @@ index 33851fdc8..70c59f55f 100644 a_q, a_scale, w, -@@ -183,7 +366,7 @@ def _grouped_gemm_mxfp8( +@@ -173,9 +369,9 @@ def _grouped_gemm_mxfp8( + w.stride(0), + w.stride(1), + w.stride(2), +- w_scale.stride(0), +- w_scale.stride(1), +- w_scale.stride(2), ++ scale_stride_e, ++ scale_stride_n, ++ scale_stride_k, + out.stride(0), + out.stride(1), + A_DIV=a_div, +@@ -183,7 +379,7 @@ def _grouped_gemm_mxfp8( BLOCK_M=block_m, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, @@ -399,7 +425,7 @@ index 33851fdc8..70c59f55f 100644 ) return out -@@ -202,12 +385,20 @@ def fused_moe_mxfp8_native( +@@ -202,12 +398,20 @@ def fused_moe_mxfp8_native( limit: float | None, global_num_experts: int, expert_map: torch.Tensor | None, @@ -421,7 +447,7 @@ index 33851fdc8..70c59f55f 100644 sorted_ids, expert_ids, num_post = moe_align_block_size( topk_ids, block_m, -@@ -256,17 +447,64 @@ def fused_moe_mxfp8_native( +@@ -256,17 +460,74 @@ def fused_moe_mxfp8_native( M, top_k, block_m, @@ -486,10 +512,20 @@ index 33851fdc8..70c59f55f 100644 + self.w1_bf16 = w1_bf16 + self.w2_bf16 = w2_bf16 + self.native_weights_available = native_weights_available ++ ++ def bind_packed_weight_scales( ++ self, ++ w1_scale: torch.Tensor, ++ w2_scale: torch.Tensor, ++ ) -> None: ++ if not current_platform.is_fp8_fnuz(): ++ raise RuntimeError("Packed MXFP8 scales are specific to gfx94x.") ++ self.w1_scale_val = w1_scale ++ self.w2_scale_val = w2_scale @property def quant_dtype(self) -> torch.dtype | str | None: -@@ -283,7 +521,9 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -283,7 +544,9 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): @staticmethod def _supports_current_device() -> bool: @@ -500,7 +536,7 @@ index 33851fdc8..70c59f55f 100644 def apply( self, -@@ -303,6 +543,36 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -303,6 +566,36 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): @@ -537,7 +573,7 @@ index 33851fdc8..70c59f55f 100644 alpha = self.quant_config.gemm1_alpha alpha = 1.702 if alpha is None else float(alpha) beta = self.quant_config.gemm1_beta -@@ -322,5 +592,6 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -322,5 +615,6 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): limit=limit, global_num_experts=global_num_experts, expert_map=expert_map, @@ -659,7 +695,7 @@ index d0d7c7648..bc00da41e 100644 raise ValueError("No MXFP8 MoE backends available.") diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py -index 33c7c7532..80baf5585 100644 +index 33c7c7532..9e26aa823 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -92,6 +92,7 @@ from vllm.model_executor.parameter import ( @@ -679,7 +715,7 @@ index 33c7c7532..80baf5585 100644 ``Mxfp8EmulationTritonExperts`` otherwise dequantizes every expert weight to BF16 on *every* forward step -- the dominant cost (conc1 ~1.3 tok/s). Doing the dequant once here and replacing the MXFP8 -@@ -2121,6 +2122,62 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): +@@ -2121,6 +2122,90 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): num_experts, ) @@ -738,11 +774,39 @@ index 33c7c7532..80baf5585 100644 + "layers and retaining both MXFP8 and BF16 weights for the " + "remaining layers." + ) ++ ++ def _pack_mxfp8_weight_scales(self, layer: RoutedExperts, experts) -> None: ++ """Pack gfx94x E8M0 scales so consecutive output columns are contiguous.""" ++ if not experts.native_weights_available: ++ return ++ ++ def pack(scale: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: ++ E, N, K = weight.shape ++ if scale.shape == (E, K // 32, N): ++ return scale ++ if scale.shape != (E, N, K // 32): ++ raise ValueError( ++ "Unexpected MXFP8 weight-scale shape " ++ f"{tuple(scale.shape)} for weight {tuple(weight.shape)}." ++ ) ++ return scale.transpose(1, 2).contiguous() ++ ++ w13_scale = pack(layer.w13_weight_scale, layer.w13_weight) ++ w2_scale = pack(layer.w2_weight_scale, layer.w2_weight) ++ replace_parameter(layer, "w13_weight_scale", w13_scale) ++ replace_parameter(layer, "w2_weight_scale", w2_scale) ++ experts.bind_packed_weight_scales( ++ layer.w13_weight_scale, ++ layer.w2_weight_scale, ++ ) ++ logger.info_once( ++ "Packed gfx94x MXFP8 MoE weight scales as [expert, K/32, N]." ++ ) + def process_weights_after_loading(self, layer: RoutedExperts) -> None: # TODO(bnell): why is this required only for mxfp8? if getattr(layer, "_already_called_process_weights_after_loading", False): -@@ -2158,7 +2215,20 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): +@@ -2158,7 +2243,21 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): routing_tables=layer._expert_routing_tables(), ) @@ -756,9 +820,10 @@ index 33c7c7532..80baf5585 100644 + self.mxfp8_backend == Fp8MoeBackend.NATIVE_MXFP8 + and current_platform.is_fp8_fnuz() + and isinstance(experts, Mxfp8NativeTritonExperts) -+ and experts.requires_bf16_fallback_weights + ): -+ self._retain_bf16_fallback_weights(layer) ++ if experts.requires_bf16_fallback_weights: ++ self._retain_bf16_fallback_weights(layer) ++ self._pack_mxfp8_weight_scales(layer, experts) + + # No fused MXFP8 MoE kernel on this device: the emulation # experts would dequant MXFP8->BF16 every forward step. Convert the From 280c030a19a6de9ed0bb992457d796aa4bee8f23 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Sun, 14 Jun 2026 17:12:56 -0700 Subject: [PATCH 08/10] perf(mi300x): tune MiniMax M3 MXFP8 refill dispatch Co-authored-by: OpenAI Codex Signed-off-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> --- .../minimaxm3_mi300x_mxfp8.patch | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch index 76d63591e..1b5cd2ac7 100644 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch @@ -126,7 +126,7 @@ index 71dd7634a..63500487d 100644 import torch diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py -index 33851fdc8..815d8ef94 100644 +index 33851fdc8..50554b98c 100644 --- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py @@ -1,28 +1,35 @@ @@ -173,12 +173,14 @@ index 33851fdc8..815d8ef94 100644 from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size, ) -@@ -34,9 +41,33 @@ from vllm.triton_utils import tl, triton +@@ -34,9 +41,46 @@ from vllm.triton_utils import tl, triton logger = init_logger(__name__) +_BF16_DECODE_TOKEN_THRESHOLD = 8 -+_BF16_PREFILL_TOKEN_THRESHOLD = 1024 ++# MiniMax-M3 eager refill shapes cross over to the retained BF16 experts ++# between 827 and 843 tokens on MI300X. Keep the cutoff tile-aligned. ++_BF16_PREFILL_TOKEN_THRESHOLD = 832 +_LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE = 5 + + @@ -200,6 +202,17 @@ index 33851fdc8..815d8ef94 100644 + return ( + max_model_len > 4096 and layer_index % _LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE == 0 + ) ++ ++ ++def _should_use_bf16_experts( ++ num_tokens: int, ++ native_weights_available: bool, ++) -> bool: ++ return ( ++ not native_weights_available ++ or num_tokens >= _BF16_PREFILL_TOKEN_THRESHOLD ++ or num_tokens <= _BF16_DECODE_TOKEN_THRESHOLD ++ ) + @triton.jit @@ -208,7 +221,7 @@ index 33851fdc8..815d8ef94 100644 a_ptr, a_scale_ptr, b_ptr, -@@ -125,6 +156,116 @@ def _mxfp8_grouped_gemm_kernel( +@@ -125,6 +169,116 @@ def _mxfp8_grouped_gemm_kernel( ) @@ -325,7 +338,7 @@ index 33851fdc8..815d8ef94 100644 def _grouped_gemm_mxfp8( a_q: torch.Tensor, # [M, K] fp8 e4m3 a_scale: torch.Tensor, # [M, K//32] uint8 (E8M0) -@@ -143,16 +284,71 @@ def _grouped_gemm_mxfp8( +@@ -143,16 +297,71 @@ def _grouped_gemm_mxfp8( ) -> torch.Tensor: M_routed = num_valid_tokens E, N, K = w.shape @@ -403,7 +416,7 @@ index 33851fdc8..815d8ef94 100644 a_q, a_scale, w, -@@ -173,9 +369,9 @@ def _grouped_gemm_mxfp8( +@@ -173,9 +382,9 @@ def _grouped_gemm_mxfp8( w.stride(0), w.stride(1), w.stride(2), @@ -416,7 +429,7 @@ index 33851fdc8..815d8ef94 100644 out.stride(0), out.stride(1), A_DIV=a_div, -@@ -183,7 +379,7 @@ def _grouped_gemm_mxfp8( +@@ -183,7 +392,7 @@ def _grouped_gemm_mxfp8( BLOCK_M=block_m, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, @@ -425,7 +438,7 @@ index 33851fdc8..815d8ef94 100644 ) return out -@@ -202,12 +398,20 @@ def fused_moe_mxfp8_native( +@@ -202,12 +411,20 @@ def fused_moe_mxfp8_native( limit: float | None, global_num_experts: int, expert_map: torch.Tensor | None, @@ -447,7 +460,7 @@ index 33851fdc8..815d8ef94 100644 sorted_ids, expert_ids, num_post = moe_align_block_size( topk_ids, block_m, -@@ -256,17 +460,74 @@ def fused_moe_mxfp8_native( +@@ -256,17 +473,74 @@ def fused_moe_mxfp8_native( M, top_k, block_m, @@ -525,7 +538,7 @@ index 33851fdc8..815d8ef94 100644 @property def quant_dtype(self) -> torch.dtype | str | None: -@@ -283,7 +544,9 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -283,7 +557,9 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): @staticmethod def _supports_current_device() -> bool: @@ -536,16 +549,15 @@ index 33851fdc8..815d8ef94 100644 def apply( self, -@@ -303,6 +566,36 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -303,6 +579,35 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): + num_tokens = hidden_states.shape[0] + bf16_experts = self.bf16_experts -+ if bf16_experts is not None and ( -+ not self.native_weights_available -+ or num_tokens >= _BF16_PREFILL_TOKEN_THRESHOLD -+ or num_tokens <= _BF16_DECODE_TOKEN_THRESHOLD ++ if bf16_experts is not None and _should_use_bf16_experts( ++ num_tokens, ++ self.native_weights_available, + ): + if self.w1_bf16 is None or self.w2_bf16 is None: + raise RuntimeError( @@ -573,7 +585,7 @@ index 33851fdc8..815d8ef94 100644 alpha = self.quant_config.gemm1_alpha alpha = 1.702 if alpha is None else float(alpha) beta = self.quant_config.gemm1_beta -@@ -322,5 +615,6 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -322,5 +627,6 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): limit=limit, global_num_experts=global_num_experts, expert_map=expert_map, From 23925ccb4860de7f2face75283ebfe3f70943bdc Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Sun, 14 Jun 2026 18:41:32 -0700 Subject: [PATCH 09/10] perf(mi300x): tune short-k MXFP8 MoE GEMM2 Co-authored-by: OpenAI Codex Signed-off-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> --- .../minimaxm3_mi300x_mxfp8.patch | 156 ++++++++++++------ 1 file changed, 109 insertions(+), 47 deletions(-) diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch index 1b5cd2ac7..b391d59f1 100644 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch @@ -126,7 +126,7 @@ index 71dd7634a..63500487d 100644 import torch diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py -index 33851fdc8..50554b98c 100644 +index 33851fdc8..9e0145ff9 100644 --- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py @@ -1,28 +1,35 @@ @@ -173,7 +173,7 @@ index 33851fdc8..50554b98c 100644 from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size, ) -@@ -34,9 +41,46 @@ from vllm.triton_utils import tl, triton +@@ -34,9 +41,46 @@ logger = init_logger(__name__) @@ -182,7 +182,7 @@ index 33851fdc8..50554b98c 100644 +# between 827 and 843 tokens on MI300X. Keep the cutoff tile-aligned. +_BF16_PREFILL_TOKEN_THRESHOLD = 832 +_LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE = 5 -+ + + +def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: + """Limit BF16 fallback weights to the exact MiniMax-M3 TP shape.""" @@ -214,14 +214,14 @@ index 33851fdc8..50554b98c 100644 + or num_tokens <= _BF16_DECODE_TOKEN_THRESHOLD + ) + - ++ @triton.jit -def _mxfp8_grouped_gemm_kernel( +def _mxfp8_grouped_gemm_dot_scaled_kernel( a_ptr, a_scale_ptr, b_ptr, -@@ -125,6 +169,116 @@ def _mxfp8_grouped_gemm_kernel( +@@ -125,6 +169,148 @@ ) @@ -334,11 +334,50 @@ index 33851fdc8..50554b98c 100644 + mask=token_mask[:, None] & n_mask[None, :], + ) + ++ ++def _gfx94x_grouped_gemm_config( ++ m_routed: int, ++ n: int, ++ k: int, ++ block_m: int, ++ is_gemm2: bool, ++) -> tuple[int, int, int]: ++ short_k_gemm2 = is_gemm2 and block_m <= 16 and k <= 512 and k % 64 == 0 ++ if short_k_gemm2: ++ # MiniMax-M3 TP GEMM2 has a wide N=6144 and short K=384. Pairing two ++ # waves over 64 columns amortizes indexing while a 64-wide K tile ++ # exposes enough independent work for the short reduction. ++ return 64, 64, 2 ++ ++ block_k = 128 if k % 128 == 0 and block_m <= 16 else 64 if k % 64 == 0 else 32 ++ if block_m <= 16: ++ # One wave per 32 output columns avoids the register pressure of the ++ # original 128-column tile. At the very smallest routed batch, pairing ++ # two waves in a 64-column program amortizes launch/indexing overhead. ++ block_n = 64 if m_routed < 32 else 32 ++ num_warps = 2 if m_routed < 32 else 1 ++ elif block_m >= 64 and n >= 2048 and k >= 2048: ++ # EP prefill GEMMs remain register-bound at a 128-column tile even with ++ # 64 rows. Two-wave 64-column programs expose more independent work. ++ block_n = 64 ++ num_warps = 2 ++ else: ++ block_n = 128 ++ num_warps = 4 if block_m <= 32 else 8 ++ return block_n, block_k, num_warps ++ + def _grouped_gemm_mxfp8( a_q: torch.Tensor, # [M, K] fp8 e4m3 a_scale: torch.Tensor, # [M, K//32] uint8 (E8M0) -@@ -143,16 +297,71 @@ def _grouped_gemm_mxfp8( +@@ -140,19 +326,81 @@ + a_div: int, + mul_weight_by: torch.Tensor | None = None, + expert_map: torch.Tensor | None = None, ++ is_gemm2: bool = False, ++ block_n_override: int = 0, ++ block_k_override: int = 0, ++ num_warps_override: int = 0, ) -> torch.Tensor: M_routed = num_valid_tokens E, N, K = w.shape @@ -360,35 +399,38 @@ index 33851fdc8..50554b98c 100644 + "MXFP8 weight scales must use [E, N, K/32] or packed " + f"[E, K/32, N] layout, got {tuple(w_scale.shape)}." + ) -+ BLOCK_K = ( -+ 128 -+ if current_platform.is_fp8_fnuz() and K % 128 == 0 and block_m <= 16 -+ else 64 -+ if current_platform.is_fp8_fnuz() and K % 64 == 0 -+ else 32 -+ if current_platform.is_fp8_fnuz() -+ else 128 -+ ) ++ is_fnuz = current_platform.is_fp8_fnuz() ++ if is_fnuz: ++ BLOCK_N, BLOCK_K, num_warps = _gfx94x_grouped_gemm_config( ++ M_routed, ++ N, ++ K, ++ block_m, ++ is_gemm2, ++ ) ++ else: ++ BLOCK_N = 128 ++ BLOCK_K = 128 ++ num_warps = 8 + # moe_align_block_size allocates for the worst case where every expert is + # active. At small batches that can be much larger than the number of + # blocks that can contain valid assignments. Limit the launch to the + # tighter static upper bound; the device-side num_post check handles the + # remaining tail. + max_post_padded = min(sorted_token_ids.shape[0], M_routed * block_m) -+ if current_platform.is_fp8_fnuz() and block_m <= 16: -+ # One wave per 32 output columns avoids the register pressure of the -+ # original 128-column tile. At the very smallest routed batch, pairing -+ # two waves in a 64-column program amortizes launch/indexing overhead. -+ BLOCK_N = 64 if M_routed < 32 else 32 -+ num_warps = 2 if M_routed < 32 else 1 -+ elif current_platform.is_fp8_fnuz() and block_m >= 64 and N >= 2048 and K >= 2048: -+ # EP prefill GEMMs remain register-bound at a 128-column tile even with -+ # 64 rows. Two-wave 64-column programs expose more independent work. -+ BLOCK_N = 64 -+ num_warps = 2 -+ else: -+ BLOCK_N = 128 -+ num_warps = 4 if current_platform.is_fp8_fnuz() and block_m <= 32 else 8 ++ if block_n_override: ++ BLOCK_N = block_n_override ++ if block_k_override: ++ BLOCK_K = block_k_override ++ if num_warps_override: ++ num_warps = num_warps_override ++ if BLOCK_K % 32 != 0 or K % BLOCK_K != 0: ++ raise ValueError( ++ f"MXFP8 grouped GEMM requires BLOCK_K to divide K in 32-value " ++ f"units, got BLOCK_K={BLOCK_K}, K={K}." ++ ) ++ if num_warps not in (1, 2, 4, 8): ++ raise ValueError(f"Unsupported num_warps={num_warps}.") + m_blocks = triton.cdiv(max_post_padded, block_m) + n_blocks = triton.cdiv(N, BLOCK_N) + @@ -416,7 +458,7 @@ index 33851fdc8..50554b98c 100644 a_q, a_scale, w, -@@ -173,9 +382,9 @@ def _grouped_gemm_mxfp8( +@@ -173,9 +421,9 @@ w.stride(0), w.stride(1), w.stride(2), @@ -429,7 +471,7 @@ index 33851fdc8..50554b98c 100644 out.stride(0), out.stride(1), A_DIV=a_div, -@@ -183,7 +392,7 @@ def _grouped_gemm_mxfp8( +@@ -183,7 +431,7 @@ BLOCK_M=block_m, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, @@ -438,11 +480,17 @@ index 33851fdc8..50554b98c 100644 ) return out -@@ -202,12 +411,20 @@ def fused_moe_mxfp8_native( +@@ -202,12 +450,26 @@ limit: float | None, global_num_experts: int, expert_map: torch.Tensor | None, + output: torch.Tensor | None = None, ++ g1_block_n: int = 0, ++ g1_block_k: int = 0, ++ g1_num_warps: int = 0, ++ g2_block_n: int = 0, ++ g2_block_k: int = 0, ++ g2_num_warps: int = 0, ) -> torch.Tensor: T, H = hidden_states.shape top_k = topk_ids.shape[1] @@ -460,35 +508,49 @@ index 33851fdc8..50554b98c 100644 sorted_ids, expert_ids, num_post = moe_align_block_size( topk_ids, block_m, -@@ -256,17 +473,74 @@ def fused_moe_mxfp8_native( - M, - top_k, +@@ -232,6 +494,9 @@ + hidden_states.dtype, + a_div=top_k, + expert_map=expert_map, ++ block_n_override=g1_block_n, ++ block_k_override=g1_block_k, ++ num_warps_override=g1_num_warps, + ) # [M, 2I] + + # SwiGLU-OAI (split layout: gate=g1[:, :I], up=g1[:, I:]) FUSED with the +@@ -258,15 +523,76 @@ block_m, - torch.float32, + hidden_states.dtype if current_platform.is_fp8_fnuz() else torch.float32, a_div=1, mul_weight_by=topk_weights.reshape(-1).to(torch.float32), expert_map=expert_map, - ) # [M, H] == [T*top_k, H] - +- ) # [M, H] == [T*top_k, H] +- - return g2.view(T, top_k, H).sum(dim=1).to(hidden_states.dtype) ++ is_gemm2=True, ++ block_n_override=g2_block_n, ++ block_k_override=g2_block_k, ++ num_warps_override=g2_num_warps, ++ ) # [M, H] == [T*top_k, H] + + if current_platform.is_fp8_fnuz(): + if output is None: + output = torch.empty_like(hidden_states) + ops.moe_sum(g2.view(T, top_k, H), output) + return output -+ + + result = g2.view(T, top_k, H).sum(dim=1).to(hidden_states.dtype) + if output is not None: + output.copy_(result) + return output + return result - - ++ ++ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): - """Native MXFP8 MoE (CDNA4 ``dot_scaled``) on gfx950.""" + """Fused MXFP8 MoE on gfx94x/gfx95x.""" -+ + + def __init__( + self, + moe_config: FusedMoEConfig, @@ -509,7 +571,7 @@ index 33851fdc8..50554b98c 100644 + ) + self.bf16_experts = TritonExperts(moe_config, bf16_config) + -+ @property + @property + def requires_bf16_fallback_weights(self) -> bool: + return self.bf16_experts is not None + @@ -535,10 +597,10 @@ index 33851fdc8..50554b98c 100644 + raise RuntimeError("Packed MXFP8 scales are specific to gfx94x.") + self.w1_scale_val = w1_scale + self.w2_scale_val = w2_scale - - @property ++ ++ @property def quant_dtype(self) -> torch.dtype | str | None: -@@ -283,7 +557,9 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -283,7 +609,9 @@ @staticmethod def _supports_current_device() -> bool: @@ -549,7 +611,7 @@ index 33851fdc8..50554b98c 100644 def apply( self, -@@ -303,6 +579,35 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -303,6 +631,35 @@ expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): @@ -585,7 +647,7 @@ index 33851fdc8..50554b98c 100644 alpha = self.quant_config.gemm1_alpha alpha = 1.702 if alpha is None else float(alpha) beta = self.quant_config.gemm1_beta -@@ -322,5 +627,6 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -322,5 +679,6 @@ limit=limit, global_num_experts=global_num_experts, expert_map=expert_map, From 1e3bfdd30f1da87c092f7599400b6e57c089b54c Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Sun, 14 Jun 2026 20:25:41 -0700 Subject: [PATCH 10/10] fix(benchmarks): fail if MI300X patch is not applied --- .../single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh index 50aa7c236..4fbe92bcd 100755 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh @@ -38,7 +38,14 @@ PY MXFP8_PATCH="$(dirname "$0")/minimaxm3_mi300x_mxfp8.patch" MXFP8_ORACLE="$VLLM_PACKAGE_ROOT/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py" if ! grep -q "Using fused CDNA3 (gfx94x)" "$MXFP8_ORACLE"; then - patch --batch --forward -d "$VLLM_PACKAGE_ROOT" -p1 < "$MXFP8_PATCH" + if ! patch --batch --forward -d "$VLLM_PACKAGE_ROOT" -p1 < "$MXFP8_PATCH"; then + echo "Failed to apply the MI300X MXFP8 patch" >&2 + exit 1 + fi +fi +if ! grep -q "Using fused CDNA3 (gfx94x)" "$MXFP8_ORACLE"; then + echo "MI300X MXFP8 backend marker is missing after patching" >&2 + exit 1 fi if [[ "$MODEL" != /* ]]; then hf download "$MODEL"; fi