From 9ec24863123a6309539d380ba896f12f1eba99d0 Mon Sep 17 00:00:00 2001
From: Jeff Yang_Cin
Date: Tue, 30 May 2023 15:13:07 +0900
Subject: [PATCH 1/2] Add 4-bit quantization and QLoRA

---
 README.md        | 14 ++++++++++++--
 finetune.py      | 26 +++++++++++++++++---------
 generate.py      | 18 ++++++++++++------
 requirements.txt |  6 +++---
 4 files changed, 44 insertions(+), 20 deletions(-)

diff --git a/README.md b/README.md
index 744060dc..c137e7d9 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,7 @@
 - 🤗 **Try the pretrained model out [here](https://huggingface.co/spaces/tloen/alpaca-lora), courtesy of a GPU grant from Huggingface!**
 - Users have created a Discord server for discussion and support [here](https://discord.gg/prbq284xX5)
 - 4/14: Chansung Park's GPT4-Alpaca adapters: https://github.com/tloen/alpaca-lora/issues/340
+- 5/30: 4-bit quantization and QLoRA: https://github.com/tloen/alpaca-lora/issues/486
 
 This repository contains code for reproducing the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) results using [low-rank adaptation (LoRA)](https://arxiv.org/pdf/2106.09685.pdf).
 We provide an Instruct model of similar quality to `text-davinci-003` that can run [on a Raspberry Pi](https://twitter.com/miolini/status/1634982361757790209) (for research),
@@ -32,10 +33,11 @@ This file contains a straightforward application of PEFT to the LLaMA
 model, as well as some code related to prompt construction and tokenization.
 PRs adapting this code to support larger models are always welcome.
 
-Example usage:
+Example usage with 4-bit quantization:
 
 ```bash
 python finetune.py \
+    --load_in_4bit \
     --base_model 'decapoda-research/llama-7b-hf' \
     --data_path 'yahma/alpaca-cleaned' \
     --output_dir './lora-alpaca'
@@ -66,7 +68,7 @@ python finetune.py \
 This file reads the foundation model from the Hugging Face model hub and the LoRA weights from `tloen/alpaca-lora-7b`,
 and runs a Gradio interface for inference on a specified input. Users should treat this as example code for the use of the model, and modify it as needed.
 
-Example usage:
+Example usage (8-bit):
 
 ```bash
 python generate.py \
@@ -75,6 +77,14 @@ python generate.py \
     --lora_weights 'tloen/alpaca-lora-7b'
 ```
 
+Example usage (4-bit):
+```bash
+python generate.py \
+    --load_4bit \
+    --base_model 'decapoda-research/llama-7b-hf' \
+    --lora_weights 'tloen/alpaca-lora-7b'
+```
+
 ### Official weights
 
 The most recent "official" Alpaca-LoRA adapter available at [`tloen/alpaca-lora-7b`](https://huggingface.co/tloen/alpaca-lora-7b) was trained on March 26 with the following command:
diff --git a/finetune.py b/finetune.py
index 0e74641d..e83e9a15 100644
--- a/finetune.py
+++ b/finetune.py
@@ -17,10 +17,10 @@
     LoraConfig,
     get_peft_model,
     get_peft_model_state_dict,
-    prepare_model_for_int8_training,
+    prepare_model_for_kbit_training,
     set_peft_model_state_dict,
 )
-from transformers import LlamaForCausalLM, LlamaTokenizer
+from transformers import LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig
 
 from utils.prompter import Prompter
 
@@ -56,6 +56,7 @@ def train(
     wandb_log_model: str = "",  # options: false | true
     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
     prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
+    load_in_4bit: bool = False,  # use 4-bit quantization (QLoRA)
 ):
     if int(os.environ.get("LOCAL_RANK", 0)) == 0:
         print(
@@ -82,6 +83,7 @@ def train(
             f"wandb_log_model: {wandb_log_model}\n"
             f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
             f"prompt template: {prompt_template_name}\n"
+            f"load_in_4bit: {load_in_4bit}\n"
         )
     assert (
         base_model
@@ -108,14 +110,20 @@ def train(
         os.environ["WANDB_WATCH"] = wandb_watch
     if len(wandb_log_model) > 0:
         os.environ["WANDB_LOG_MODEL"] = wandb_log_model
-
+
+    load_in_8bit = not load_in_4bit  # fall back to 8-bit loading when 4-bit is not requested
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=load_in_4bit,
+        load_in_8bit=load_in_8bit,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.float16,
+    )
     model = LlamaForCausalLM.from_pretrained(
-        base_model,
-        load_in_8bit=True,
-        torch_dtype=torch.float16,
-        device_map=device_map,
+        base_model, quantization_config=bnb_config, torch_dtype=torch.float16, device_map=device_map
     )
+
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
 
     tokenizer.pad_token_id = (
@@ -171,7 +179,7 @@ def generate_and_tokenize_prompt(data_point):
         ]  # could be sped up, probably
         return tokenized_full_prompt
 
-    model = prepare_model_for_int8_training(model)
+    model = prepare_model_for_kbit_training(model)
 
     config = LoraConfig(
         r=lora_r,
@@ -241,7 +249,7 @@ def generate_and_tokenize_prompt(data_point):
         learning_rate=learning_rate,
         fp16=True,
         logging_steps=10,
-        optim="adamw_torch",
+        optim="paged_adamw_8bit",  # paged 8-bit AdamW for QLoRA; "adamw_bnb_8bit" also works
         evaluation_strategy="steps" if val_set_size > 0 else "no",
         save_strategy="steps",
         eval_steps=200 if val_set_size > 0 else None,
diff --git a/generate.py b/generate.py
index 4e1a9d7f..63640d27 100644
--- a/generate.py
+++ b/generate.py
@@ -6,7 +6,7 @@
 import torch
 import transformers
 from peft import PeftModel
-from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
+from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig
 
 from utils.callbacks import Iteratorize, Stream
 from utils.prompter import Prompter
@@ -25,6 +25,7 @@
 
 def main(
     load_8bit: bool = False,
+    load_4bit: bool = False,
     base_model: str = "",
     lora_weights: str = "tloen/alpaca-lora-7b",
     prompt_template: str = "",  # The prompt template to use, will default to alpaca.
@@ -39,12 +40,17 @@ def main(
     prompter = Prompter(prompt_template)
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
     if device == "cuda":
-        model = LlamaForCausalLM.from_pretrained(
-            base_model,
+        load_8bit = load_8bit and not load_4bit  # 4-bit takes precedence over 8-bit
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=load_4bit,
             load_in_8bit=load_8bit,
-            torch_dtype=torch.float16,
-            device_map="auto",
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
         )
+        model = LlamaForCausalLM.from_pretrained(
+            base_model, quantization_config=bnb_config, torch_dtype=torch.float16, device_map="auto"
+        )
         model = PeftModel.from_pretrained(
             model,
             lora_weights,
@@ -77,7 +83,7 @@ def main(
     model.config.bos_token_id = 1
     model.config.eos_token_id = 2
 
-    if not load_8bit:
+    if not (load_8bit or load_4bit):
         model.half()  # seems to fix bugs for some users.
 
     model.eval()
diff --git a/requirements.txt b/requirements.txt
index 35fd00fe..f8dcef09 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,12 +1,12 @@
-accelerate
+git+https://github.com/huggingface/accelerate.git
 appdirs
 loralib
-bitsandbytes
+bitsandbytes>=0.39.0
 black
 black[jupyter]
 datasets
 fire
 git+https://github.com/huggingface/peft.git
-transformers>=4.28.0
+git+https://github.com/huggingface/transformers.git
 sentencepiece
 gradio

From 6aacc532ccadffdc27b64a2901469f7e6f6e834f Mon Sep 17 00:00:00 2001
From: Jeff
Date: Wed, 7 Jun 2023 12:18:59 +0000
Subject: [PATCH 2/2] add missing package

---
 requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements.txt b/requirements.txt
index f8dcef09..f20df074 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,3 +10,4 @@ git+https://github.com/huggingface/peft.git
 git+https://github.com/huggingface/transformers.git
 sentencepiece
 gradio
+scipy
\ No newline at end of file
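
For reviewers who want to see the new 4-bit loading path in one place rather than spread across hunks, the sketch below strings together the pieces this patch adds to `finetune.py`: a `BitsAndBytesConfig` with NF4 double quantization, `prepare_model_for_kbit_training`, and LoRA adapters. It is an illustrative standalone script, not part of the patch; it assumes the pinned `bitsandbytes>=0.39.0` plus the git versions of `transformers` and `peft` from `requirements.txt`, reuses the example checkpoint name from the README, and mirrors `finetune.py`'s default LoRA hyperparameters (r=8, alpha=16, dropout 0.05, `q_proj`/`v_proj` targets).

```python
# Minimal QLoRA-style loading sketch (illustrative only, not part of the patch).
# Assumes bitsandbytes >= 0.39.0 and git versions of transformers/peft, as pinned above.
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig, LlamaForCausalLM, LlamaTokenizer

base_model = "decapoda-research/llama-7b-hf"  # same example checkpoint as the README

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # quantize base weights to 4 bits
    bnb_4bit_quant_type="nf4",             # NF4 data type, as in the QLoRA setup
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.float16,  # run matmuls in fp16
)

model = LlamaForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = LlamaTokenizer.from_pretrained(base_model)

# Prepare the quantized model for training and attach LoRA adapters,
# mirroring what finetune.py does after loading the base model.
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA adapters are trainable
```

Running `finetune.py` without `--load_in_4bit` keeps the previous behaviour: the config then falls back to `load_in_8bit=True`, so existing 8-bit workflows are unaffected.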