Rename loader to processor
Signed-off-by: Dushyant Behl <dushyantbehl@users.noreply.github.com>
dushyantbehl committed Nov 28, 2024
1 parent 13aa6b6 commit 70252af
Showing 8 changed files with 40 additions and 40 deletions.
14 changes: 7 additions & 7 deletions tests/data/test_data_preprocessing_utils.py
@@ -42,7 +42,7 @@

# Local
from tuning.config import configs
- from tuning.data.data_config import DataLoaderConfig, DataSetConfig
+ from tuning.data.data_config import DataPreProcessorConfig, DataSetConfig
from tuning.data.data_preprocessing_utils import (
combine_sequence,
get_data_collator,
@@ -120,7 +120,7 @@ def test_combine_sequence_adds_eos(input_element, output_element, expected_res):
def test_load_dataset_with_datafile(datafile, column_names):
"""Ensure that both dataset is loaded with datafile."""
processor = get_datapreprocessor(
- dataloaderconfig=DataLoaderConfig(), tokenizer=None
+ processor_config=DataPreProcessorConfig(), tokenizer=None
)
load_dataset = processor.load_dataset(
datasetconfig=None, splitName="train", datafile=datafile
@@ -162,7 +162,7 @@ def test_load_dataset_with_datasetconfig(datafile, column_names, datasetconfigname):
"""Ensure that both dataset is loaded with datafile."""
datasetconfig = DataSetConfig(name=datasetconfigname, data_paths=[datafile])
processor = get_datapreprocessor(
- dataloaderconfig=DataLoaderConfig(), tokenizer=None
+ processor_config=DataPreProcessorConfig(), tokenizer=None
)
load_dataset = processor.load_dataset(
datasetconfig=datasetconfig, splitName="train", datafile=None
@@ -185,7 +185,7 @@ def test_load_dataset_with_dataconfig_and_datafile(datafile, datasetconfigname):
"""Ensure that both datasetconfig and datafile cannot be passed."""
datasetconfig = DataSetConfig(name=datasetconfigname, data_paths=[datafile])
processor = get_datapreprocessor(
- dataloaderconfig=DataLoaderConfig(), tokenizer=None
+ processor_config=DataPreProcessorConfig(), tokenizer=None
)
with pytest.raises(ValueError):
processor.load_dataset(
@@ -196,7 +196,7 @@ def test_load_dataset_with_dataconfig_and_datafile(datafile, datasetconfigname):
def test_load_dataset_without_dataconfig_and_datafile():
"""Ensure that both datasetconfig and datafile cannot be None."""
processor = get_datapreprocessor(
- dataloaderconfig=DataLoaderConfig(), tokenizer=None
+ processor_config=DataPreProcessorConfig(), tokenizer=None
)
with pytest.raises(ValueError):
processor.load_dataset(datasetconfig=None, splitName="train", datafile=None)
@@ -623,10 +623,10 @@ def test_process_dataargs_pretokenized(data_args):
)
def test_process_dataset_configs(datafile, column_names, datasetconfigname):
"""Test process_dataset_configs for expected output."""
- dataloaderconfig = DataLoaderConfig()
+ dataprocessor_config = DataPreProcessorConfig()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
processor = HFBasedDataPreProcessor(
- dataloaderconfig=dataloaderconfig,
+ processor_config=dataprocessor_config,
tokenizer=tokenizer,
)
datasetconfig = [DataSetConfig(name=datasetconfigname, data_paths=[datafile])]
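Taken together, the test changes reduce to one mechanical substitution at every construction site. A minimal sketch of the renamed call, assuming the tuning package from this branch is installed; the datafile path is an illustrative placeholder, and tokenizer=None mirrors the tests above:

# Sketch of the renamed factory call used throughout these tests.
from tuning.data.data_config import DataPreProcessorConfig
from tuning.data.data_processors import get_datapreprocessor

processor = get_datapreprocessor(
    processor_config=DataPreProcessorConfig(),  # formerly dataloaderconfig=DataLoaderConfig()
    tokenizer=None,
)
dataset = processor.load_dataset(
    datasetconfig=None, splitName="train", datafile="sample_data.json"  # placeholder path
)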
2 changes: 1 addition & 1 deletion tests/predefined_data_configs/apply_custom_template.yaml
@@ -1,4 +1,4 @@
- dataloader:
+ dataprocessor:
type: default
datasets:
- name: apply_custom_data_template
2 changes: 1 addition & 1 deletion tests/predefined_data_configs/pretokenized_json_data.yaml
@@ -1,4 +1,4 @@
- dataloader:
+ dataprocessor:
type: default
datasets:
- name: pretokenized_dataset
@@ -1,11 +1,11 @@
- dataloader:
+ dataprocessor:
type: default
datasets:
- name: text_dataset_input_output_masking
data_paths:
- "FILE_PATH"
data_handlers:
- - name: tokenize_and_apply_instruction_masking
+ - name: tokenize_and_apply_input_masking
arguments:
remove_columns: all
batched: false
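Both renames in this commit meet in these fixtures: the top-level dataloader: key becomes dataprocessor:, and the handler tokenize_and_apply_instruction_masking becomes tokenize_and_apply_input_masking. A complete updated config, reconstructed from the fixture above (FILE_PATH remains the placeholder the tests substitute):

dataprocessor:
  type: default
datasets:
  - name: text_dataset_input_output_masking
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false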
20 changes: 10 additions & 10 deletions tuning/data/data_config.py
@@ -37,13 +37,13 @@ class DataSetConfig:


@dataclass
- class DataLoaderConfig:
+ class DataPreProcessorConfig:
type: Optional[str] = "default"


@dataclass
class DataConfig:
- dataloader: DataLoaderConfig
+ dataprocessor: DataPreProcessorConfig
datasets: List[DataSetConfig]


@@ -102,15 +102,15 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
return c


- def _validate_dataloader_config(dataloader_config) -> DataLoaderConfig:
- kwargs = dataloader_config
- c = DataLoaderConfig()
- assert isinstance(kwargs, dict), "dataloader in data_config needs to be a dict"
+ def _validate_dataprocessor_config(dataprocessor_config) -> DataPreProcessorConfig:
+ kwargs = dataprocessor_config
+ c = DataPreProcessorConfig()
+ assert isinstance(kwargs, dict), "dataprocessor in data_config needs to be a dict"
return c


def validate_data_config(dataconfig: DataConfig):
- _validate_dataloader_config(dataconfig.dataloader)
+ _validate_dataprocessor_config(dataconfig.dataprocessor)
for d in dataconfig.datasets:
_validate_dataset_config(d)

@@ -127,8 +127,8 @@ def load_and_validate_data_config(data_config_file: str) -> DataConfig:
datasets = []
for d in raw_data["datasets"]:
datasets.append(_validate_dataset_config(d))
if "dataloader" in raw_data:
dataloader = _validate_dataloader_config(raw_data["dataloader"])
if "dataprocessor" in raw_data:
dataprocessor = _validate_dataprocessor_config(raw_data["dataprocessor"])

- data_config = DataConfig(dataloader=dataloader, datasets=datasets)
+ data_config = DataConfig(dataprocessor=dataprocessor, datasets=datasets)
return data_config
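One observation on the validation path: load_and_validate_data_config only binds dataprocessor when that key is present in the file, so a config omitting it would hit an unbound local when DataConfig is constructed; that behavior predates this commit and is merely renamed here. A sketch of exercising the renamed validator, assuming the config file format is YAML as the fixtures in tests/predefined_data_configs suggest; the dataset name and path are placeholders:

# Sketch: round-trip a config that uses the renamed "dataprocessor" key.
import tempfile

import yaml

from tuning.data.data_config import load_and_validate_data_config

raw = {
    "dataprocessor": {"type": "default"},
    "datasets": [{"name": "my_dataset", "data_paths": ["train.jsonl"]}],  # placeholders
}
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    yaml.safe_dump(raw, f)

data_config = load_and_validate_data_config(f.name)
assert data_config.dataprocessor.type == "default"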
2 changes: 1 addition & 1 deletion tuning/data/data_handlers.py
@@ -90,7 +90,7 @@ def apply_custom_data_formatting_template(


AVAILABLE_DATA_HANDLERS = {
"tokenize_and_apply_instruction_masking": tokenize_and_apply_input_masking,
"tokenize_and_apply_input_masking": tokenize_and_apply_input_masking,
"apply_dataset_formatting": apply_dataset_formatting,
"apply_custom_data_formatting_template": apply_custom_data_formatting_template,
}
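Handlers are resolved by name from this dictionary at registration time, so any data config still referencing the old key will fail lookup after this commit. A quick sketch of the renamed mapping:

# Sketch: the registry keeps the same callable under the new key only.
from tuning.data.data_handlers import AVAILABLE_DATA_HANDLERS

handler = AVAILABLE_DATA_HANDLERS["tokenize_and_apply_input_masking"]
assert "tokenize_and_apply_instruction_masking" not in AVAILABLE_DATA_HANDLERS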
20 changes: 10 additions & 10 deletions tuning/data/data_processors.py
@@ -26,7 +26,7 @@
import torch

# Local
- from tuning.data.data_config import DataConfig, DataLoaderConfig, DataSetConfig
+ from tuning.data.data_config import DataConfig, DataPreProcessorConfig, DataSetConfig
from tuning.data.data_handlers import AVAILABLE_DATA_HANDLERS
from tuning.utils.utils import get_extension, get_loader_for_filepath

@@ -35,12 +35,12 @@ class DataPreProcessor(ABC):

tokenizer = None
data_config: DataConfig = None
- dataloaderconfig: DataLoaderConfig = None
+ processor_config: DataPreProcessorConfig = None
registered_handlers: Dict[str, callable] = None

- def __init__(self, dataloaderconfig: DataLoaderConfig, tokenizer: AutoTokenizer):
+ def __init__(self, processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer):
self.tokenizer = tokenizer
- self.dataloaderconfig = dataloaderconfig
+ self.processor_config = processor_config

# Initialize other objects
self.registered_handlers = {}
@@ -67,10 +67,10 @@ def process_dataset_configs(
class HFBasedDataPreProcessor(DataPreProcessor):
def __init__(
self,
- dataloaderconfig: DataLoaderConfig,
+ processor_config: DataPreProcessorConfig,
tokenizer: AutoTokenizer,
):
- super().__init__(dataloaderconfig=dataloaderconfig, tokenizer=tokenizer)
+ super().__init__(processor_config=processor_config, tokenizer=tokenizer)

def load_dataset(
self,
@@ -224,12 +224,12 @@ def autoregister_available_handlers(processor: DataPreProcessor):


def get_datapreprocessor(
- dataloaderconfig: DataLoaderConfig, tokenizer: AutoTokenizer
+ processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer
) -> DataPreProcessor:
- loader = dataloaderconfig.type
- if loader == "default":
+ processor = processor_config.type
+ if processor == "default":
processor = HFBasedDataPreProcessor(
- dataloaderconfig=dataloaderconfig,
+ processor_config=processor_config,
tokenizer=tokenizer,
)
else:
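The factory now branches on processor_config.type instead of a loader type; only the "default" branch is visible here (the else arm is elided from this diff). A sketch of both construction paths, with "gpt2" standing in for any tokenizer name:

# Sketch: the factory path and the direct construction used in the updated test.
from transformers import AutoTokenizer

from tuning.data.data_config import DataPreProcessorConfig
from tuning.data.data_processors import HFBasedDataPreProcessor, get_datapreprocessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # "gpt2" is an illustrative model
config = DataPreProcessorConfig()  # type defaults to "default"

# get_datapreprocessor returns an HFBasedDataPreProcessor when type == "default".
processor = get_datapreprocessor(processor_config=config, tokenizer=tokenizer)

# Equivalent direct construction, as in test_process_dataset_configs.
direct = HFBasedDataPreProcessor(processor_config=config, tokenizer=tokenizer)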
16 changes: 8 additions & 8 deletions tuning/data/setup_dataprocessor.py
@@ -26,7 +26,7 @@
from tuning.config.configs import DataArguments, TrainingArguments
from tuning.data.data_config import (
DataHandlerConfig,
- DataLoaderConfig,
+ DataPreProcessorConfig,
DataSetConfig,
load_and_validate_data_config,
)
@@ -46,9 +46,9 @@ def is_pretokenized_dataset(data: Union[str, Dataset, IterableDataset]):
if not data:
return False
if isinstance(data, str):
- # Create a data processor with default loader config
+ # Create a data processor with default processor config
processor = get_datapreprocessor(
- dataloaderconfig=DataLoaderConfig(), tokenizer=None
+ processor_config=DataPreProcessorConfig(), tokenizer=None
)
data = processor.load_dataset(None, splitName="train[:1]", datafile=data)

@@ -62,7 +62,7 @@ def _process_dataconfig_file(
):
data_config = load_and_validate_data_config(data_args.data_config_path)
processor = get_datapreprocessor(
- dataloaderconfig=data_config.dataloader, tokenizer=tokenizer
+ processor_config=data_config.dataprocessor, tokenizer=tokenizer
)
train_dataset = processor.process_dataset_configs(data_config.datasets)

@@ -124,10 +124,10 @@ def process_dataargs(
data_args, tokenizer, train_args.packing, max_seq_length
)

- # Create a data processor with default loader config
- default_loader_config = DataLoaderConfig()
+ # Create a data processor with default processor config
+ default_processor_config = DataPreProcessorConfig()
data_processor = get_datapreprocessor(
- dataloaderconfig=default_loader_config, tokenizer=tokenizer
+ processor_config=default_processor_config, tokenizer=tokenizer
)

# TODO: This check loads first slice of the dataset to view its columns
@@ -205,7 +205,7 @@
}

handler = DataHandlerConfig(
"tokenize_and_apply_instruction_masking", arguments=kwargs
"tokenize_and_apply_input_masking", arguments=kwargs
)
handlers = [handler]

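The last rename site builds the default handler inside process_dataargs; the kwargs it passes are assembled earlier and elided from this diff. A sketch of that construction with a placeholder arguments dict:

# Sketch: the default handler now registers under the renamed key.
# The real kwargs are built earlier in process_dataargs and elided here.
from tuning.data.data_config import DataHandlerConfig

kwargs = {}  # placeholder for the elided handler arguments
handler = DataHandlerConfig("tokenize_and_apply_input_masking", arguments=kwargs)
handlers = [handler]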
