Conversation
import numpy as np
import os

os.environ["USE_TRITON_FUSED_CAST_TRANSPOSE"] = "1"
We have already defined the env var NVTE_USE_CAST_TRANSPOSE_TRITON.
def test_quantize_mxfp4(shape, in_dtype, rowwise, columnwise, shuffle_B_matrix):
    """Test MXFP4 quantization for rowwise/columnwise modes with/without FP4 shuffle.

    Note: FP4 data shuffle (shuffle_B_matrix_for_aiter) is not yet supported in Triton kernel.
If the FP4 data shuffle is not yet supported in the Triton kernel, why do we need to add it here?
This is kept to ensure API consistency between Triton and the upcoming HIP kernel, for which I'll create a separate PR. In the HIP kernel we were able to fuse the shuffle.
HIP vs Triton flow

HIP flow:
Input: BF16 [M, N]
  ↓
MXFP4Quantizer.update_quantized()
  ↓
tex.cast_transpose_mxfp4_fused_shuffle()  [single HIP kernel]
  ↓
  ├─→ Rowwise FP4   [M, K/2]           (MFMA shuffled)
  ├─→ Rowwise Scale [M_pad, K/32_pad]  (shuffled)
  ├─→ Colwise FP4   [N, M/2]           (MFMA shuffled)
  └─→ Colwise Scale [N_pad, M/32_pad]  (shuffled)
  ↓
AITER gemm_a4w4 (zero-copy)

Triton flow:
Input: BF16 [M, N]
  ↓
MXFP4Quantizer.update_quantized()
  ↓
te_cast_transpose_mxfp4_triton()  [Triton JIT kernel]
  ↓
  ├─→ Rowwise FP4   [M, K/2]           (linear layout)
  ├─→ Rowwise Scale [M_pad, K/32_pad]  (shuffled)
  ├─→ Colwise FP4   [N, M/2]           (linear layout)
  └─→ Colwise Scale [N_pad, M/32_pad]  (shuffled)
  ↓
aiter.ops.shuffle.shuffle_weight()  [external call]
  ↓
FP4 data → MFMA layout
  ↓
AITER gemm_a4w4
(32768, 160),
(4096, 1632),
(8, 32, 1024),
(16, 8, 4, 512),
Can we add some shapes with prime dimensions, like the ones in TransformerEngine/tests/cpp/operator/test_cast_transpose.cu, lines 90 to 92 (at 9d6b0e5)?
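For illustration only, additional shapes could look like the list below. The name PRIME_SHAPES and the concrete values are hypothetical, and the sketch assumes the innermost dimension must stay even (byte packing) and a multiple of 32 (MXFP4 block-scaling size), so only the leading dimensions use primes:

```python
# Hypothetical extra test shapes (not from this PR): prime leading dimensions,
# inner dimension kept at a multiple of 32 for MXFP4 block scaling.
PRIME_SHAPES = [
    (127, 1024),
    (257, 2048),
    (7, 13, 2048),
]
```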
data_atol = 20.0 if in_dtype != torch.float32 else 16.0
scale_atol = 2.0 if in_dtype != torch.float32 else 1.0
The data tolerance seems quite large. You can follow our MXFP8 scale and data adjustment scheme in TransformerEngine/tests/cpp/test_common.cu, line 730 (at 9d6b0e5).
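As a rough Python sketch of that idea (not the TE test code): require the E8M0 scales to agree within one exponent step, and relax the data tolerance only for the blocks whose scale differs, instead of one large global data_atol. The function name and the 0.5 "representative E2M1 step" are assumptions made for illustration:

```python
import numpy as np

def mxfp4_blockwise_allclose(ref_scale, out_scale, ref_deq, out_deq, block=32):
    """Illustrative per-block comparison (not the TE implementation).

    ref_scale/out_scale: [M, K//block] uint8 E8M0 exponents
    ref_deq/out_deq:     [M, K] dequantized float32 values
    """
    scale_diff = np.abs(ref_scale.astype(np.int32) - out_scale.astype(np.int32))
    if scale_diff.max() > 1:              # scales must agree within one E8M0 step
        return False
    # Per-block tolerance: one representative E2M1 step (0.5) at the block's scale,
    # doubled for blocks whose scale is off by one exponent step.
    step = 0.5 * np.exp2(np.maximum(ref_scale, out_scale).astype(np.float32) - 127.0)
    tol = np.where(scale_diff == 0, step, 2.0 * step)
    tol = np.repeat(tol, block, axis=-1)  # broadcast [M, K//block] -> [M, K]
    return bool(np.all(np.abs(ref_deq - out_deq) <= tol))
```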
    use_torch_semantics=True
)

# Compare only valid (non-padded) region - no shuffle extraction needed
The FP4 shuffle basically rearranges the [M, K/2] linear layout → MFMA instruction layout (16×16).
The current training workflow when the TE MXFP4 quantization kernel is used is as follows:
TE Triton kernel → linear FP4 [N, K/2] → aiter.ops.shuffle_weight() → MFMA FP4 → aiter.gemm_a4w4()
You can find the shuffle code in aiter/aiter/ops/shuffle.py.
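A minimal Python sketch of the weight-side portion of that chain. te_quantize_mxfp4 is a hypothetical stand-in for the TE Triton quantization path, and the exact arguments of aiter's shuffle_weight may differ; this is an assumed outline, not verified API usage:

```python
import torch
from aiter.ops.shuffle import shuffle_weight    # linear FP4 -> MFMA tile layout

weight_bf16 = torch.randn(4096, 2048, dtype=torch.bfloat16, device="cuda")

# Hypothetical stand-in for the TE Triton kernel described above: returns packed
# FP4 data in linear layout [N, K/2] plus E8M0 block scales.
w_fp4_linear, w_scale = te_quantize_mxfp4(weight_bf16)

# Today the shuffle is an extra external call; the HIP kernel fuses this step.
w_fp4_mfma = shuffle_weight(w_fp4_linear)       # exact arguments may differ

# w_fp4_mfma and w_scale are then consumed by aiter's gemm_a4w4 as the B matrix.
```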
| .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ | ||
| .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ | ||
| .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ | ||
| .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \ |
If we are going to enable kFloat4E2M1, other related changes are needed as well. See https://github.com/search?q=repo%3AROCm%2FTransformerEngine%20kFloat4E2M1&type=code for the affected places.
- Data: [M, K/2] uint8 tensor (2 FP4 values packed per byte)
- Scale: [M, K/32] uint8 tensor (E8M0 format, one scale per 32-element block)
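For reference, a standalone NumPy sketch of how a tensor in that format decodes (two E2M1 nibbles per byte, one E8M0 scale per 32-element block). The low-nibble-first ordering is an assumption, and this is an independent illustration, not the kernel's code:

```python
import numpy as np

# E2M1 magnitude table; the sign is the top bit of each 4-bit code.
E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)

def dequantize_mxfp4(data_u8, scale_u8):
    """data_u8: [M, K//2] packed FP4 codes, scale_u8: [M, K//32] E8M0 exponents."""
    lo = data_u8 & 0x0F                     # first value of each byte (assumed order)
    hi = data_u8 >> 4                       # second value of each byte
    codes = np.stack([lo, hi], axis=-1).reshape(data_u8.shape[0], -1)  # [M, K]
    vals = np.where(codes & 0x8, -1.0, 1.0) * E2M1[codes & 0x7]
    scales = np.exp2(scale_u8.astype(np.float32) - 127.0)              # E8M0: 2^(e - 127)
    return vals * np.repeat(scales, 32, axis=1)                        # [M, K] float32
```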
Are there alignment/padding requirements for M and K?
if inp.ndim < 2:
    return False
TE currently supports 2D matrices flattened from high-dimensional tensors; see TransformerEngine/transformer_engine/common/common.h, lines 238 to 262 (at 9d6b0e5).
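In other words, a higher-dimensional input such as the (16, 8, 4, 512) test shape would be viewed as a 2D matrix before quantization, roughly like this (illustrative only):

```python
import torch

x = torch.randn(16, 8, 4, 512, dtype=torch.bfloat16)
# Collapse all leading dimensions into M and keep the innermost dimension as K,
# matching how TE treats high-dimensional tensors as flattened 2D matrices.
x2d = x.reshape(-1, x.shape[-1])     # shape: [16 * 8 * 4, 512] == [512, 512]
```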
# Allocate PADDED scale tensors for shuffle compatibility
rowwise_scale_N = K // MXFP4_BLOCK_SCALING_SIZE
rowwise_scale_M_pad = cdiv(M, 256) * 256
I presume this 256 is from some alignment/padding requirement?
The 256 alignment is required by AITER's CK-based MXFP4 GEMM kernels for the scale tensor swizzle/shuffle layout: 256 = ScaleBlockSize (32) × 8 waves.
See aiter/aiter/utility/fp4_utils.py:398 and gemm_a4w4_blockscale_common.cuh:66.
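A quick worked example of the padding this implies (illustrative numbers; only the row-side rule quoted above is shown, and M = 1000 is a hypothetical value):

```python
def cdiv(a, b):
    return (a + b - 1) // b

# M = 4096 is already a multiple of 256, so no padding rows are added;
# a hypothetical M = 1000 would be padded up to 1024 scale rows.
for M, K in [(4096, 1632), (1000, 1632)]:
    rowwise_scale_N = K // 32                    # 51 scale columns (one per 32-element block)
    rowwise_scale_M_pad = cdiv(M, 256) * 256     # 4096 and 1024, respectively
    print(M, K, rowwise_scale_M_pad, rowwise_scale_N)
```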
@@ -0,0 +1,178 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
You will need to add this pytest to our CI script (somewhere near TransformerEngine/ci/pytorch.sh, line 74, at 9d6b0e5).
Description
- Implements the MXFP4 rowwise and columnwise FP32/BF16 -> MXFP4 fused quantization + cast kernel
- Verifies tolerances and adds functional unit tests

The Triton te_cast_transpose_mxfp4_triton currently outputs FP4 data in linear layout [M, N/2] with contiguous byte packing. AITER's gemm_a4w4 requires the B matrix in MFMA shuffle layout for tensor cores. This layout shuffle can be fused into the Triton kernel in the future.