diff --git a/tests/helpers.py b/tests/helpers.py
index a88ae3ef8..59695826f 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -37,11 +37,9 @@ def causal_lm_train_kwargs(train_kwargs):
         lora_config,
         prompt_tuning_config,
     ) = parser.parse_dict(train_kwargs, allow_extra_keys=True)
-    return (
-        model_args,
-        data_args,
-        training_args,
-        lora_config
-        if train_kwargs.get("peft_method") == "lora"
-        else prompt_tuning_config,
-    )
+    tuning_config = None
+    if train_kwargs.get("peft_method") == "lora":
+        tuning_config = lora_config
+    elif train_kwargs.get("peft_method") == "pt":
+        tuning_config = prompt_tuning_config
+    return (model_args, data_args, training_args, tuning_config)
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index c23c7e2c5..a55f7d25b 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -34,6 +34,7 @@
 
 # Local
 from tuning import sft_trainer
+from tuning.config import peft_config
 
 MODEL_NAME = "Maykeye/TinyLLama-v0"
 BASE_PEFT_KWARGS = {
@@ -68,9 +69,9 @@
 BASE_LORA_KWARGS["peft_method"] = "lora"
 
 BASE_FT_KWARGS = copy.deepcopy(BASE_PEFT_KWARGS)
-BASE_FT_KWARGS["peft_method"] = ""
-BASE_FT_KWARGS["prompt_tuning_init"] = ""
-BASE_FT_KWARGS["prompt_tuning_init_text"] = ""
+BASE_FT_KWARGS["peft_method"] = None
+del BASE_FT_KWARGS["prompt_tuning_init"]
+del BASE_FT_KWARGS["prompt_tuning_init_text"]
 
 
 def test_helper_causal_lm_train_kwargs():
@@ -96,6 +97,15 @@ def test_helper_causal_lm_train_kwargs():
     assert tune_config.tokenizer_name_or_path == MODEL_NAME
     assert tune_config.num_virtual_tokens == 8
 
+    model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
+        BASE_FT_KWARGS
+    )
+    assert tune_config is None
+    model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
+        BASE_LORA_KWARGS
+    )
+    assert isinstance(tune_config, peft_config.LoraConfig)
+
 
 def test_run_train_requires_output_dir():
     """Check fails when output dir not provided."""
@@ -277,13 +287,14 @@ def test_run_causallm_ft_and_inference():
         model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
            BASE_FT_KWARGS
         )
+        # Ensure no PEFT tuning config (PT or LoRA) is passed for full fine tuning
+        assert tune_config is None
+
         sft_trainer.train(model_args, data_args, training_args, tune_config)
 
         # validate ft tuning configs
         _validate_training(tempdir)
         checkpoint_path = _get_checkpoint_path(tempdir)
-        adapter_config = _get_adapter_config(checkpoint_path)
-        _validate_adapter_config(adapter_config, "PROMPT_TUNING", BASE_FT_KWARGS)
 
         # Load the model
         loaded_model = TunedCausalLM.load(checkpoint_path)
@@ -296,6 +307,7 @@
     assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
 
 
+############################# Helper functions #############################
 def _validate_training(tempdir, check_eval=False):
     assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir))
     train_logs_file_path = "{}/training_logs.jsonl".format(tempdir)
@@ -333,6 +345,7 @@ def _validate_adapter_config(adapter_config, peft_type, base_kwargs):
     )
 
 
+############################# Other Tests #############################
 ### Tests for a variety of edge cases and potentially problematic cases;
 # some of these test directly test validation within external dependencies
 # and validate errors that we expect to get from them which might be unintuitive.
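Note on the `tests/helpers.py` change above: the helper now returns `None` as the tuning config unless `peft_method` is explicitly `"lora"` or `"pt"`, instead of silently defaulting every non-LoRA case to the prompt-tuning config. The sketch below is a minimal, self-contained illustration of that dispatch contract; the `LoraConfig`/`PromptTuningConfig` dataclasses and the `select_tuning_config` name are hypothetical stand-ins for `tuning.config.peft_config` and the real helper, and only the `"lora"`/`"pt"`/`None` mapping mirrors the patch.

```python
# Self-contained sketch of the dispatch introduced in tests/helpers.py.
# LoraConfig and PromptTuningConfig are hypothetical stand-ins, NOT the
# real classes from tuning.config.peft_config.
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class LoraConfig:  # stand-in for peft_config.LoraConfig
    r: int = 8


@dataclass
class PromptTuningConfig:  # stand-in for peft_config.PromptTuningConfig
    num_virtual_tokens: int = 8


def select_tuning_config(
    train_kwargs: dict,
) -> Optional[Union[LoraConfig, PromptTuningConfig]]:
    """Map peft_method to a PEFT config; None means full fine tuning."""
    tuning_config = None
    if train_kwargs.get("peft_method") == "lora":
        tuning_config = LoraConfig()
    elif train_kwargs.get("peft_method") == "pt":
        tuning_config = PromptTuningConfig()
    return tuning_config


# The same three cases the new test_helper_causal_lm_train_kwargs asserts cover:
assert isinstance(select_tuning_config({"peft_method": "lora"}), LoraConfig)
assert isinstance(select_tuning_config({"peft_method": "pt"}), PromptTuningConfig)
assert select_tuning_config({"peft_method": None}) is None  # full fine tuning
```

Returning `None` rather than an empty-string sentinel lets the tests (and, presumably, `sft_trainer.train` downstream) distinguish full fine tuning with a plain `tune_config is None` check, which is exactly what the new assertions in the patch rely on.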