Commit 8150851
Merge remote-tracking branch 'upstream/main'
Jooho committed May 21, 2024
2 parents bb5fd1d + 7a56e91 commit 8150851
Showing 2 changed files with 24 additions and 13 deletions.
14 changes: 6 additions & 8 deletions tests/helpers.py
@@ -37,11 +37,9 @@ def causal_lm_train_kwargs(train_kwargs):
lora_config,
prompt_tuning_config,
) = parser.parse_dict(train_kwargs, allow_extra_keys=True)
-return (
-model_args,
-data_args,
-training_args,
-lora_config
-if train_kwargs.get("peft_method") == "lora"
-else prompt_tuning_config,
-)
+tuning_config = None
+if train_kwargs.get("peft_method") == "lora":
+tuning_config = lora_config
+elif train_kwargs.get("peft_method") == "pt":
+tuning_config = prompt_tuning_config
+return (model_args, data_args, training_args, tuning_config)
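As a side note, here is a minimal standalone sketch (not part of this commit) of the selection behavior the new helper code implements: peft_method picks which tuning config, if any, is handed back to the test. The function name and placeholder config values below are illustrative only, not repository code.

# Illustrative only; mirrors the logic added above, not repository code.
def pick_tuning_config(peft_method, lora_config, prompt_tuning_config):
    if peft_method == "lora":
        return lora_config           # LoRA tuning
    if peft_method == "pt":
        return prompt_tuning_config  # prompt tuning
    return None                      # anything else -> full fine-tuning

assert pick_tuning_config("lora", "LORA_CFG", "PT_CFG") == "LORA_CFG"
assert pick_tuning_config("pt", "LORA_CFG", "PT_CFG") == "PT_CFG"
assert pick_tuning_config(None, "LORA_CFG", "PT_CFG") is None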
23 changes: 18 additions & 5 deletions tests/test_sft_trainer.py
@@ -34,6 +34,7 @@

# Local
from tuning import sft_trainer
+from tuning.config import peft_config

MODEL_NAME = "Maykeye/TinyLLama-v0"
BASE_PEFT_KWARGS = {
@@ -68,9 +69,9 @@
BASE_LORA_KWARGS["peft_method"] = "lora"

BASE_FT_KWARGS = copy.deepcopy(BASE_PEFT_KWARGS)
BASE_FT_KWARGS["peft_method"] = ""
BASE_FT_KWARGS["prompt_tuning_init"] = ""
BASE_FT_KWARGS["prompt_tuning_init_text"] = ""
BASE_FT_KWARGS["peft_method"] = None
del BASE_FT_KWARGS["prompt_tuning_init"]
del BASE_FT_KWARGS["prompt_tuning_init_text"]


def test_helper_causal_lm_train_kwargs():
@@ -96,6 +97,15 @@ def test_helper_causal_lm_train_kwargs():
assert tune_config.tokenizer_name_or_path == MODEL_NAME
assert tune_config.num_virtual_tokens == 8

+model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
+BASE_FT_KWARGS
+)
+assert tune_config is None
+model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
+BASE_LORA_KWARGS
+)
+assert isinstance(tune_config, peft_config.LoraConfig)


def test_run_train_requires_output_dir():
"""Check fails when output dir not provided."""
@@ -277,13 +287,14 @@ def test_run_causallm_ft_and_inference():
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
BASE_FT_KWARGS
)
+# Just assuring no tuning config is passed for PT or LoRA
+assert tune_config is None

sft_trainer.train(model_args, data_args, training_args, tune_config)

# validate ft tuning configs
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
-adapter_config = _get_adapter_config(checkpoint_path)
-_validate_adapter_config(adapter_config, "PROMPT_TUNING", BASE_FT_KWARGS)

# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path)
@@ -296,6 +307,7 @@ def test_run_causallm_ft_and_inference():
assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference


+############################# Helper functions #############################
def _validate_training(tempdir, check_eval=False):
assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir))
train_logs_file_path = "{}/training_logs.jsonl".format(tempdir)
@@ -333,6 +345,7 @@ def _validate_adapter_config(adapter_config, peft_type, base_kwargs):
)


+############################# Other Tests #############################
### Tests for a variety of edge cases and potentially problematic cases;
# some of these test directly test validation within external dependencies
# and validate errors that we expect to get from them which might be unintuitive.
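For context, a minimal sketch (not part of this commit) of the full fine-tuning path the updated test exercises: causal_lm_train_kwargs, BASE_FT_KWARGS, and sft_trainer.train appear in the diff above, while the import paths and the output_dir handling here are assumptions.

# Sketch under the assumptions noted above; intended to run from the repository root.
import tempfile

from tests.helpers import causal_lm_train_kwargs    # assumed import path
from tests.test_sft_trainer import BASE_FT_KWARGS   # assumed import path
from tuning import sft_trainer

with tempfile.TemporaryDirectory() as tempdir:
    kwargs = {**BASE_FT_KWARGS, "output_dir": tempdir}  # assumes output_dir is accepted here
    model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(kwargs)
    assert tune_config is None  # no PEFT method -> plain fine-tuning
    sft_trainer.train(model_args, data_args, training_args, tune_config)
    # The test then loads the resulting checkpoint directly; with no PEFT config
    # there is no adapter_config.json to validate.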
