
Conversation

@rdspring1
Collaborator

@rdspring1 rdspring1 commented Dec 23, 2025

This PR adds MxFp8 CUTLASS kernels to nvfuser_direct.

  • FP16 and BF16 output dtypes.
  • The MmaTileShape is Shape<_256, _256, _256>.
  • The ClusterShape is Shape<_2, _4, _1>.
  • The PerSmTileShape_MNK is Shape<_128, _256, _256>.
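
For orientation, here is a minimal call-site sketch of the new Python API. It is a hedged illustration, not code from this PR: the shapes are chosen so that the linear (M, K/32) scale layout already matches the swizzled storage shape, and all-ones scales are used so the swizzle is numerically a no-op; the test file further down shows the full quantization and swizzle path.

    import torch
    from nvfuser_direct import nvf_cutlass

    m, n, k = 256, 128, 128  # K and N must be divisible by 16
    a_fp8 = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)  # A is (M, K), row-major
    b_fp8 = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)  # B is passed as (N, K)
    # One UE8M0 scale per 32-element block along K; all-ones scales are layout-invariant.
    a_sf = torch.ones(m, k // 32, device="cuda").to(torch.float8_e8m0fnu)
    b_sf = torch.ones(n, k // 32, device="cuda").to(torch.float8_e8m0fnu)
    alpha = torch.tensor(1.0, device="cuda")

    out = nvf_cutlass.mxfp8_scaled_mm(a_fp8, b_fp8, a_sf, b_sf, alpha, torch.bfloat16)
    assert out.shape == (m, n)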

@github-actions

github-actions bot commented Dec 23, 2025

Review updated until commit ca3d9f6

Description

  • Add MXFP8 scaled matrix multiplication using CUTLASS kernels for SM100+ GPUs

  • Support FP16 and BF16 output data types with optimized tile shapes

  • Implement comprehensive input validation and scale factor handling

  • Add complete test suite with quantization/dequantization utilities

  • Integrate new kernel into build system and Python bindings

Changes walkthrough

Relevant files

Enhancement

nvf_cutlass.cpp: Add MXFP8 input validation function (cutlass/nvf_cutlass.cpp, +112/-0)

  • Added validateInputsMxFp8ScaledMm function for comprehensive input validation
  • Validates MXFP8 data types, alignment requirements, and scale matrix properties
  • Checks CUDA device, contiguity, and matrix dimension compatibility
  • A hedged Python sketch of this input contract appears just after this walkthrough.

cutlass.cpp: Add Python binding for MXFP8 GEMM (python/python_direct/cutlass.cpp, +21/-1)

  • Added mxfp8_scaled_mm Python binding with proper documentation
  • Updated nvfp4_scaled_mm docstring to remove fp32 output mention
  • Exported new MXFP8 functionality to Python interface

mxfp8_scaled_mm.cu: Implement MXFP8 CUTLASS kernel (cutlass/mxfp8_scaled_mm.cu, +316/-0)

  • Implemented core MXFP8 scaled matrix multiplication kernel using CUTLASS
  • Configured kernel traits for FP16/BF16 outputs with optimized tile shapes
  • Added argument construction and GEMM execution functions
  • Included fallback for unsupported CUTLASS versions

nvf_cutlass.h: Add MXFP8 function declarations (cutlass/nvf_cutlass.h, +51/-0)

  • Added declarations for validateInputsMxFp8ScaledMm and mxfp8_scaled_mm
  • Included comprehensive documentation for new functions
  • Extended API with MXFP8 matrix multiplication capabilities

Tests

test_cutlass_mxfp8_gemm.py: Add MXFP8 GEMM test suite (tests/python/direct/test_cutlass_mxfp8_gemm.py, +122/-0)

  • Added comprehensive test suite for MXFP8 GEMM operations
  • Implemented quantization/dequantization utilities and reference implementation
  • Tests cover FP16/BF16 outputs and multiple matrix shapes
  • Validates compute capability requirements and proper scale factor handling

Configuration changes

CMakeLists.txt: Add MXFP8 kernel to build (CMakeLists.txt, +1/-0)

  • Added mxfp8_scaled_mm.cu to NVFUSER_CUTLASS_SRCS list
  • Integrated new kernel into build system compilation
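
The input contract summarized in the nvf_cutlass.cpp entry above can be sketched in Python as follows. This is a hedged reconstruction from the walkthrough, the Greptile notes, and the sequence diagram further down, not the actual C++ implementation in cutlass/nvf_cutlass.cpp.

    import torch

    def check_mxfp8_inputs(a, b, scales_a, scales_b):
        # Shape and device checks
        assert a.dim() == 2 and b.dim() == 2, "a and b must be 2D"
        assert a.is_cuda and b.is_cuda, "inputs must be CUDA tensors"
        assert a.is_contiguous() and b.is_contiguous(), "inputs must be contiguous"
        # MXFP8 dtypes: e4m3 data with e8m0 block scales
        assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
        assert scales_a.dtype == torch.float8_e8m0fnu
        assert scales_b.dtype == torch.float8_e8m0fnu
        # Dimension compatibility and alignment (b is passed as (N, K))
        m, k = a.shape
        n, k_b = b.shape
        assert k == k_b, "inner dimensions of a and b must match"
        assert k % 16 == 0 and n % 16 == 0, "K and N must be divisible by 16"
        # Scale-matrix shape checks are omitted here; the scales use a swizzled layout.
        return m, n, k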

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
🔒 No security concerns identified

⚡ Recommended focus areas for review

Missing Performance Data

The PR lacks performance benchmarking data and comparison with existing implementations. No performance goals, roofline analysis, or quantitative results are provided to demonstrate the effectiveness of the MXFP8 implementation.

    // clang-format off
    /*
     * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
     * All rights reserved.
     * SPDX-License-Identifier: BSD-3-Clause
     */
    // clang-format on
    #include <cutlass_utils.h>
    #include <exceptions.h>
    #include <nvf_cutlass.h>
    
    #include <ATen/cuda/CUDAContext.h>
    #include <c10/cuda/CUDAGuard.h>
    #include <torch/torch.h>
    
    #include "cutlass/cutlass.h"
    #include "cutlass/epilogue/collective/collective_builder.hpp"
    #include "cutlass/gemm/collective/collective_builder.hpp"
    #include "cutlass/gemm/device/gemm_universal_adapter.h"
    #include "cutlass/gemm/kernel/gemm_universal.hpp"
    #include "cutlass/util/packed_stride.hpp"
    
    namespace nvfuser::cutlass_kernels {
    
    namespace {
    
    using namespace cute;
    
    #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
    // Kernel configuration traits for different output data types
    // Defines tile shapes and cluster configurations.
    template <typename T>
    struct KernelTraits;
    
    // Kernel traits for FP16 output
    template <>
    struct KernelTraits<cutlass::half_t> {
      using MmaTileShape = Shape<_256, _256, _256>;
      using ClusterShape = Shape<_2, _4, _1>;
      using PerSmTileShape_MNK = Shape<_128, _256, _256>;
    };
    
    // Kernel traits for BF16 output
    template <>
    struct KernelTraits<cutlass::bfloat16_t> {
      using MmaTileShape = Shape<_256, _256, _256>;
      using ClusterShape = Shape<_2, _4, _1>;
      using PerSmTileShape_MNK = Shape<_128, _256, _256>;
    };
    
    // Main GEMM configuration for MXFP8 scaled matrix multiplication on SM100+
    // Defines all the types, layouts, and configurations needed for the CUTLASS
    // kernel
    template <typename T>
    struct MxFp8GemmSm100 {
      // A matrix configuration
      using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
      using LayoutATag = cutlass::layout::RowMajor;
      static constexpr int kAlignmentA = 16;
    
      // B matrix configuration
      using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
      using LayoutBTag = cutlass::layout::ColumnMajor;
      static constexpr int kAlignmentB = 16;
    
      // C/D matrix configuration
      using ElementD = T;
      using ElementC = T;
      using LayoutCTag = cutlass::layout::RowMajor;
      using LayoutDTag = cutlass::layout::RowMajor;
      static constexpr int kAlignmentD =
          128 / cutlass::sizeof_bits<ElementD>::value;
      static constexpr int kAlignmentC =
          128 / cutlass::sizeof_bits<ElementC>::value;
      // Kernel functional config
      using ElementAccumulator = float;
      using ArchTag = cutlass::arch::Sm100;
      using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
    
      // Kernel Perf config
      using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
      using ClusterShape = typename KernelTraits<T>::ClusterShape;
      using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;
    
      using CollectiveEpilogue =
          typename cutlass::epilogue::collective::CollectiveBuilder<
              ArchTag,
              OperatorClass,
              PerSmTileShape_MNK,
              ClusterShape,
              cutlass::epilogue::collective::EpilogueTileAuto,
              ElementAccumulator,
              ElementAccumulator,
              ElementC,
              LayoutCTag,
              kAlignmentC,
              ElementD,
              LayoutDTag,
              kAlignmentD,
              cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
    
      using CollectiveMainloop =
          typename cutlass::gemm::collective::CollectiveBuilder<
              ArchTag,
              OperatorClass,
              ElementA,
              LayoutATag,
              kAlignmentA,
              ElementB,
              LayoutBTag,
              kAlignmentB,
              ElementAccumulator,
              MmaTileShape,
              ClusterShape,
              cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
                  sizeof(typename CollectiveEpilogue::SharedStorage))>,
              cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
    
      using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
          Shape<int, int, int, int>,
          CollectiveMainloop,
          CollectiveEpilogue,
          void>;
      using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
    
      // Reference device GEMM implementation type
      using StrideA = typename Gemm::GemmKernel::StrideA;
      using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{}));
      // Scale Factor tensors have an interleaved layout. Use a Layout instead of
      // a stride.
      using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
      using StrideB = typename Gemm::GemmKernel::StrideB;
      using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{}));
      // Scale Factor tensors have an interleaved layout. Use a Layout instead of
      // a stride.
      using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
      using StrideC = typename Gemm::GemmKernel::StrideC;
      using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{}));
      using StrideD = typename Gemm::GemmKernel::StrideD;
      using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
    };
    
    // Constructs CUTLASS GEMM arguments from PyTorch tensors and dimensions
    //
    // This function converts PyTorch tensor data and metadata into the format
    // expected by CUTLASS GEMM kernels, including proper stride calculations
    // and layout configurations for the scaled matrix multiplication.
    //
    // Parameters:
    //   output: Output tensor for storing results
    //   a: Input matrix A in MXFP8 format
    //   b: Input matrix B in MXFP8 format
    //   scales_a: Per-block scaling factors for matrix A
    //   scales_b: Per-block scaling factors for matrix B
    //   alpha: Global scaling factor
    //   M, N, K: Matrix dimensions
    //
    // Returns: CUTLASS GEMM arguments structure ready for kernel execution
    template <typename T>
    typename T::Gemm::Arguments args_from_options(
        at::Tensor& output,
        const at::Tensor& a,
        const at::Tensor& b,
        const at::Tensor& scales_a,
        const at::Tensor& scales_b,
        const at::Tensor& alpha,
        int64_t M,
        int64_t N,
        int64_t K) {
      using ElementA = typename T::Gemm::ElementA;
      using ElementB = typename T::Gemm::ElementB;
      using ElementSFA = cutlass::float_ue8m0_t;
      using ElementSFB = cutlass::float_ue8m0_t;
      using ElementD = typename T::Gemm::ElementD;
      using ElementCompute = float;
      using StrideA = typename T::StrideA;
      using StrideB = typename T::StrideB;
      using StrideD = typename T::StrideD;
      using Sm1xxBlkScaledConfig =
          typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
    
      int m = static_cast<int>(M);
      int n = static_cast<int>(N);
      int k = static_cast<int>(K);
      auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
      auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
      auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});
    
      auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(
          cute::make_shape(m, n, k, 1));
      auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(
          cute::make_shape(m, n, k, 1));
    
      typename T::Gemm::Arguments arguments{
          cutlass::gemm::GemmUniversalMode::kGemm,
          {m, n, k, 1},
          {// Mainloop arguments
           static_cast<ElementA const*>(a.data_ptr()),
           stride_A,
           static_cast<ElementB const*>(b.data_ptr()),
           stride_B,
           static_cast<ElementSFA const*>(scales_a.data_ptr()),
           layout_SFA,
           static_cast<ElementSFB const*>(scales_b.data_ptr()),
           layout_SFB},
          {// Epilogue arguments
           {}, // epilogue.thread
           static_cast<ElementD const*>(output.data_ptr()),
           stride_D,
           static_cast<ElementD*>(output.data_ptr()),
           stride_D}};
      auto& fusion_args = arguments.epilogue.thread;
      fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
      return arguments;
    }
    
    // Executes the MXFP8 scaled matrix multiplication using CUTLASS kernels
    //
    // This function orchestrates the GEMM operation by setting up the kernel,
    // allocating workspace memory, and running the computation on the GPU.
    // It handles the complete lifecycle from kernel initialization to execution.
    //
    // Parameters:
    //   output: Output tensor to store the result
    //   a, b: Input matrices in MXFP8 format
    //   scales_a, scales_b: Per-block scaling factors
    //   alpha: Global scaling factor
    //   m, n, k: Matrix dimensions
    //   stream: CUDA stream for asynchronous execution
    template <typename T>
    void runGemm(
        at::Tensor& output,
        const at::Tensor& a,
        const at::Tensor& b,
        const at::Tensor& scales_a,
        const at::Tensor& scales_b,
        const at::Tensor& alpha,
        int64_t m,
        int64_t n,
        int64_t k,
        cudaStream_t stream) {
      typename MxFp8GemmSm100<T>::Gemm gemm;
    
      auto arguments = args_from_options<MxFp8GemmSm100<T>>(
          output, a, b, scales_a, scales_b, alpha, m, n, k);
    
      size_t workspace_size =
          MxFp8GemmSm100<T>::Gemm::get_workspace_size(arguments);
      auto const workspace_options =
          torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
      auto workspace = torch::empty(workspace_size, workspace_options);
    
      auto can_implement_status = gemm.can_implement(arguments);
      NVF_CHECK(
          can_implement_status == cutlass::Status::kSuccess,
          "Failed to implement GEMM");
    
      auto status = gemm.initialize(arguments, workspace.data_ptr(), stream);
      NVF_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
    
      status = gemm.run(arguments, workspace.data_ptr(), stream);
      NVF_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
    }
    #else
    // Fallback implementation for unsupported CUTLASS versions
    // Throws an error when SM100+ CUTLASS support is not available
    template <typename T>
    void runGemm(
        at::Tensor& output,
        at::Tensor const& a,
        at::Tensor const& b,
        at::Tensor const& scales_a,
        at::Tensor const& scales_b,
        at::Tensor const& alpha,
        int64_t m,
        int64_t n,
        int64_t k,
        cudaStream_t stream) {
      NVF_THROW("Unsupported CUTLASS version.");
    }
    #endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
    
    } // namespace
    
    torch::Tensor mxfp8_scaled_mm(
        const torch::Tensor& a,
        const torch::Tensor& b,
        const torch::Tensor& scales_a,
        const torch::Tensor& scales_b,
        const torch::Tensor& alpha,
        const at::ScalarType out_dtype,
        bool skip_checks) {
      // Validate all inputs and get matrix dimensions
      auto [m, n, k] =
          validateInputsMxFp8ScaledMm(a, b, scales_a, scales_b, alpha, skip_checks);
    
      at::cuda::CUDAGuard device_guard{(int8_t)a.get_device()};
      const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
    
      auto options =
          at::TensorOptions().dtype(out_dtype).device(at::kCUDA, a.get_device());
      torch::Tensor output = at::empty({a.sizes()[0], b.sizes()[0]}, options);
    
      if (out_dtype == at::ScalarType::Half) {
        runGemm<cutlass::half_t>(
            output, a, b, scales_a, scales_b, alpha, m, n, k, stream);
      } else if (out_dtype == at::ScalarType::BFloat16) {
        runGemm<cutlass::bfloat16_t>(
            output, a, b, scales_a, scales_b, alpha, m, n, k, stream);
      } else {
        NVF_THROW("Unsupported output data type of mxfp8 scaled_mm.");
      }
      return output;
    }
    
    } // namespace nvfuser::cutlass_kernels
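
One way to start addressing the Missing Performance Data point above is a simple CUDA-event timing harness like the sketch below. This is a generic illustration, not part of the PR; the helper name time_ms and the iteration counts are made up, and it would be pointed at nvf_cutlass.mxfp8_scaled_mm with the same argument shapes the tests use.

    import torch

    def time_ms(fn, iters=100, warmup=10):
        # Warm up to exclude one-time costs (workspace allocation, caching, etc.)
        for _ in range(warmup):
            fn()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters  # average milliseconds per call

    # Example: time_ms(lambda: nvf_cutlass.mxfp8_scaled_mm(a_fp8, b_fp8, a_sf, b_sf, alpha, torch.bfloat16))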

No Regression Analysis

No analysis of potential performance regressions or impact on existing functionality is provided. The PR should include comparative performance data and ensure no regressions in related operations.

Limited Test Coverage

While tests are present, they only cover basic correctness. Additional tests should include edge cases, different tensor sizes, and performance stress tests to ensure robustness across various scenarios.

    # SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
    # All rights reserved.
    # SPDX-License-Identifier: BSD-3-Clause
    # Owner(s): ["module: nvfuser"]
    
    import pytest
    import torch
    from nvfuser_direct import nvf_cutlass
    
    compute_cap = torch.cuda.get_device_capability()
    if compute_cap < (10, 0) or compute_cap >= (12, 0):
        pytest.skip(
            reason="MxFp8 requires compute capability 10.x or 11.x.",
            allow_module_level=True,
        )
    
    from python.direct_utils import (
        linear_to_swizzled_128_4,
        swizzled_to_linear_128_4,
    )
    
    
    def dequantize_mxfp8(tensor_fp8, tensor_sf):
        """Dequantize the fp8 tensor back to high precision."""
        m, k = tensor_fp8.shape
        BLOCK_SIZE = 32
        tensor_sf_linear = swizzled_to_linear_128_4(tensor_sf, m, k)
        # Apply scale factor to all elements in the same block
        sf = tensor_sf_linear.repeat_interleave(BLOCK_SIZE, dim=1).to(torch.float32)
        dqx = tensor_fp8.to(torch.float32)
        # Account for padding of scale factor
        sf = sf[: dqx.shape[0], : dqx.shape[1]]
        dequant = dqx * sf
        return dequant.reshape(m, k)
    
    
    def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
        finfo = torch.finfo(torch.float8_e4m3fn)
        return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
            dtype=torch.float8_e4m3fn
        )
    
    
    def pytorch_mxfp8_quantize(a):
        BLOCK_SIZE = 32
        assert (
            a.size(-1) % BLOCK_SIZE == 0
        ), "The inner-most dim must be divisible by block_size; Padding is not implemented."
        assert a.is_contiguous(), "Only contiguous tensors are supported."
    
        # Find absolute maximum along blockwise dimension
        original_shape = a.shape
        a_fp32 = a.float().reshape(original_shape[0], -1, BLOCK_SIZE)
        max_abs = torch.amax(torch.abs(a_fp32), dim=-1)
    
        # Get fp32 block scale factor for fp8
        FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
        block_scale_fp32 = (max_abs / FLOAT8_E4M3_MAX).float()
    
        # Clamp scale factor within UE8M0
        FLOAT8_UE8M0_EPS = torch.finfo(torch.float8_e8m0fnu).tiny
        FLOAT8_UE8M0_MAX = torch.finfo(torch.float8_e8m0fnu).max
        block_scale_fp32 = torch.clamp(
            block_scale_fp32, min=FLOAT8_UE8M0_EPS, max=FLOAT8_UE8M0_MAX
        )
    
        # Apply block conversion factor
        a_scaled = a_fp32 / block_scale_fp32.unsqueeze(-1)
        a_scaled = a_scaled.view(original_shape)
    
        return to_fp8(a_scaled), block_scale_fp32.to(torch.float8_e8m0fnu)
    
    
    def get_ref_results(
        a_fp8,
        b_fp8,
        a_sf,
        b_sf,
        m,
        n,
    ):
        _, m_k = a_fp8.shape
        _, n_k = b_fp8.shape
        assert m_k == n_k
        a_in_dtype = dequantize_mxfp8(a_fp8, a_sf)
        b_in_dtype = dequantize_mxfp8(b_fp8, b_sf)
        return torch.matmul(a_in_dtype, b_in_dtype.t())
    
    
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    @pytest.mark.parametrize(
        "shape", [(128, 128, 128), (128, 128, 256), (256, 128, 128), (128, 256, 256)]
    )
    @torch.inference_mode()
    def test_mxfp8_gemm(
        dtype: torch.dtype,
        shape: tuple[int, int, int],
    ) -> None:
        m, n, k = shape
        block_size = 32
        a_dtype = torch.randn((m, k), dtype=dtype, device="cuda")
        b_dtype = torch.randn((n, k), dtype=dtype, device="cuda")
    
        alpha = torch.tensor(1.0, device="cuda")
        a_fp8, a_scale_linear = pytorch_mxfp8_quantize(a_dtype)
        b_fp8, b_scale_linear = pytorch_mxfp8_quantize(b_dtype)
        a_scale_interleaved = linear_to_swizzled_128_4(a_scale_linear)
        b_scale_interleaved = linear_to_swizzled_128_4(b_scale_linear)
    
        expected_out = get_ref_results(
            a_fp8,
            b_fp8,
            a_scale_interleaved,
            b_scale_interleaved,
            m,
            n,
        )
        out = nvf_cutlass.mxfp8_scaled_mm(
            a_fp8, b_fp8, a_scale_interleaved, b_scale_interleaved, alpha, dtype
        )
    
        torch.testing.assert_close(out, expected_out.to(dtype=dtype))

    @greptile-apps
    Contributor

    greptile-apps bot commented Dec 23, 2025

    Greptile Summary

    This PR adds support for MxFp8 (microscaling FP8) block-scaled matrix multiplication to nvfuser_direct by implementing CUTLASS kernels optimized for SM100+ (compute capability 10.x) architectures.

    Key Changes:

    • Implements mxfp8_scaled_mm kernel in cutlass/mxfp8_scaled_mm.cu using CUTLASS 3.x collective builders with cluster shape Shape<_2, _4, _1> and tile shape Shape<_256, _256, _256>
    • Supports FP16 and BF16 output dtypes with Float8_e4m3fn input matrices and Float8_e8m0fnu block scale factors
    • Follows established patterns from nvfp4_scaled_mm.cu for consistency across the codebase
    • Adds comprehensive input validation in validateInputsMxFp8ScaledMm with alignment checks (K and N must be divisible by 16) and scale matrix shape validation
    • Includes Python bindings and test suite with quantization/dequantization utilities for correctness verification

    Architecture:
    The implementation uses the same architecture as existing NVFP4 kernels: validation layer → kernel launcher → CUTLASS adapter. The MxFp8 format uses per-block (32 elements) scaling factors stored in an interleaved/swizzled layout for optimal memory access patterns.
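
    A small, hedged illustration of the scale-factor bookkeeping described above, using the swizzle helpers from python.direct_utils that the tests rely on. The exact swizzled storage shape is an implementation detail; only the linear (M, K/32) shape and the round-trip used by the test file above are assumed here.

        import torch
        from python.direct_utils import linear_to_swizzled_128_4, swizzled_to_linear_128_4

        M, K, BLOCK_SIZE = 256, 128, 32
        # One UE8M0 scale per 32-element block along K, in the "linear" layout.
        linear_sf = torch.ones(M, K // BLOCK_SIZE, device="cuda").to(torch.float8_e8m0fnu)
        swizzled_sf = linear_to_swizzled_128_4(linear_sf)           # layout the kernel consumes
        recovered_sf = swizzled_to_linear_128_4(swizzled_sf, M, K)  # as used by dequantize_mxfp8 in the tests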

    Confidence Score: 5/5

    • This PR is safe to merge with no critical issues found
    • The implementation follows established patterns from existing nvfp4 kernels, includes comprehensive validation logic, and provides thorough test coverage. The code is well-documented and properly integrates with the existing build system. Previous review comments addressed error message corrections which are minor issues.
    • No files require special attention

    Important Files Changed

    Filename Overview
    cutlass/mxfp8_scaled_mm.cu New CUTLASS kernel implementation for MxFp8 matrix multiplication with proper SM100+ architecture support, follows existing patterns from nvfp4_scaled_mm.cu
    cutlass/nvf_cutlass.cpp Adds comprehensive input validation for MxFp8 operations with proper dtype checks and alignment requirements
    cutlass/nvf_cutlass.h Header declarations for MxFp8 API matching existing NVFP4 patterns with appropriate documentation
    python/python_direct/cutlass.cpp Python bindings for mxfp8_scaled_mm added correctly with proper documentation
    tests/python/direct/test_cutlass_mxfp8_gemm.py Comprehensive test suite with quantization, dequantization, and reference comparison for multiple dtypes and shapes
    CMakeLists.txt Adds mxfp8_scaled_mm.cu to build system correctly

    Sequence Diagram

    sequenceDiagram
        participant User as Python User
        participant Binding as cutlass.cpp (Python Binding)
        participant API as mxfp8_scaled_mm (nvf_cutlass.cpp)
        participant Validator as validateInputsMxFp8ScaledMm
        participant Kernel as runGemm<T>
        participant CUTLASS as CUTLASS GEMM Adapter
    
        User->>Binding: mxfp8_scaled_mm(a, b, scales_a, scales_b, alpha, dtype)
        Binding->>API: cutlass_kernels::mxfp8_scaled_mm(...)
        API->>Validator: validateInputsMxFp8ScaledMm(a, b, scales_a, scales_b, alpha, skip_checks)
        Validator->>Validator: Check dimensions (a.dim==2, b.dim==2)
        Validator->>Validator: Check CUDA device & contiguity
        Validator->>Validator: Validate dtypes (Float8_e4m3fn, Float8_e8m0fnu)
        Validator->>Validator: Check alignment (K%16==0, N%16==0)
        Validator->>Validator: Validate scale matrix shapes
        Validator-->>API: Return (m, n, k)
        API->>API: Create output tensor
        API->>Kernel: runGemm<cutlass::half_t or bfloat16_t>(...)
        Kernel->>Kernel: args_from_options (setup CUTLASS arguments)
        Kernel->>Kernel: Allocate workspace
        Kernel->>CUTLASS: gemm.can_implement(arguments)
        CUTLASS-->>Kernel: Status
        Kernel->>CUTLASS: gemm.initialize(arguments, workspace, stream)
        CUTLASS-->>Kernel: Status
        Kernel->>CUTLASS: gemm.run(arguments, workspace, stream)
        CUTLASS-->>Kernel: Status (compute C = alpha * A @ B)
        Kernel-->>API: Return
        API-->>Binding: Return output tensor
        Binding-->>User: Return torch.Tensor
    

    Contributor

    @greptile-apps greptile-apps bot left a comment


    6 files reviewed, 5 comments


    @greptile-apps
    Contributor

    greptile-apps bot commented Jan 5, 2026

    Greptile's behavior is changing!

    From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

    This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

    Contributor

    @greptile-apps greptile-apps bot left a comment


    6 files reviewed, 1 comment


    @rdspring1
    Collaborator Author

    !test

    @rdspring1 rdspring1 added the Low precision FP8, FP4, MXFP8, nvFP4 label Jan 5, 2026
    @rdspring1 rdspring1 requested a review from jacobhinkle January 5, 2026 18:10