Adding CorDA as an optional initialization method of LoRA #2231

Open · iboing wants to merge 3 commits into base: main
Conversation

@iboing commented Nov 21, 2024

Existing PEFT methods are mostly agnostic to the context of the task of concern, e.g., a downstream task to learn or some pre-trained world knowledge to maintain. In our paper (https://openreview.net/pdf?id=Gi00NVru6n), accepted by NeurIPS 2024, we propose a PEFT method named CorDA (context-oriented decomposition adaptation), which builds task-aware LoRA adapters from a weight decomposition oriented by the context of the task concerned.

Concretely, CorDA randomly collects a few (usually 256) data samples from a target task, e.g., questions from a QA dataset or instructions to write code or solve a math problem, and feeds these samples into a pre-trained LLM. We can then obtain the covariance matrix of the input activation of each linear layer, i.e., $C=XX^T\in\mathbb{R}^{d_{in}\times d_{in}}$, where $X$ is the input of this linear layer. We then perform singular value decomposition (SVD) on the weight $W\in \mathbb{R}^{d_{out}\times d_{in}}$ multiplied by the covariance matrix, i.e., $\verb|SVD|(WC) = U\Sigma V^T$, where $U$ and $V$ contain the singular vectors and $\Sigma$ is the diagonal matrix of singular values arranged in descending order. In this way, the context expressed by these representative covariance matrices orients the decomposition, so that the principal components are most associated with the task of concern. To ensure the same inference result as the pre-trained model at the start of adaptation, we multiply the inverse of the covariance matrix back into the decomposed components, $\hat{W}=U\Sigma V^T C^{-1}$, where $\hat{W}$ is the weight after decomposition and reconstruction.
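To make the procedure concrete, here is a minimal PyTorch sketch of this decomposition and reconstruction; the function name, the damping term, and the tensor layout are illustrative assumptions rather than the PR's actual implementation:

```python
import torch

def corda_decompose(W, X, damping=1e-6):
    # W: pre-trained weight, shape (d_out, d_in)
    # X: stacked input activations of the layer, shape (d_in, n_samples)
    C = X @ X.T                                             # covariance, (d_in, d_in)
    C = C + damping * torch.eye(C.size(0), dtype=C.dtype, device=C.device)  # keep C invertible
    U, S, Vh = torch.linalg.svd(W @ C, full_matrices=False)  # SVD(WC) = U diag(S) V^T
    # C is symmetric, so solve(C, Vh.T).T equals Vh @ C^{-1} without forming the inverse
    Vh_Cinv = torch.linalg.solve(C, Vh.T).T
    # sanity check: U @ torch.diag(S) @ Vh_Cinv reconstructs W up to numerical error
    return U, S, Vh_Cinv
```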

Thanks to this task-awareness, CorDA enables two optional modes: knowledge-preserving adaptation mode (KPM) and instruction-previewed adaptation mode (IPM). In KPM, we use questions from a question-answering dataset whose knowledge needs to be preserved, such as TriviaQA and NQ open, to obtain the covariance matrices. After our context-oriented decomposition, we use the components with the smallest $r$ singular values, $U_{[:,-r:]}$, $\Sigma_{[-r:]}$, and $(V^TC^{-1})_{[-r:,:]}$, to initialize the learnable LoRA adapters $A=\sqrt{\Sigma}_{[-r:]}(V^TC^{-1})_{[-r:,:]}$ and $B=U_{[:,-r:]}\sqrt{\Sigma}_{[-r:]}$. The other components, which compact the question-answering ability, are frozen during adaptation. KPM thus learns new tasks effectively while keeping the world knowledge associated with $C$ as sound as possible.

Alternatively, when one only aims to achieve the highest possible performance on the fine-tuning task, without concern for maintaining world knowledge, our IPM is preferable. In this mode, CorDA uses the instructions and responses from the fine-tuning task (e.g., Math or Code) to produce the covariance matrices. The principal components with the largest singular values, which capture the characteristics of the fine-tuning task in advance, can better accommodate the new ability. So we initialize the adapters as $A= \sqrt{\Sigma}_{[:r]} (V^T C^{-1})_{[:r,:]}$ and $B =U_{[:,:r]} \sqrt{\Sigma}_{[:r]}$, and freeze the remaining components. The implementations of KPM and IPM are compared as follows:

| Mode | Collect covariance from | LoRA $A$ | LoRA $B$ |
| --- | --- | --- | --- |
| KPM | questions from a knowledge benchmark to maintain | $A=\sqrt{\Sigma}_{[-r:]}(V^T C^{-1})_{[-r:,:]}$ | $B=U_{[:,-r:]}\sqrt{\Sigma}_{[-r:]}$ |
| IPM | instructions and responses from a downstream task to learn | $A= \sqrt{\Sigma}_{[:r]} (V^T C^{-1})_{[:r,:]}$ | $B =U_{[:,:r]} \sqrt{\Sigma}_{[:r]}$ |
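
A rough sketch of how the two modes could build their adapters from the factors returned by the decomposition sketch above (again illustrative only; the slicing mirrors the formulas in the table, not the PR's actual API):

```python
def corda_init_adapters(U, S, Vh_Cinv, r, mode="kpm"):
    # KPM adapts the r smallest components and freezes the principal ones;
    # IPM adapts the r largest components and freezes the rest
    if mode == "ipm":
        adapt, keep = slice(None, r), slice(r, None)
    else:  # "kpm"
        adapt, keep = slice(-r, None), slice(None, -r)
    sqrt_S = S[adapt].sqrt()
    lora_A = sqrt_S.unsqueeze(1) * Vh_Cinv[adapt, :]   # (r, d_in)
    lora_B = U[:, adapt] * sqrt_S.unsqueeze(0)         # (d_out, r)
    W_res = U[:, keep] @ torch.diag(S[keep]) @ Vh_Cinv[keep, :]  # frozen residual weight
    return lora_A, lora_B, W_res
```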

[figure: fig1]

Similar to PiSSA, our method is an initialization method for the LoRA adapter that keeps the same LoRA structure, and our IPM also uses the principal components to initialize $A$ and $B$. However, PiSSA adopts the plain SVD and does not consider task awareness. Our method builds task-aware adapters, enabling two optional modes to satisfy customized needs:

  • KPM not only achieves better fine-tuning performance than LoRA but also effectively preserves pre-trained knowledge.
  • IPM can further accelerate convergence and enhance the fine-tuning performance on downstream tasks, surpassing PiSSA in our experiments.

Some results are shown below.

1. The superiority of our decomposition method.

[figure: fig2]

Plain SVD is the normal SVD on pre-trained weights, as adopted by PiSSA. This experiment discards the smallest $r$ components after different decomposition methods and tests the perplexity on Wiki and PTB. It indicates that our decomposition better compacts the associated task abilities into the principal components, and thus can better maintain the corresponding knowledge by freezing them in KPM, or better learn new abilities by adapting them in IPM.
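For illustration, discarding the smallest $r$ components with the factors from the sketch above would amount to something like the following (a sketch of the evaluation setup, not the actual benchmark code):

```python
def drop_smallest_components(U, S, Vh_Cinv, r):
    # reconstruct the weight without its r smallest singular directions;
    # the truncated weight is then placed back into the model for perplexity evaluation
    keep = slice(None, -r)
    return U[:, keep] @ torch.diag(S[keep]) @ Vh_Cinv[keep, :]
```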

2. Adaptation performance with knowledge-preserving mode (sample from NQopen, fine-tune on Math)

| Method | Model | NQ open | GSM8k | Math | Avg. |
| --- | --- | --- | --- | --- | --- |
| Pre-trained | Llama-2-7b | 14.99 | - | - | - |
| LoRA | Llama-2-7b | 1.27 | 42.68 | 5.88 | 16.61 |
| CorDA (KPM) | Llama-2-7b | 8.20 | 46.32 | 7.00 | 20.51 |
| Pre-trained | Llama-2-13b | 23.63 | - | - | - |
| LoRA | Llama-2-13b | 16.26 | 57.24 | 8.92 | 27.47 |
| CorDA (KPM) | Llama-2-13b | 19.86 | 59.29 | 9.62 | 29.59 |
| Pre-trained | Llama-3-8b | 13.41 | - | - | - |
| LoRA | Llama-3-8b | 8.75 | 72.33 | 24.04 | 35.04 |
| CorDA (KPM) | Llama-3-8b | 9.61 | 74.68 | 25.34 | 36.54 |
| Pre-trained | Gemma-2-9b | 12.85 | - | - | - |
| LoRA | Gemma-2-9b | 9.28 | 83.47 | 42.30 | 45.02 |
| CorDA (KPM) | Gemma-2-9b | 10.17 | 84.08 | 42.64 | 45.63 |

3. Adaptation performance with instruction-previewed mode (sample from Math, fine-tune on Math)

| Method | Model | GSM8k | Math |
| --- | --- | --- | --- |
| LoRA | Llama-2-7b | 42.68 | 5.88 |
| PiSSA | Llama-2-7b | 51.63 | 7.32 |
| CorDA (IPM) | Llama-2-7b | 53.45 | 8.64 |
| LoRA | Llama-2-13b | 57.24 | 8.92 |
| PiSSA | Llama-2-13b | 60.88 | 11.08 |
| CorDA (IPM) | Llama-2-13b | 62.47 | 11.54 |
| LoRA | Gemma-2-9b | 83.47 | 42.30 |
| PiSSA | Gemma-2-9b | 84.23 | 43.52 |
| CorDA (IPM) | Gemma-2-9b | 84.45 | 43.88 |

[figure: fig3]

@BenjaminBossan (Member) left a comment

Thanks for this PR to add CorDA to PEFT.

I don't have time to do a proper review this week, so I just did a very quick check and added some comments. I'll do a proper review next week.

# if (S!=S).any():
if torch.isnan(S).any() or torch.isinf(S).any():
    # print("nan in S")
    raise Exception("nan or inf in S")
@BenjaminBossan (Member):

Let's raise a ValueError here and below.
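
For example, something along these lines (a sketch of the suggested change):

```python
if torch.isnan(S).any() or torch.isinf(S).any():
    raise ValueError("NaN or Inf encountered in the singular values S")
```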

scale_u = torch.linalg.norm(U) / math.sqrt(r)
scale_v = torch.linalg.norm(V) / math.sqrt(r)
logging.info(f"scale_u: {scale_u:.2f}, scale_v: {scale_v:.2f}, svd_error: {svd_error:.2f}")
assert U.size(0) == out_dim
@BenjaminBossan (Member):

Let's not use asserts but instead raise proper errors with nice error messages.
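
For example, the assert could become something like (a sketch):

```python
if U.size(0) != out_dim:
    raise ValueError(f"Expected U to have {out_dim} rows, but got {U.size(0)}")
```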

@@ -0,0 +1,314 @@
import logging
@BenjaminBossan (Member):

Let's add the license notice.

@5eqn commented Nov 22, 2024

@BenjaminBossan Thanks for your review! As @iboing's collaborator, I've revised the code according to your comments.

@iboing (Author) commented Nov 23, 2024

@BenjaminBossan Thank you for your review!

@BenjaminBossan (Member) commented:

Thanks a lot for the updates. The week has been quite busy so far, but I'm hopeful I can give this a proper review tomorrow. Just from reading the description of the method and skimming the code, I have some questions/comments:

  1. The description of the method mentions that it depends on the input data. I could not find in the code where this happens; for instance, preprocess_corda does not take any data as input, and in the corda_finetuning.py script I also don't see how CorDA incorporates the dataset. Could you please point me to where this happens?
  2. You're using logging throughout the code. In general, we don't do this in PEFT. Could you please remove those calls?
  3. Purely from the description (I haven't studied the paper yet), this method reminded me of EVA, which was recently added to PEFT ([FEAT] New LoRA Initialization Method: Explained Variance Adaptation #2142, https://arxiv.org/abs/2410.07170). Do you know if the two methods are similar? AFAICT, the respective papers don't mention one another.

@5eqn commented Nov 27, 2024

@BenjaminBossan Thanks for your comments! Let me clarify:

I also don't see how CorDA incorporates the dataset. Could you please point me to where this happens?

In corda_finetuning/preprocess.py, calib_loader is the dataset iterator, and we're passing config.corda_config.run_model = lambda: run_model(model, calib_loader); this way CorDA gets a closure that captures a reference to the dataset.

This does not happen in corda_finetuning/corda_finetuning.py because that file does the actual fine-tuning; since the CorDA initial weights and the residual model have already been preprocessed, the model can be trained like a standard LoRA model, so the dataset is not needed there.
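
Roughly, the pattern is (a sketch; what happens inside run_model, i.e. the hook and covariance bookkeeping, is an assumption about the preprocessing script rather than its exact code):

```python
import torch

# model, calib_loader, and config come from the preprocessing script (preprocess.py)
def run_model(model, calib_loader):
    # feed the ~256 calibration samples through the frozen model so that the
    # registered hooks can accumulate each linear layer's input covariance
    model.eval()
    with torch.no_grad():
        for batch in calib_loader:
            model(**batch)

# the lambda captures model and calib_loader; CorDA's preprocessing only needs to
# call config.corda_config.run_model() later, without ever seeing the data itself
config.corda_config.run_model = lambda: run_model(model, calib_loader)
```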

Edit: we have mentioned this in developer_guides/lora.md and corda_finetuning/README.md, and there are comments in lora/config.py as well.

You're using logs throughout the code. In general, we don't do this in PEFT. Could you please remove those calls.

Sure, I've updated the PR so that logs are removed.
