Add Cutlass MxFp8 Block Scale Matrix Multiplication #5736
Conversation

Review updated until commit ca3d9f6

Description

Relevant files:
- Enhancement
- Tests
- Configuration changes
PR Reviewer Guide

Here are some key observations to aid the review process:

- 🧪 PR contains tests
- 🔒 No security concerns identified
- ⚡ Recommended focus area for review: Missing Performance Data
Greptile Summary

This PR adds support for MxFp8 (microscaling FP8) block-scaled matrix multiplication to nvfuser_direct by implementing CUTLASS kernels optimized for SM100+ (compute capability 10.x) architectures.

Confidence Score: 5/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as Python User
    participant Binding as cutlass.cpp (Python Binding)
    participant API as mxfp8_scaled_mm (nvf_cutlass.cpp)
    participant Validator as validateInputsMxFp8ScaledMm
    participant Kernel as runGemm<T>
    participant CUTLASS as CUTLASS GEMM Adapter
    User->>Binding: mxfp8_scaled_mm(a, b, scales_a, scales_b, alpha, dtype)
    Binding->>API: cutlass_kernels::mxfp8_scaled_mm(...)
    API->>Validator: validateInputsMxFp8ScaledMm(a, b, scales_a, scales_b, alpha, skip_checks)
    Validator->>Validator: Check dimensions (a.dim==2, b.dim==2)
    Validator->>Validator: Check CUDA device & contiguity
    Validator->>Validator: Validate dtypes (Float8_e4m3fn, Float8_e8m0fnu)
    Validator->>Validator: Check alignment (K%16==0, N%16==0)
    Validator->>Validator: Validate scale matrix shapes
    Validator-->>API: Return (m, n, k)
    API->>API: Create output tensor
    API->>Kernel: runGemm<cutlass::half_t or bfloat16_t>(...)
    Kernel->>Kernel: args_from_options (setup CUTLASS arguments)
    Kernel->>Kernel: Allocate workspace
    Kernel->>CUTLASS: gemm.can_implement(arguments)
    CUTLASS-->>Kernel: Status
    Kernel->>CUTLASS: gemm.initialize(arguments, workspace, stream)
    CUTLASS-->>Kernel: Status
    Kernel->>CUTLASS: gemm.run(arguments, workspace, stream)
    CUTLASS-->>Kernel: Status (compute C = alpha * A @ B)
    Kernel-->>API: Return
    API-->>Binding: Return output tensor
    Binding-->>User: Return torch.Tensor
```
!test
This PR adds MxFp8 CUTLASS kernels to nvfuser_direct. The kernel configuration uses:

- MMA tile shape: `Shape<_256, _256, _256>`
- Cluster shape: `Shape<_2, _4, _1>`
- Per-SM tile shape: `Shape<_128, _256, _256>`
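For context on what "block scale" means here: MxFp8 stores values in FP8 (e4m3) and pairs each 32-element block with a shared power-of-two scale held in e8m0 format. Below is a minimal pure-Python sketch of that idea. The block size of 32 and the e4m3 maximum of 448 follow the OCP microscaling formats; the function names and the exponent-selection formula are illustrative assumptions, not the PR's actual code (the real kernels also round the scaled values to e4m3, which this sketch skips).

```python
import math

E4M3_MAX = 448.0  # largest finite float8 e4m3fn value
BLOCK = 32        # MX scaling block size (one shared scale per 32 elements)

def mxfp8_quantize_block(block):
    """Pick a power-of-two scale 2^e (e8m0-style) so the block's largest
    magnitude fits in e4m3 range, then divide the values by it.
    Returns (scale_exponent, scaled_values)."""
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        return 0, [0.0] * len(block)
    # Smallest e with amax / 2^e <= E4M3_MAX (e may be negative).
    e = math.ceil(math.log2(amax / E4M3_MAX))
    scale = 2.0 ** e
    return e, [x / scale for x in block]

def mxfp8_dequantize_block(e, scaled):
    """Undo the block scaling by multiplying the shared 2^e back in."""
    return [x * scale for x in scaled] if (scale := 2.0 ** e) else scaled
```

Because the scale is a pure power of two, scaling and unscaling are exact in binary floating point, which is the main appeal of e8m0 scales: quantization error comes only from rounding the elements to e4m3, never from the scale itself.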