MXFP4 Cast Transpose Triton [WIP] #422
base: dev
Changes from all commits: fd7129d, aca9e33, b7cc9f2, 7b2b4e5, df39c9a, c1680cb, f2bef5a
@@ -0,0 +1,178 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# License for AMD contributions = MIT. See LICENSE for more information

import math
import pytest
import torch
import numpy as np
import os

os.environ["USE_TRITON_FUSED_CAST_TRANSPOSE"] = "1"
Collaborator: We previously already defined env NVTE_USE_CAST_TRANSPOSE_TRITON.
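If the suggestion is adopted, the test would presumably just reuse the existing variable; a minimal sketch, assuming the same opt-in semantics as the line above:

```python
# Reviewer-suggested variable name; semantics assumed identical to the
# USE_TRITON_FUSED_CAST_TRANSPOSE line in the diff above.
os.environ["NVTE_USE_CAST_TRANSPOSE_TRITON"] = "1"
```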
from transformer_engine.pytorch.tensor.mxfp4_tensor import MXFP4Quantizer, MXFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.triton_kernels.cast import te_quantize_triton
from test_common import te_compare_results, fill_uniform
def mxfp4_quantize_cpu(input_tensor, axis='row'):
    """CPU reference for MXFP4 quantization matching Triton kernel behavior (no shuffle)."""
    original_shape = input_tensor.shape
    if input_tensor.dim() > 2:
        input_tensor = input_tensor.view(-1, input_tensor.shape[-1])

    M, N = input_tensor.shape

    if axis == 'col':
        input_tensor = input_tensor.t().contiguous()
        M, N = N, M

    data = input_tensor.cpu().float().numpy()

    BLOCK_SIZE = 32
    assert N % BLOCK_SIZE == 0, f"N={N} must be divisible by {BLOCK_SIZE}"

    num_blocks = N // BLOCK_SIZE

    # E2M1 FP4 lookup table
    fp4_values = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

    # Reshape to blocks: [M, num_blocks, BLOCK_SIZE]
    data_blocks = data.reshape(M, num_blocks, BLOCK_SIZE)
    amax_blocks = np.max(np.abs(data_blocks), axis=2)

    # Triton's amax rounding: (amax + 0x200000) & 0xFF800000
    amax_int = amax_blocks.astype(np.float32).view(np.uint32)
    amax_int = ((amax_int + 0x200000) & 0xFF800000).astype(np.uint32)
    amax_rounded = amax_int.view(np.float32)

    # E8M0 scale computation: floor(log2(amax)) - 2 + 127
    scale_unbiased = np.floor(np.log2(np.maximum(amax_rounded, 1e-45))) - 2
    scale_unbiased = np.clip(scale_unbiased, -127, 127)
    scales = (scale_unbiased + 127).astype(np.uint8)
    scales = np.where(amax_blocks == 0, 0, scales)

    # Scale values for quantization. Cast to int32 before subtracting the bias:
    # scales is uint8, and uint8 arithmetic would wrap for scales > 127.
    scale_exp = scales[:, :, None].astype(np.int32) - 127
    scale_vals = np.where(scales[:, :, None] > 0,
                          2.0 ** (-scale_exp),
                          1.0)

    scaled_blocks = data_blocks * scale_vals

    # Quantize to FP4
    signs = (scaled_blocks < 0).astype(np.uint8)
    abs_vals = np.abs(scaled_blocks)
    diffs = np.abs(abs_vals[:, :, :, None] - fp4_values[None, None, None, :])
    indices = np.argmin(diffs, axis=3).astype(np.uint8)
    fp4_encoded = (signs << 3) | indices

    fp4_flat = fp4_encoded.reshape(M, N)

    # Pack: (odd_col << 4) | even_col
    fp4_even = fp4_flat[:, 0::2]
    fp4_odd = fp4_flat[:, 1::2]
    fp4_packed = ((fp4_odd << 4) | fp4_even).astype(np.uint8)

    def cdiv(a, b):
        return (a + b - 1) // b

    scale_M_pad = cdiv(M, 256) * 256
    scale_N_pad = cdiv(num_blocks, 8) * 8
    scales_padded = np.full((scale_M_pad, scale_N_pad), 127, dtype=np.uint8)

    # Copy scales directly (no data shuffle support in Triton kernel)
    scales_padded[:M, :num_blocks] = scales

    fp4_packed_torch = torch.from_numpy(fp4_packed).to(input_tensor.device)
    scales_torch = torch.from_numpy(scales_padded).to(input_tensor.device)

    return fp4_packed_torch, scales_torch
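An aside on the bit trick in the amax rounding step of the reference above: adding `0x200000` and masking with `0xFF800000` rounds a float32 magnitude to the nearest power of two. A small self-contained check of that behavior:

```python
import numpy as np

def round_to_pow2(x):
    # float32 layout: 1 sign bit, 8 exponent bits, 23 mantissa bits.
    # Adding 0x200000 (half the mantissa MSB 0x400000) before masking the
    # mantissa away makes any value >= 1.75 * 2^e carry into the exponent
    # and round up to 2^(e+1); everything below rounds down to 2^e.
    u = np.float32(x).view(np.uint32)
    return ((u + np.uint32(0x200000)) & np.uint32(0xFF800000)).view(np.float32)

print(round_to_pow2(1.5))  # 1.0 (mantissa 0x400000 < 0x600000: rounds down)
print(round_to_pow2(1.8))  # 2.0 (mantissa ~0x666666 >= 0x600000: rounds up)
```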
@pytest.mark.parametrize("shape", [
    (128, 128),
    (256, 256),
    (256, 1024),
    (2048, 6144),
    (16384, 128),
    (32768, 160),
    (4096, 1632),
    (8, 32, 1024),
    (16, 8, 4, 512),
Collaborator: Can we add some prime numbers, like TransformerEngine/tests/cpp/operator/test_cast_transpose.cu, lines 90 to 92 in 9d6b0e5?
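A sketch of shapes in that spirit that still satisfy this test's constraint that both the flattened row count and the last dim be multiples of the 32-element MXFP4 block (values illustrative, not copied from the referenced file):

```python
# Hypothetical additions: prime batch dims, block-aligned inner dims.
PRIME_BATCH_SHAPES = [
    (7, 32, 1024),   # prime batch dim; flattened M = 224
    (13, 32, 512),   # prime batch dim; flattened M = 416
    (31, 64, 128),   # prime batch dim; flattened M = 1984
]
```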
])
@pytest.mark.parametrize("in_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize(("rowwise", "columnwise"), [
    (True, True),
    (False, True),
    (True, False)
])
@pytest.mark.parametrize("shuffle_B_matrix", [False, True])
def test_quantize_mxfp4(shape, in_dtype, rowwise, columnwise, shuffle_B_matrix):
    """Test MXFP4 quantization for rowwise/columnwise modes with/without FP4 shuffle.

    Note: FP4 data shuffle (shuffle_B_matrix_for_aiter) is not yet supported in the Triton kernel.
Collaborator: If FP4 data shuffle is not yet supported in the Triton kernel, why do we need to add it here?

Collaborator (Author): This is kept to ensure API consistency between Triton and the upcoming hip kernel, for which I'll create a separate PR. In the hip kernel we were able to fuse the shuffle.

Collaborator (Author): hip vs triton flow:
Input: BF16 [M, N]
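A rough sketch of the two flows being contrasted (all helper names below are hypothetical placeholders, not real APIs in this PR):

```python
# Triton flow (this PR): quantize first, shuffle in a separate pass.
def triton_flow(x_bf16, quantizer):
    q = te_quantize_triton(x_bf16, quantizer=quantizer)  # FP4 data + E8M0 scales
    return shuffle_for_aiter(q)                          # hypothetical separate shuffle pass

# hip flow (upcoming PR): quantize and shuffle fused in a single kernel,
# so shuffle_B_matrix_for_aiter=True is honored inside the kernel itself.
def hip_flow(x_bf16, quantizer):
    return hip_fused_quantize_shuffle(x_bf16, quantizer)  # hypothetical fused kernel
```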
    """
    if shuffle_B_matrix:
        pytest.skip("FP4 data shuffle not yet supported in Triton kernel")

    input_tensor = fill_uniform(shape, dtype=in_dtype)

    quantizer = MXFP4Quantizer(
        rowwise=rowwise,
        columnwise=columnwise,
        shuffle_B_matrix_for_aiter=shuffle_B_matrix
    )
    out = quantizer.make_empty(input_tensor.shape, dtype=in_dtype)
    quantized_out = te_quantize_triton(input_tensor, quantizer=quantizer, output=out)

    # Tolerance: allow 1 nibble diff for rare edge cases near FP4 boundaries
    data_atol = 20.0 if in_dtype != torch.float32 else 16.0
    scale_atol = 2.0 if in_dtype != torch.float32 else 1.0
Comment on lines +127 to +128
Collaborator: Data tol seems to be quite large. You can follow our mxfp8 scale and data adjustment scheme: TransformerEngine/tests/cpp/test_common.cu, line 730 in 9d6b0e5
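One way to read that suggestion, sketched here as a paraphrase of the idea rather than the actual test_common.cu code: compare in the dequantized domain, so a ±1 E8M0 scale step that is compensated by the FP4 nibbles does not trip a raw-byte tolerance.

```python
import numpy as np

# Assumes *unpacked* FP4 nibbles [M, N] and per-element E8M0 scales already
# broadcast from one-per-32-element block; both inputs are uint8 arrays.
FP4_LUT = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def dequantize_fp4(nibbles, scales_e8m0):
    sign = np.where(nibbles & 0x8, -1.0, 1.0)
    return sign * FP4_LUT[nibbles & 0x7] * np.exp2(scales_e8m0.astype(np.int32) - 127)

def compare_adjusted(test_nibbles, test_scales, ref_nibbles, ref_scales):
    a = dequantize_fp4(test_nibbles, test_scales)
    b = dequantize_fp4(ref_nibbles, ref_scales)
    # Allow one quantization step at the coarsest E2M1 spacing (2.0, between
    # codes 4.0 and 6.0), scaled by the larger of the two block scales.
    step = 2.0 * np.exp2(np.maximum(test_scales, ref_scales).astype(np.int32) - 127)
    assert np.all(np.abs(a - b) <= step), "dequantized MXFP4 values diverge"
```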
    if rowwise:
        ref_data, ref_scale = mxfp4_quantize_cpu(input_tensor, axis='row')
        M = math.prod(input_tensor.shape[:-1])
        K = input_tensor.shape[-1]
        num_blocks = K // MXFP4_BLOCK_SCALING_SIZE

        te_compare_results(
            quantized_out._rowwise_data.view(torch.uint8),
            ref_data,
            atol=data_atol,
            rtol=0.0,
            msg="rowwise FP4 data mismatch",
            use_torch_semantics=True
        )

        # Compare only valid (non-padded) region - no shuffle extraction needed
Collaborator: What is fp4 shuffle?

Collaborator (Author): fp4 shuffle basically rearranges the [M, K/2] linear layout into the MFMA instruction layout (16×16). The current training workflow when the TE MXFP4 quantization kernel is used is as follows. You can find the shuffle code in
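For readers unfamiliar with the layout transform, a conceptual sketch of tiling a packed byte matrix into 16×16 tiles, assuming a plain tiling (the real aiter/MFMA shuffle may additionally permute elements within tiles):

```python
import numpy as np

def shuffle_16x16_tiles(packed, tile=16):
    """Conceptual only: regroup a row-major [M, K/2] byte matrix so each
    16x16 tile is contiguous in memory. Not the actual aiter shuffle."""
    M, K2 = packed.shape
    assert M % tile == 0 and K2 % tile == 0
    t = packed.reshape(M // tile, tile, K2 // tile, tile)
    # [tiles_m, 16, tiles_k, 16] -> [tiles_m, tiles_k, 16, 16] -> flat tiles
    return t.transpose(0, 2, 1, 3).reshape(M // tile, K2 // tile, tile * tile)
```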
        te_compare_results(
            quantized_out._rowwise_scale.view(torch.uint8)[:M, :num_blocks],
            ref_scale[:M, :num_blocks],
            atol=scale_atol,
            rtol=0.0,
            msg="rowwise E8M0 scales mismatch",
            use_torch_semantics=True
        )

    if columnwise:
        ref_data, ref_scale = mxfp4_quantize_cpu(input_tensor, axis='col')
        M = math.prod(input_tensor.shape[:-1])
        K = input_tensor.shape[-1]
        num_blocks = M // MXFP4_BLOCK_SCALING_SIZE

        te_compare_results(
            quantized_out._columnwise_data.view(torch.uint8),
            ref_data,
            atol=data_atol,
            rtol=0.0,
            msg="columnwise FP4 data mismatch",
            use_torch_semantics=True
        )

        # Compare only valid (non-padded) region - no shuffle extraction needed
        te_compare_results(
            quantized_out._columnwise_scale.view(torch.uint8)[:K, :num_blocks],
            ref_scale[:K, :num_blocks],
            atol=scale_atol,
            rtol=0.0,
            msg="columnwise E8M0 scales mismatch",
            use_torch_semantics=True
        )
@@ -108,7 +108,8 @@
     .value("kFloat16", transformer_engine::DType::kFloat16) \
     .value("kBFloat16", transformer_engine::DType::kBFloat16) \
     .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
-    .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \
+    .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
+    .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \
Collaborator: If we are going to enable kFloat4E2M1, there are other related changes needed. Search https://github.com/search?q=repo%3AROCm%2FTransformerEngine%20kFloat4E2M1&type=code for more details:
   pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \
       .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
       .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You will need to add this pytest into our ci script (somewhere near
TransformerEngine/ci/pytorch.sh
Line 74 in 9d6b0e5