
Conversation

@Micky774 (Contributor) commented Sep 5, 2025

Description

This PR disentangles the back-end Triton implementation from the front-end API by introducing a unified intermediate, te_norm_fwd_triton, which acts as a generalized dispatch function. The change is fully backwards compatible: te_rmsnorm_fwd_triton and te_layernorm_fwd_triton are preserved and implemented as thin wrappers around te_norm_fwd_triton.
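For readers skimming the conversation, here is a minimal, self-contained sketch of the dispatch/thin-wrapper pattern described above. It is written against plain PyTorch purely for illustration; the norm_type keyword and the function bodies are assumptions, and only the three public names come from this PR.

```python
import torch

def te_norm_fwd_triton(inp, weight, eps, *, norm_type, bias=None):
    # Single dispatch point: one place to validate arguments and (in the real
    # implementation) launch the shared Triton kernel.
    if norm_type == "layernorm":
        mean = inp.mean(dim=-1, keepdim=True)
        var = (inp - mean).pow(2).mean(dim=-1, keepdim=True)
        return (inp - mean) / torch.sqrt(var + eps) * weight + bias
    if norm_type == "rmsnorm":
        rms = torch.sqrt(inp.pow(2).mean(dim=-1, keepdim=True) + eps)
        return inp / rms * weight
    raise ValueError(f"Unknown norm_type: {norm_type}")

# Backwards-compatible thin wrappers, as described above.
def te_layernorm_fwd_triton(inp, weight, bias, eps):
    return te_norm_fwd_triton(inp, weight, eps, norm_type="layernorm", bias=bias)

def te_rmsnorm_fwd_triton(inp, weight, eps):
    return te_norm_fwd_triton(inp, weight, eps, norm_type="rmsnorm")
```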

This way, when bugs appear, we fix them once instead of duplicating the fix across both norms.

Consequently, there are some changes to the imports to accommodate this restructuring. This PR also includes a minor cleanup/simplification of previously redundant behavior in the layernorm fwd implementation, as well as support for Float8CurrentScalingQuantizer.

FWIW I don't think we can apply a similar unification to the backwards passes, as it seems that -- at least for layernorm -- the backwards implementations are pretty specialized and have asymmetric heuristics.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Copilot AI left a comment

Pull Request Overview

This PR refactors the Triton normalization (RMSNorm and LayerNorm) implementations by creating a unified dispatch mechanism. It introduces a new te_norm_fwd_triton function that serves as a generalized entry point for both norm types, while preserving backward compatibility by maintaining the existing API functions as thin wrappers.

Key changes include:

  • Created a unified te_norm_fwd_triton dispatch function in a new norms.py file
  • Modified kernel signatures to support both RMSNorm and LayerNorm use cases
  • Updated imports across multiple modules to reference the new consolidated location

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.

Summary per file:

File | Description
---- | ----
transformer_engine/pytorch/triton_kernels/norms.py | New file containing unified norm dispatch logic and relocated function implementations
transformer_engine/pytorch/triton_kernels/rmsnorm.py | Removed te_rmsnorm_fwd_triton function and updated kernel signature for unification
transformer_engine/pytorch/triton_kernels/layernorm.py | Removed forward/backward functions and simplified reduction kernel signature
transformer_engine/pytorch/ops/basic/rmsnorm.py | Updated import to reference new norms module
transformer_engine/pytorch/ops/basic/layer_norm.py | Updated import to reference new norms module
transformer_engine/pytorch/module/layernorm_mlp.py | Consolidated imports from new norms module
transformer_engine/pytorch/module/layernorm_linear.py | Consolidated imports from new norms module
transformer_engine/pytorch/module/_common.py | Consolidated imports from new norms module
tests/pytorch/triton_kernels/test_norms.py | Updated imports to reference new norms module


@Micky774 marked this pull request as draft September 5, 2025 22:35
@Micky774 marked this pull request as ready for review September 5, 2025 23:09
@Micky774 (Contributor, Author) commented Sep 5, 2025

Note that currently some of the layernorm tests are failing, but they're citing NaN vals in the expected tensor, i.e. the HIP reference kernel. I tried including @alextmagro's PR #303 but it still fails. @alextmagro have you seen such an error as well? Is it something related?

@alextmagro (Contributor) commented:

> Note that currently some of the layernorm tests are failing, but they're citing NaN vals in the expected tensor, i.e. the HIP reference kernel. I tried including @alextmagro's PR #303 but it still fails. @alextmagro have you seen such an error as well? Is it something related?

I haven't seen anything like that.

@Micky774 (Contributor, Author) commented Sep 8, 2025

Turns out something in this PR gives the layernorm kernel bad memory behavior: it mutates either the weight tensor or the bias tensor in the test. I believe this happens because the tensors are allocated contiguously with respect to each other on the GPU (i.e. first the input array, then gamma, then bias), which leads me to suspect some kind of masking problem in the layernorm kernel, but I have not been able to pinpoint it yet.

Everything seems to work on dev, but I don't have any functional changes aside from a logical simplification of non-atomic layernorm fwd cases. Most of it is just variable renaming.
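As a small debugging sketch of the failure mode described above (a hypothetical helper, not code from this PR): snapshot the tensors that should be read-only, run the kernel, and check that nothing mutated them.

```python
import torch

def assert_inputs_untouched(launch_kernel, *read_only_tensors):
    # Clone the tensors the kernel should only read, invoke the kernel, and
    # flag any out-of-bounds write that corrupted them.
    snapshots = [t.detach().clone() for t in read_only_tensors]
    out = launch_kernel()
    for before, after in zip(snapshots, read_only_tensors):
        assert torch.equal(before, after), "kernel overwrote a read-only input tensor"
    return out

# Usage in the failing test (names illustrative):
# out = assert_inputs_untouched(lambda: te_layernorm_fwd_triton(x, gamma, beta, eps), gamma, beta)
```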

@wenchenvincent (Collaborator) commented:

@Micky774 Could you remind me of what we had decided on this PR?

@Micky774 (Contributor, Author) commented Jan 14, 2026

> @Micky774 Could you remind me of what we had decided on this PR?

@wenchenvincent last we talked about this, I think @matthiasdiener was supposed to eventually take it over. There's currently a bug exposed by this PR that will require a bit of work to resolve, I think. There's some kind of memory mismanagement occurring, where output tensors' memory is being overwritten after being produced.

@Micky774 (Contributor, Author) commented Jan 28, 2026

> Note that I skip the tests where the HIP kernel generates nan values, but if desired, I can instead test against a reference torch implementation. I prefer the skip, since it highlights the fact that the test isn't fully-enabled as opposed to us forgetting that certain cases were handled with a torch reference implementation...

This was a manifestation of the aforementioned bug.

@Micky774 mentioned this pull request Jan 28, 2026
@Micky774 (Contributor, Author) commented:

cc: @wenchenvincent @wangye805
I finally found the underlying issue. There was a problem with the sizing of the amax array in the layernorm kernel, which led to unsafe memory writes corrupting adjacent data. I've corrected this.
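For context, a hedged sketch of the kind of sizing fix described here (illustrative only; the actual allocation lives in the Triton layernorm forward path, and the helper name below is made up): the per-program amax workspace has to be sized to the launch grid, otherwise trailing programs write past the end of the buffer and corrupt whatever tensor happens to be adjacent in memory.

```python
import torch

def allocate_amax_workspace(n_rows: int, rows_per_program: int, device: str = "cuda") -> torch.Tensor:
    # One partial-amax slot per launched program; undersizing this buffer is
    # exactly the out-of-bounds write pattern that corrupts neighboring tensors.
    grid_size = (n_rows + rows_per_program - 1) // rows_per_program  # ceil division
    return torch.zeros(grid_size, dtype=torch.float32, device=device)
```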

The PR is ready for review!

@wenchenvincent (Collaborator) commented:

LGTM. @ipanfilo You reviewed it a while ago. Do you have further comments?

@@ -1,29 +1,30 @@
-# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
Collaborator:

2025-2026

Contributor (Author):

Fixed

import pytest
from functools import partial
from itertools import product
from torch.utils.cpp_extension import IS_HIP_EXTENSION
Collaborator:

This entire file is ROCm-specific. Basically, we can assume IS_HIP_EXTENSION is true when running this pytest.

Contributor (Author):

It was unused anyways -- fixed.


# The scale_inv values may differ slightly, but will still dequantize close enough to
# pass the earlier comparisons.
compare_func = partial(te_compare_results, atol=1, rtol=0, use_torch_semantics=True)
Collaborator:

For MXFP8 data and scale_inv comparison, we can reuse the same logic as in the C++ gtest:

void adjust_ref_for_e8m0_scale_error(const std::string &name,

#ifdef __HIP_PLATFORM_AMD__
  if (::testing::Test::HasFatalFailure()) return;
  adjust_ref_for_e8m0_scale_error("scales", mismatches_scales_indices, gpu_scales_ptr,
                                  ref_output_scales.get(), scales_stride, rows, cols, rowwise,
                                  ref_output_c.get(), otype);
  mismatches_scales = 0;
#endif
  const size_t mismatches_elts = 32 * mismatches_scales;
  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol, true, mismatches_elts);
  if (processing_method == ProcessingMethod::CAST_DBIAS
      || processing_method == ProcessingMethod::CAST_DBIAS_DACT)
  {
    auto [atol_dbias, rtol_dbias] = getTolerances(itype);
    if (itype == DType::kFloat32) {
      atol_dbias = 1e-4;
      rtol_dbias *= sqrt(static_cast<double>(rows));
    } else {
      rtol_dbias *= 4;
    }
    compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
  }
}

Contributor (Author):

We essentially already do this implicitly by relying on the dequantization of the MXFP8Tensors before comparison. While we could handle this explicitly as in the C tests, I don't think that's necessary given that the dequantization behavior has its own testing which passes. Let me know if you have other thoughts on the matter.
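For illustration, a sketch of the implicit check being described (the dequantize call is assumed from the QuantizedTensor interface rather than quoted from the test): dequantize both MXFP8 outputs to high precision and compare, which absorbs small scale_inv encoding differences that still reconstruct to nearly the same values.

```python
import torch

def compare_dequantized(out_triton, out_ref, atol=1e-2, rtol=1e-2):
    # Comparing dequantized values implicitly tolerates slightly different
    # scale_inv choices, as long as data * scale_inv matches within tolerance.
    torch.testing.assert_close(
        out_triton.dequantize().float(),
        out_ref.dequantize().float(),
        atol=atol, rtol=rtol,
    )
```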

# The MXFP8 tensors carry their scale_inv values in a padded
# format, hence we must omit the padded values.
input_shape = out_triton.shape
unpadded_scale_inv_shape = (math.prod(input_shape[:-1]), input_shape[-1] // MXFP8_BLOCK_SCALING_SIZE)
Collaborator:

Should we have different shape for row-wise and col-wise scaling?

Contributor (Author):

We resolve this with re-indexing, but I've updated the variable name for a bit of extra clarity.

@@ -1,11 +1,11 @@
-# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
Collaborator:

Same here

Contributor (Author):

Fixed

 if IS_HIP_EXTENSION:
-    from ...triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton
+    from ...triton_kernels.norms_common import te_layernorm_fwd_triton, te_layernorm_bwd_triton
 from ...fp8 import FP8GlobalStateManager
Collaborator:

Do we need to import FP8GlobalStateManager and QuantizedTensor here? (This applies to both NV upstream and us.)

Contributor (Author):

Removed extra import

 if IS_HIP_EXTENSION:
-    from ...triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton
+    from ...triton_kernels.norms_common import te_layernorm_fwd_triton, te_layernorm_bwd_triton
 from ...tensor import QuantizedTensor
Collaborator:

Is it needed for this PR?

Contributor (Author):

No, it was left over by accident. Removed.

Collaborator:

Copyright date

Contributor (Author):

Done, along with a few others I missed.
