Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions bitsandbytes/nn/modules.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1054,7 +1054,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
scb_name = "SCB"

# case 1: .cuda was called, SCB is in self.weight
- param_from_weight = getattr(self.weight, scb_name)
+ param_from_weight = getattr(self.weight, scb_name, None)
# case 2: self.init_8bit_state was called, SCB is in self.state
param_from_state = getattr(self.state, scb_name)

Expand Down Expand Up @@ -1095,15 +1095,16 @@ def _load_from_state_dict(
for key in unexpected_copy:
input_name = key[len(prefix) :]
if input_name == "SCB":
- if self.weight.SCB is None:
+ weight_scb = getattr(self.weight, "SCB", None)
+ if weight_scb is None:
# buffers not yet initialized, can't access them directly without quantizing first
raise RuntimeError(
"Loading a quantized checkpoint into non-quantized Linear8bitLt is "
"not supported. Please call module.cuda() before module.load_state_dict()",
)

input_param = state_dict[key]
- self.weight.SCB.copy_(input_param)
+ weight_scb.copy_(input_param)

if self.state.SCB is not None:
self.state.SCB = self.weight.SCB
Expand Down Expand Up @@ -1133,7 +1134,7 @@ def to(self, *args, **kwargs):

def forward(self, x: torch.Tensor):
self.state.is_training = self.training
- if self.weight.CB is not None:
+ if getattr(self.weight, "CB", None) is not None:
self.init_8bit_state()

# weights are cast automatically as Int8Params, but the bias has to be cast manually
Expand All @@ -1142,7 +1143,7 @@ def forward(self, x: torch.Tensor):

out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

- if not self.state.has_fp16_weights and self.state.CB is not None:
+ if not self.state.has_fp16_weights and self.state.CB is not None and hasattr(self.weight, "CB"):
self.weight.data = self.state.CB

return out
Expand Down