Triton norms dispatch refactor #305
base: dev
Conversation
Pull Request Overview
This PR refactors the Triton normalization (RMSNorm and LayerNorm) implementations by creating a unified dispatch mechanism. It introduces a new te_norm_fwd_triton function that serves as a generalized entry point for both norm types, while preserving backward compatibility by maintaining the existing API functions as thin wrappers.
Key changes include:
- Created a unified `te_norm_fwd_triton` dispatch function in a new `norms.py` file
- Modified kernel signatures to support both RMSNorm and LayerNorm use cases
- Updated imports across multiple modules to reference the new consolidated location
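To make the dispatch idea above concrete, here is a minimal sketch of the pattern, assuming a `norm_type` selector; the real `te_norm_fwd_triton` in `norms.py` launches Triton kernels and almost certainly takes more arguments, so the parameter names and the plain-PyTorch stand-ins below are illustrative assumptions only:

```python
import torch

# Hypothetical sketch of the unified dispatch; the Triton kernel launches are
# stubbed out with plain PyTorch math so the dispatch structure runs standalone.
def te_norm_fwd_triton(inp, weight, eps, norm_type="rmsnorm", bias=None):
    if norm_type == "rmsnorm":
        # RMSNorm: scale by the reciprocal root-mean-square; no bias, no mean.
        rstd = torch.rsqrt(inp.pow(2).mean(dim=-1, keepdim=True) + eps)
        return inp * rstd * weight
    if norm_type == "layernorm":
        # LayerNorm: subtract the mean, normalize, then apply weight and bias.
        mean = inp.mean(dim=-1, keepdim=True)
        rstd = torch.rsqrt(inp.var(dim=-1, unbiased=False, keepdim=True) + eps)
        return (inp - mean) * rstd * weight + bias
    raise ValueError(f"Unsupported norm_type: {norm_type!r}")
```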
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| transformer_engine/pytorch/triton_kernels/norms.py | New file containing unified norm dispatch logic and relocated function implementations |
| transformer_engine/pytorch/triton_kernels/rmsnorm.py | Removed te_rmsnorm_fwd_triton function and updated kernel signature for unification |
| transformer_engine/pytorch/triton_kernels/layernorm.py | Removed forward/backward functions and simplified reduction kernel signature |
| transformer_engine/pytorch/ops/basic/rmsnorm.py | Updated import to reference new norms module |
| transformer_engine/pytorch/ops/basic/layer_norm.py | Updated import to reference new norms module |
| transformer_engine/pytorch/module/layernorm_mlp.py | Consolidated imports from new norms module |
| transformer_engine/pytorch/module/layernorm_linear.py | Consolidated imports from new norms module |
| transformer_engine/pytorch/module/_common.py | Consolidated imports from new norms module |
| tests/pytorch/triton_kernels/test_norms.py | Updated imports to reference new norms module |
Note that currently some of the layernorm tests are failing, but they're citing
I haven't seen anything like that.
Turns out something in this PR gives the layernorm kernel bad memory behavior: it mutates either the weight tensor or the bias tensor in the test. I believe this happens because they are allocated on the GPU contiguously with respect to each other (i.e. first the input array, then gamma, then bias), which leads me to suspect some kind of masking problem in the layernorm kernel, but I have not been able to pinpoint it yet. Everything seems to work on
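A hedged sketch of how that mutation can be caught in isolation follows; the import path and call signature of `te_layernorm_fwd_triton` here are assumptions for illustration and may not match the actual API or test:

```python
# Hypothetical repro sketch (not the actual test): allocate the input, gamma,
# and beta one after another (the caching allocator usually places them
# adjacently in device memory), then check gamma/beta after the forward call.
# The import path and call signature below are assumptions, not the real API.
import torch
from transformer_engine.pytorch.triton_kernels.norms import te_layernorm_fwd_triton

x = torch.randn(128, 1024, device="cuda")
gamma = torch.randn(1024, device="cuda")
beta = torch.randn(1024, device="cuda")

gamma_ref, beta_ref = gamma.clone(), beta.clone()
out = te_layernorm_fwd_triton(x, gamma, beta, 1e-5)  # assumed positional arguments

# A badly masked (out-of-bounds) store in the kernel would corrupt whatever
# tensors happen to live directly after the input in device memory.
assert torch.equal(gamma, gamma_ref), "gamma was mutated by the forward kernel"
assert torch.equal(beta, beta_ref), "beta was mutated by the forward kernel"
```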
@Micky774 Could you remind me of what we had decided on this PR?
@wenchenvincent Last time we talked about this, I think @matthiasdiener was supposed to eventually take it over. There's currently a bug exposed by this PR that I think will require a bit of work to resolve. There's some kind of memory mismanagement occurring, where output tensors' memory is being overwritten after being produced.
This was a manifestation of the aforementioned bug.
cc: @wenchenvincent @wangye805 The PR is ready for review!
LGTM. @ipanfilo You reviewed it a while ago. Do you have further comments?
```diff
@@ -1,29 +1,30 @@
-# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
```
2025-2026
Fixed
```python
import pytest
from functools import partial
from itertools import product
from torch.utils.cpp_extension import IS_HIP_EXTENSION
```
This entire file is ROCm-specific; basically we can assume IS_HIP_EXTENSION is true when running this pytest.
It was unused anyways -- fixed.
```python
# The scale_inv values may differ slightly, but will still dequantize close enough to
# pass the earlier comparisons.
compare_func = partial(te_compare_results, atol=1, rtol=0, use_torch_semantics=True)
```
For mxfp8 data and scale inv comparison, we can reuse the same logic as in the cpp gtest:

TransformerEngine/tests/cpp/test_common.cu, line 730 in 0dfee56:

```cpp
void adjust_ref_for_e8m0_scale_error(const std::string &name,
```

TransformerEngine/tests/cpp/operator/test_cast_mxfp8.cu, lines 331 to 355 in 0dfee56:

```cpp
#ifdef __HIP_PLATFORM_AMD__
  if (::testing::Test::HasFatalFailure()) return;
  adjust_ref_for_e8m0_scale_error("scales", mismatches_scales_indices, gpu_scales_ptr,
                                  ref_output_scales.get(), scales_stride, rows, cols, rowwise,
                                  ref_output_c.get(), otype);
  mismatches_scales = 0;
#endif
  const size_t mismatches_elts = 32 * mismatches_scales;
  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol, true, mismatches_elts);
  if (processing_method == ProcessingMethod::CAST_DBIAS
      || processing_method == ProcessingMethod::CAST_DBIAS_DACT)
  {
    auto [atol_dbias, rtol_dbias] = getTolerances(itype);
    if (itype == DType::kFloat32) {
      atol_dbias = 1e-4;
      rtol_dbias *= sqrt(static_cast<double>(rows));
    } else {
      rtol_dbias *= 4;
    }
    compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
  }
}
```
We essentially already do this implicitly by relying on the dequantization of the MXFP8Tensors before comparison. While we could handle this explicitly as in the C++ tests, I don't think that's necessary given that the dequantization behavior has its own testing, which passes. Let me know if you have other thoughts on the matter.
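A minimal sketch of this dequantize-then-compare idea, assuming the MXFP8 tensor exposes a `dequantize()` method (an assumption for illustration; the actual API and tolerances may differ):

```python
import torch

# Sketch only: compare MXFP8 results after dequantization rather than comparing
# raw scale_inv factors bit-for-bit, since slightly different e8m0 scales can
# still dequantize to values within tolerance. The .dequantize() call is assumed.
def compare_mxfp8_outputs(out_triton, out_ref, atol=0.0, rtol=1e-2):
    deq_triton = out_triton.dequantize().to(torch.float32)
    deq_ref = out_ref.dequantize().to(torch.float32)
    torch.testing.assert_close(deq_triton, deq_ref, atol=atol, rtol=rtol)
```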
```python
# The MXFP8 tensors carry their scale_inv values in a padded
# format, hence we must omit the padded values.
input_shape = out_triton.shape
unpadded_scale_inv_shape = (math.prod(input_shape[:-1]), input_shape[-1] // MXFP8_BLOCK_SCALING_SIZE)
```
Should we have different shape for row-wise and col-wise scaling?
We resolve this with re-indexing, but I've updated the variable name for a bit of extra clarity.
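For illustration, a hedged sketch of what trimming the padded scale_inv buffer to its valid region could look like; the row-wise shape mirrors the test's computation, while the column-wise branch and the simple 2-D slicing are assumptions rather than the test's actual re-indexing:

```python
import math

MXFP8_BLOCK_SCALING_SIZE = 32  # one e8m0 scale per 32 elements

# Sketch only: slice the padded scale_inv buffer down to the logically valid
# region. The column-wise convention and the 2-D padding layout are assumed.
def trim_scale_inv(scale_inv_padded, input_shape, rowwise=True):
    rows = math.prod(input_shape[:-1])
    cols = input_shape[-1]
    if rowwise:
        shape = (rows, cols // MXFP8_BLOCK_SCALING_SIZE)   # blocks along the last dim
    else:
        shape = (rows // MXFP8_BLOCK_SCALING_SIZE, cols)   # blocks along the leading dim
    return scale_inv_padded[: shape[0], : shape[1]]
```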
```diff
@@ -1,11 +1,11 @@
-# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
```
Same here
Fixed
```diff
 if IS_HIP_EXTENSION:
-    from ...triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton
+    from ...triton_kernels.norms_common import te_layernorm_fwd_triton, te_layernorm_bwd_triton
 from ...fp8 import FP8GlobalStateManager
```
Do we need to import FP8GlobalStateManager and QuantizedTensor here? For both NV upstream and us
Removed extra import
```diff
 if IS_HIP_EXTENSION:
-    from ...triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton
+    from ...triton_kernels.norms_common import te_layernorm_fwd_triton, te_layernorm_bwd_triton
 from ...tensor import QuantizedTensor
```
Is it needed for this PR?
No, it was left over by accident. Removed.
Copyright date
Done, along with a few others I missed.
Description
This PR disentangles the backend Triton implementation from the front-end API, creating a unified intermediate `te_norm_fwd_triton`, which is a generalized dispatch function. The PR is fully backwards compatible, as `te_rmsnorm_fwd_triton` and `te_layernorm_fwd_triton` are preserved and implemented as thin wrappers around `te_norm_fwd_triton`. This way, when bugs appear, we fix them once without needing to duplicate the fix across norms.
Consequently, there are some changes to the imports to accommodate this restructuring. This PR also includes a minor cleanup/simplification of previously redundant behavior in the layernorm forward implementation, as well as support for `Float8CurrentScalingQuantizer`. FWIW, I don't think we can apply a similar unification to the backward passes, as it seems that -- at least for layernorm -- the backward implementations are fairly specialized and have asymmetric heuristics.
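A minimal sketch of the thin-wrapper pattern described above, paired with the hypothetical `te_norm_fwd_triton` sketch earlier on this page; the argument lists are assumptions, not the PR's actual signatures:

```python
# Hypothetical wrappers preserving the existing public API; the exact argument
# lists in norms.py may differ.
def te_rmsnorm_fwd_triton(inp, weight, eps, **kwargs):
    # Preserved entry point: forwards straight to the unified dispatcher.
    return te_norm_fwd_triton(inp, weight, eps, norm_type="rmsnorm", **kwargs)

def te_layernorm_fwd_triton(inp, weight, bias, eps, **kwargs):
    # Preserved entry point: LayerNorm additionally passes its bias through.
    return te_norm_fwd_triton(inp, weight, eps, norm_type="layernorm",
                              bias=bias, **kwargs)
```

Because existing call sites only ever see these two names, they keep working unchanged while the shared dispatch logic lives in one place.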
Fixes # (issue)