Add 4-bit quantization and QLoRA #487

Open · wants to merge 2 commits into base: main
14 changes: 12 additions & 2 deletions README.md
@@ -3,6 +3,7 @@
- 🤗 **Try the pretrained model out [here](https://huggingface.co/spaces/tloen/alpaca-lora), courtesy of a GPU grant from Huggingface!**
- Users have created a Discord server for discussion and support [here](https://discord.gg/prbq284xX5)
- 4/14: Chansung Park's GPT4-Alpaca adapters: https://github.com/tloen/alpaca-lora/issues/340
- 5/30: 4-bit quantization and QLoRA: https://github.com/tloen/alpaca-lora/issues/486

This repository contains code for reproducing the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) results using [low-rank adaptation (LoRA)](https://arxiv.org/pdf/2106.09685.pdf).
We provide an Instruct model of similar quality to `text-davinci-003` that can run [on a Raspberry Pi](https://twitter.com/miolini/status/1634982361757790209) (for research),
@@ -32,10 +33,11 @@ This file contains a straightforward application of PEFT to the LLaMA model,
as well as some code related to prompt construction and tokenization.
PRs adapting this code to support larger models are always welcome.

Example usage:
Example usage with 4-bit quantization:

```bash
python finetune.py \
--load_in_4bit \
--base_model 'decapoda-research/llama-7b-hf' \
--data_path 'yahma/alpaca-cleaned' \
--output_dir './lora-alpaca'
@@ -66,7 +68,7 @@ python finetune.py \

This file reads the foundation model from the Hugging Face model hub and the LoRA weights from `tloen/alpaca-lora-7b`, and runs a Gradio interface for inference on a specified input. Users should treat this as example code for the use of the model, and modify it as needed.

Example usage:
Example usage (8-bit):

```bash
python generate.py \
@@ -75,6 +77,14 @@ python generate.py \
--lora_weights 'tloen/alpaca-lora-7b'
```

Example usage (4-bit):

```bash
python generate.py \
--load_4bit \
--base_model 'decapoda-research/llama-7b-hf' \
--lora_weights 'tloen/alpaca-lora-7b'
```

### Official weights

The most recent "official" Alpaca-LoRA adapter available at [`tloen/alpaca-lora-7b`](https://huggingface.co/tloen/alpaca-lora-7b) was trained on March 26 with the following command:
26 changes: 17 additions & 9 deletions finetune.py
@@ -17,10 +17,10 @@
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
prepare_model_for_kbit_training,
set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig

from utils.prompter import Prompter

@@ -56,6 +56,7 @@ def train(
wandb_log_model: str = "", # options: false | true
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
load_in_4bit: bool = False,  # use 4-bit (QLoRA) quantization instead of 8-bit
):
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
print(
@@ -82,6 +83,7 @@ def train(
f"wandb_log_model: {wandb_log_model}\n"
f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
f"prompt template: {prompt_template_name}\n"
f"load_in_4bit: {load_in_4bit}\n"
)
assert (
base_model
@@ -108,14 +110,20 @@ def train(
os.environ["WANDB_WATCH"] = wandb_watch
if len(wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = wandb_log_model


load_in_8bit = not load_in_4bit  # fall back to 8-bit loading unless 4-bit is requested
bnb_config = BitsAndBytesConfig(
load_in_4bit=load_in_4bit,
load_in_8bit=load_in_8bit,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
-model = LlamaForCausalLM.from_pretrained(
-    base_model,
-    load_in_8bit=True,
-    torch_dtype=torch.float16,
-    device_map=device_map,
-)
+model = LlamaForCausalLM.from_pretrained(
+    base_model, quantization_config=bnb_config, torch_dtype=torch.float16, device_map=device_map
+)


tokenizer = LlamaTokenizer.from_pretrained(base_model)

tokenizer.pad_token_id = (
@@ -171,7 +179,7 @@ def generate_and_tokenize_prompt(data_point):
] # could be sped up, probably
return tokenized_full_prompt

model = prepare_model_for_int8_training(model)
model = prepare_model_for_kbit_training(model)

config = LoraConfig(
r=lora_r,
@@ -241,7 +249,7 @@ def generate_and_tokenize_prompt(data_point):
learning_rate=learning_rate,
fp16=True,
logging_steps=10,
optim="adamw_torch",
optim="paged_adamw_8bit", #adamw_bnb_8bit
evaluation_strategy="steps" if val_set_size > 0 else "no",
save_strategy="steps",
eval_steps=200 if val_set_size > 0 else None,
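Pulled out of the diff, the fine-tuning path this PR sets up looks roughly like the following. This is a condensed sketch, assuming `bitsandbytes>=0.39.0` plus the git builds of `peft` and `transformers` pinned in requirements.txt; the base model and LoRA hyperparameters are the repository's usual defaults and are shown only for illustration.

```python
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig, LlamaForCausalLM, LlamaTokenizer

base_model = "decapoda-research/llama-7b-hf"

# QLoRA-style 4-bit loading: NF4 weights, double quantization, fp16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = LlamaForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",  # finetune.py adjusts this per rank when running with DDP
)
tokenizer = LlamaTokenizer.from_pretrained(base_model)

# Prepare the quantized model for training (casts norms to fp32, enables input grads).
model = prepare_model_for_kbit_training(model)

# Illustrative LoRA settings matching finetune.py's defaults.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA adapters are trainable
```

Training then proceeds with `optim="paged_adamw_8bit"`, which uses bitsandbytes' paged 8-bit AdamW so that optimizer-state memory spikes are less likely to trigger out-of-memory errors.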
18 changes: 12 additions & 6 deletions generate.py
@@ -6,7 +6,7 @@
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig

from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter
@@ -25,6 +25,7 @@

def main(
load_8bit: bool = False,
load_4bit: bool = False,
base_model: str = "",
lora_weights: str = "tloen/alpaca-lora-7b",
prompt_template: str = "", # The prompt template to use, will default to alpaca.
@@ -39,12 +40,17 @@ def main(
prompter = Prompter(prompt_template)
tokenizer = LlamaTokenizer.from_pretrained(base_model)
if device == "cuda":
-model = LlamaForCausalLM.from_pretrained(
-    base_model,
-    load_in_8bit=load_8bit,
-    torch_dtype=torch.float16,
-    device_map="auto",
-)
+load_8bit = False if load_4bit else load_8bit
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=load_4bit,
+    load_in_8bit=load_8bit,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.float16,
+)
+model = LlamaForCausalLM.from_pretrained(
+    base_model, quantization_config=bnb_config, torch_dtype=torch.float16, device_map="auto"
+)
model = PeftModel.from_pretrained(
model,
lora_weights,
@@ -77,7 +83,7 @@ def main(
model.config.bos_token_id = 1
model.config.eos_token_id = 2

if not load_8bit:
if not (load_8bit or load_4bit):
model.half() # seems to fix bugs for some users.

model.eval()
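The inference side mirrors this: build the same `BitsAndBytesConfig`, load the base model through it, then attach the LoRA adapter with PEFT. A minimal sketch, assuming a CUDA device and the `tloen/alpaca-lora-7b` adapter; it skips the Alpaca prompt template that `generate.py` applies via `Prompter` and the Gradio UI, and just generates from a raw prompt.

```python
import torch
from peft import PeftModel
from transformers import BitsAndBytesConfig, LlamaForCausalLM, LlamaTokenizer

base_model = "decapoda-research/llama-7b-hf"
lora_weights = "tloen/alpaca-lora-7b"

# Same 4-bit NF4 configuration used for fine-tuning.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = LlamaTokenizer.from_pretrained(base_model)
model = LlamaForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
)
# Attach the LoRA adapter weights on top of the quantized base model.
model = PeftModel.from_pretrained(model, lora_weights, torch_dtype=torch.float16)
model.eval()

inputs = tokenizer("Tell me about alpacas.", return_tensors="pt").to("cuda")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Because the weights stay quantized, the `model.half()` call in `generate.py` is only applied on the full-precision path, which is why the condition becomes `if not (load_8bit or load_4bit)`.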
7 changes: 4 additions & 3 deletions requirements.txt
@@ -1,12 +1,13 @@
accelerate
git+https://github.com/huggingface/accelerate.git
appdirs
loralib
bitsandbytes
bitsandbytes>=0.39.0
black
black[jupyter]
datasets
fire
git+https://github.com/huggingface/peft.git
transformers>=4.28.0
git+https://github.com/huggingface/transformers.git
sentencepiece
gradio
scipy