Commit: Merge remote-tracking branch 'upstream/main'
Showing 19 changed files with 97,611 additions and 30 deletions.
@@ -0,0 +1,94 @@
""" Script to post-process tuned LoRA adapters for inference on vLLM. | ||
vLLM requires that any token embeddings added while tuning be moved to a new file \ | ||
called new_embeddings.safetensors. \ | ||
See the description in utility function \ | ||
/tuning/utils/merge_model_utils/post_process_vLLM_adapters_new_tokens for more details. | ||
This script takes a path to tuned model artifacts containing adapters \ | ||
(or checkpoints with adapters) and the file 'added_tokens_info.json' produced while tuning. \ | ||
It will perform the post-processing as needed for inferencing on vLLM. | ||
""" | ||
# Standard | ||
import argparse | ||
import json | ||
import logging | ||
import os | ||
import sys | ||
|
||
# Local | ||
from tuning.utils.merge_model_utils import ( | ||
copy_files_to_directory, | ||
post_process_vLLM_adapters_new_tokens, | ||
) | ||
|
||
|
||
### Main & arg parsing | ||
def main(): | ||
parser = argparse.ArgumentParser( | ||
description="Post processes LoRA adapters due to addition of new tokens, as needed by vLLM" | ||
) | ||
parser.add_argument( | ||
"--model_path", | ||
help="Path to tuned model containing either one or multiple checkpoints. \ | ||
Path should have file added_tokens_info.json produced by tuning. \ | ||
Hint: This will be either output_dir or save_model_dir arguments while tuning. \ | ||
If multiple checkpoints are present, each checkpoint folder name \ | ||
should begin with 'checkpoint-'", | ||
required=True, | ||
) | ||
parser.add_argument( | ||
"--output_model_path", | ||
help="Output directory where post-processed artifacts will be stored. \ | ||
If not provided, artifacts will be modified in place", | ||
default=None, | ||
) | ||
args = parser.parse_args() | ||
|
||
if args.output_model_path is None: | ||
output_model_path = args.model_path | ||
else: | ||
output_model_path = args.output_model_path | ||
if os.path.exists(os.path.join(args.model_path, "added_tokens_info.json")): | ||
with open( | ||
os.path.join(args.model_path, "added_tokens_info.json"), encoding="utf-8" | ||
) as json_data: | ||
added_tokens_info = json.load(json_data) | ||
num_added_tokens = added_tokens_info["num_new_tokens"] | ||
else: | ||
raise ValueError( | ||
"file added_tokens_info.json not in model_path. \ | ||
Cannot post-processes" | ||
) | ||
if num_added_tokens == 0: | ||
logging.info("No new tokens added, hence post-processing not needed") | ||
sys.exit(0) | ||
|
||
found_adapters = 0 | ||
if os.path.exists(os.path.join(args.model_path, "adapter_model.safetensors")): | ||
found_adapters = 1 | ||
post_process_vLLM_adapters_new_tokens( | ||
args.model_path, output_model_path, num_added_tokens | ||
) | ||
# if multiple checkpoints in directory, process each checkpoint | ||
found_checkpoints = 0 | ||
for _, dirs, _ in os.walk(args.model_path, topdown=False): | ||
for name in dirs: | ||
if "checkpoint-" in name.lower(): | ||
post_process_vLLM_adapters_new_tokens( | ||
os.path.join(args.model_path, name), | ||
os.path.join(output_model_path, name), | ||
num_added_tokens, | ||
) | ||
found_checkpoints = 1 | ||
if found_checkpoints and output_model_path != args.model_path: | ||
copy_files_to_directory( | ||
args.model_path, | ||
output_model_path, | ||
exclude_files=["adapter_model.safetensors"], | ||
) | ||
if not found_adapters and not found_checkpoints: | ||
logging.warning("No adapters were found to process in model path provided") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
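For context, vLLM cannot read the embedding rows for newly added tokens directly out of adapter_model.safetensors; the utility post_process_vLLM_adapters_new_tokens (imported above from tuning.utils.merge_model_utils, implementation not shown in this diff) moves those rows into a separate new_embeddings.safetensors file. The snippet below is only a minimal sketch of that idea, with assumed key names ("embed_tokens"/"lm_head") and an assumed layout where the added tokens occupy the last num_new_tokens rows; it is not the actual library implementation.

    # Illustrative sketch only, not the fms-hf-tuning implementation.
    # Assumptions: the adapter checkpoint stores resized embedding / lm_head weights
    # under keys containing "embed_tokens" or "lm_head", and the newly added tokens
    # occupy the last `num_new_tokens` rows of those matrices.
    import os
    from safetensors.torch import load_file, save_file

    def split_new_token_embeddings(adapter_dir: str, out_dir: str, num_new_tokens: int) -> None:
        tensors = load_file(os.path.join(adapter_dir, "adapter_model.safetensors"))
        new_embeddings, remaining = {}, {}
        for name, weight in tensors.items():
            if ("embed_tokens" in name or "lm_head" in name) and weight.dim() == 2:
                # Rows for tokens added during tuning go to the separate file vLLM expects.
                new_embeddings[name] = weight[-num_new_tokens:].contiguous()
                remaining[name] = weight[:-num_new_tokens].contiguous()
            else:
                remaining[name] = weight
        os.makedirs(out_dir, exist_ok=True)
        save_file(remaining, os.path.join(out_dir, "adapter_model.safetensors"))
        save_file(new_embeddings, os.path.join(out_dir, "new_embeddings.safetensors"))

The script above is then simply pointed at the tuning output via --model_path (and optionally --output_model_path), applying this transformation to the top-level adapters and to every checkpoint-* subfolder it finds.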
29 changes: 29 additions & 0 deletions
tests/artifacts/tuned_llama_with_added_tokens/adapter_config.json
@@ -0,0 +1,29 @@
{
  "alpha_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": "Maykeye/TinyLLama-v0",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 32,
  "lora_dropout": 0.05,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 8,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "v_proj",
    "q_proj"
  ],
  "task_type": "CAUSAL_LM",
  "use_dora": false,
  "use_rslora": false
}
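This adapter_config.json is the kind of file PEFT writes when a LoRA adapter is saved; a LoraConfig along the following lines (values mirror the test artifact above; the snippet is illustrative and not part of the commit) would serialize to it:

    from peft import LoraConfig

    # Mirrors the test artifact's settings: rank 8, alpha 32, dropout 0.05,
    # LoRA applied to the attention q/v projections of a causal LM.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )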
Binary file added: +5.7 KB
tests/artifacts/tuned_llama_with_added_tokens/adapter_model.safetensors (binary file not shown)
3 changes: 3 additions & 0 deletions
tests/artifacts/tuned_llama_with_added_tokens/added_tokens.json
@@ -0,0 +1,3 @@
{
  "<pad>": 32000
}
30 changes: 30 additions & 0 deletions
tests/artifacts/tuned_llama_with_added_tokens/special_tokens_map.json
@@ -0,0 +1,30 @@
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
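These two tokenizer files record that a single <pad> token (id 32000) was added on top of the TinyLLama-v0 vocabulary. A hedged sketch of how such artifacts come about during tuning follows; the actual training code is not part of this diff, and the recording of the count into added_tokens_info.json is assumed here:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0")
    # Adding a pad token grows the vocabulary by one; this count is what the
    # post-processing script later reads as "num_new_tokens".
    num_new_tokens = tokenizer.add_special_tokens({"pad_token": "<pad>"})
    tokenizer.save_pretrained("tests/artifacts/tuned_llama_with_added_tokens")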