@@ -74,39 +74,40 @@ def update_component(
7474 component .load_state_dict (state_dict , assign = True )
7575 component .to (device = device , dtype = dtype , non_blocking = True )
7676
77- def load_loras (
77+ def _load_lora_state_dicts (
7878 self ,
79- lora_list : List [Tuple [str , Union [float , LoraConfig ]]],
79+ lora_state_dict_list : List [Tuple [Dict [ str , torch . Tensor ], Union [float , LoraConfig ], str ]],
8080 fused : bool = True ,
8181 save_original_weight : bool = False ,
8282 lora_converter : Optional [LoRAStateDictConverter ] = None ,
8383 ):
8484 if not lora_converter :
8585 lora_converter = self .lora_converter
8686
87- for lora_path , lora_item in lora_list :
87+ for state_dict , lora_item , lora_name in lora_state_dict_list :
8888 if isinstance (lora_item , float ):
8989 lora_scale = lora_item
9090 scheduler_config = None
91- if isinstance (lora_item , LoraConfig ):
91+ elif isinstance (lora_item , LoraConfig ):
9292 lora_scale = lora_item .scale
9393 scheduler_config = lora_item .scheduler_config
94+ else :
95+ raise ValueError (f"lora_item must be float or LoraConfig, got { type (lora_item )} " )
9496
95- logger .info (f"loading lora from { lora_path } with LoraConfig (scale={ lora_scale } )" )
96- state_dict = load_file (lora_path , device = self .device )
97+ logger .info (f"loading lora from state_dict '{ lora_name } ' with scale={ lora_scale } " )
9798
9899 if scheduler_config is not None :
99100 self .apply_scheduler_config (scheduler_config )
100101 logger .info (f"Applied scheduler args from LoraConfig: { scheduler_config } " )
101102
102103 lora_state_dict = lora_converter .convert (state_dict )
103- for model_name , state_dict in lora_state_dict .items ():
104+ for model_name , model_state_dict in lora_state_dict .items ():
104105 model = getattr (self , model_name )
105106 lora_args = []
106- for key , param in state_dict .items ():
107+ for key , param in model_state_dict .items ():
107108 lora_args .append (
108109 {
109- "name" : lora_path ,
110+ "name" : lora_name ,
110111 "key" : key ,
111112 "scale" : lora_scale ,
112113 "rank" : param ["rank" ],
@@ -120,6 +121,26 @@ def load_loras(
120121 )
121122 model .load_loras (lora_args , fused = fused )
122123
124+ def load_loras (
125+ self ,
126+ lora_list : List [Tuple [str , Union [float , LoraConfig ]]],
127+ fused : bool = True ,
128+ save_original_weight : bool = False ,
129+ lora_converter : Optional [LoRAStateDictConverter ] = None ,
130+ ):
131+ lora_state_dict_list = []
132+ for lora_path , lora_item in lora_list :
133+ logger .info (f"loading lora from { lora_path } " )
134+ state_dict = load_file (lora_path , device = self .device )
135+ lora_state_dict_list .append ((state_dict , lora_item , lora_path ))
136+
137+ self ._load_lora_state_dicts (
138+ lora_state_dict_list = lora_state_dict_list ,
139+ fused = fused ,
140+ save_original_weight = save_original_weight ,
141+ lora_converter = lora_converter ,
142+ )
143+
123144 def load_lora (self , path : str , scale : float , fused : bool = True , save_original_weight : bool = False ):
124145 self .load_loras ([(path , scale )], fused , save_original_weight )
125146
0 commit comments