Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
dchourasia committed May 30, 2024
2 parents 4fa8b7a + 3d0c4f3 commit e94920a
Show file tree
Hide file tree
Showing 8 changed files with 350 additions and 10 deletions.
60 changes: 59 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ pip install -e ".[aim]"
```

## Data format
The data format expectation is a single column text. The trainer is configured to expect a response template as a string. For example, if one wants to prepare the `alpaca` format data to feed into this trainer, it is quite easy and can be done with the following code.
We support two data formats:

1. #### Pre-process the JSON/JSONL dataset
Pre-process the JSON/JSONL dataset to contain a single sequence of each data instance containing input + Response. The trainer is configured to expect a response template as a string. For example, if one wants to prepare the `alpaca` format data to feed into this trainer, it is quite easy and can be done with the following code.

```python
PROMPT_DICT = {
Expand Down Expand Up @@ -56,6 +59,24 @@ The `response template` corresponding to the above dataset and the `Llama` token

The same way can be applied to any dataset, with more info can be found [here](https://huggingface.co/docs/trl/main/en/sft_trainer#format-your-input-prompts).

Once the JSON is converted using the formatting function, pass the `dataset_text_field` containing the single sequence to the trainer.

2. #### Format JSON/JSONL on the fly
Pass a JSON/JSONL and a `data_formatter_template` to use the formatting function on the fly while tuning. The template should specify fields of JSON with `{{field}}`. While tuning, the data will be converted to a single sequence using the template.
JSON fields can contain alpha-numeric characters, spaces and the following special symbols - "." , "_", "-".

Example: Train.json
`[{ "input" : <text>,
"output" : <text>,
},
...
]`
data_formatter_template: `### Input: {{input}} \n\n##Label: {{output}}`

Formatting will happen on the fly while tuning. The keys in template should match fields in JSON file. The `response template` corresponding to the above template will need to be supplied. in this case, `response template` = `\n## Label:`.


##### In conclusion, either the `data_formatter_template` argument or `dataset_text_field` needs to be supplied to the trainer.

## Supported Models

Expand All @@ -64,12 +85,16 @@ Current supported and tested models are `Llama2` (7 and 13B configurations have
## Training

### Single GPU

1. Using pre-processed dataset for training.

```bash
# if you want to use one GPU on multi-gpu machine
export CUDA_VISIBLE_DEVICES=0

# MODEL_PATH=meta-llama/Llama-2-7b-hf # Huggingface model id or path to a checkpoint
# TRAIN_DATA_PATH=twitter_complaints.json # Path to the dataset
# contains data in single sequence {"output": "### Input: text \n\n### Response: text"}
# OUTPUT_PATH=out # Path to the output folder where the checkpoints are saved

python tuning/sft_trainer.py \
Expand All @@ -94,6 +119,39 @@ python tuning/sft_trainer.py \

```

2. Using formatter with JSON/JSONL files

```bash
# if you want to use one GPU on multi-gpu machine
export CUDA_VISIBLE_DEVICES=0

# MODEL_PATH=meta-llama/Llama-2-7b-hf # Huggingface model id or path to a checkpoint
# TRAIN_DATA_PATH=twitter_complaints.json # Path to the dataset
# contains data in form of [{"input": text , "output": text}]
# OUTPUT_PATH=out # Path to the output folder where the checkpoints are saved

python tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
--training_data_path $TRAIN_DATA_PATH \
--output_dir $OUTPUT_PATH \
--num_train_epochs 5 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--eval_strategy "no" \
--save_strategy "epoch" \
--learning_rate 1e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--include_tokens_per_second \
--packing False \
--response_template "\n## Label:" \
--data_formatter_template: "### Input: {{input}} \n\n##Label: {{output}}"

```

### Multiple GPUs with FSDP

The recommendation is to use [huggingface accelerate](https://huggingface.co/docs/accelerate/en/index) to launch multi-gpu jobs, in particular when using FSDP:
Expand Down
1 change: 1 addition & 0 deletions tests/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
### Constants used for data
DATA_DIR = os.path.join(os.path.dirname(__file__))
TWITTER_COMPLAINTS_DATA = os.path.join(DATA_DIR, "twitter_complaints_small.json")
TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join(DATA_DIR, "twitter_complaints_json.json")
EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json")
MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json")
12 changes: 12 additions & 0 deletions tests/data/twitter_complaints_json.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[
{"Tweet text":"@HMRCcustomers No this is my first job","ID":0,"Label":2,"text_label":"no complaint","output":"### Text: @HMRCcustomers No this is my first job\n\n### Label: no complaint"},
{"Tweet text":"@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.","ID":1,"Label":2,"text_label":"no complaint","output":"### Text: @KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.\n\n### Label: no complaint"},
{"Tweet text":"If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService","ID":2,"Label":1,"text_label":"complaint","output":"### Text: If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService\n\n### Label: complaint"},
{"Tweet text":"@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.","ID":3,"Label":1,"text_label":"complaint","output":"### Text: @EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.\n\n### Label: complaint"},
{"Tweet text":"Couples wallpaper, so cute. :) #BrothersAtHome","ID":4,"Label":2,"text_label":"no complaint","output":"### Text: Couples wallpaper, so cute. :) #BrothersAtHome\n\n### Label: no complaint"},
{"Tweet text":"@mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG","ID":5,"Label":2,"text_label":"no complaint","output":"### Text: @mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG\n\n### Label: no complaint"},
{"Tweet text":"@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?","ID":6,"Label":2,"text_label":"no complaint","output":"### Text: @Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?\n\n### Label: no complaint"},
{"Tweet text":"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?","ID":7,"Label":1,"text_label":"complaint","output":"### Text: @nationalgridus I have no water and the bill is current and paid. Can you do something about this?\n\n### Label: complaint"},
{"Tweet text":"Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora","ID":8,"Label":1,"text_label":"complaint","output":"### Text: Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora\n\n### Label: complaint"},
{"Tweet text":"@JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd","ID":9,"Label":2,"text_label":"no complaint","output":"### Text: @JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd\n\n### Label: no complaint"}
]
116 changes: 114 additions & 2 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@

# First Party
from scripts.run_inference import TunedCausalLM
from tests.data import EMPTY_DATA, MALFORMATTED_DATA, TWITTER_COMPLAINTS_DATA
from tests.data import (
EMPTY_DATA,
MALFORMATTED_DATA,
TWITTER_COMPLAINTS_DATA,
TWITTER_COMPLAINTS_JSON_FORMAT,
)

# Local
from tuning import sft_trainer
Expand Down Expand Up @@ -114,6 +119,70 @@ def test_run_causallm_pt_and_inference():
assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference


def test_run_causallm_pt_and_inference_with_formatting_data():
"""Check if we can bootstrap and peft tune causallm models
This test needs the trainer to format data to a single sequence internally.
"""
with tempfile.TemporaryDirectory() as tempdir:
data_formatting_args = copy.deepcopy(DATA_ARGS)
data_formatting_args.dataset_text_field = None
data_formatting_args.data_formatter_template = (
"### Text: {{Tweet text}} \n\n### Label: {{text_label}}"
)

train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args, PEFT_PT_ARGS)

# validate peft tuning configs
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
_validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS)

# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path)

# Run inference on the text
output_inference = loaded_model.run(
"### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
)
assert len(output_inference) > 0
assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference


def test_run_causallm_pt_and_inference_JSON_file_formatter():
"""Check if we can bootstrap and peft tune causallm models with JSON train file format"""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
data_args = copy.deepcopy(DATA_ARGS)
data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT
data_args.dataset_text_field = None
data_args.data_formatter_template = (
"### Text: {{Tweet text}} \n\n### Label: {{text_label}}"
)

sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS)

# validate peft tuning configs
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
_validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS)

# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path)

# Run inference on the text
output_inference = loaded_model.run(
"### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
)
assert len(output_inference) > 0
assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference


def test_run_causallm_pt_init_text():
"""Check if we can bootstrap and peft tune causallm models with init text as 'TEXT'"""
with tempfile.TemporaryDirectory() as tempdir:
Expand Down Expand Up @@ -174,6 +243,23 @@ def test_run_causallm_pt_with_validation():
_validate_training(tempdir, check_eval=True)


def test_run_causallm_pt_with_validation_data_formatting():
"""Check if we can bootstrap and peft tune causallm models with validation dataset"""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
train_args.eval_strategy = "epoch"
data_args = copy.deepcopy(DATA_ARGS)
data_args.validation_data_path = TWITTER_COMPLAINTS_DATA
data_args.dataset_text_field = None
data_args.data_formatter_template = (
"### Text: {{Tweet text}} \n\n### Label: {{text_label}}"
)

sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS)
_validate_training(tempdir, check_eval=True)


############################# Lora Tests #############################

target_modules_val_map = [
Expand Down Expand Up @@ -335,6 +421,30 @@ def test_invalid_dataset_text_field():
sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS)


### Tests that giving dataset_text_field as well as formatter template gives error
def test_invalid_dataset_text_field_and_formatter_template():
"""Only one of dataset_text_field or formatter can be supplied"""
data_args = copy.deepcopy(DATA_ARGS)
data_args.data_formatter_template = (
"### Text: {{Tweet text}} \n\n### Label: {{text_label}}"
)

