Commit f24e582

support load lora from state dict (#206)
1 parent e3cf908 commit f24e582

File tree

1 file changed (+30, -9 lines)


diffsynth_engine/pipelines/base.py

Lines changed: 30 additions & 9 deletions
@@ -74,39 +74,40 @@ def update_component(
         component.load_state_dict(state_dict, assign=True)
         component.to(device=device, dtype=dtype, non_blocking=True)
 
-    def load_loras(
+    def _load_lora_state_dicts(
         self,
-        lora_list: List[Tuple[str, Union[float, LoraConfig]]],
+        lora_state_dict_list: List[Tuple[Dict[str, torch.Tensor], Union[float, LoraConfig], str]],
         fused: bool = True,
         save_original_weight: bool = False,
         lora_converter: Optional[LoRAStateDictConverter] = None,
     ):
         if not lora_converter:
             lora_converter = self.lora_converter
 
-        for lora_path, lora_item in lora_list:
+        for state_dict, lora_item, lora_name in lora_state_dict_list:
             if isinstance(lora_item, float):
                 lora_scale = lora_item
                 scheduler_config = None
-            if isinstance(lora_item, LoraConfig):
+            elif isinstance(lora_item, LoraConfig):
                 lora_scale = lora_item.scale
                 scheduler_config = lora_item.scheduler_config
+            else:
+                raise ValueError(f"lora_item must be float or LoraConfig, got {type(lora_item)}")
 
-            logger.info(f"loading lora from {lora_path} with LoraConfig (scale={lora_scale})")
-            state_dict = load_file(lora_path, device=self.device)
+            logger.info(f"loading lora from state_dict '{lora_name}' with scale={lora_scale}")
 
             if scheduler_config is not None:
                 self.apply_scheduler_config(scheduler_config)
                 logger.info(f"Applied scheduler args from LoraConfig: {scheduler_config}")
 
             lora_state_dict = lora_converter.convert(state_dict)
-            for model_name, state_dict in lora_state_dict.items():
+            for model_name, model_state_dict in lora_state_dict.items():
                 model = getattr(self, model_name)
                 lora_args = []
-                for key, param in state_dict.items():
+                for key, param in model_state_dict.items():
                     lora_args.append(
                         {
-                            "name": lora_path,
+                            "name": lora_name,
                             "key": key,
                             "scale": lora_scale,
                             "rank": param["rank"],
@@ -120,6 +121,26 @@ def load_loras(
                     )
                 model.load_loras(lora_args, fused=fused)
 
+    def load_loras(
+        self,
+        lora_list: List[Tuple[str, Union[float, LoraConfig]]],
+        fused: bool = True,
+        save_original_weight: bool = False,
+        lora_converter: Optional[LoRAStateDictConverter] = None,
+    ):
+        lora_state_dict_list = []
+        for lora_path, lora_item in lora_list:
+            logger.info(f"loading lora from {lora_path}")
+            state_dict = load_file(lora_path, device=self.device)
+            lora_state_dict_list.append((state_dict, lora_item, lora_path))
+
+        self._load_lora_state_dicts(
+            lora_state_dict_list=lora_state_dict_list,
+            fused=fused,
+            save_original_weight=save_original_weight,
+            lora_converter=lora_converter,
+        )
+
     def load_lora(self, path: str, scale: float, fused: bool = True, save_original_weight: bool = False):
         self.load_loras([(path, scale)], fused, save_original_weight)
 

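With this refactor, a caller that already holds LoRA weights in memory can skip the file read and hand the state dict to the new private hook directly, while load_loras and load_lora stay unchanged for file-based callers. A minimal sketch, assuming a pipeline built on diffsynth_engine/pipelines/base.py; the helper name apply_lora_from_memory and the default scale/name are illustrative, and only _load_lora_state_dicts and its parameters come from this diff:

from typing import Dict

import torch


def apply_lora_from_memory(
    pipe,
    state_dict: Dict[str, torch.Tensor],
    scale: float = 1.0,
    name: str = "in_memory_lora",
) -> None:
    # pipe: any pipeline subclassing the base pipeline in this file (assumption);
    # state_dict: LoRA weights obtained elsewhere, e.g. from training code,
    # a merge script, or a prior safetensors load_file call.
    pipe._load_lora_state_dicts(
        lora_state_dict_list=[(state_dict, scale, name)],  # (state_dict, scale or LoraConfig, name)
        fused=True,
    )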