
MXFP4 Cast Transpose Triton [WIP] #422

Draft: sarthak-amd wants to merge 7 commits into dev from feature/cast-transpose-mxfp4

Conversation

@sarthak-amd (Collaborator) commented Jan 20, 2026

Description

Implements the fused FP32/BF16 -> MXFP4 quantization + cast kernel, producing both rowwise and columnwise outputs.

  • Verify tolerances and functional unit tests

  • The Triton te_cast_transpose_mxfp4_triton kernel currently outputs FP4 data in a linear [M, N/2] layout with contiguous byte packing. AITER's gemm_a4w4 requires the B matrix in the MFMA shuffle layout for tensor cores. This layout shuffle can be fused into the Triton kernel in the future.

@wangye805 (Collaborator) left a comment

import numpy as np
import os

os.environ["USE_TRITON_FUSED_CAST_TRANSPOSE"] = "1"

We have already defined the env var NVTE_USE_CAST_TRANSPOSE_TRITON.
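
For illustration, a minimal sketch of gating the Triton path on the existing variable instead of introducing a new one (the helper name is hypothetical; the actual dispatch in the PR may look different):

import os

def use_triton_cast_transpose() -> bool:
    # Hypothetical helper: reuse the existing NVTE_USE_CAST_TRANSPOSE_TRITON flag
    # rather than adding a new USE_TRITON_FUSED_CAST_TRANSPOSE variable.
    return os.getenv("NVTE_USE_CAST_TRANSPOSE_TRITON", "0") == "1"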

def test_quantize_mxfp4(shape, in_dtype, rowwise, columnwise, shuffle_B_matrix):
"""Test MXFP4 quantization for rowwise/columnwise modes with/without FP4 shuffle.

Note: FP4 data shuffle (shuffle_B_matrix_for_aiter) is not yet supported in Triton kernel.

If the FP4 data shuffle is not yet supported in the Triton kernel, why do we need to add it here?

@sarthak-amd (Author) replied:

This is kept to ensure API consistency between Triton and the upcoming HIP kernel, for which I'll create a separate PR. In the HIP kernel we were able to fuse the shuffle.

@sarthak-amd (Author):

HIP vs Triton flow:
Input: BF16 [M, N]

MXFP4Quantizer.update_quantized()

tex.cast_transpose_mxfp4_fused_shuffle() [Single HIP kernel]

├─→ Rowwise FP4 [M, K/2] (MFMA shuffled)
├─→ Rowwise Scale [M_pad, K/32_pad] (shuffled)
├─→ Colwise FP4 [N, M/2] (MFMA shuffled)
└─→ Colwise Scale [N_pad, M/32_pad] (shuffled)

AITER gemm_a4w4 (zero-copy)

vs

Input: BF16 [M, N]

MXFP4Quantizer.update_quantized()

te_cast_transpose_mxfp4_triton() [Triton JIT kernel]

├─→ Rowwise FP4 [M, K/2] (linear layout)
├─→ Rowwise Scale [M_pad, K/32_pad] (shuffled)
├─→ Colwise FP4 [N, M/2] (linear layout)
└─→ Colwise Scale [N_pad, M/32_pad] (shuffled)

aiter.ops.shuffle.shuffle_weight() [External call]

FP4 data → MFMA layout

AITER gemm_a4w4
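
For illustration only, a rough sketch of the external shuffle step that distinguishes the Triton path from the HIP path above. Only the call sequence is taken from the diagram; the wrapper name is hypothetical, and shuffle_weight's default arguments are assumed to produce the 16×16 MFMA layout:

import torch
from aiter.ops.shuffle import shuffle_weight  # FP4 data shuffle, see aiter/aiter/ops/shuffle.py

def shuffle_triton_fp4_for_aiter(col_fp4_linear: torch.Tensor) -> torch.Tensor:
    """Take the colwise FP4 output of te_cast_transpose_mxfp4_triton
    (linear [N, M/2] uint8 layout) and rearrange it into the MFMA layout
    that aiter.gemm_a4w4 expects."""
    return shuffle_weight(col_fp4_linear)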

(32768, 160),
(4096, 1632),
(8, 32, 1024),
(16, 8, 4, 512),

Can we add some prime numbers like

{1, 3221}, // Prime 456
{2333, 1}, // Prime 345
{1481, 677}}; // Primes 234, 123

Comment on lines +127 to +128
data_atol = 20.0 if in_dtype != torch.float32 else 16.0
scale_atol = 2.0 if in_dtype != torch.float32 else 1.0

The data tolerance seems quite large. You can follow our MXFP8 scale and data adjustment scheme:

void adjust_ref_for_e8m0_scale_error(const std::string &name,
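
As a hedged Python-side illustration of that idea (this is not the actual adjust_ref_for_e8m0_scale_error implementation; the power-of-two rounding mode is an assumption and must match the kernel under test):

import torch

MXFP4_BLOCK = 32       # elements per E8M0 scale
FP4_E2M1_MAX = 6.0     # largest representable E2M1 magnitude

def rescale_ref_with_e8m0(ref: torch.Tensor) -> torch.Tensor:
    """Divide the FP32 reference by per-block power-of-two (E8M0) scales,
    mirroring how the kernel scales before FP4 conversion, so data can be
    compared with a tighter tolerance than data_atol = 20.0."""
    blocks = ref.float().reshape(-1, MXFP4_BLOCK)
    amax = blocks.abs().amax(dim=1, keepdim=True).clamp_min(torch.finfo(torch.float32).tiny)
    scales = torch.exp2(torch.floor(torch.log2(amax / FP4_E2M1_MAX)))
    return (blocks / scales).reshape(ref.shape)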

use_torch_semantics=True
)

# Compare only valid (non-padded) region - no shuffle extraction needed

What is fp4 shuffle?

@sarthak-amd (Author) replied:

The FP4 shuffle rearranges the [M, K/2] linear layout into the MFMA instruction layout (16×16).

The current training flow when the TE MXFP4 quantization kernel is used is:
TE Triton Kernel → Linear FP4 [N, K/2] → aiter.ops.shuffle_weight() → MFMA FP4 → aiter.gemm_a4w4()

You can find the shuffle code in aiter/aiter/ops/shuffle.py

.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
.value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \

If we are going to enable kFloat4E2M1, there are other related changes needed. See https://github.com/search?q=repo%3AROCm%2FTransformerEngine%20kFloat4E2M1&type=code for more details.

Comment on lines +61 to +62
- Data: [M, K/2] uint8 tensor (2 FP4 values packed per byte)
- Scale: [M, K/32] uint8 tensor (E8M0 format, one scale per 32-element block)
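
For reference, a minimal sketch of the "2 FP4 values packed per byte" convention described above (the nibble order, even column in the low nibble, is an assumption and must match what gemm_a4w4 expects):

import torch

def pack_fp4_pairs(codes: torch.Tensor) -> torch.Tensor:
    """Pack an [M, K] uint8 tensor of 4-bit E2M1 codes into [M, K/2] bytes."""
    lo = codes[:, 0::2] & 0x0F   # even columns -> low nibble (assumed order)
    hi = codes[:, 1::2] & 0x0F   # odd columns  -> high nibble
    return (hi << 4) | lo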

Are there alignment/padding requirements for M and K?

Comment on lines +113 to +114
if inp.ndim < 2:
return False

TE currently supports 2D matrices flattened from higher-dimensional tensors:

  size_t flat_first_dim() const {
    const auto &full_shape = shape();
    size_t ret = 1;
    if (!full_shape.empty()) {
      for (size_t i = 0; i < full_shape.size() - 1; i++) {
        ret *= full_shape[i];
      }
    }
    return ret;
  }

  /*! Matrix width after tensor is flattened to 2D
   *
   * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
   * as a (D1*D2*...*D(n-1), Dn) matrix.
   */
  size_t flat_last_dim() const {
    const auto &full_shape = shape();
    if (full_shape.empty()) {
      return 1;
    } else {
      return full_shape.back();
    }
  }
};
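
A minimal Python-side sketch of the same convention, which would let >2-D inputs take the Triton path instead of returning False (the guard condition is from the quoted diff; the helper itself is hypothetical):

import torch

def flatten_to_2d(inp: torch.Tensor) -> torch.Tensor:
    """Reinterpret (D1, ..., Dn) as (D1*...*D(n-1), Dn), matching
    flat_first_dim()/flat_last_dim() in the C++ core."""
    if inp.ndim < 2:
        raise ValueError("expected at least a 2-D tensor")
    return inp.reshape(-1, inp.shape[-1])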


# Allocate PADDED scale tensors for shuffle compatibility
rowwise_scale_N = K // MXFP4_BLOCK_SCALING_SIZE
rowwise_scale_M_pad = cdiv(M, 256) * 256

I presume this 256 is from some alignment/padding requirement?

@sarthak-amd (Author) replied on Feb 3, 2026:

The 256 alignment is required by AITER's CK-based MXFP4 GEMM kernels for the scale tensor swizzle/shuffle layout: 256 = ScaleBlockSize(32) × 8 waves.
See aiter/aiter/utility/fp4_utils.py:398 and gemm_a4w4_blockscale_common.cuh:66.
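
For illustration, the padded rowwise-scale shape implied by the quoted allocation code and the 256 = 32 × 8 explanation above (whether the K/32 dimension also needs padding is not shown in the quoted lines):

MXFP4_BLOCK_SCALING_SIZE = 32   # one E8M0 scale per 32-element block
SCALE_M_ALIGN = 256             # ScaleBlockSize(32) * 8 waves, per AITER's CK kernels

def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b

def rowwise_scale_shape(M: int, K: int) -> tuple[int, int]:
    rowwise_scale_N = K // MXFP4_BLOCK_SCALING_SIZE
    rowwise_scale_M_pad = cdiv(M, SCALE_M_ALIGN) * SCALE_M_ALIGN
    return rowwise_scale_M_pad, rowwise_scale_N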

@@ -0,0 +1,178 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

You will need to add this pytest to our CI script (somewhere near

run_default_fa 1 triton_kernels/test_norms.py

), otherwise it won't be tested.
