feat: Add LoRA (Low-Rank Adaptation) support for efficient model fine-tuning #108
chen2021673 wants to merge 5 commits into master from
Conversation
- Add LoRA module infrastructure with configurable rank, alpha, dropout
- Implement LoRALinear wrapper for seamless integration with Linear layers
- Support tensor parallelism via LoRAParallelLinear
- Add LoRAModel utility for managing multiple LoRA layers
- Integrate LoRA configuration and utilities
- Add GPT2 example demonstrating LoRA fine-tuning
- Include comprehensive usage documentation and test suite

Co-Authored-By: Claude Opus 4.6 <[email protected]>
- Refactor LoRA config construction with proper target module parsing
- Add GetLoRAModel for in-place LoRA layer injection
- Fix DDP reducer to correctly handle LoRA parameters
- Fix RowParallel/ColumnParallel LoRA input handling to match base module behavior
- Add shape-based defensive checks for TP/SP consistency
- Move TP/SP communication helper function declarations to utils.h
- Move getter implementations from header to .cc file
- Add unit test for SaveLoRAWeights/LoadLoRAWeights functionality

Co-Authored-By: Claude Opus 4.6 <[email protected]>
```cpp
std::vector<std::shared_ptr<Tensor>> LoRAModel::TrainableParameters() const { return GetLoRAParameters(base_model_); }

std::vector<std::shared_ptr<Tensor>> LoRAModel::Parameters() const { return base_model_->Parameters(); }
```
In principle the behavior of base_model's Parameters() has already been rewritten here (call stack: FreezeBaseModel -> GetLoRAParameters -> Module::LoRAParameters()). It would be best to add a NOTE explaining that the behavior differs from the naive Module::Parameters().
```cpp
// 3. Add LoRA contribution to base output
// Both should now have the same sequence dimension
auto output = base_output->Add(scaled_lora);
```
Compared with ColumnParallelLinear::Forward, this seems to be missing an allgather step: GatherFromTPRegionFunc().
The correctness of the few existing test cases should not be affected, because by default a RowParallelLinear follows, in which case gather_output=false.
```cpp
// Freeze base model parameters
FreezeBaseModel(base_model_);

LOG(INFO) << "LoRAModel created with rank=" << config_.rank << ", alpha=" << config_.alpha;
```
Shouldn't base_model_ be registered into modules_ (DDP does this with modules_[kModuleName] = std::move(module);)? Otherwise To(dtype) or To(device) may misbehave. As it stands, base_model apparently must finish To(dtype)/To(device) before being wrapped in a LoRAModel; wrapping first and calling To() afterwards would leave the frozen params unconverted.
```cpp
// LoRA A: [rank, in_features] - replicated across TP ranks (implemented as Linear)
// LoRA B: [out_features_per_partition, rank] - sharded like base weight (implemented as ColumnParallelLinear with
// gather_output)
class LoRAColumnParallelLinear : public nn::CloneableModule<LoRAColumnParallelLinear> {
```
It feels like this could inherit from the original ColumnParallelLinear, which would save redefining the base-class members and getters.
```cpp
// Weight shape: [out_features, in_features_per_partition]
// LoRA A: [rank, in_features_per_partition] - sharded like base weight (implemented as RowParallelLinear with
// input_is_parallel)
// LoRA B: [out_features, rank] - replicated (implemented as Linear)
class LoRARowParallelLinear : public nn::CloneableModule<LoRARowParallelLinear> {
```
```cpp
        continue;
    }

    if (type == Linear::kType) {
```
From here on, this file has quite a few of these three-way if checks whose bodies differ only in a class name; a more elegant pattern might be worth adopting.
Summary
Added LoRA (Low-Rank Adaptation) support for parameter-efficient fine-tuning. This feature significantly reduces the number of trainable parameters through low-rank decomposition, enabling efficient fine-tuning of large models.
Changes
New Features
LoRA Infrastructure (infini_train/include/nn/lora/):
- lora_config.h/cc - LoRA configuration (rank, alpha, dropout)
- lora_linear.h/cc - LoRA linear layer wrapper
- lora_model.h/cc - Multi-LoRA layer management
- lora_parallel_linear.h/cc - Tensor parallelism support
- lora_utils.h/cc - Utility functions

Tests: test/lora/test_lora.cc - Unit tests

Documentation: docs/lora_usage.md - Usage documentation

Examples: example/gpt2/main.cc - Added LoRA training example

Build: CMakeLists.txt - Added test_lora build target

Test Result
Accuracy:



Performance:
llama3 run result comparison: