Conversation
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
📝 Walkthrough

The PR adds an experimental Conv3D implementation using implicit GEMM with CUDA optimization, FP4 fake quantization, and BF16 tensor cores. The implementation includes comprehensive documentation and Python bindings that dispatch to JIT-compiled CUDA kernels with configurable quantization and activation support.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User as Python User
    participant Wrapper as Python Wrapper<br/>(conv3d_implicit_gemm_cuda)
    participant Preprocess as Input Validation<br/>& Preprocessing
    participant CUDA as CUDA JIT<br/>Compilation
    participant Kernel as CUDA Kernel<br/>(implicit_gemm_wmma)
    participant Postprocess as Output<br/>Conversion
    User->>Wrapper: Call conv3d_implicit_gemm_cuda(x, w, bias, ...)
    Wrapper->>Preprocess: Validate tensors & parameters
    Preprocess->>Preprocess: Apply padding, reshape weights to matrix
    Preprocess->>CUDA: Request JIT-compiled kernel
    CUDA->>CUDA: Compile conv3d_implicit_gemm_wmma template
    CUDA->>Kernel: Load compiled kernel
    Preprocess->>Kernel: Execute kernel on GPU with FP4 quantization
    Kernel->>Kernel: Load tiles, quantize A-tile to FP4
    Kernel->>Kernel: Perform WMMA operations with BF16
    Kernel->>Kernel: Apply optional activation quantization
    Kernel->>Postprocess: Return output tensor
    Postprocess->>Postprocess: Reshape to [N, Cout, D, H, W]
    Postprocess->>User: Return result
```
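For orientation, a hedged sketch of the call the diagram describes; the import path and the stride/padding keyword names are assumptions not confirmed by the PR, while the other names appear in the diagram and the README excerpt:

```python
import torch

# Assumed import path, based on the file location experimental/conv/implicit_gemm_cuda.py.
from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda

x = torch.randn(1, 32, 8, 16, 16, device="cuda")   # [N, Cin, D, H, W]
w = torch.randn(64, 32, 3, 3, 3, device="cuda")    # [Cout, Cin, kD, kH, kW]
bias = torch.randn(64, device="cuda")
act_amax = x.abs().amax()                           # activation abs-max used for FP4 scaling

y = conv3d_implicit_gemm_cuda(
    x, w, bias,
    stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1),  # keyword names assumed
    act_amax=act_amax, quant_act=True, fp4_block_size=128,
)
print(y.shape)  # [N, Cout, D, H, W], mirroring torch.nn.functional.conv3d
```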
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ❌ 2 failed (2 warnings) | ✅ 2 passed
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main     #886   +/-   ##
=======================================
  Coverage   73.74%   73.74%
=======================================
  Files         199      199
  Lines       21163    21163
=======================================
  Hits        15606    15606
  Misses       5557     5557
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@experimental/conv/implicit_gemm_cuda.py`:
- Around lines 549-560: Add an explicit validation for the fp4_block_size
parameter so unsupported values don't silently use the 256 branch: at the start
of the Python wrapper function that accepts fp4_block_size (the function that
ultimately selects between the two LAUNCH_WMMA_KERNEL branches), check that
fp4_block_size is either 128 or 256 and raise a ValueError with a clear message
if not (e.g., "fp4_block_size must be 128 or 256, got {fp4_block_size}"). Ensure
this validation is performed before any kernel-launch logic or passing
fp4_block_size into the CUDA launch path.
- Around lines 669-673: The code currently combines quant_act and act_amax into
do_quant, silently disabling quantization when quant_act is True but act_amax is
None; change this by adding an explicit guard: if quant_act is True and act_amax
is None raise a ValueError (e.g., "act_amax is required when quant_act=True") so
callers are notified, otherwise keep the existing behavior of creating amax_t
when do_quant is True; update the block around the symbols quant_act, act_amax,
do_quant, and amax_t accordingly.
In `@experimental/conv/README.md`:
- Line 76: The README table uses the constant-style name `FP4_BLOCK_SIZE` which
doesn't match the Python function parameter `fp4_block_size`; update the table
entry to use `fp4_block_size` (or explicitly list both forms if you want to
document the env/constant separately) so it matches the function signature and
avoids confusion when calling the function with keyword arguments; locate the
table row that currently shows `FP4_BLOCK_SIZE` and replace it with
`fp4_block_size` (or add a parenthetical note like `fp4_block_size
(FP4_BLOCK_SIZE)` if documenting both).
🧹 Nitpick comments (5)
experimental/conv/implicit_gemm_cuda.py (5)
134-138: Stale template-parameter comments — BLOCK_N is 64, not 32. The comments on lines 135 and 138 say `BLOCK_N = 32` and `WARPS_N = 2`, but every actual instantiation (lines 554, 559) uses `BLOCK_N=64, WARPS_N=4`. Similarly, the comment on line 423 says `64 * 32 * 4 = 8192 bytes` when the real footprint is `64 * 64 * 4 = 16384 bytes`. The code is correct (it's fully parameterized), but these stale comments will mislead anyone reading the kernel.
255-319: Quantized A-tile load: FP4 block size is implicitly coupled to BLOCK_K. The quantize-dequantize path computes one `block_max` per warp-row (i.e., over `BLOCK_K` elements via `warp_reduce_max`). This means the FP4 quantization block size is always exactly `BLOCK_K`, which only works correctly because `BLOCK_K == fp4_block_size` for both supported configs. If a future config changes `BLOCK_K` independently of `fp4_block_size`, quantization granularity will silently break. Worth a brief comment or a `static_assert` in the kernel: `static_assert(BLOCK_K == 128 || BLOCK_K == 256, "BLOCK_K must match fp4_block_size");`
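A Python-side analog of this check is sketched below (a runtime guard in the wrapper rather than the suggested compile-time `static_assert` in the kernel); the helper name and variable names are illustrative, not code from the PR:

```python
# Illustrative runtime guard for the Python wrapper: keep the FP4 quantization
# block size tied to the kernel's BLOCK_K tiling, as the quantized A-tile load assumes.
def _check_block_coupling(block_k: int, fp4_block_size: int) -> None:
    if block_k != fp4_block_size:
        raise ValueError(
            f"BLOCK_K ({block_k}) must equal fp4_block_size ({fp4_block_size}); "
            "the quantized A-tile load computes one block max per BLOCK_K elements."
        )
```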
578-591: `verbose=True` will spam build logs on every first invocation. For an experimental module this is fine during development, but consider gating it behind an environment variable or defaulting to `False` so downstream users don't get unexpected compiler output.
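A minimal sketch of such env-var gating, assuming the extension is built with `torch.utils.cpp_extension.load_inline`; the variable name `CONV3D_JIT_VERBOSE` and the stand-in source are illustrative, not the PR's code:

```python
import os
from torch.utils.cpp_extension import load_inline

# Gate JIT build chatter behind an env var instead of hard-coding verbose=True.
verbose = os.environ.get("CONV3D_JIT_VERBOSE", "0") == "1"  # assumed variable name

# Trivial stand-in source; the PR's real C++/CUDA kernel sources would go here.
demo = load_inline(
    name="verbose_gating_demo",
    cpp_sources="int answer() { return 42; }",
    functions=["answer"],
    verbose=verbose,
)
print(demo.answer())  # prints 42; compiler output appears only when the env var is set
```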
641-646: Input validation uses bare `assert`, which is stripped under `python -O`. The `assert` statements on lines 643 and 646 will be silently removed when Python runs with optimizations enabled. For a CUDA kernel wrapper, invalid shapes reaching the kernel could cause silent corruption or hard crashes. Consider using explicit checks:

Proposed fix

```diff
-    assert x.ndim == 5 and w.ndim == 5
+    if x.ndim != 5 or w.ndim != 5:
+        raise ValueError(f"Expected 5D tensors, got x.ndim={x.ndim}, w.ndim={w.ndim}")
     n_batch, cin, d, h, w_in = x.shape
     cout, cin_w, kd, kh, kw = w.shape
-    assert cin_w == cin
+    if cin_w != cin:
+        raise ValueError(f"Input channels mismatch: x has {cin}, w has {cin_w}")
```
663-667: All inputs are cast to `.float()` (FP32) — potential unnecessary memory doubling. If inputs are already FP32, the `.float().contiguous()` calls are cheap. But if inputs arrive as BF16 (common for the use-case described), this silently doubles memory. The docstring says "BF16 WMMA" but the kernel actually consumes FP32 global-memory inputs and converts to BF16 only in shared memory. This is worth a brief comment so users understand the kernel is not end-to-end BF16 in global memory.
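As a sketch of the clarification suggested here, one option is to make the FP32 upcast explicit to callers; the helper name and warning text below are illustrative assumptions, not the PR's code:

```python
import warnings
import torch

def _prepare_fp32_inputs(x: torch.Tensor, w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical helper: the kernel consumes FP32 global-memory tensors and
    # uses BF16 only for the WMMA tiles staged in shared memory.
    if x.dtype == torch.bfloat16 or w.dtype == torch.bfloat16:
        warnings.warn(
            "BF16 inputs are upcast to FP32 copies before the kernel; "
            "the path is not end-to-end BF16 in global memory."
        )
    return x.float().contiguous(), w.float().contiguous()
```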
```cpp
    if (fp4_block_size == 128) {
      // BLOCK_M=64, BLOCK_N=64, BLOCK_K=128, WARPS_M=2, WARPS_N=4
      // 8 warps = 256 threads -> faster cooperative loading
      // WARP_M=32, WARP_N=16, WARP_TILES_M=2, WARP_TILES_N=1 -> 2 mma per warp per K-step
      // Shared: 64*(128+8)*2 + 128*(64+8)*2 = 17,408 + 18,432 = 35,840 bytes (~35KB)
      LAUNCH_WMMA_KERNEL(64, 64, 128, 2, 4)
    } else {
      // BLOCK_M=64, BLOCK_N=64, BLOCK_K=256, WARPS_M=2, WARPS_N=4
      // 8 warps = 256 threads -> faster cooperative loading
      // Shared: 64*(256+8)*2 + 256*(64+8)*2 = 33,792 + 36,864 = 70,656 bytes (~69KB)
      LAUNCH_WMMA_KERNEL(64, 64, 256, 2, 4)
    }
```
No validation of fp4_block_size — any unsupported value silently falls through to 256.
If a user accidentally passes fp4_block_size=64 or any value other than 128, the else branch selects the 256 config without warning. An explicit check would prevent silent misconfiguration.
Proposed fix (in the Python wrapper)
Add a check near the top of the Python function:
```python
if fp4_block_size not in (128, 256):
    raise ValueError(f"fp4_block_size must be 128 or 256, got {fp4_block_size}")
```
```python
    has_bias = bias is not None
    bias_t = bias.float().contiguous() if has_bias else torch.empty(0, device=x.device)  # type: ignore[union-attr]

    do_quant = quant_act and act_amax is not None
    amax_t = act_amax.float().contiguous() if do_quant else torch.empty(0, device=x.device)  # type: ignore[union-attr]
```
`quant_act=True` with `act_amax=None` silently disables quantization.
Line 672 collapses `quant_act` and `act_amax is not None` into `do_quant`, so a caller who explicitly requests quantization but forgets the scale tensor will get unquantized output with no indication. This is a likely user error, especially since the README states `act_amax` is "required when `quant_act=True`."
Consider raising an error instead of silently degrading:
Proposed fix
```diff
-    do_quant = quant_act and act_amax is not None
-    amax_t = act_amax.float().contiguous() if do_quant else torch.empty(0, device=x.device)  # type: ignore[union-attr]
+    if quant_act and act_amax is None:
+        raise ValueError("act_amax is required when quant_act=True")
+    do_quant = quant_act
+    amax_t = act_amax.float().contiguous() if do_quant else torch.empty(0, device=x.device)  # type: ignore[union-attr]
```
```markdown
| `dilation`       | Convolution dilation `(D, H, W)` |
| `act_amax`       | Activation abs-max scalar tensor (required when `quant_act=True`) |
| `quant_act`      | Enable FP4 fake quantization on activations |
| `FP4_BLOCK_SIZE` | FP4 quantization block size (`128` or `256`) |
```
Parameter name mismatch with actual API.
The table lists `FP4_BLOCK_SIZE` but the Python function signature uses `fp4_block_size`. This will confuse users trying to call the function with keyword arguments.
```diff
-| `FP4_BLOCK_SIZE` | FP4 quantization block size (`128` or `256`) |
+| `fp4_block_size` | FP4 quantization block size (`128` or `256`) |
```
```python
# See the License for the specific language governing permissions and
# limitations under the License.

"""Optimized CUDA-based Conv3D Implicit GEMM with FP4 quantization using BF16 WMMA Tensor Cores.
```
Looks like it uses a lot of WMMA-specific functions. Then it will work on Hopper but not on Blackwell. Is that expected?
I think it’s fine for fake quant. We can add a note in the README clarifying that it doesn’t work on Blackwell.
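If a runtime guard is wanted alongside the README note, a sketch is shown below; it assumes the reviewers' expectation that the WMMA path targets Hopper and earlier, and the sm_100 cutoff and helper name are illustrative, not verified against the PR:

```python
import torch

def _check_wmma_support(device: torch.device) -> None:
    # Hypothetical guard: reviewers expect the WMMA-based kernel not to work on
    # Blackwell, so treat compute capability 10.x and above as unsupported.
    major, _minor = torch.cuda.get_device_capability(device)
    if major >= 10:
        raise RuntimeError(
            "conv3d_implicit_gemm_cuda relies on WMMA tensor cores and is not "
            "expected to work on Blackwell (sm_100+) GPUs."
        )
```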
What does this PR do?
Type of change: new feature
Overview:
Experimental Conv3D implicit-GEMM CUDA kernel with optional NVFP4-style (E2M1 + FP8 E4M3 scale) fake quantization for activations.
It is intended for research/prototyping and quantization-accuracy experiments only, not production deployment.
The implementation runs as a JIT-compiled PyTorch extension, mirrors conv3d output shape, and provides a quantized and non-quantized path to compare numerical behavior.
There is currently no real quantized production kernel integration in the formal ModelOpt export/compress/runtime stack; this path is kept in experimental/ for fake-quant accuracy validation and benchmarking.
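For intuition only, a minimal sketch of NVFP4-style fake quantization (the E2M1 value grid with a per-block scale), independent of the CUDA kernel; it ignores the FP8 E4M3 encoding of the scale itself, and all names are illustrative:

```python
import torch

# Representable |E2M1| magnitudes; NVFP4 pairs these with a per-block scale.
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fake_quant_e2m1(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Quantize-dequantize x in 1D blocks of `block_size` using the E2M1 grid."""
    flat = x.reshape(-1, block_size)                    # assumes numel divisible by block_size
    scale = flat.abs().amax(dim=1, keepdim=True) / 6.0  # 6.0 = largest E2M1 magnitude
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    scaled = flat / scale
    grid = E2M1_VALUES.to(device=x.device, dtype=x.dtype)
    # Snap magnitudes to the nearest grid value, then restore sign and scale.
    idx = (scaled.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    return (grid[idx] * scaled.sign() * scale).reshape_as(x)
```

Comparing a model's outputs with and without such a quantize-dequantize pass is the kind of accuracy experiment this experimental path is meant to support.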
Usage
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Documentation