Merge remote-tracking branch 'upstream/main'
dchourasia committed Sep 26, 2024
2 parents 80aed7a + 8676d01 commit 7f6d6be
Showing 19 changed files with 97,611 additions and 30 deletions.
29 changes: 29 additions & 0 deletions README.md
@@ -436,6 +436,35 @@ Example 3:

</details>

#### Post-processing needed for inference on vLLM

In order to run inference with LoRA adapters on vLLM, any new token embeddings added while tuning need to be moved out of `adapter_model.safetensors` into a new file, `new_embeddings.safetensors`. The `adapter_model.safetensors` file should contain only LoRA weights and no modified embedding vectors. This is required to support vLLM's paradigm that one base model can serve multiple adapters: the new token embedding vectors are appended to the embedding matrix that vLLM reads from the base model.
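
For intuition, the sketch below shows roughly what that split looks like. It is not the library's implementation (use the script described below for real artifacts), and it assumes the added-token rows are appended at the end of embedding weights stored under keys containing `embed_tokens` or `lm_head`; actual key names vary by model and tuning configuration.

```python
# Rough illustration only, not the shipped implementation.
import json

from safetensors.torch import load_file, save_file

with open("added_tokens_info.json", encoding="utf-8") as f:
    num_added_tokens = json.load(f)["num_new_tokens"]

adapters = load_file("adapter_model.safetensors")
new_embeddings = {}
for key in list(adapters.keys()):
    # Assumption: resized embedding weights are stored under keys containing
    # "embed_tokens" or "lm_head", with the new-token rows appended at the end.
    if "embed_tokens" in key or "lm_head" in key:
        new_embeddings[key] = adapters.pop(key)[-num_added_tokens:].clone()

# LoRA weights stay in adapter_model.safetensors; new-token rows move out.
save_file(adapters, "adapter_model.safetensors")
save_file(new_embeddings, "new_embeddings.safetensors")
```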

To enable this post-processing, the tuning script `sft_trainer.py` generates a file `added_tokens_info.json` alongside the model artifacts. After tuning, you can run the script `post_process_adapters_vLLM.py`:

```bash
# model_path: Path to saved model artifacts which has file 'added_tokens_info.json'
# output_model_path: Optional. If you want to store modified \
# artifacts in a different directory rather than modify in-place.
python scripts/post_process_adapters_vLLM.py \
--model_path "/testing/tuning/output/post-process-LoRA-saved" \
--output_model_path "/testing/tuning/output/post-process-LoRA-modified"
```

<details>
<summary>Alternatively, if using the SDK:</summary>

```python
# function in tuning/utils/merge_model_utils.py
from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens

post_process_vLLM_adapters_new_tokens(
    path_to_checkpoint="/testing/tuning/output/post-process-LoRA-saved",
    modified_checkpoint_path=None,
    num_added_tokens=1,
)
# where num_added_tokens is returned by sft_trainer.train()
```
</details>

_________________________


6 changes: 4 additions & 2 deletions build/README.md
@@ -38,14 +38,16 @@ For example, the below config is used for running with two GPUs and FSDP for fin
"per_device_train_batch_size": 4,
"learning_rate": 1e-5,
"response_template": "\n### Label:",
"dataset_text_field": "output"
"dataset_text_field": "output",
"lora_post_process_for_vllm": true
}
```

Users should always set `num_processes` to be explicit about the number of processes to run tuning on. When `num_processes` is greater than 1, the [FSDP config](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/fixtures/accelerate_fsdp_defaults.yaml) is used by default. Thus in the above example, you don't need to pass in the FSDP flags since they match the ones used in the default FSDP config. You can also set your own default values by specifying your own config file using key `config_file`. Any of these values in configs can be overwritten by passing in flags via `accelerate_launch_args` in the JSON config.
`num_processes` defaults to the amount of GPUs allocated for tuning, unless the user sets `SET_NUM_PROCESSES_TO_NUM_GPUS` to `False`. When `num_processes` is greater than 1, the [FSDP config](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/fixtures/accelerate_fsdp_defaults.yaml) is used by default. Thus in the above example, you don't need to pass in the FSDP flags since they match the ones used in the default FSDP config. You can also set your own default values by specifying your own config file using key `config_file`. Any of these values in configs can be overwritten by passing in flags via `accelerate_launch_args` in the JSON config.

Note that `num_processes`, which is the total number of processes to be launched in parallel, should match the number of GPUs to run on. The number of GPUs used can also be set via the environment variable `CUDA_VISIBLE_DEVICES`. If `num_processes=1`, the script will assume a single GPU.

If tuning LoRA adapters for inference on vLLM, set `lora_post_process_for_vllm` to `true`. This post-processes the LoRA adapters so they can be served by vLLM, which requires any new token embedding weights added during tuning to be moved to a separate file, `new_embeddings.safetensors`.
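
After post-processing, a tuned checkpoint directory might look roughly like the following (illustrative only; tokenizer and other config files are omitted):

```
checkpoint-100/
|-- adapter_config.json
|-- adapter_model.safetensors    # LoRA weights only
`-- new_embeddings.safetensors   # embedding vectors for the newly added tokens
```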

## Building the Image

53 changes: 53 additions & 0 deletions build/accelerate_launch.py
@@ -23,6 +23,7 @@
import subprocess
import sys
import traceback
import json
from pathlib import Path

# Third Party
@@ -32,6 +33,9 @@
from build.utils import (
    process_accelerate_launch_args,
)
from tuning.utils.merge_model_utils import (
    post_process_vLLM_adapters_new_tokens,
)
from tuning.utils.config_utils import get_json_config
from tuning.utils.error_logging import (
    write_termination_log,
@@ -115,6 +119,55 @@ def main():
        write_termination_log(f"Unhandled exception during training. {e}")
        sys.exit(INTERNAL_ERROR_EXIT_CODE)

    peft_method = job_config.get("peft_method")

    if job_config.get("lora_post_process_for_vllm") and peft_method == "lora":
        save_model_dir = job_config.get("save_model_dir")
        if save_model_dir:
            if os.path.exists(
                os.path.join(save_model_dir, "added_tokens_info.json")
            ):
                with open(
                    os.path.join(save_model_dir, "added_tokens_info.json"),
                    encoding="utf-8",
                ) as json_data:
                    added_tokens_info = json.load(json_data)
                    num_added_tokens = added_tokens_info["num_new_tokens"]

                if os.path.exists(
                    os.path.join(save_model_dir, "adapter_model.safetensors")
                ):
                    post_process_vLLM_adapters_new_tokens(
                        save_model_dir, save_model_dir, num_added_tokens
                    )
            else:
                logging.warning(
                    "Failed to post-process: file added_tokens_info.json not in path %s",
                    save_model_dir,
                )

        if (
            os.path.exists(os.path.join(output_dir, "added_tokens_info.json"))
            and job_config.get("save_strategy") != "no"
        ):
            with open(
                os.path.join(output_dir, "added_tokens_info.json"), encoding="utf-8"
            ) as json_data:
                added_tokens_info = json.load(json_data)
                num_added_tokens = added_tokens_info["num_new_tokens"]
            # if multiple checkpoints in directory, process each checkpoint
            for _, dirs, _ in os.walk(output_dir, topdown=False):
                for name in dirs:
                    if "checkpoint-" in name.lower():
                        post_process_vLLM_adapters_new_tokens(
                            os.path.join(output_dir, name),
                            os.path.join(output_dir, name),
                            num_added_tokens,
                        )
        else:
            logging.warning(
                "Failed to post-process: file added_tokens_info.json not in path %s",
                output_dir,
            )

    # The .complete file will signal to users that we are finished copying
    # files over
    if os.path.exists(output_dir):
6 changes: 5 additions & 1 deletion build/utils.py
@@ -24,12 +24,16 @@
import shutil


def copy_checkpoint(source, destination):
def copy_checkpoint(source, destination, exclude_files: list[str] = None):
    if not os.path.exists(destination):
        os.makedirs(destination)
    shutil.copystat(source, destination)
    # Have a list of directory objects, now iterate over them.
    if exclude_files is None:
        exclude_files = []
    for item in os.listdir(source):
        if item in exclude_files:
            continue
        source_file = os.path.join(source, item)
        destination_file = os.path.join(destination, item)
        if os.path.isdir(source_file):
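
A hypothetical call showing the new `exclude_files` parameter; the paths are illustrative:

```python
# Copy a saved checkpoint but skip the adapter file, which the vLLM
# post-processing step rewrites separately.
copy_checkpoint(
    "/tmp/tuning/output/checkpoint-100",
    "/tmp/tuning/processed/checkpoint-100",
    exclude_files=["adapter_model.safetensors"],
)
```
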
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -28,7 +28,7 @@ classifiers=[
dependencies = [
"numpy>=1.26.4,<2.0",
"accelerate>=0.20.3,<0.34",
"transformers>4.41,<5.0",
"transformers>4.41,<4.45",
"torch>=2.2.0,<3.0",
"sentencepiece>=0.1.99,<0.3",
"tokenizers>=0.13.3,<1.0",
94 changes: 94 additions & 0 deletions scripts/post_process_adapters_vLLM.py
@@ -0,0 +1,94 @@
""" Script to post-process tuned LoRA adapters for inference on vLLM.
vLLM requires that any token embeddings added while tuning be moved to a new file \
called new_embeddings.safetensors. \
See the description in utility function \
/tuning/utils/merge_model_utils/post_process_vLLM_adapters_new_tokens for more details.
This script takes a path to tuned model artifacts containing adapters \
(or checkpoints with adapters) and the file 'added_tokens_info.json' produced while tuning. \
It will perform the post-processing as needed for inferencing on vLLM.
"""
# Standard
import argparse
import json
import logging
import os
import sys

# Local
from tuning.utils.merge_model_utils import (
copy_files_to_directory,
post_process_vLLM_adapters_new_tokens,
)


### Main & arg parsing
def main():
parser = argparse.ArgumentParser(
description="Post processes LoRA adapters due to addition of new tokens, as needed by vLLM"
)
parser.add_argument(
"--model_path",
help="Path to tuned model containing either one or multiple checkpoints. \
Path should have file added_tokens_info.json produced by tuning. \
Hint: This will be either output_dir or save_model_dir arguments while tuning. \
If multiple checkpoints are present, each checkpoint folder name \
should begin with 'checkpoint-'",
required=True,
)
parser.add_argument(
"--output_model_path",
help="Output directory where post-processed artifacts will be stored. \
If not provided, artifacts will be modified in place",
default=None,
)
args = parser.parse_args()

if args.output_model_path is None:
output_model_path = args.model_path
else:
output_model_path = args.output_model_path
if os.path.exists(os.path.join(args.model_path, "added_tokens_info.json")):
with open(
os.path.join(args.model_path, "added_tokens_info.json"), encoding="utf-8"
) as json_data:
added_tokens_info = json.load(json_data)
num_added_tokens = added_tokens_info["num_new_tokens"]
else:
raise ValueError(
"file added_tokens_info.json not in model_path. \
Cannot post-processes"
)
if num_added_tokens == 0:
logging.info("No new tokens added, hence post-processing not needed")
sys.exit(0)

found_adapters = 0
if os.path.exists(os.path.join(args.model_path, "adapter_model.safetensors")):
found_adapters = 1
post_process_vLLM_adapters_new_tokens(
args.model_path, output_model_path, num_added_tokens
)
# if multiple checkpoints in directory, process each checkpoint
found_checkpoints = 0
for _, dirs, _ in os.walk(args.model_path, topdown=False):
for name in dirs:
if "checkpoint-" in name.lower():
post_process_vLLM_adapters_new_tokens(
os.path.join(args.model_path, name),
os.path.join(output_model_path, name),
num_added_tokens,
)
found_checkpoints = 1
if found_checkpoints and output_model_path != args.model_path:
copy_files_to_directory(
args.model_path,
output_model_path,
exclude_files=["adapter_model.safetensors"],
)
if not found_adapters and not found_checkpoints:
logging.warning("No adapters were found to process in model path provided")


if __name__ == "__main__":
main()
29 changes: 29 additions & 0 deletions tests/artifacts/tuned_llama_with_added_tokens/adapter_config.json
@@ -0,0 +1,29 @@
{
  "alpha_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": "Maykeye/TinyLLama-v0",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 32,
  "lora_dropout": 0.05,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 8,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "v_proj",
    "q_proj"
  ],
  "task_type": "CAUSAL_LM",
  "use_dora": false,
  "use_rslora": false
}
Binary file not shown.
@@ -0,0 +1,3 @@
{
  "<pad>": 32000
}
@@ -0,0 +1,30 @@
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}