with pytest.raises(ValueError):
sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS)


### Tests passing formatter with invalid keys gives error
def test_invalid_formatter_template():
data_args = copy.deepcopy(DATA_ARGS)
data_args.dataset_text_field = None
data_args.data_formatter_template = (
"### Text: {{not found}} \n\n### Label: {{text_label}}"
)

with pytest.raises(KeyError):
sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS)


### Tests for bad training data (i.e., data_path is an unhappy value or points to an unhappy thing)
def test_malformatted_data():
"""Ensure that malformatted data explodes due to failure to generate the dataset."""
Expand Down Expand Up @@ -382,13 +492,15 @@ def test_run_causallm_lora_with_invalid_modules():


### Direct validation tests based on whether or not packing is enabled
def test_no_packing_needs_dataset_text_field():
def test_no_packing_needs_dataset_text_field_or_data_formatter_template():
"""Ensure we need to set the dataset text field if packing is False"""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
data_args = copy.deepcopy(DATA_ARGS)
# One of dataset_text_field or data_formatter_template should be set
data_args.dataset_text_field = None
data_args.data_formatter_template = None

with pytest.raises(ValueError):
sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS)
Expand Down
66 changes: 66 additions & 0 deletions tests/utils/test_data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Third Party
import datasets
import pytest

# First Party
from tests.data import TWITTER_COMPLAINTS_DATA

# Local
from tuning.utils import data_utils


def test_apply_custom_formatting_template():
json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA)
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
# First response from the data file that is read.
expected_response = (
"### Input: @HMRCcustomers No this is my first job"
+ " \n\n ### Response: no complaint"
)
formatted_dataset, dataset_text_field = data_utils.apply_custom_formatting_template(
json_dataset, template
)
# a new dataset_text_field is created in Dataset
assert dataset_text_field in formatted_dataset["train"][0]
assert formatted_dataset["train"][0][dataset_text_field] == expected_response


def test_apply_custom_formatting_template_adds_eos_token():
json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA)
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
# First response from the data file that is read.
expected_response = (
"### Input: @HMRCcustomers No this is my first job"
+ " \n\n ### Response: no complaintEOS"
)
formatted_dataset, dataset_text_field = data_utils.apply_custom_formatting_template(
json_dataset, template, "EOS"
)
# a new dataset_text_field is created in Dataset
assert dataset_text_field in formatted_dataset["train"][0]
assert formatted_dataset["train"][0][dataset_text_field] == expected_response


def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
"""Tests that the formatting function will throw error if wrong keys are passed to template"""
json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA)
template = "### Input: {{not found}} \n\n ### Response: {{text_label}}"
with pytest.raises(KeyError):
data_utils.apply_custom_formatting_template(json_dataset, template, "EOS")
23 changes: 20 additions & 3 deletions tuning/config/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,34 @@ class ModelArguments:
@dataclass
class DataArguments:
training_data_path: str = field(
default=None, metadata={"help": "Path to the training data in JSONL format."}
default=None,
metadata={"help": "Path to the training data in JSON/JSONL format."},
)
response_template: str = field(
default=None,
metadata={"help": "Response template, separator to train on completions only"},
)
dataset_text_field: str = field(
default=None, metadata={"help": "Training dataset text field"}
default=None,
metadata={
"help": "Training dataset text field containing single sequence. \
Either the dataset_text_field \
or data_formatter_template need to be supplied."
},
)
validation_data_path: str = field(
default=None, metadata={"help": "Path to the validation data in JSONL format."}
default=None,
metadata={"help": "Path to the validation data in JSON/JSONL format."},
)
data_formatter_template: str = field(
default=None,
metadata={
"help": "formatter template to format a single sequence \
from each instance in JSONL files. \
Keys of JSON can be referred to as {{key}} in template. \
Either the dataset_text_field \
or data_formatter_template needs to be supplied."
},
)


Expand Down
Loading

0 comments on commit e94920a

Please sign in to comment.