From 1c96a34e7d7c846b4b1ed038f810b2e32fdfb7c7 Mon Sep 17 00:00:00 2001 From: badaoui Date: Tue, 3 Mar 2026 12:04:52 +0000 Subject: [PATCH 1/3] fix when A.numel() not divisible by blocksize --- bitsandbytes/backends/triton/ops.py | 8 ++++++-- tests/test_ops.py | 26 ++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py index 3a16961fa..4b1444b35 100644 --- a/bitsandbytes/backends/triton/ops.py +++ b/bitsandbytes/backends/triton/ops.py @@ -76,8 +76,12 @@ def quantize_4bit( n = A.numel() - # TODO: Support when weight matrix is not divisible by blocksize - # torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}") + # Pad to next multiple of blocksize so the kernel always processes full blocks + remainder = n % blocksize + if remainder != 0: + padding = blocksize - remainder + A = torch.nn.functional.pad(A.view(-1), (0, padding), value=0.0) + n = A.numel() blocks = -(n // -(blocksize * 2)) diff --git a/tests/test_ops.py b/tests/test_ops.py index c5a439eaf..005084c52 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -172,6 +172,32 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize opcheck(torch.ops.bitsandbytes.quantize_4bit.default, (A, blocksize, quant_type, storage_dtype)) + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) + @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) + @pytest.mark.parametrize("blocksize", [64, 128, 256]) + def test_quantize_4bit_not_divisible_by_blocksize(self, device, dtype, quant_type, blocksize): + """Test quantize/dequantize roundtrip when n_elements is not divisible by blocksize.""" + # Shape chosen so numel is NOT divisible by blocksize + shape = (7, blocksize - 1) + A = torch.randn(shape, dtype=dtype, device=device) + 
storage_dtype = torch.uint8 + + # Should not raise + packed, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, quant_type, storage_dtype) + + assert packed.device == A.device + assert absmax.device == A.device + + # Dequantize back and verify shape is preserved + out = torch.ops.bitsandbytes.dequantize_4bit(packed, absmax, blocksize, quant_type, shape, dtype) + + assert out.shape == shape + assert out.dtype == dtype + + # Verify output is finite (no NaN/Inf) + assert torch.isfinite(out).all(), "Dequantized output contains NaN or Inf" + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) From 01d060dc456a9549b8946dc51a696fa75112e60f Mon Sep 17 00:00:00 2001 From: badaoui Date: Tue, 3 Mar 2026 12:35:34 +0000 Subject: [PATCH 2/3] fix --- bitsandbytes/backends/default/ops.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py index 707aeb3c3..eb0c4e0d3 100644 --- a/bitsandbytes/backends/default/ops.py +++ b/bitsandbytes/backends/default/ops.py @@ -248,6 +248,9 @@ def _( # Quantize with the lookup table code = CODE[quant_type].to(scaled.device).to(scaled.dtype) + # Pad to even length so packing pairs all elements + if scaled.numel() % 2 != 0: + scaled = torch.nn.functional.pad(scaled, (0, 1), value=0.0) quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - code), dim=-1, keepdim=True).to(torch.uint8) # Pack two quantized values per byte From 98f3a3420a94edc64f24ef04ebbd021be79f0348 Mon Sep 17 00:00:00 2001 From: badaoui Date: Tue, 3 Mar 2026 13:13:25 +0000 Subject: [PATCH 3/3] another one --- bitsandbytes/backends/default/ops.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/backends/default/ops.py 
b/bitsandbytes/backends/default/ops.py index eb0c4e0d3..78ba818ca 100644 --- a/bitsandbytes/backends/default/ops.py +++ b/bitsandbytes/backends/default/ops.py @@ -277,17 +277,20 @@ def _dequantize_4bit_impl( A = A.reshape(-1) # Map nf4 to [-1, 1] out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) - n = out_dq.numel() out_dq[1::2] = A & 0xF out_dq[::2] = A >> 4 # code is fp32, cast to dtype to avoid the mismatch issue code = CODE[quant_type].to(dtype).to(A.device) out_dq = code[out_dq] + # Use the actual output size, not the unpacked size (which may include padding) + n = 1 + for s in shape: + n *= s + # Trim any extra elements from padding during quantization + out_dq = out_dq[:n] + # Apply scales - if out_dq.numel() != n: - assert out_dq.numel() == n + 1 - out_dq = torch.narrow(out_dq, 0, 0, n) blocks = n // blocksize blocks += 1 if n % blocksize > 0 else 0 rem = n % blocksize