diff --git a/README.md b/README.md index bb1a95876..f552f06fe 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,10 @@ pip install -e ".[aim]" ``` ## Data format -The data format expectation is a single column text. The trainer is configured to expect a response template as a string. For example, if one wants to prepare the `alpaca` format data to feed into this trainer, it is quite easy and can be done with the following code. +We support two data formats: + +1. #### Pre-process the JSON/JSONL dataset + Pre-process the JSON/JSONL dataset so that each data instance contains a single sequence combining the input and the response. The trainer is configured to expect a response template as a string. For example, if one wants to prepare the `alpaca` format data to feed into this trainer, it is quite easy and can be done with the following code. ```python PROMPT_DICT = { @@ -56,6 +59,24 @@ The `response template` corresponding to the above dataset and the `Llama` token The same way can be applied to any dataset, with more info can be found [here](https://huggingface.co/docs/trl/main/en/sft_trainer#format-your-input-prompts). +Once the JSON is converted using the formatting function, pass the `dataset_text_field` containing the single sequence to the trainer. + +2. #### Format JSON/JSONL on the fly + Pass a JSON/JSONL file and a `data_formatter_template` to apply the formatting function on the fly while tuning. The template should refer to JSON fields with `{{field}}`. While tuning, the data will be converted to a single sequence using the template. + JSON field names can contain alphanumeric characters, spaces, and the following special symbols: ".", "_", "-". + +Example: Train.json +`[{ "input" : , + "output" : , + }, + ... +]` +data_formatter_template: `### Input: {{input}} \n\n## Label: {{output}}` + +Formatting will happen on the fly while tuning. The keys in the template must match the fields in the JSON file. The `response template` corresponding to the above template also needs to be supplied; in this case, `response template` = `\n## Label:`. + + +##### In conclusion, either the `data_formatter_template` argument or the `dataset_text_field` argument needs to be supplied to the trainer. ## Supported Models @@ -64,12 +85,16 @@ Current supported and tested models are `Llama2` (7 and 13B configurations have ## Training ### Single GPU + +1. Using a pre-processed dataset for training. + ```bash # if you want to use one GPU on multi-gpu machine export CUDA_VISIBLE_DEVICES=0 # MODEL_PATH=meta-llama/Llama-2-7b-hf # Huggingface model id or path to a checkpoint # TRAIN_DATA_PATH=twitter_complaints.json # Path to the dataset + # contains data in single sequence {"output": "### Input: text \n\n### Response: text"} # OUTPUT_PATH=out # Path to the output folder where the checkpoints are saved python tuning/sft_trainer.py \ @@ -94,6 +119,39 @@ python tuning/sft_trainer.py \ ``` +2. 
Using a formatter with JSON/JSONL files + +```bash +# if you want to use one GPU on multi-gpu machine +export CUDA_VISIBLE_DEVICES=0 + +# MODEL_PATH=meta-llama/Llama-2-7b-hf # Huggingface model id or path to a checkpoint +# TRAIN_DATA_PATH=twitter_complaints.json # Path to the dataset + # contains data in the form of [{"input": text , "output": text}] +# OUTPUT_PATH=out # Path to the output folder where the checkpoints are saved + +python tuning/sft_trainer.py \ +--model_name_or_path $MODEL_PATH \ +--training_data_path $TRAIN_DATA_PATH \ +--output_dir $OUTPUT_PATH \ +--num_train_epochs 5 \ +--per_device_train_batch_size 4 \ +--per_device_eval_batch_size 4 \ +--gradient_accumulation_steps 4 \ +--eval_strategy "no" \ +--save_strategy "epoch" \ +--learning_rate 1e-5 \ +--weight_decay 0. \ +--warmup_ratio 0.03 \ +--lr_scheduler_type "cosine" \ +--logging_steps 1 \ +--include_tokens_per_second \ +--packing False \ +--response_template "\n## Label:" \ +--data_formatter_template "### Input: {{input}} \n\n## Label: {{output}}" + +``` + ### Multiple GPUs with FSDP The recommendation is to use [huggingface accelerate](https://huggingface.co/docs/accelerate/en/index) to launch multi-gpu jobs, in particular when using FSDP: diff --git a/tests/data/__init__.py b/tests/data/__init__.py index 6df7802cd..b81ccaff2 100644 --- a/tests/data/__init__.py +++ b/tests/data/__init__.py @@ -20,5 +20,6 @@ ### Constants used for data DATA_DIR = os.path.join(os.path.dirname(__file__)) TWITTER_COMPLAINTS_DATA = os.path.join(DATA_DIR, "twitter_complaints_small.json") +TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join(DATA_DIR, "twitter_complaints_json.json") EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json") MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json") diff --git a/tests/data/twitter_complaints_json.json b/tests/data/twitter_complaints_json.json new file mode 100644 index 000000000..fba22a9fd --- /dev/null +++ b/tests/data/twitter_complaints_json.json @@ -0,0 +1,12 @@ +[ + {"Tweet text":"@HMRCcustomers No this is my first job","ID":0,"Label":2,"text_label":"no complaint","output":"### Text: @HMRCcustomers No this is my first job\n\n### Label: no complaint"}, + {"Tweet text":"@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.","ID":1,"Label":2,"text_label":"no complaint","output":"### Text: @KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.\n\n### Label: no complaint"}, + {"Tweet text":"If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService","ID":2,"Label":1,"text_label":"complaint","output":"### Text: If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService\n\n### Label: complaint"}, + {"Tweet text":"@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.","ID":3,"Label":1,"text_label":"complaint","output":"### Text: @EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.\n\n### Label: complaint"}, + {"Tweet text":"Couples wallpaper, so cute. :) #BrothersAtHome","ID":4,"Label":2,"text_label":"no complaint","output":"### Text: Couples wallpaper, so cute. :) #BrothersAtHome\n\n### Label: no complaint"}, + {"Tweet text":"@mckelldogs This might just be me, but-- eyedrops? 
Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG","ID":5,"Label":2,"text_label":"no complaint","output":"### Text: @mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG\n\n### Label: no complaint"}, + {"Tweet text":"@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?","ID":6,"Label":2,"text_label":"no complaint","output":"### Text: @Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?\n\n### Label: no complaint"}, + {"Tweet text":"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?","ID":7,"Label":1,"text_label":"complaint","output":"### Text: @nationalgridus I have no water and the bill is current and paid. Can you do something about this?\n\n### Label: complaint"}, + {"Tweet text":"Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora","ID":8,"Label":1,"text_label":"complaint","output":"### Text: Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora\n\n### Label: complaint"}, + {"Tweet text":"@JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd","ID":9,"Label":2,"text_label":"no complaint","output":"### Text: @JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd\n\n### Label: no complaint"} +] \ No newline at end of file diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index bbe91f890..42368879f 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -29,7 +29,12 @@ # First Party from scripts.run_inference import TunedCausalLM -from tests.data import EMPTY_DATA, MALFORMATTED_DATA, TWITTER_COMPLAINTS_DATA +from tests.data import ( + EMPTY_DATA, + MALFORMATTED_DATA, + TWITTER_COMPLAINTS_DATA, + TWITTER_COMPLAINTS_JSON_FORMAT, +) # Local from tuning import sft_trainer @@ -114,6 +119,70 @@ def test_run_causallm_pt_and_inference(): assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference +def test_run_causallm_pt_and_inference_with_formatting_data(): + """Check if we can bootstrap and peft tune causallm models + This test needs the trainer to format data to a single sequence internally. 
+ """ + with tempfile.TemporaryDirectory() as tempdir: + data_formatting_args = copy.deepcopy(DATA_ARGS) + data_formatting_args.dataset_text_field = None + data_formatting_args.data_formatter_template = ( + "### Text: {{Tweet text}} \n\n### Label: {{text_label}}" + ) + + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args, PEFT_PT_ARGS) + + # validate peft tuning configs + _validate_training(tempdir) + checkpoint_path = _get_checkpoint_path(tempdir) + adapter_config = _get_adapter_config(checkpoint_path) + _validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS) + + # Load the model + loaded_model = TunedCausalLM.load(checkpoint_path) + + # Run inference on the text + output_inference = loaded_model.run( + "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50 + ) + assert len(output_inference) > 0 + assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference + + +def test_run_causallm_pt_and_inference_JSON_file_formatter(): + """Check if we can bootstrap and peft tune causallm models with JSON train file format""" + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + data_args = copy.deepcopy(DATA_ARGS) + data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT + data_args.dataset_text_field = None + data_args.data_formatter_template = ( + "### Text: {{Tweet text}} \n\n### Label: {{text_label}}" + ) + + sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) + + # validate peft tuning configs + _validate_training(tempdir) + checkpoint_path = _get_checkpoint_path(tempdir) + adapter_config = _get_adapter_config(checkpoint_path) + _validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS) + + # Load the model + loaded_model = TunedCausalLM.load(checkpoint_path) + + # Run inference on the text + output_inference = loaded_model.run( + "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50 + ) + assert len(output_inference) > 0 + assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference + + def test_run_causallm_pt_init_text(): """Check if we can bootstrap and peft tune causallm models with init text as 'TEXT'""" with tempfile.TemporaryDirectory() as tempdir: @@ -174,6 +243,23 @@ def test_run_causallm_pt_with_validation(): _validate_training(tempdir, check_eval=True) +def test_run_causallm_pt_with_validation_data_formatting(): + """Check if we can bootstrap and peft tune causallm models with validation dataset""" + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + train_args.eval_strategy = "epoch" + data_args = copy.deepcopy(DATA_ARGS) + data_args.validation_data_path = TWITTER_COMPLAINTS_DATA + data_args.dataset_text_field = None + data_args.data_formatter_template = ( + "### Text: {{Tweet text}} \n\n### Label: {{text_label}}" + ) + + sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) + _validate_training(tempdir, check_eval=True) + + ############################# Lora Tests ############################# target_modules_val_map = [ @@ -335,6 +421,30 @@ def test_invalid_dataset_text_field(): sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) +### Tests that giving dataset_text_field as well as formatter template gives error +def test_invalid_dataset_text_field_and_formatter_template(): + """Only one of 
dataset_text_field or formatter can be supplied""" + data_args = copy.deepcopy(DATA_ARGS) + data_args.data_formatter_template = ( + "### Text: {{Tweet text}} \n\n### Label: {{text_label}}" + ) + + with pytest.raises(ValueError): + sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) + + +### Tests passing formatter with invalid keys gives error +def test_invalid_formatter_template(): + data_args = copy.deepcopy(DATA_ARGS) + data_args.dataset_text_field = None + data_args.data_formatter_template = ( + "### Text: {{not found}} \n\n### Label: {{text_label}}" + ) + + with pytest.raises(KeyError): + sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) + + ### Tests for bad training data (i.e., data_path is an unhappy value or points to an unhappy thing) def test_malformatted_data(): """Ensure that malformatted data explodes due to failure to generate the dataset.""" @@ -382,13 +492,15 @@ def test_run_causallm_lora_with_invalid_modules(): ### Direct validation tests based on whether or not packing is enabled -def test_no_packing_needs_dataset_text_field(): +def test_no_packing_needs_dataset_text_field_or_data_formatter_template(): """Ensure we need to set the dataset text field if packing is False""" with tempfile.TemporaryDirectory() as tempdir: train_args = copy.deepcopy(TRAIN_ARGS) train_args.output_dir = tempdir data_args = copy.deepcopy(DATA_ARGS) + # One of dataset_text_field or data_formatter_template should be set data_args.dataset_text_field = None + data_args.data_formatter_template = None with pytest.raises(ValueError): sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) diff --git a/tests/utils/test_data_utils.py b/tests/utils/test_data_utils.py new file mode 100644 index 000000000..471f28590 --- /dev/null +++ b/tests/utils/test_data_utils.py @@ -0,0 +1,66 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Third Party +import datasets +import pytest + +# First Party +from tests.data import TWITTER_COMPLAINTS_DATA + +# Local +from tuning.utils import data_utils + + +def test_apply_custom_formatting_template(): + json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA) + template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" + # First response from the data file that is read. 
+ expected_response = ( + "### Input: @HMRCcustomers No this is my first job" + + " \n\n ### Response: no complaint" + ) + formatted_dataset, dataset_text_field = data_utils.apply_custom_formatting_template( + json_dataset, template + ) + # a new dataset_text_field is created in Dataset + assert dataset_text_field in formatted_dataset["train"][0] + assert formatted_dataset["train"][0][dataset_text_field] == expected_response + + +def test_apply_custom_formatting_template_adds_eos_token(): + json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA) + template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" + # First response from the data file that is read. + expected_response = ( + "### Input: @HMRCcustomers No this is my first job" + + " \n\n ### Response: no complaintEOS" + ) + formatted_dataset, dataset_text_field = data_utils.apply_custom_formatting_template( + json_dataset, template, "EOS" + ) + # a new dataset_text_field is created in Dataset + assert dataset_text_field in formatted_dataset["train"][0] + assert formatted_dataset["train"][0][dataset_text_field] == expected_response + + +def test_apply_custom_formatting_template_gives_error_with_wrong_keys(): + """Tests that the formatting function will throw error if wrong keys are passed to template""" + json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA) + template = "### Input: {{not found}} \n\n ### Response: {{text_label}}" + with pytest.raises(KeyError): + data_utils.apply_custom_formatting_template(json_dataset, template, "EOS") diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 247652b7c..bccf5d15b 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -43,17 +43,34 @@ class ModelArguments: @dataclass class DataArguments: training_data_path: str = field( - default=None, metadata={"help": "Path to the training data in JSONL format."} + default=None, + metadata={"help": "Path to the training data in JSON/JSONL format."}, ) response_template: str = field( default=None, metadata={"help": "Response template, separator to train on completions only"}, ) dataset_text_field: str = field( - default=None, metadata={"help": "Training dataset text field"} + default=None, + metadata={ + "help": "Training dataset text field containing single sequence. \ + Either the dataset_text_field \ + or data_formatter_template need to be supplied." + }, ) validation_data_path: str = field( - default=None, metadata={"help": "Path to the validation data in JSONL format."} + default=None, + metadata={"help": "Path to the validation data in JSON/JSONL format."}, + ) + data_formatter_template: str = field( + default=None, + metadata={ + "help": "formatter template to format a single sequence \ + from each instance in JSONL files. \ + Keys of JSON can be referred to as {{key}} in template. \ + Either the dataset_text_field \ + or data_formatter_template needs to be supplied." 
+ }, ) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index b307505c0..fdf7efc8d 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -46,6 +46,7 @@ from tuning.trainercontroller import TrainerControllerCallback from tuning.utils.config_utils import get_hf_peft_config from tuning.utils.data_type_utils import get_torch_dtype +from tuning.utils.data_utils import apply_custom_formatting_template def train( @@ -218,8 +219,6 @@ def train( # TODO: Fix this, currently unreachable due to crashing in batch encoding tokenization # We should do this validation up front, then do the encoding, then handle the collator raise ValueError("Response template is None, needs to be set for training") - if data_args.dataset_text_field is None: - raise ValueError("Dataset_text_field is None, needs to be set for training") data_collator = DataCollatorForCompletionOnlyLM( response_template_ids, tokenizer=tokenizer, @@ -227,6 +226,19 @@ def train( ) packing = False + # Currently we support formatted datasets with single sequence instances. + if not (data_args.dataset_text_field or data_args.data_formatter_template): + raise ValueError( + "dataset_text_field and data_formatter_template are None. \ + One of them needs to be set for training" + ) + # Only one of dataset_text_field or data_formatter_template should be set. + if data_args.dataset_text_field and data_args.data_formatter_template: + raise ValueError( + "dataset_text_field and data_formatter_template are both set,\ + but are mutually exclusive options" + ) + # load the data by parsing JSON data_files = {"train": data_args.training_data_path} if data_args.validation_data_path: @@ -238,12 +250,34 @@ def train( } json_dataset = datasets.load_dataset("json", data_files=data_files) - formatted_train_dataset = json_dataset["train"].map(format_dataset) + if data_args.data_formatter_template: + ( + formatted_train_dataset, + data_args.dataset_text_field, + ) = apply_custom_formatting_template( + json_dataset["train"], + data_args.data_formatter_template, + tokenizer.eos_token, + ) + else: + formatted_train_dataset = json_dataset["train"].map(format_dataset) logger.info("Training dataset length is %s", len(formatted_train_dataset)) formatted_validation_dataset = None if data_args.validation_data_path: - formatted_validation_dataset = json_dataset["validation"].map(format_dataset) + if data_args.data_formatter_template: + ( + formatted_validation_dataset, + data_args.dataset_text_field, + ) = apply_custom_formatting_template( + json_dataset["validation"], + data_args.data_formatter_template, + tokenizer.eos_token, + ) + else: + formatted_validation_dataset = json_dataset["validation"].map( + format_dataset + ) logger.info( "Validation dataset length is %s", len(formatted_validation_dataset) ) diff --git a/tuning/utils/data_utils.py b/tuning/utils/data_utils.py new file mode 100644 index 000000000..3e67cc56f --- /dev/null +++ b/tuning/utils/data_utils.py @@ -0,0 +1,40 @@ +# Standard +import re + + +def apply_custom_formatting_template(dataset, template, eos_token=""): + """Function to format datasets with Alpaca style / other templates. + Args: + dataset: the HF Dataset element loaded from a JSON or DatasetDict object. + template: Template to format data with. Features of Dataset + should be referred to by {{key}} + eos_token: string EOS token to be appended while formatting data to a single sequence. + Defaults to empty + Returns: + Formatted HF Dataset, dataset_field name that contains formatted data. 
+ """ + + formatted_dataset_field = "formatted_data_field" + template += eos_token + + def formatter(element): + def replace_text(match_obj): + captured_groups = match_obj.groups() + if len(captured_groups) != 1: + raise ValueError( + "Unexpectedly captured multiple groups in template formatting" + ) + + index_object = captured_groups[0] + if index_object not in element: + raise KeyError("Requested template string is not a valid key in dict") + + return element[index_object] + + return { + formatted_dataset_field: re.sub( + r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template + ) + } + + return dataset.map(formatter), formatted_dataset_field