Skip to content

Commit

Permalink
Refactor data util tests as data handler tests.
Browse files Browse the repository at this point in the history
Signed-off-by: Dushyant Behl <dushyantbehl@users.noreply.github.com>
  • Loading branch information
dushyantbehl committed Nov 22, 2024
1 parent 43626ed commit 3bd42b5
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 82 deletions.
48 changes: 22 additions & 26 deletions tests/utils/test_data_utils.py → tests/data/test_data_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,49 +16,39 @@
# https://spdx.dev/learn/handling-license-info/

# Third Party
from transformers import AutoTokenizer
import datasets
import pytest

# First Party
from tests.testdata import TWITTER_COMPLAINTS_DATA_JSONL
from tests.testdata import MODEL_NAME, TWITTER_COMPLAINTS_DATA_JSONL

# Local
from tuning.utils import data_utils
from tuning.data.data_handlers import apply_custom_data_formatting_template


def test_apply_custom_formatting_template():
json_dataset = datasets.load_dataset(
"json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
)
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
# First response from the data file that is read.
expected_response = (
"### Input: @HMRCcustomers No this is my first job"
+ " \n\n ### Response: no complaint"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
formatted_dataset_field = "formatted_data_field"
formatted_dataset = data_utils.apply_custom_formatting_template(
json_dataset, template, formatted_dataset_field
)
# a new dataset_text_field is created in Dataset
assert formatted_dataset_field in formatted_dataset["train"][0]
assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response


def test_apply_custom_formatting_template_adds_eos_token():
json_dataset = datasets.load_dataset(
"json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
formatted_dataset = json_dataset.map(
apply_custom_data_formatting_template,
fn_kwargs={
"tokenizer": tokenizer,
"dataset_text_field": formatted_dataset_field,
"template": template,
},
)
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
# First response from the data file that is read.
expected_response = (
"### Input: @HMRCcustomers No this is my first job"
+ " \n\n ### Response: no complaintEOS"
)
formatted_dataset_field = "formatted_data_field"
formatted_dataset = data_utils.apply_custom_formatting_template(
json_dataset, template, formatted_dataset_field, "EOS"
+ " \n\n ### Response: no complaint"
+ tokenizer.eos_token
)

# a new dataset_text_field is created in Dataset
assert formatted_dataset_field in formatted_dataset["train"][0]
assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response
Expand All @@ -71,7 +61,13 @@ def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
)
template = "### Input: {{not found}} \n\n ### Response: {{text_label}}"
formatted_dataset_field = "formatted_data_field"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
with pytest.raises(KeyError):
data_utils.apply_custom_formatting_template(
json_dataset, template, formatted_dataset_field, "EOS"
json_dataset.map(
apply_custom_data_formatting_template,
fn_kwargs={
"tokenizer": tokenizer,
"dataset_text_field": formatted_dataset_field,
"template": template,
},
)
File renamed without changes.
16 changes: 14 additions & 2 deletions tuning/data/data_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from transformers import AutoTokenizer

# Local
from tuning.data.data_preprocessing_utils import combine_sequence
from tuning.utils.data_utils import custom_data_formatter
from tuning.data.data_preprocessing_utils import combine_sequence, custom_data_formatter


def tokenize_and_apply_input_masking(
Expand Down Expand Up @@ -71,6 +70,19 @@ def apply_custom_data_formatting_template(
template: str,
**kwargs,
):
"""Function to format datasets with Alpaca style / other templates.
Expects to be run as a HF Map API function.
Args:
element: the HF Dataset element loaded from a JSON or DatasetDict object.
template: Template to format data with. Features of Dataset
should be referred to by {{key}}
formatted_dataset_field: Dataset_text_field
eos_token: string EOS token to be appended while formatting data to a single sequence.
Defaults to empty
Returns:
Formatted HF Dataset
"""

template += tokenizer.eos_token

# TODO: Eventually move the code here.
Expand Down
22 changes: 22 additions & 0 deletions tuning/data/data_preprocessing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
# Standard
from typing import Callable, Optional
import re

# Third Party
from transformers import AutoTokenizer, DataCollatorForSeq2Seq
Expand Down Expand Up @@ -184,3 +185,24 @@ def get_data_collator(
raise ValueError(
"Could not pick a data collator. Please refer to supported data formats"
)


def custom_data_formatter(element, template, formatted_dataset_field):
    """Render ``template`` against ``element`` and wrap the result in a dict.

    Each ``{{key}}`` placeholder in the template is replaced by
    ``element[key]``. Designed to be used as an HF ``Dataset.map`` function:
    the returned single-entry dict adds ``formatted_dataset_field`` as a new
    column on the dataset.

    Args:
        element: Mapping (dataset row) providing values for the placeholders.
        template: Format string with placeholders written as ``{{key}}``.
        formatted_dataset_field: Name of the output field/column.

    Returns:
        ``{formatted_dataset_field: <rendered template>}``

    Raises:
        ValueError: If the placeholder pattern unexpectedly captures more
            than one group (defensive; should not happen with this regex).
        KeyError: If a placeholder names a key absent from ``element``.
    """

    def _substitute(match):
        groups = match.groups()
        # The pattern below has exactly one capture group; anything else
        # indicates the regex was changed incompatibly.
        if len(groups) != 1:
            raise ValueError(
                "Unexpectedly captured multiple groups in template formatting"
            )
        key = groups[0]
        if key not in element:
            raise KeyError("Requested template string is not a valid key in dict")
        return element[key]

    rendered = re.sub(r"{{([\s0-9a-zA-Z_\-\.]+)}}", _substitute, template)
    return {formatted_dataset_field: rendered}
6 changes: 4 additions & 2 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def train(
data_collator = None
logger.info("Packing is set to %s ", train_args.packing)

data_preprocessor_time = time.time()
data_preprocessing_time = time.time()
(
formatted_train_dataset,
formatted_validation_dataset,
Expand All @@ -298,7 +298,9 @@ def train(
max_seq_length,
dataset_kwargs,
) = process_dataargs(data_args, tokenizer, train_args)
additional_metrics["data_preprocessor_time"] = time.time() - data_preprocessor_time
additional_metrics["data_preprocessing_time"] = (
time.time() - data_preprocessing_time
)

if framework is not None and framework.requires_agumentation:
model, (peft_config,) = framework.augmentation(
Expand Down
52 changes: 0 additions & 52 deletions tuning/utils/data_utils.py

This file was deleted.

0 comments on commit 3bd42b5

Please sign in to comment.