Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 34 additions & 15 deletions modelopt/torch/quantization/calib/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .. import utils as quant_utils
from .calibrator import _Calibrator

__all__ = ["MseCalibrator", "NVFP4MSECalibrator"]
__all__ = ["FP8ScaleSweepCalibrator", "MseCalibrator", "NVFP4MSECalibrator"]


class MseCalibrator(_Calibrator):
Expand Down Expand Up @@ -171,30 +171,49 @@ def compute_amax(self, verbose: bool = False):
return self._amax


class NVFP4MSECalibrator(MseCalibrator):
"""Per-block FP8 scale sweep calibrator for NVFP4 static quantization."""
class FP8ScaleSweepCalibrator(MseCalibrator):
    """MSE calibrator that sweeps the 126 valid FP8 E4M3 scalings of ``initial_amax``.

    Candidate amax values are ``initial_amax * candidate``, where each candidate is a
    strictly-positive, finite FP8 E4M3 value normalized by the E4M3 maximum (448).
    """

    def _generate_candidates(self, device: torch.device) -> torch.Tensor:
        """Generate the 126 valid (positive, finite) FP8 E4M3 scale candidates."""
        # Reinterpret byte patterns 0..127 as FP8 E4M3: this enumerates every
        # non-negative encoding (zero, subnormals, normals, and NaN at 0x7F).
        raw = torch.arange(128, dtype=torch.uint8, device=device)
        values = raw.view(torch.float8_e4m3fn).float()
        # Keep only the strictly-positive finite values (drops 0 and NaN, leaving
        # 126 candidates), then normalize by the E4M3 maximum representable value.
        keep = (values > 0) & torch.isfinite(values)
        return values[keep] / 448.0


class NVFP4MSECalibrator(FP8ScaleSweepCalibrator):
    """FP8 scale sweep calibrator for NVFP4 per-block static quantization.

    Extends :class:`FP8ScaleSweepCalibrator` with a ``global_amax`` that drives the
    candidate amax computation: each candidate scales ``global_amax`` uniformly across
    all blocks.
    """

    def __init__(
        self,
        amax: torch.Tensor,
        global_amax: torch.Tensor,
        axis: int | tuple | list | None = None,
        quant_func: Callable | None = None,
        error_func: Callable | None = None,
    ):
        """Initialize NVFP4 calibrator.

        Args:
            amax: Per-block amax tensor (shape ``[num_blocks]``).
            global_amax: Scalar global amax used to scale all FP8 candidates.
            axis: Quantization axis. None means per-tensor quantization.
            quant_func: Function that quantizes input tensor given an amax value.
            error_func: Function to compute error between x and xq.
        """
        super().__init__(amax=amax, axis=axis, quant_func=quant_func, error_func=error_func)
        self._global_amax = global_amax

    def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
        """Return candidate amax values: ``global_amax * candidate``, broadcast to every block."""
        if candidates.ndim != 0:  # Called during final compute_amax with one winner per block
            candidates = candidates.view_as(self._initial_amax)
        # Uniformly scale global_amax across all blocks (ones_like gives the per-block shape).
        return torch.ones_like(self._initial_amax) * self._global_amax * candidates
82 changes: 66 additions & 16 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,31 @@
"awq",
"local_hessian_calibrate",
"max_calibrate",
"register_fp8_sweep_calibrator",
"sequential_calibrate",
"smoothquant",
"svdquant",
]

# Registry for backends that want a custom FP8-sweep calibrator for mse_calibrate().
# Keys are backend name strings; values are FP8ScaleSweepCalibrator (sub)classes.
# NOTE: registering the same backend twice silently overwrites the earlier entry.
_FP8_SWEEP_CALIBRATOR_REGISTRY: dict[str, type] = {}


