LoRA: Principle and PyTorch Implementation
November 9, 2024 · about 4 min read
Background
Fine-tuning either a large language model (LLM) or a text-to-image model such as Stable Diffusion requires a large amount of GPU memory, far more than a personal graphics card can provide. Parameter-efficient fine-tuning methods have therefore appeared one after another, and the most popular of them is LoRA ("LoRA: Low-Rank Adaptation of Large Language Models").
LoRA has many advantages: it saves GPU memory, trains quickly, loses relatively little quality compared with full-parameter fine-tuning, adds no extra latency at inference time, and can be used as a plug-in component. It does, of course, have a drawback: there is still some loss in quality (laughs).
The memory savings come mainly from the much smaller set of trainable parameters (for example, applying LoRA only to the q/k/v projection layers), which shrinks the gradients and optimizer states that have to be kept in memory accordingly; a rough count is sketched below.
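As a rough illustration (the 4096 x 4096 weight shape below is just an assumed example, not taken from any particular model), here is the trainable-parameter count of one projection matrix under full fine-tuning versus LoRA with rank 8:

# Rough trainable-parameter count for a single weight matrix.
# The 4096 x 4096 shape is an assumed example, not a specific model.
d, k, r = 4096, 4096, 8

full_params = d * k              # full fine-tuning: every entry is trainable
lora_params = d * r + r * k      # LoRA: only A (d x r) and B (r x k) are trainable

print(f"full fine-tuning: {full_params:,}")                 # 16,777,216
print(f"LoRA (r=8):       {lora_params:,}")                 # 65,536
print(f"ratio:            {lora_params / full_params:.2%}") # 0.39%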
Core Principle
The core principle is very simple. For any pretrained weight matrix $W \in \mathbb{R}^{d \times k}$, LoRA freezes $W$ and learns a low-rank update $\Delta W = AB$, where $A \in \mathbb{R}^{d \times r}$, $B \in \mathbb{R}^{r \times k}$, and the rank $r \ll \min(d, k)$.
During training, the forward computation therefore becomes
$$h = Wx + \frac{\alpha}{r}\, A B\, x$$
where $\alpha$ (lora_alpha in the code below) is a scaling hyperparameter and $r$ is the rank; only $A$ and $B$ receive gradients, while $W$ stays frozen.
- Why is it enough to optimize only the two matrices A and B? What is the assumption behind this?
- The assumption is that the weight matrix (and in particular the update needed for fine-tuning) is not full rank: many of its parameters are redundant, so the update can be captured by the product $AB$ of two much smaller matrices, each of which is itself essentially full rank.
- Recall that a matrix can be written in terms of linearly independent vectors, and the maximum number of linearly independent (column or row) vectors is its rank; a small numerical check of the low-rank claim follows below.
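As a minimal numerical check (the sizes 512, 768, and 8 are arbitrary, chosen only for illustration): the product of a (512 x 8) and an (8 x 768) matrix can never have rank above 8, no matter how its two factors are trained.

import torch

d, k, r = 512, 768, 8
A = torch.randn(d, r)          # plays the role of lora_a below
B = torch.randn(r, k)          # plays the role of lora_b below
delta_w = A @ B                # the low-rank update, shape (512, 768)

print(delta_w.shape)                       # torch.Size([512, 768])
print(torch.linalg.matrix_rank(delta_w))   # 8, bounded by r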
PyTorch Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class LinearLoRALayer(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 merge=False,
                 rank=8,
                 lora_alpha=16,
                 dropout=0.1,
                 ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        # Tracks whether lora_a @ lora_b is currently folded into linear.weight;
        # merge_weight() / unmerge_weight() toggle this flag
        self.merge = False

        # nn.Linear stores its weight with shape (out_features, in_features),
        # so the forward pass computes x @ W^T
        self.linear = nn.Linear(in_features, out_features)

        # The small details below matter for a correct implementation
        if rank > 0:
            # Wrapping the tensors in nn.Parameter marks lora_a and lora_b
            # as trainable parameters
            self.lora_a = nn.Parameter(
                torch.zeros(out_features, rank)
            )
            # The LoRA paper draws one factor from a Gaussian; Kaiming-uniform
            # init works just as well. What matters is that the other factor
            # (lora_b) starts at zero, so the update is zero at the start of training.
            nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
            self.lora_b = nn.Parameter(
                torch.zeros(rank, in_features)
            )
            self.scale = lora_alpha / rank

            # The pretrained linear weight must be frozen
            self.linear.weight.requires_grad = False

        self.dropout = nn.Dropout(
            dropout
        ) if dropout > 0 else nn.Identity()

        # If merge=True, the parameters of the two small matrices lora_a and
        # lora_b are folded directly into linear.weight for inference
        # (numerically a no-op here, since lora_b is still zero)
        if merge:
            self.merge_weight()
    def forward(self, X):
        # X shape: (batch, seq_len, in_features)
        # lora_a: (out_features, rank), lora_b: (rank, in_features)
        if self.rank > 0 and not self.merge:
            # Frozen linear output plus the scaled low-rank update
            output = self.linear(X) + self.scale * (X @ (self.lora_a @ self.lora_b).T)
        else:
            # Either rank == 0 (no LoRA) or the update has already been merged
            # into linear.weight, so the plain linear layer is enough
            output = self.linear(X)

        return self.dropout(output)
    def merge_weight(self):
        # Fold the low-rank update into the frozen weight so that
        # inference needs only a single matmul
        if self.rank > 0 and not self.merge:
            self.linear.weight.data += self.scale * (self.lora_a @ self.lora_b)
            self.merge = True

    def unmerge_weight(self):
        # Undo merge_weight() and restore the original frozen weight
        if self.rank > 0 and self.merge:
            self.linear.weight.data -= self.scale * (self.lora_a @ self.lora_b)
            self.merge = False
# Test the LinearLoRALayer
batch_size = 32
seq_len = 128
in_features = 768
out_features = 512
rank = 8
lora_alpha = 16
dropout = 0.1
# Create a test input
x = torch.randn(batch_size, seq_len, in_features)
# Test regular mode (no merge)
lora_layer = LinearLoRALayer(
in_features=in_features,
out_features=out_features,
rank=rank,
lora_alpha=lora_alpha,
dropout=dropout,
merge=False
)
# Forward pass
output = lora_layer(x)
print(f"Output shape (no merge): {output.shape}") # Should be [batch_size, seq_len, out_features]
# Test merged mode
lora_layer_merged = LinearLoRALayer(
in_features=in_features,
out_features=out_features,
rank=rank,
lora_alpha=lora_alpha,
dropout=dropout,
merge=True
)
# Forward pass with merged weights
output_merged = lora_layer_merged(x)
print(f"Output shape (merged): {output_merged.shape}") # Should be [batch_size, seq_len, out_features]
# Test weight merging/unmerging
# Switch to eval mode so dropout does not add randomness to the comparison
lora_layer.eval()
output_before = lora_layer(x)
lora_layer.merge_weight()
output_after_merge = lora_layer(x)
lora_layer.unmerge_weight()
output_after_unmerge = lora_layer(x)
print("Max difference merged vs. unmerged:",
      torch.max(torch.abs(output_before - output_after_merge)).item())
print("Max difference after merge/unmerge cycle:",
      torch.max(torch.abs(output_before - output_after_unmerge)).item())
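As a small extra check (not part of the original test): one backward pass confirms that gradients flow only into the LoRA factors, while the frozen linear weight receives none.

# Extra check: gradients flow only into the LoRA factors
lora_layer.train()
loss = lora_layer(x).sum()
loss.backward()

print("linear.weight.grad is None:", lora_layer.linear.weight.grad is None)  # True, it is frozen
print("lora_a.grad shape:", lora_layer.lora_a.grad.shape)  # torch.Size([512, 8])
print("lora_b.grad shape:", lora_layer.lora_b.grad.shape)  # torch.Size([8, 768])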
- Q: Is LoRA for large models really this simple?
- A: The principle is, but a real implementation has to deal with many layers and extra configuration, such as whether to apply LoRA to the Q/K/V projections or to the FFN layers, which adds complexity to the code. The core, however, is exactly the code above; a sketch of how it could be wired into an existing model follows below.
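As a minimal sketch only (the helper apply_lora and the module names q_proj / v_proj are hypothetical, chosen to mimic common attention-layer naming, and nothing here comes from a specific library), one way to use LinearLoRALayer on an existing model is to find the targeted nn.Linear submodules, copy their pretrained weights into the frozen inner linear, and swap them in place:

def apply_lora(model, target_names=("q_proj", "v_proj"), rank=8, lora_alpha=16):
    # Hypothetical helper: replace selected nn.Linear submodules with
    # LinearLoRALayer and copy the pretrained weights into the frozen linear.
    replacements = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and name.split(".")[-1] in target_names:
            replacements.append((name, module))

    for name, linear in replacements:
        lora = LinearLoRALayer(
            in_features=linear.in_features,
            out_features=linear.out_features,
            rank=rank,
            lora_alpha=lora_alpha,
            dropout=0.0,
        )
        lora.linear.weight.data.copy_(linear.weight.data)
        if linear.bias is not None:
            lora.linear.bias.data.copy_(linear.bias.data)
        else:
            lora.linear.bias.data.zero_()  # original layer had no bias

        # Re-attach the new layer at the same place in the module tree
        parent = model.get_submodule(name.rsplit(".", 1)[0]) if "." in name else model
        setattr(parent, name.rsplit(".", 1)[-1], lora)

    # Freeze everything except the LoRA factors
    for n, p in model.named_parameters():
        if "lora_" not in n:
            p.requires_grad = False
    return model

After this, only the lora_a / lora_b factors receive gradients, which is exactly where the memory savings come from.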
References
If you are interested, you can also read my other articles:

This differs somewhat from PCA and SVD: the former are used for data dimensionality reduction / compression, while LoRA merely learns a low-rank matrix whose parameters can still be updated during training. ↩︎
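To make this footnote concrete (a small illustrative sketch added here, not from the original post): SVD produces a fixed best rank-r approximation of a matrix that already exists, whereas LoRA's two factors start out as a zero update and are changed by gradient descent during fine-tuning.

import torch
import torch.nn as nn

W = torch.randn(512, 768)

# SVD: a fixed rank-8 approximation of an existing matrix (no training involved)
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
W_rank8 = U[:, :8] @ torch.diag(S[:8]) @ Vh[:8, :]

# LoRA: trainable factors that start as a zero update and are learned by the optimizer
A = nn.Parameter(torch.randn(512, 8) * 0.01)
B = nn.Parameter(torch.zeros(8, 768))
delta_w = A @ B   # all zeros at initialization; updated during fine-tuning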