
Implicit Gemm NVFP4 on Conv3D #886

Open

jingyu-ml wants to merge 4 commits into main from jingyux/implicit-gemm-nvfp4

Conversation

@jingyu-ml
Contributor

@jingyu-ml jingyu-ml commented Feb 13, 2026

What does this PR do?

Type of change: new feature

Overview:

Experimental Conv3D implicit-GEMM CUDA kernel with optional NVFP4-style (E2M1 + FP8 E4M3 scale) fake quantization for activations.

It is intended for research/prototyping and quantization-accuracy experiments only, not production deployment.
The implementation runs as a JIT-compiled PyTorch extension, mirrors the standard conv3d output shape, and provides both quantized and non-quantized paths for comparing numerical behavior.

There is currently no integration with a real quantized production kernel in the formal ModelOpt export/compress/runtime stack; this path is kept in experimental/ for fake-quant accuracy validation and benchmarking.

Usage

import torch

from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda
from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op

x = torch.randn(1, 128, 21, 60, 106, device="cuda")
w = torch.randn(512, 128, 3, 3, 3, device="cuda")
block_size = 128

# Without FP4 activation quantization (drop-in-style Conv3D call)
out = conv3d_implicit_gemm_cuda(x, w, stride=(1, 1, 1), padding=(1, 1, 1))

# Optional block quantization of weights for experiments
w_q = dynamic_block_quantize_op(
    w,
    block_size,
    w.abs().max().unsqueeze(0),
    4,  # num_bits
    2,  # exponent_bits
    8,  # scale_num_bits
    4,  # scale_exponent_bits
)

# With FP4 activation fake quantization
out_q = conv3d_implicit_gemm_cuda(
    x,
    w_q,
    stride=(1, 1, 1),
    padding=(1, 1, 1),
    act_amax=x.abs().max().unsqueeze(0),
    quant_act=True,
    fp4_block_size=block_size,  # 128 or 256
)
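
Since the point of the two paths is to compare numerical behavior, a quick sanity check is to measure both outputs against PyTorch's reference conv3d. This snippet is an illustrative addition, not part of the PR:

import torch.nn.functional as F

# FP32 reference for the same convolution
ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1))

def max_rel_err(a, b):
    # Largest element-wise relative error, guarded against division by zero.
    return ((a - b).abs() / b.abs().clamp_min(1e-6)).max().item()

print("non-quantized vs reference:", max_rel_err(out, ref))
print("fake-quantized vs reference:", max_rel_err(out_q, ref))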

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Added experimental Conv3D CUDA implementation with implicit GEMM optimization.
    • Supports FP4 quantization and BF16 tensor cores for performance optimization.
    • Includes configurable stride, padding, dilation, and optional activation quantization controls.
  • Documentation

    • Added experimental documentation covering implementation details, API reference, deployment notes, and known limitations.

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested a review from a team as a code owner February 13, 2026 08:40
@jingyu-ml jingyu-ml marked this pull request as draft February 13, 2026 08:40
@copy-pr-bot

copy-pr-bot bot commented Feb 13, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Feb 13, 2026

📝 Walkthrough

The PR adds an experimental Conv3D implementation using implicit GEMM with CUDA optimization, FP4 fake quantization, and BF16 tensor cores. The implementation includes comprehensive documentation and Python bindings that dispatch to JIT-compiled CUDA kernels with configurable quantization and activation support.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Documentation<br/>`experimental/conv/README.md` | Describes the Conv3D implicit GEMM prototype, its experimental status, API surface, usage examples, CUDA JIT-compilation behavior, FP4 fake quantization options, and known limitations. |
| CUDA Implementation<br/>`experimental/conv/implicit_gemm_cuda.py` | Introduces the `conv3d_implicit_gemm_cuda()` Python binding, which orchestrates input validation, padding, reshaping, and dispatch to JIT-compiled CUDA kernels. Implements an optimized implicit GEMM kernel with FP4 quantization helpers, BF16 WMMA tensor cores, tile configurations, shared-memory management, and optional bias/activation quantization. |

Sequence Diagram(s)

sequenceDiagram
    participant User as Python User
    participant Wrapper as Python Wrapper<br/>(conv3d_implicit_gemm_cuda)
    participant Preprocess as Input Validation<br/>& Preprocessing
    participant CUDA as CUDA JIT<br/>Compilation
    participant Kernel as CUDA Kernel<br/>(implicit_gemm_wmma)
    participant Postprocess as Output<br/>Conversion

    User->>Wrapper: Call conv3d_implicit_gemm_cuda(x, w, bias, ...)
    Wrapper->>Preprocess: Validate tensors & parameters
    Preprocess->>Preprocess: Apply padding, reshape weights to matrix
    Preprocess->>CUDA: Request JIT-compiled kernel
    CUDA->>CUDA: Compile conv3d_implicit_gemm_wmma template
    CUDA->>Kernel: Load compiled kernel
    Preprocess->>Kernel: Execute kernel on GPU with FP4 quantization
    Kernel->>Kernel: Load tiles, quantize A-tile to FP4
    Kernel->>Kernel: Perform WMMA operations with BF16
    Kernel->>Kernel: Apply optional activation quantization
    Kernel->>Postprocess: Return output tensor
    Postprocess->>Postprocess: Reshape to [N, Cout, D, H, W]
    Postprocess->>User: Return result

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 2
❌ Failed checks (2 warnings)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 50.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Merge Conflict Detection (⚠️ Warning): merge conflicts detected in 2 files:
    ⚔️ modelopt/torch/export/plugins/mcore_nemotron.py (content)
    ⚔️ modelopt/torch/export/plugins/megatron_importer.py (content)
    These conflicts must be resolved before merging into main; resolve them locally and push changes to this branch.

✅ Passed checks (2 passed)

  • Description Check (✅ Passed): check skipped because CodeRabbit’s high-level summary is enabled.
  • Title check (✅ Passed): the pull request title clearly describes the main change: adding an implicit GEMM implementation with NVFP4 support for Conv3D operations.


@jingyu-ml jingyu-ml self-assigned this Feb 13, 2026
@codecov

codecov bot commented Feb 13, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.74%. Comparing base (ae69d5d) to head (fcb4571).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #886   +/-   ##
=======================================
  Coverage   73.74%   73.74%           
=======================================
  Files         199      199           
  Lines       21163    21163           
=======================================
  Hits        15606    15606           
  Misses       5557     5557           

☔ View full report in Codecov by Sentry.
@jingyu-ml jingyu-ml marked this pull request as ready for review February 14, 2026 00:36
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@experimental/conv/implicit_gemm_cuda.py`:
- Around line 549-560: Add an explicit validation for the fp4_block_size
parameter so unsupported values don't silently use the 256 branch: at the start
of the Python wrapper function that accepts fp4_block_size (the function that
ultimately selects between the two LAUNCH_WMMA_KERNEL branches), check that
fp4_block_size is either 128 or 256 and raise a ValueError with a clear message
if not (e.g., "fp4_block_size must be 128 or 256, got {fp4_block_size}"). Ensure
this validation is performed before any kernel-launch logic or passing
fp4_block_size into the CUDA launch path.
- Around line 669-673: The code currently combines quant_act and act_amax into
do_quant, silently disabling quantization when quant_act is True but act_amax is
None; change this by adding an explicit guard: if quant_act is True and act_amax
is None raise a ValueError (e.g., "act_amax is required when quant_act=True") so
callers are notified, otherwise keep the existing behavior of creating amax_t
when do_quant is True; update the block around the symbols quant_act, act_amax,
do_quant, and amax_t accordingly.

In `@experimental/conv/README.md`:
- Line 76: The README table uses the constant-style name `FP4_BLOCK_SIZE` which
doesn't match the Python function parameter `fp4_block_size`; update the table
entry to use `fp4_block_size` (or explicitly list both forms if you want to
document the env/constant separately) so it matches the function signature and
avoids confusion when calling the function with keyword arguments; locate the
table row that currently shows `FP4_BLOCK_SIZE` and replace it with
`fp4_block_size` (or add a parenthetical note like `fp4_block_size
(FP4_BLOCK_SIZE)` if documenting both).
🧹 Nitpick comments (5)
experimental/conv/implicit_gemm_cuda.py (5)

134-138: Stale template-parameter comments — BLOCK_N is 64, not 32.

The comments on lines 135 and 138 say BLOCK_N = 32 and WARPS_N = 2, but every actual instantiation (lines 554, 559) uses BLOCK_N=64, WARPS_N=4. Similarly, the comment on line 423 says 64 * 32 * 4 = 8192 bytes when the real footprint is 64 * 64 * 4 = 16384 bytes. The code is correct (it's fully parameterized), but these stale comments will mislead anyone reading the kernel.


255-319: Quantized A-tile load: FP4 block size is implicitly coupled to BLOCK_K.

The quantize-dequantize path computes one block_max per warp-row (i.e., over BLOCK_K elements via warp_reduce_max). This means the FP4 quantization block size is always exactly BLOCK_K, which only works correctly because BLOCK_K == fp4_block_size for both supported configs. If a future config changes BLOCK_K independently of fp4_block_size, quantization granularity will silently break. Worth a brief comment or a static_assert in the kernel:

static_assert(BLOCK_K == 128 || BLOCK_K == 256, "BLOCK_K must match fp4_block_size");

578-591: verbose=True will spam build logs on every first invocation.

For an experimental module this is fine during development, but consider gating it behind an environment variable or defaulting to False so downstream users don't get unexpected compiler output.
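
One possible shape for that gate, as a sketch only (the environment-variable name and the loader helper below are illustrative, not the module's actual code):

import os
from torch.utils.cpp_extension import load_inline

# Opt-in JIT build verbosity; the env-var name is a placeholder.
JIT_VERBOSE = os.environ.get("MODELOPT_CONV3D_JIT_VERBOSE", "0") == "1"

def build_extension(cpp_src: str, cuda_src: str):
    # Compile the kernel sources, surfacing nvcc output only when requested.
    return load_inline(
        name="conv3d_implicit_gemm_wmma",
        cpp_sources=cpp_src,
        cuda_sources=cuda_src,
        functions=["conv3d_implicit_gemm_wmma"],
        verbose=JIT_VERBOSE,
    )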


641-646: Input validation uses bare assert, which is stripped under python -O.

The assert statements on lines 643 and 646 will be silently removed when Python runs with optimizations enabled. For a CUDA kernel wrapper, invalid shapes reaching the kernel could cause silent corruption or hard crashes. Consider using explicit checks:

Proposed fix
-    assert x.ndim == 5 and w.ndim == 5
+    if x.ndim != 5 or w.ndim != 5:
+        raise ValueError(f"Expected 5D tensors, got x.ndim={x.ndim}, w.ndim={w.ndim}")
     n_batch, cin, d, h, w_in = x.shape
     cout, cin_w, kd, kh, kw = w.shape
-    assert cin_w == cin
+    if cin_w != cin:
+        raise ValueError(f"Input channels mismatch: x has {cin}, w has {cin_w}")

663-667: All inputs are cast to .float() (FP32) — potential unnecessary memory doubling.

If inputs are already FP32, the .float().contiguous() calls are cheap. But if inputs arrive as BF16 (common for the use-case described), this silently doubles memory. The docstring says "BF16 WMMA" but the kernel actually consumes FP32 global-memory inputs and converts to BF16 only in shared memory. This is worth a brief comment so users understand the kernel is not end-to-end BF16 in global memory.
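
If such a note is added, it could sit directly on the casts; a minimal sketch (variable names assumed, not copied from the PR):

# NOTE: inputs are kept in FP32 in global memory; BF16 is used only for the
# in-kernel WMMA tiles, so passing BF16 tensors upcasts them and roughly
# doubles activation memory for the duration of the call.
x_f = x.float().contiguous()
w_f = w.float().contiguous()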

Comment on lines +549 to +560
if (fp4_block_size == 128) {
// BLOCK_M=64, BLOCK_N=64, BLOCK_K=128, WARPS_M=2, WARPS_N=4
// 8 warps = 256 threads -> faster cooperative loading
// WARP_M=32, WARP_N=16, WARP_TILES_M=2, WARP_TILES_N=1 -> 2 mma per warp per K-step
// Shared: 64*(128+8)*2 + 128*(64+8)*2 = 17,408 + 18,432 = 35,840 bytes (~35KB)
LAUNCH_WMMA_KERNEL(64, 64, 128, 2, 4)
} else {
// BLOCK_M=64, BLOCK_N=64, BLOCK_K=256, WARPS_M=2, WARPS_N=4
// 8 warps = 256 threads -> faster cooperative loading
// Shared: 64*(256+8)*2 + 256*(64+8)*2 = 33,792 + 36,864 = 70,656 bytes (~69KB)
LAUNCH_WMMA_KERNEL(64, 64, 256, 2, 4)
}

⚠️ Potential issue | 🟡 Minor

No validation of fp4_block_size — any unsupported value silently falls through to 256.

If a user accidentally passes fp4_block_size=64 or any value other than 128, the else branch selects the 256 config without warning. An explicit check would prevent silent misconfiguration.

Proposed fix (in the Python wrapper)

Add a check near the top of the Python function:

if fp4_block_size not in (128, 256):
    raise ValueError(f"fp4_block_size must be 128 or 256, got {fp4_block_size}")

Comment on lines +669 to +673
has_bias = bias is not None
bias_t = bias.float().contiguous() if has_bias else torch.empty(0, device=x.device) # type: ignore[union-attr]

do_quant = quant_act and act_amax is not None
amax_t = act_amax.float().contiguous() if do_quant else torch.empty(0, device=x.device) # type: ignore[union-attr]

⚠️ Potential issue | 🟠 Major

quant_act=True with act_amax=None silently disables quantization.

Line 672 collapses quant_act and act_amax is not None into do_quant, so a caller who explicitly requests quantization but forgets the scale tensor will get unquantized output with no indication. This is a likely user error, especially since the README states act_amax is "required when quant_act=True."

Consider raising an error instead of silently degrading:

Proposed fix
-    do_quant = quant_act and act_amax is not None
-    amax_t = act_amax.float().contiguous() if do_quant else torch.empty(0, device=x.device)  # type: ignore[union-attr]
+    if quant_act and act_amax is None:
+        raise ValueError("act_amax is required when quant_act=True")
+    do_quant = quant_act
+    amax_t = act_amax.float().contiguous() if do_quant else torch.empty(0, device=x.device)  # type: ignore[union-attr]

| `dilation` | Convolution dilation `(D, H, W)` |
| `act_amax` | Activation abs-max scalar tensor (required when `quant_act=True`) |
| `quant_act` | Enable FP4 fake quantization on activations |
| `FP4_BLOCK_SIZE` | FP4 quantization block size (`128` or `256`) |

⚠️ Potential issue | 🟡 Minor

Parameter name mismatch with actual API.

The table lists FP4_BLOCK_SIZE but the Python function signature uses fp4_block_size. This will confuse users trying to call the function with keyword arguments.

-| `FP4_BLOCK_SIZE` | FP4 quantization block size (`128` or `256`) |
+| `fp4_block_size` | FP4 quantization block size (`128` or `256`) |

# See the License for the specific language governing permissions and
# limitations under the License.

"""Optimized CUDA-based Conv3D Implicit GEMM with FP4 quantization using BF16 WMMA Tensor Cores.
Contributor

Looks like it uses a lot of WMMA-specific functions, so it will work on Hopper but not on Blackwell. Is that expected?

Contributor Author

I think it’s fine for fake quant. We can add a note in the README clarifying that it doesn’t work on Blackwell.
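
If that README note lands, a small runtime guard could surface the limitation early as well; a minimal sketch, assuming the kernel is only validated on Ampere/Hopper-class parts (the exact supported set below is a placeholder):

import torch

def _check_supported_gpu() -> None:
    # The WMMA fake-quant path here has only been exercised on pre-Blackwell GPUs.
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) not in {(8, 0), (8, 6), (8, 9), (9, 0)}:
        raise RuntimeError(
            f"conv3d_implicit_gemm_cuda is experimental and untested on sm_{major}{minor}; "
            "see experimental/conv/README.md for supported architectures."
        )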
