From 8548a6df86e0a3ea00ece11a47c3aac2971a512e Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 24 Apr 2024 10:39:19 -0600 Subject: [PATCH] Add unit tests for various edge cases (#97) * Add unit tests for various edge cases Signed-off-by: Alex-Brooks * Fix bf16 check in skipped test Signed-off-by: Alex-Brooks * Remove redundant test Signed-off-by: Alex-Brooks * Fix linting Signed-off-by: Alex-Brooks --------- Signed-off-by: Alex-Brooks --- tests/data/__init__.py | 2 + tests/data/empty_data.json | 0 tests/data/malformatted_data.json | 1 + tests/test_sft_trainer.py | 197 ++++++++++++++++++++++++++++-- tuning/sft_trainer.py | 46 +++---- 5 files changed, 212 insertions(+), 34 deletions(-) create mode 100644 tests/data/empty_data.json create mode 100644 tests/data/malformatted_data.json diff --git a/tests/data/__init__.py b/tests/data/__init__.py index e7462d27b..6df7802cd 100644 --- a/tests/data/__init__.py +++ b/tests/data/__init__.py @@ -20,3 +20,5 @@ ### Constants used for data DATA_DIR = os.path.join(os.path.dirname(__file__)) TWITTER_COMPLAINTS_DATA = os.path.join(DATA_DIR, "twitter_complaints_small.json") +EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json") +MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json") diff --git a/tests/data/empty_data.json b/tests/data/empty_data.json new file mode 100644 index 000000000..e69de29bb diff --git a/tests/data/malformatted_data.json b/tests/data/malformatted_data.json new file mode 100644 index 000000000..437763095 --- /dev/null +++ b/tests/data/malformatted_data.json @@ -0,0 +1 @@ +This data is bad! We can't use it to tune. diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 5f1c65bce..c23c7e2c5 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -22,11 +22,14 @@ import tempfile # Third Party +from datasets.exceptions import DatasetGenerationError import pytest +import torch +import transformers # First Party from scripts.run_inference import TunedCausalLM -from tests.data import TWITTER_COMPLAINTS_DATA +from tests.data import EMPTY_DATA, MALFORMATTED_DATA, TWITTER_COMPLAINTS_DATA from tests.helpers import causal_lm_train_kwargs # Local @@ -122,9 +125,10 @@ def test_run_train_fails_training_data_path_not_exist(): def test_run_causallm_pt_and_inference(): """Check if we can bootstrap and peft tune causallm models""" with tempfile.TemporaryDirectory() as tempdir: - BASE_PEFT_KWARGS["output_dir"] = tempdir + TRAIN_KWARGS = {**BASE_PEFT_KWARGS, **{"output_dir": tempdir}} + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - BASE_PEFT_KWARGS + TRAIN_KWARGS ) sft_trainer.train(model_args, data_args, training_args, tune_config) @@ -148,11 +152,12 @@ def test_run_causallm_pt_and_inference(): def test_run_causallm_pt_init_text(): """Check if we can bootstrap and peft tune causallm models with init text as 'TEXT'""" with tempfile.TemporaryDirectory() as tempdir: - pt_init_text = copy.deepcopy(BASE_PEFT_KWARGS) - pt_init_text["output_dir"] = tempdir - pt_init_text["prompt_tuning_init"] = "TEXT" + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"output_dir": tempdir, "prompt_tuning_init": "TEXT"}, + } model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - pt_init_text + TRAIN_KWARGS ) sft_trainer.train(model_args, data_args, training_args, tune_config) @@ -160,7 +165,7 @@ def test_run_causallm_pt_init_text(): _validate_training(tempdir) checkpoint_path = _get_checkpoint_path(tempdir) adapter_config = _get_adapter_config(checkpoint_path) - _validate_adapter_config(adapter_config, "PROMPT_TUNING", pt_init_text) + _validate_adapter_config(adapter_config, "PROMPT_TUNING", TRAIN_KWARGS) invalid_params_map = [ @@ -326,3 +331,179 @@ def _validate_adapter_config(adapter_config, peft_type, base_kwargs): if peft_type == "PROMPT_TUNING" else True ) + + +### Tests for a variety of edge cases and potentially problematic cases; +# some of these test directly test validation within external dependencies +# and validate errors that we expect to get from them which might be unintuitive. +# In such cases, it would probably be best for us to handle these things directly +# for better error messages, etc. + +### Tests related to tokenizer configuration +def test_tokenizer_has_no_eos_token(): + """Ensure that if the model has no EOS token, it sets the default before formatting.""" + # This is a bit roundabout, but patch the tokenizer and export it and the model to a tempdir + # that we can then reload out of for the train call, and clean up afterwards. + tokenizer = transformers.AutoTokenizer.from_pretrained( + BASE_PEFT_KWARGS["model_name_or_path"] + ) + model = transformers.AutoModelForCausalLM.from_pretrained( + BASE_PEFT_KWARGS["model_name_or_path"] + ) + tokenizer.eos_token = None + with tempfile.TemporaryDirectory() as tempdir: + tokenizer.save_pretrained(tempdir) + model.save_pretrained(tempdir) + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"model_name_or_path": tempdir, "output_dir": tempdir}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + # If we handled this badly, we would probably get something like a + # TypeError: can only concatenate str (not "NoneType") to str error + # when we go to apply the data formatter. + sft_trainer.train(model_args, data_args, training_args, tune_config) + _validate_training(tempdir) + + +### Tests for Bad dataset specification, i.e., data is valid, but the field we point it at isn't +def test_invalid_dataset_text_field(): + """Ensure that if we specify a dataset_text_field that doesn't exist, we get a KeyError.""" + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"dataset_text_field": "not found", "output_dir": "foo/bar/baz"}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + with pytest.raises(KeyError): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +### Tests for bad training data (i.e., data_path is an unhappy value or points to an unhappy thing) +def test_malformatted_data(): + """Ensure that malformatted data explodes due to failure to generate the dataset.""" + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"training_data_path": MALFORMATTED_DATA, "output_dir": "foo/bar/baz"}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + with pytest.raises(DatasetGenerationError): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +def test_empty_data(): + """Ensure that malformatted data explodes due to failure to generate the dataset.""" + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"training_data_path": EMPTY_DATA, "output_dir": "foo/bar/baz"}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + with pytest.raises(DatasetGenerationError): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +def test_data_path_is_a_directory(): + """Ensure that we get FileNotFoundError if we point the data path at a dir, not a file.""" + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"training_data_path": tempdir, "output_dir": tempdir}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + # Confusingly, if we pass a directory for our data path, it will throw a + # FileNotFoundError saying "unable to find ''", since it can't + # find a matchable file in the path. + with pytest.raises(FileNotFoundError): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +### Tests for bad tuning module configurations +def test_run_causallm_lora_with_invalid_modules(): + """Check that we throw a value error if the target modules for lora don't exist.""" + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"peft_method": "lora", "output_dir": tempdir}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + # Defaults are q_proj / v_proj; this will fail lora as the torch module doesn't have them + tune_config.target_modules = ["foo", "bar"] + # Peft should throw a value error about modules not matching the base module + with pytest.raises(ValueError): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +### Direct validation tests based on whether or not packing is enabled +def test_no_packing_needs_dataset_text_field(): + """Ensure we need to set the dataset text field if packing is False""" + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"dataset_text_field": None, "output_dir": tempdir}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + with pytest.raises(ValueError): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +# TODO: Fix this case +@pytest.mark.skip(reason="currently crashes before validation is done") +def test_no_packing_needs_reponse_template(): + """Ensure we need to set the response template if packing is False""" + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"response_template": None, "output_dir": tempdir}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + with pytest.raises(ValueError): + sft_trainer.train(model_args, data_args, training_args, tune_config) + + +### Tests for model dtype edge cases +@pytest.mark.skipif( + not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()), + reason="Only runs if bf16 is unsupported", +) +def test_bf16_still_tunes_if_unsupported(): + """Ensure that even if bf16 is not supported, tuning still works without problems.""" + assert not torch.cuda.is_bf16_supported() + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"torch_dtype": "bfloat16", "output_dir": tempdir}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + sft_trainer.train(model_args, data_args, training_args, tune_config) + _validate_training(tempdir) + + +def test_bad_torch_dtype(): + """Ensure that specifying an invalid torch dtype yields a ValueError.""" + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = { + **BASE_PEFT_KWARGS, + **{"torch_dtype": "not a type", "output_dir": tempdir}, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + with pytest.raises(ValueError): + sft_trainer.train(model_args, data_args, training_args, tune_config) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 5583a2dfa..1f8effa28 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -17,7 +17,6 @@ from typing import Optional, Union import json import os -import sys # Third Party from peft.utils.other import fsdp_auto_wrap_policy @@ -202,6 +201,26 @@ def train( model=model, ) + # Configure the collator and validate args related to packing prior to formatting the dataset + if train_args.packing: + logger.info("Packing is set to True") + data_collator = None + packing = True + else: + logger.info("Packing is set to False") + if data_args.response_template is None: + # TODO: Fix this, currently unreachable due to crashing in batch encoding tokenization + # We should do this validation up front, then do the encoding, then handle the collator + raise ValueError("Response template is None, needs to be set for training") + if data_args.dataset_text_field is None: + raise ValueError("Dataset_text_field is None, needs to be set for training") + data_collator = DataCollatorForCompletionOnlyLM( + response_template_ids, + tokenizer=tokenizer, + ignore_index=configs.IGNORE_INDEX, + ) + packing = False + # load the data by parsing JSON data_files = {"train": data_args.training_data_path} if data_args.validation_data_path: @@ -235,31 +254,6 @@ def train( ) callbacks.append(tc_callback) - if train_args.packing: - logger.info("Packing is set to True") - data_collator = None - packing = True - else: - logger.info("Packing is set to False") - if data_args.response_template is None: - logger.error( - "Error, response template is None, needs to be set for training" - ) - sys.exit(-1) - - if data_args.dataset_text_field is None: - logger.error( - "Error, dataset_text_field is None, needs to be set for training" - ) - sys.exit(-1) - - data_collator = DataCollatorForCompletionOnlyLM( - response_template_ids, - tokenizer=tokenizer, - ignore_index=configs.IGNORE_INDEX, - ) - packing = False - trainer = SFTTrainer( model=model, tokenizer=tokenizer,