1 parent 0df67e8 · commit 7ae3122 · Showing 7 changed files with 1,718 additions and 1 deletion.
# T5 Model for Conditional Generation

## Introduction
Text-To-Text Transfer Transformer (T5) is an encoder-decoder model that reframes all NLP tasks into a unified text-to-text format, where the input and output are always text strings. Tasks are selected simply by adding a specific prefix to the input; for instance, summarization inputs begin with `summarize:`. This flexibility enables T5 to excel across a wide range of tasks.

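As a point of reference, the same prefix-driven interface can be exercised with the plain Hugging Face model on CPU. This is a minimal sketch using the reference `t5-small` checkpoint, not the ttnn implementation, and the input text is a made-up example:
```
# Minimal CPU-only sketch of T5's text-to-text interface; the task is selected
# purely by the "summarize:" prefix on the input string.
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()

text = "summarize: The Eiffel Tower was completed in 1889 and remains one of the most visited monuments in the world."
input_ids = tokenizer(text, return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```
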
## Details

The entry point to the T5 model is `t5_for_conditional_generation` in `ttnn_optimized_functional_t5.py`. The model loads its configuration and weights from the Hugging Face pretrained model; the `t5-small` and `google/flan-t5-small` checkpoints are used as references.

In this demo, the model accepts input text and produces a summarized version of it.

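In outline (abridged from the demo script in this commit), the Hugging Face weights are converted to ttnn parameters once, and the entry point is then called for each decoding step. Device creation via `ttnn.open_device` is shown here as an assumption; in the demo the pytest `device` fixture provides it:
```
import torch
import ttnn
from transformers import AutoTokenizer, T5Config, T5ForConditionalGeneration
from ttnn.model_preprocessing import preprocess_model_parameters
from models.demos.grayskull.t5.tt import ttnn_optimized_functional_t5

device = ttnn.open_device(device_id=0)  # assumption: the demo gets this from a pytest fixture

config = T5Config.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()

# Convert the Hugging Face weights into ttnn parameters (done once).
parameters = preprocess_model_parameters(
    model_name="ttnn_optimized_t5-small",
    initialize_model=lambda: model,
    convert_to_ttnn=ttnn_optimized_functional_t5.convert_to_ttnn,
    custom_preprocessor=ttnn_optimized_functional_t5.custom_preprocessor,
    device=device,
)

tokenizer = AutoTokenizer.from_pretrained("t5-small")
input_ids = tokenizer("summarize: some input text", return_tensors="pt").input_ids
decoder_input_ids = model.generation_config.pad_token_id * torch.ones(1, input_ids.shape[-1]).to(torch.long)

# One forward pass of the ttnn model; the demo loops this to generate tokens greedily.
logits, encoder_hidden_states = ttnn_optimized_functional_t5.t5_for_conditional_generation(
    config,
    ttnn.from_torch(input_ids, device=device),
    ttnn.to_device(ttnn.from_torch(decoder_input_ids), device),
    parameters=parameters,
)

ttnn.close_device(device)
```
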
## Inputs
Inputs for the demo are read from `input_data.json` by default. If you need different inputs, provide another file via the `--input-path` parameter rather than modifying `input_data.json` directly; a sketch of the expected layout follows.

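The demo's `load_inputs` helper only reads the `content` field of each entry and requires at least `batch_size` entries (8 in the commands below). As an illustration, a custom input file could be generated like this (the file name and texts are placeholders):
```
import json

# Hypothetical input file in the layout the demo expects: a JSON list of
# objects, each with a "content" field holding the text to summarize.
samples = [
    {"content": f"Placeholder article text number {i} to be summarized."}
    for i in range(8)  # at least batch_size entries are required
]

with open("my_input_data.json", "w") as f:
    json.dump(samples, f, indent=4)
```
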
## How to Run

- Use the following command to run the T5 for conditional generation demo using `t5-small`:
```
pytest --disable-warnings --input-path="models/demos/grayskull/t5/demo/input_data.json" models/demos/grayskull/t5/demo/demo.py::test_t5_demo_for_summarize[8-128-64-t5-small]
```

- Alternatively, use the following command to run the demo using `google/flan-t5-small`:
```
pytest --disable-warnings --input-path="models/demos/grayskull/t5/demo/input_data.json" models/demos/grayskull/t5/demo/demo.py::test_t5_demo_for_summarize[8-128-64-google/flan-t5-small]
```

- If you wish to run the demo with a different input file, replace `<address_to_your_json_file.json>` with the path to your JSON file in the following command:
```
pytest --disable-warnings --input-path=<address_to_your_json_file.json> models/demos/grayskull/t5/demo/demo.py::test_t5_demo_for_summarize[8-128-64-t5-small]
```

The second demo runs on the `openai/summarize_from_feedback` dataset. The human feedback posts from the dataset are used as input text, and the accompanying reference summaries are used to validate the model's output.

- Use the following command to run the second (dataset) demo of T5 using the `t5-small` variant:
```
pytest --disable-warnings models/demos/grayskull/t5/demo/demo.py::test_t5_demo_for_summarize_dataset[8-128-64-t5-small]
```

- Alternatively, use the following command to run the second (dataset) demo of T5 using the `google/flan-t5-small` variant:
```
pytest --disable-warnings models/demos/grayskull/t5/demo/demo.py::test_t5_demo_for_summarize_dataset[8-128-64-google/flan-t5-small]
```

## Results
The input is fed into the T5 model for conditional generation, and the output is a summarized and simplified version of the input text. For the dataset demo, the generated summaries are also compared against the reference summaries with BERTScore, and the average F1 score is logged.
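
The scoring step is the standard `evaluate` BERTScore usage; the following sketch mirrors the demo's validation logic with placeholder strings:
```
import evaluate

# Score generated summaries against the dataset's reference summaries.
bert_score = evaluate.load("bertscore")
results = bert_score.compute(
    predictions=["a generated summary produced by the demo"],  # placeholder
    references=["the reference summary from the dataset"],     # placeholder
    lang="en",
)
avg_f1 = sum(results["f1"]) / len(results["f1"])
print(f"Average F1 score: {avg_f1}")
```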
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import json
import pytest
import torch
import evaluate
from loguru import logger
from datasets import load_dataset
from models.generation_utils import get_logits_processor
import ttnn

from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Config
from models.demos.grayskull.t5.tt import ttnn_functional_t5
from models.demos.grayskull.t5.tt import ttnn_optimized_functional_t5
from ttnn.model_preprocessing import preprocess_model_parameters

from models.utility_functions import (
    disable_compilation_reports,
    disable_persistent_kernel_cache,
    profiler,
)

def load_inputs(input_path, batch):
    # Read the demo inputs from a JSON file; each entry must provide a "content" field.
    with open(input_path) as f:
        input_data = json.load(f)
    assert len(input_data) >= batch, f"Input data needs to have at least {batch} (batch size) entries."

    input_text = []
    for i in range(batch):
        input_text.append(input_data[i]["content"])

    return input_text


def run_generate(input_ids, model, config, parameters, device, max_tokens, batch_size):
    tt_model = ttnn_optimized_functional_t5

    logits_processor = get_logits_processor(input_ids, config)

    # Start decoding from a sequence of pad tokens; tokens are filled in greedily below.
    decoder_input_ids = model.generation_config.pad_token_id * torch.ones(batch_size, input_ids.shape[-1]).to(
        torch.long
    )

    input_ids = ttnn.from_torch(input_ids, device=device)

    profiler.start("inference_time")
    for iteration in range(max_tokens):
        decoder_input_ids = ttnn.from_torch(decoder_input_ids)
        decoder_input_ids = ttnn.to_device(decoder_input_ids, device)

        tt_output, encoder_hidden_states = tt_model.t5_for_conditional_generation(
            config,
            input_ids,
            decoder_input_ids,
            parameters=parameters,
        )
        tt_output = ttnn.from_device(tt_output)
        next_token_logits = ttnn.to_torch(tt_output)

        # Greedy decoding: pick the highest-scoring token and append it to the decoder input.
        next_tokens_scores = logits_processor(input_ids, next_token_logits)
        next_tokens = torch.argmax(next_tokens_scores, dim=-1)

        decoder_input_ids = ttnn.from_device(decoder_input_ids)
        decoder_input_ids = ttnn.to_torch(decoder_input_ids)
        decoder_input_ids[:, iteration + 1] = next_tokens[:, iteration]

    profiler.end("inference_time")

    return decoder_input_ids


def run_summarization_inference(input_path, device, batch_size, sequence_length, max_tokens, model_name):
    config = T5Config.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=32)

    input_sentences = load_inputs(input_path, batch_size)

    profiler.start("preprocessing_input")
    input_ids = tokenizer(
        input_sentences,
        padding="max_length",
        max_length=sequence_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids
    profiler.end("preprocessing_input")

    tt_model_name = "ttnn_optimized_" + model_name

    decoded_tt_output = []

    convert_to_ttnn = ttnn_optimized_functional_t5.convert_to_ttnn
    custom_preprocessor = ttnn_optimized_functional_t5.custom_preprocessor

    # Convert the Hugging Face weights into ttnn parameters.
    profiler.start("preprocessing_parameter")
    parameters = preprocess_model_parameters(
        model_name=tt_model_name,
        initialize_model=lambda: model,
        convert_to_ttnn=convert_to_ttnn,
        custom_preprocessor=custom_preprocessor,
        device=device,
    )
    profiler.end("preprocessing_parameter")

    tt_output = run_generate(
        input_ids,
        model,
        config,
        parameters,
        device,
        max_tokens,
        batch_size,
    )

    profiler.start("post_processing_output_to_string")
    for batch in range(batch_size):
        output = tokenizer.decode(tt_output[batch], skip_special_tokens=True)
        decoded_tt_output.append(output)
    profiler.end("post_processing_output_to_string")

    for i in range(batch_size):
        logger.info(
            "------------------------------------------------------------------------------------------------------------------------"
        )
        logger.info(f"Input text {i} >> {input_sentences[i]}")
        logger.info(f"Output text {i} >> {decoded_tt_output[i]}")
        logger.info("")

    measurements = {
        "preprocessing_parameter": profiler.get("preprocessing_parameter"),
        "preprocessing_input": profiler.get("preprocessing_input"),
        "inference_time": profiler.get("inference_time"),
        "post_processing": profiler.get("post_processing_output_to_string"),
    }
    logger.info(f"preprocessing_parameter: {measurements['preprocessing_parameter']} s")
    logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s")
    logger.info(f"inference_time: {measurements['inference_time']} s")
    logger.info(f"post_processing: {measurements['post_processing']} s")

    return measurements


def run_summarization_dataset_inference(device, batch_size, sequence_length, max_tokens, model_name):
    config = T5Config.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=32)

    dataset = load_dataset("openai/summarize_from_feedback", "axis")
    dataset = dataset.shuffle(seed=19)
    bert_score = evaluate.load("bertscore")

    validation_split = dataset["validation"]["info"]
    reference_split = dataset["validation"]["summary"]

    # Build the "summarize:" prompts and keep the reference summaries for scoring.
    input_sentences = []
    references = []

    for i in range(batch_size):
        references.append(reference_split[i]["text"][1:])
        input_sentences.append(f"summarize: {validation_split[i]['post']}")

    profiler.start("preprocessing_input")
    input_ids = tokenizer(
        input_sentences,
        padding="max_length",
        max_length=sequence_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids
    profiler.end("preprocessing_input")

    tt_model_name = "ttnn_optimized_" + model_name

    decoded_tt_output = []

    convert_to_ttnn = ttnn_optimized_functional_t5.convert_to_ttnn
    custom_preprocessor = ttnn_optimized_functional_t5.custom_preprocessor

    profiler.start("preprocessing_parameter")
    parameters = preprocess_model_parameters(
        model_name=tt_model_name,
        initialize_model=lambda: model,
        convert_to_ttnn=convert_to_ttnn,
        custom_preprocessor=custom_preprocessor,
        device=device,
    )
    profiler.end("preprocessing_parameter")

    tt_output = run_generate(
        input_ids,
        model,
        config,
        parameters,
        device,
        max_tokens,
        batch_size,
    )

    profiler.start("post_processing_output_to_string")
    for batch in range(batch_size):
        output = tokenizer.decode(tt_output[batch], skip_special_tokens=True)
        decoded_tt_output.append(output)
    profiler.end("post_processing_output_to_string")

    for i in range(batch_size):
        logger.info(
            "------------------------------------------------------------------------------------------------------------------------"
        )
        logger.info(f"Input text {i} >> {input_sentences[i]}")
        logger.info(f"Output text {i} >> {decoded_tt_output[i]}")
        logger.info("")

    # Validate the generated summaries against the references with BERTScore.
    results = bert_score.compute(predictions=decoded_tt_output, references=references, lang="en")
    avg_f1 = sum(results["f1"]) / len(results["f1"])
    logger.info("")
    logger.info(f"Average F1 score: {avg_f1}")

    measurements = {
        "preprocessing_parameter": profiler.get("preprocessing_parameter"),
        "preprocessing_input": profiler.get("preprocessing_input"),
        "inference_time": profiler.get("inference_time"),
        "post_processing": profiler.get("post_processing_output_to_string"),
    }
    logger.info(f"preprocessing_parameter: {measurements['preprocessing_parameter']} s")
    logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s")
    logger.info(f"inference_time: {measurements['inference_time']} s")
    logger.info(f"post_processing: {measurements['post_processing']} s")

    return measurements


@pytest.mark.parametrize(
    ("batch_size", "sequence_length", "max_tokens", "model_name"),
    (
        (8, 128, 64, "t5-small"),
        (8, 128, 64, "google/flan-t5-small"),
    ),
)
def test_t5_demo_for_summarize(input_path, device, batch_size, sequence_length, max_tokens, model_name):
    disable_persistent_kernel_cache()
    disable_compilation_reports()

    return run_summarization_inference(input_path, device, batch_size, sequence_length, max_tokens, model_name)


@pytest.mark.parametrize(
    ("batch_size", "sequence_length", "max_tokens", "model_name"),
    (
        (8, 128, 64, "t5-small"),
        (8, 128, 64, "google/flan-t5-small"),
    ),
)
def test_t5_demo_for_summarize_dataset(device, batch_size, sequence_length, max_tokens, model_name):
    disable_persistent_kernel_cache()
    disable_compilation_reports()

    return run_summarization_dataset_inference(device, batch_size, sequence_length, max_tokens, model_name)