feat(ops): add RmsNorm with Iluvatar, NVIDIA, CPU backends and fp16/bf16 support #6
Open
zhangyue207 wants to merge 5 commits into feat/dev-infra
Conversation
zhangyue207
commented
Mar 2, 2026
- Add `RmsNorm` operator with `CPU`, `NVIDIA`, and `Iluvatar` implementations
- Support fp32/fp16/bf16 on NVIDIA and Iluvatar; fp32 only on CPU
- Add shared CUDA kernel (`kernel.cuh`) and backend-specific wrappers
- Extend `generate_wrappers.py` and CMake for RmsNorm
- Add tests covering backends and dtypes
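For reference, RMSNorm normalizes each row by the root mean square of its elements before applying a learned weight; the following is a minimal pure-Python sketch of the math, not the PR's implementation:

```python
import math

def rms_norm(x, w, epsilon=1e-6):
    """RMSNorm: y_i = x_i / sqrt(mean(x^2) + epsilon) * w_i."""
    ms = sum(v * v for v in x) / len(x)  # mean of squares over the row
    inv = 1.0 / math.sqrt(ms + epsilon)
    return [v * inv * wi for v, wi in zip(x, w)]
```

The `epsilon` term guards against division by zero for all-zero rows.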
Force-pushed from ea03f0f to 10187f4
RmsNorm with Iluvatar, NVIDIA, CPU backends and fp16/bf16 support
…eta, and trans parameters
voltjia
reviewed
Mar 3, 2026
```cmake
# NVIDIA and Iluvatar are parallel backends; only one GPU backend at a time.
if(WITH_NVIDIA AND WITH_ILUVATAR)
    message(FATAL_ERROR "WITH_NVIDIA and WITH_ILUVATAR cannot both be ON. Build one GPU backend at a time.")
```
Collaborator
Use Markdown syntax: "`WITH_NVIDIA` and `WITH_ILUVATAR` cannot both be `ON`. Build one GPU backend at a time."
```cmake
    find_package(CUDAToolkit REQUIRED)
endif()

# Iluvatar: CUDA-compatible device, uses clang++ with -x ivcore (not nvcc).
```
Collaborator
Use Markdown syntax:
# Iluvatar: CUDA-compatible device, uses `clang++` with `-x ivcore` (not `nvcc`).
# Reference: `InfiniCore` `xmake/iluvatar.lua`.
```cmake
if(NOT WITH_NVIDIA)
    enable_language(CUDA)
    find_package(CUDAToolkit REQUIRED)
    set(ILUVATAR_ARCH "ivcore20" CACHE STRING "Iluvatar GPU architecture")
```
Collaborator
I haven't done much development on Iluvatar, but as I recall, Add and Gemm compiled fine earlier without these. Could you briefly explain what these two blocks do and why they were introduced?
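For context, the two calls in question have distinct roles in stock CMake (a hedged note; the Iluvatar `clang++` toolchain configuration may layer on top of this):

```cmake
# enable_language(CUDA) activates CMake's CUDA toolchain so that .cu
# sources are compiled at all (here via clang++ -x ivcore rather than nvcc).
enable_language(CUDA)

# find_package(CUDAToolkit) defines the imported targets (CUDA::cudart,
# CUDA::cublas, CUDA::cuda_driver) that target_link_libraries consumes later.
find_package(CUDAToolkit REQUIRED)
```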
```python
from tests.utils import Payload, empty_strided, randn_strided


def _rms_norm(x, w, out, *, epsilon=1e-6):
```
Collaborator
`_rms_norm` and `_torch_rms_norm` are private helpers; they should be placed later in the file, i.e., after `test_rms_norm`.
```python
    return out


def _torch_rms_norm(x, w, out, *, epsilon=1e-6):
```
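The private helpers above compute a reference result and a torch-based one for comparison; below is a hedged, torch-free sketch of the dtype-aware check such a test might perform (helper names and tolerance values are illustrative, not the PR's):

```python
import math

def rms_norm_ref(x, w, epsilon=1e-6):
    # Reference RMSNorm: y_i = x_i / sqrt(mean(x^2) + epsilon) * w_i
    ms = sum(v * v for v in x) / len(x)
    inv = 1.0 / math.sqrt(ms + epsilon)
    return [v * inv * wi for v, wi in zip(x, w)]

# Looser tolerances for reduced-precision dtypes (illustrative values).
ATOL = {"fp32": 1e-5, "fp16": 1e-2, "bf16": 2e-2}

def check_close(actual, expected, dtype):
    assert all(abs(a - e) <= ATOL[dtype] for a, e in zip(actual, expected)), (
        actual, expected)
```

Per-dtype tolerances matter because fp16/bf16 carry far fewer mantissa bits than fp32, so a single absolute tolerance would either mask fp32 bugs or spuriously fail fp16 runs.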
```cpp
} // namespace

template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata,

struct NvidiaBackend {
    using stream_t = cudaStream_t;

    static constexpr auto setDevice = [](int) {};
```
Collaborator
The naming and ordering in this file should be changed to align with PyTorch, as mentioned earlier.
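The separate `Tcompute`/`Tdata` template parameters suggest the kernel accumulates in a wider type (e.g. fp32) than it stores (fp16/bf16); here is a NumPy sketch of that pattern (function name and shapes are illustrative, not the PR's API):

```python
import numpy as np

def rms_norm_mixed(x_fp16, w_fp16, epsilon=1e-6):
    # Tdata is fp16, but the mean-of-squares reduction runs in fp32
    # (Tcompute) to limit precision loss, mirroring the template split;
    # the result is cast back down to fp16 at the end.
    x32 = x_fp16.astype(np.float32)
    ms = np.mean(x32 * x32)
    y32 = x32 / np.sqrt(ms + epsilon) * w_fp16.astype(np.float32)
    return y32.astype(np.float16)
```

Accumulating in fp16 directly would lose low-order bits in the sum of squares for long rows, which is why mixed-precision kernels commonly widen the reduction type.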
```cmake
        CUDA_STANDARD_REQUIRED ON)
endif()

# Iluvatar: CUDA-compatible device; -x ivcore and flags from top-level CMakeLists.txt
target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cuda_driver)

set_target_properties(infiniops PROPERTIES CUDA_STANDARD 17
    CUDA_STANDARD_REQUIRED ON)
```
Collaborator
Align `CUDA_STANDARD_REQUIRED` with `CUDA_STANDARD`; that said, we don't have a settled formatting standard for this yet.
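The requested alignment would look roughly like this (a sketch; as noted, no formatting standard is fixed yet):

```cmake
set_target_properties(infiniops PROPERTIES CUDA_STANDARD 17
                                           CUDA_STANDARD_REQUIRED ON)
```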