[minor] add a general FP8ScaleSweepCalibrator and its registry #1171
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
📝 Walkthrough

The changes introduce a new `FP8ScaleSweepCalibrator` and a registry mapping quantization backends to calibrator classes.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ✅ 4 passed
🧹 Nitpick comments (3)
modelopt/torch/quantization/model_calib.py (2)
63-80: Consider adding an overwrite warning and type validation for consistency.

Two observations:

- Unlike `register_quant_backend` (which warns when overwriting an existing backend), this silently overwrites. For consistency and to help users debug accidental overwrites, consider adding a warning.
- The type hint is `type` rather than `type[FP8ScaleSweepCalibrator]`. Adding a runtime `issubclass` check would catch incorrect registrations early.

♻️ Suggested improvement

```diff
+import warnings
+from .calib import FP8ScaleSweepCalibrator
+
 def register_fp8_sweep_calibrator(backend: str, calibrator_cls: type) -> None:
     """Register a :class:`FP8ScaleSweepCalibrator` subclass for a quantization backend.
     ...
     """
+    if not issubclass(calibrator_cls, FP8ScaleSweepCalibrator):
+        raise TypeError(
+            f"calibrator_cls must be a subclass of FP8ScaleSweepCalibrator, got {calibrator_cls}"
+        )
+    if backend in _FP8_SWEEP_CALIBRATOR_REGISTRY:
+        warnings.warn(f"Overwriting existing FP8 sweep calibrator for backend: {backend}")
     _FP8_SWEEP_CALIBRATOR_REGISTRY[backend] = calibrator_cls
```

This aligns with the existing `register_quant_backend` pattern shown in `tensor_quantizer.py:82-99`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/model_calib.py` around lines 63 - 80, The register_fp8_sweep_calibrator function currently overwrites entries silently and accepts a generic type; change it to mirror register_quant_backend by (1) validating that calibrator_cls is a subclass of FP8ScaleSweepCalibrator via issubclass(calibrator_cls, FP8ScaleSweepCalibrator) and raising a TypeError if not, and (2) emitting a warning (e.g., using warnings.warn) when _FP8_SWEEP_CALIBRATOR_REGISTRY already contains the backend key before overwriting; update the function signature/type hint to accept type[FP8ScaleSweepCalibrator] to reflect the enforced type.
651-682: Code is correct; consider extracting shared calibrator selection logic.

This block mirrors the calibrator selection in `mse_calibrate()` (lines 351-378). Both paths:

- Check `is_nvfp4_static` vs registry-backed backends
- Instantiate the appropriate calibrator class

A small helper (e.g., `_create_fp8_sweep_calibrator(...)`) could reduce duplication, but this is optional given the clarity of the current implementation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/model_calib.py` around lines 651 - 682, The calibrator selection/instantiation logic is duplicated between this block and mse_calibrate(); extract a helper (e.g., _create_fp8_sweep_calibrator) that accepts (weight_quantizer, initial_amax, quant_func, error_func, is_nvfp4_static) and returns the configured calibrator instance; inside the helper, if is_nvfp4_static return NVFP4MSECalibrator configured with amax=initial_amax, axis=weight_quantizer._calibrator._axis if present, global_amax=weight_quantizer.global_amax and quant_func/error_func; otherwise lookup calibrator_cls = _FP8_SWEEP_CALIBRATOR_REGISTRY[weight_quantizer.backend] and instantiate with amax, axis, quant_func and error_func; replace the duplicated instantiation code in both this block and mse_calibrate() to call the new helper and assign to weight_quantizer._calibrator.

modelopt/torch/quantization/calib/mse.py (1)
174-185: Add a named constant for the FP8 E4M3 maximum value and document inherited unused parameters.

The magic number `448.0` represents the maximum representable value in FP8 E4M3 format and should be defined as a module-level constant for clarity and maintainability, following the pattern used elsewhere in the codebase (e.g., `mxfp8_tensor.py`).

Additionally, `FP8ScaleSweepCalibrator` inherits `__init__` from `MseCalibrator`, which accepts `step_size`, `start_multiplier`, and `stop_multiplier` parameters. Since `_generate_candidates` is overridden in this class, these parameters are silently ignored. Consider either overriding `__init__` to explicitly exclude these parameters or adding a note to the class docstring that these inherited parameters are not used.

♻️ Suggested improvement

```diff
+_FP8_E4M3_MAX = 448.0  # Maximum representable value in FP8 E4M3 format
+
 class FP8ScaleSweepCalibrator(MseCalibrator):
     """MSE calibrator that sweeps 126 valid FP8 E4M3 candidates of ``initial_amax``.

     Candidate amax values are ``initial_amax * candidate``
+
+    Note:
+        The ``step_size``, ``start_multiplier``, and ``stop_multiplier`` parameters
+        inherited from :class:`MseCalibrator` are not used by this calibrator.
     """

     def _generate_candidates(self, device: torch.device) -> torch.Tensor:
-        """Generate 126 valid FP8 E4M3 scale candidates."""
+        """Generate 126 valid FP8 E4M3 scale candidates (all nonzero finite values / max)."""
         uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
         fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
         valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
-        return fp8_values[valid_mask] / 448.0
+        return fp8_values[valid_mask] / _FP8_E4M3_MAX
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/calib/mse.py` around lines 174 - 185, Replace the magic literal 448.0 with a module-level constant named FP8_E4M3_MAX (or similar) and use that constant in FP8ScaleSweepCalibrator._generate_candidates so the maximum representable FP8 E4M3 value is documented and reusable (matching the pattern in mxfp8_tensor.py); also update the FP8ScaleSweepCalibrator class docstring to note that the inherited MseCalibrator __init__ parameters (step_size, start_multiplier, stop_multiplier) are intentionally unused by this subclass, or alternatively override __init__ in FP8ScaleSweepCalibrator to explicitly accept and ignore those params so their omission is explicit.
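The "126 valid candidates" and the 448.0 maximum can be checked without torch by decoding E4M3 by hand. This is a pure-Python sketch of the encoding (1 sign, 4 exponent, 3 mantissa bits, bias 7; the `float8_e4m3fn` variant reserves only the all-ones pattern `0x7F` for NaN and has no infinities), not the library's implementation:

```python
def decode_e4m3fn(byte: int) -> float:
    """Decode one byte as float8 e4m3fn (S.EEEE.MMM, bias 7, 0x7F/0xFF = NaN)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return float("nan")          # the only NaN pattern per sign in e4m3fn
    if exp == 0:
        return sign * (mant / 8.0) * 2.0 ** -6   # subnormal: no implicit leading 1
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)

# Positive, finite, nonzero encodings are bytes 1..126 -> exactly 126 candidates,
# normalized by the format maximum (byte 0x7E decodes to 1.75 * 2**8 = 448.0).
candidates = [
    v / 448.0
    for b in range(0, 128)
    for v in [decode_e4m3fn(b)]
    if v == v and v > 0.0            # v == v filters the NaN at byte 0x7F
]
```

Dividing by the maximum makes the largest candidate exactly 1.0, so the sweep scales `initial_amax` by factors in (0, 1].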
ℹ️ Review info

⚙️ Run configuration
- Configuration used: `.coderabbit.yaml`
- Review profile: CHILL
- Plan: Pro
- Run ID: 6a31f4b3-e4ad-4edc-b8f6-48c7553b2e41
📒 Files selected for processing (2)
- modelopt/torch/quantization/calib/mse.py
- modelopt/torch/quantization/model_calib.py
Codecov Report: ❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##             main    #1171       +/-   ##
===========================================
- Coverage   70.20%   54.51%   -15.69%
===========================================
  Files         230      348      +118
  Lines       26098    39789    +13691
===========================================
+ Hits        18322    21692     +3370
- Misses       7776    18097    +10321
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
What does this PR do?

Type of change: ?

Usage

```python
# Add a code snippet demonstrating how to use this
```

Testing

Before your PR is "Ready for review"

- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: ✅ / ❌ / N/A

Additional Information
Summary by CodeRabbit
New Features
Refactor