def register_fp8_sweep_calibrator(backend: str, calibrator_cls: type) -> None:
    """Register a :class:`FP8ScaleSweepCalibrator` subclass for a quantization backend.

    When ``fp8_scale_sweep=True`` is passed to :func:`mse_calibrate`, any weight quantizer
    whose ``backend`` attribute matches a registered key will use the corresponding
    calibrator class instead of the default :class:`MseCalibrator`.

    Args:
        backend: Backend name string (must match ``TensorQuantizer.backend``).
        calibrator_cls: A :class:`FP8ScaleSweepCalibrator` subclass whose constructor
            accepts the keyword arguments ``(amax, axis, quant_func)`` and, in the
            weight-only calibration path, an optional ``error_func`` as well.
    """
    _FP8_SWEEP_CALIBRATOR_REGISTRY[backend] = calibrator_cls


def weight_only_quantize(model: nn.Module):
"""Just quantize the weights of the model."""
Expand Down Expand Up @@ -328,21 +348,33 @@ def mse_calibrate(
and module._block_sizes.get("scale_bits") == (4, 3)
)

is_fp8_static_per_block_scales = (
is_nvfp4_static
or getattr(module, "backend", None) in _FP8_SWEEP_CALIBRATOR_REGISTRY
)

if is_nvfp4_static:
# Compute and set global_amax
global_amax = reduce_amax(initial_amax, axis=None)

# Convert to NVFP4StaticQuantizer in-place
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)

if fp8_scale_sweep and is_nvfp4_static:
# Replace calibrator with NVFP4MSECalibrator
module._calibrator = NVFP4MSECalibrator(
amax=initial_amax,
axis=module._calibrator._axis,
global_amax=module.global_amax,
quant_func=partial(_mse_quant_func, quantizer=module),
)
if fp8_scale_sweep and is_fp8_static_per_block_scales:
if is_nvfp4_static:
module._calibrator = NVFP4MSECalibrator(
amax=initial_amax,
axis=module._calibrator._axis,
global_amax=module.global_amax,
quant_func=partial(_mse_quant_func, quantizer=module),
)
else:
calibrator_cls = _FP8_SWEEP_CALIBRATOR_REGISTRY[module.backend]
module._calibrator = calibrator_cls(
amax=initial_amax,
axis=module._calibrator._axis,
quant_func=partial(_mse_quant_func, quantizer=module),
)
continue

# Create MSE calibrator with quant_func
Expand Down Expand Up @@ -616,20 +648,38 @@ def quant_func(x, amax, quantizer=weight_quantizer):
and weight_quantizer._block_sizes.get("scale_bits") == (4, 3)
)

is_fp8_static_per_block_scales = (
is_nvfp4_static
or getattr(weight_quantizer, "backend", None) in _FP8_SWEEP_CALIBRATOR_REGISTRY
)

if is_nvfp4_static:
global_amax = reduce_amax(initial_amax, axis=None)
NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax)

error_func = helper.get_error_func()

if fp8_scale_sweep and is_nvfp4_static:
weight_quantizer._calibrator = NVFP4MSECalibrator(
amax=initial_amax,
axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None,
global_amax=weight_quantizer.global_amax,
quant_func=quant_func,
error_func=error_func,
)
if fp8_scale_sweep and is_fp8_static_per_block_scales:
if is_nvfp4_static:
weight_quantizer._calibrator = NVFP4MSECalibrator(
amax=initial_amax,
axis=weight_quantizer._calibrator._axis
if weight_quantizer._calibrator
else None,
global_amax=weight_quantizer.global_amax,
quant_func=quant_func,
error_func=error_func,
)
else:
calibrator_cls = _FP8_SWEEP_CALIBRATOR_REGISTRY[weight_quantizer.backend]
weight_quantizer._calibrator = calibrator_cls(
amax=initial_amax,
axis=weight_quantizer._calibrator._axis
if weight_quantizer._calibrator
else None,
quant_func=quant_func,
error_func=error_func,
)
else:
weight_quantizer._calibrator = MseCalibrator(
amax=initial_amax,
Expand Down
Loading