Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions bitsandbytes/backends/default/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,9 @@ def _(

# Quantize with the lookup table
code = CODE[quant_type].to(scaled.device).to(scaled.dtype)
# Pad to even length so packing pairs all elements
if scaled.numel() % 2 != 0:
scaled = torch.nn.functional.pad(scaled, (0, 1), value=0.0)
quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - code), dim=-1, keepdim=True).to(torch.uint8)

# Pack two quantized values per byte
Expand All @@ -274,17 +277,20 @@ def _dequantize_4bit_impl(
A = A.reshape(-1)
# Map nf4 to [-1, 1]
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
n = out_dq.numel()
out_dq[1::2] = A & 0xF
out_dq[::2] = A >> 4
# code is fp32, cast to dtype to avoid the mismatch issue
code = CODE[quant_type].to(dtype).to(A.device)
out_dq = code[out_dq]

# Use the actual output size, not the unpacked size (which may include padding)
n = 1
for s in shape:
n *= s
# Trim any extra elements from padding during quantization
out_dq = out_dq[:n]

# Apply scales
if out_dq.numel() != n:
assert out_dq.numel() == n + 1
out_dq = torch.narrow(out_dq, 0, 0, n)
blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0
rem = n % blocksize
Expand Down
8 changes: 6 additions & 2 deletions bitsandbytes/backends/triton/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,12 @@ def quantize_4bit(

n = A.numel()

# TODO: Support when weight matrix is not divisible by blocksize
# torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")
# Pad to next multiple of blocksize so the kernel always processes full blocks
remainder = n % blocksize
if remainder != 0:
padding = blocksize - remainder
A = torch.nn.functional.pad(A.view(-1), (0, padding), value=0.0)
n = A.numel()

blocks = -(n // -(blocksize * 2))

Expand Down
26 changes: 26 additions & 0 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,32 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize

opcheck(torch.ops.bitsandbytes.quantize_4bit.default, (A, blocksize, quant_type, storage_dtype))

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256])
def test_quantize_4bit_not_divisible_by_blocksize(self, device, dtype, quant_type, blocksize):
    """Roundtrip quantize_4bit/dequantize_4bit with numel not divisible by blocksize."""
    storage_dtype = torch.uint8
    # (7, blocksize - 1) guarantees numel is never a multiple of blocksize
    # for any of the parametrized blocksizes.
    shape = (7, blocksize - 1)
    tensor_in = torch.randn(shape, dtype=dtype, device=device)

    # Quantization must succeed despite the ragged final block.
    quantized, scales = torch.ops.bitsandbytes.quantize_4bit(tensor_in, blocksize, quant_type, storage_dtype)
    assert scales.device == tensor_in.device
    assert quantized.device == tensor_in.device

    # The dequantized tensor must come back with the original shape and dtype
    # (i.e. any padding added during quantization is trimmed again).
    restored = torch.ops.bitsandbytes.dequantize_4bit(quantized, scales, blocksize, quant_type, shape, dtype)
    assert restored.dtype == dtype
    assert restored.shape == shape

    # Padded elements must not leak NaN/Inf into the output.
    assert torch.isfinite(restored).all(), "Dequantized output contains NaN or Inf"

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
Expand Down