#7414: add T5 GS demo
boris-drazic committed Apr 12, 2024
1 parent 0df67e8 commit 7ae3122
Showing 7 changed files with 1,718 additions and 1 deletion.
1 change: 0 additions & 1 deletion models/demos/grayskull/README.md

This file was deleted.

46 changes: 46 additions & 0 deletions models/demos/grayskull/t5/README.md
@@ -0,0 +1,46 @@
# T5 Model for Conditional Generation

## Introduction
The Text-To-Text Transfer Transformer (T5) is an encoder-decoder model that reframes all NLP tasks into a unified text-to-text format in which the input and output are always text strings.
Tasks are selected by prepending a task-specific prefix to the input; for instance, summarization inputs begin with `summarize:`. This flexibility enables T5 to perform well across a wide range of tasks.
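
As a point of reference for what the demo reproduces on a Grayskull device, a minimal CPU-only sketch of prefix-driven summarization with the Hugging Face reference checkpoint might look like this (the example text and the 64-token limit are illustrative):

```python
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Load the Hugging Face reference checkpoint used by this demo.
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()

# T5 selects the task through a text prefix; "summarize:" requests summarization.
text = "summarize: T5 casts every NLP task as text-to-text generation ..."
input_ids = tokenizer(text, return_tensors="pt").input_ids

# Greedy decoding of up to 64 new tokens (an illustrative limit).
output_ids = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```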

## Details

The entry point to the T5 model is `t5_for_conditional_generation` in `ttnn_optimized_functional_t5.py`. The model picks up its configuration and weights from the pretrained Hugging Face model; the `t5-small` and `google/flan-t5-small` checkpoints are used as reference.

In this demo, the model accepts input text and produces a summarized version of it. A sketch of how the demo drives the entry point is shown below.
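
The sketch is adapted from `demo.py` in this commit; `device`, `input_ids`, and `decoder_input_ids` are assumed to be prepared as in the demo:

```python
from transformers import T5Config, T5ForConditionalGeneration
from ttnn.model_preprocessing import preprocess_model_parameters
from models.demos.grayskull.t5.tt import ttnn_optimized_functional_t5

config = T5Config.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()

# Convert the torch weights into ttnn parameters placed on the device.
parameters = preprocess_model_parameters(
    model_name="ttnn_optimized_t5-small",
    initialize_model=lambda: model,
    convert_to_ttnn=ttnn_optimized_functional_t5.convert_to_ttnn,
    custom_preprocessor=ttnn_optimized_functional_t5.custom_preprocessor,
    device=device,
)

# One decoding step: next-token logits given the encoder input and the tokens generated so far.
tt_output, encoder_hidden_states = ttnn_optimized_functional_t5.t5_for_conditional_generation(
    config,
    input_ids,
    decoder_input_ids,
    parameters=parameters,
)
```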

## Inputs
Inputs for the demo are read from `input_data.json` by default. To change the inputs or provide a different file, modify the `--input-path` argument in the command accordingly; we recommend against modifying `input_data.json` directly. An example of the expected file structure is shown below.
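
Each entry in the file is an object with a `content` field holding one input text, and the file needs at least as many entries as the batch size (8 in the commands below); the demo tokenizes each `content` string as-is. A minimal example with illustrative text:

```json
[
    {"content": "summarize: The first article to summarize ..."},
    {"content": "summarize: The second article to summarize ..."}
]
```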

## How to Run

- Use the following command to run the T5 conditional generation demo with `t5-small`:
```
pytest --disable-warnings --input-path="models/demos/grayskull/t5/demo/input_data.json" models/demos/grayskull/t5/demo/demo.py::test_t5_demo_for_summarize[8-128-64-t5-small]
```

- Alternatively, use the following command to run the T5 conditional generation demo with `google/flan-t5-small`:
```
pytest --disable-warnings --input-path="models/demos/grayskull/t5/demo/input_data.json" models/demos/grayskull/t5/demo/demo.py::test_t5_demo_for_summarize[8-128-64-google/flan-t5-small]
```

- To run the demo with a different input file, replace `<address_to_your_json_file.json>` with the path to your JSON file in the following command:
```
pytest --disable-warnings --input-path=<address_to_your_json_file.json> models/demos/grayskull/t5/demo/demo.py::test_t5_demo_for_summarize[8-128-64-t5-small]
```

The second demo runs on the `openai/summarize_from_feedback` dataset. Each post from the dataset's human-feedback entries is used as input text to the model, and the accompanying human-written summary is used to validate the model output (the demo reports an average BERTScore F1 over the batch).
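
For reference, the demo prepares its dataset inputs and references roughly as follows (adapted from `demo.py` in this commit; `batch_size` is the parametrized batch size):

```python
from datasets import load_dataset

dataset = load_dataset("openai/summarize_from_feedback", "axis").shuffle(seed=19)

validation_split = dataset["validation"]["info"]
reference_split = dataset["validation"]["summary"]

# Each post becomes a prefixed model input; the human-written summary is the reference.
input_sentance = [f"summarize: {validation_split[i]['post']}" for i in range(batch_size)]
references = [reference_split[i]["text"][1:] for i in range(batch_size)]
```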

- Use the following command to run the second demo with the `t5-small` variant:
```
pytest --disable-warnings models/demos/grayskull/t5/demo/demo.py::test_t5_demo_for_summarize_dataset[8-128-64-t5-small]
```

- Alternatively, use the following command to run the second demo with the `google/flan-t5-small` variant:
```
pytest --disable-warnings models/demos/grayskull/t5/demo/demo.py::test_t5_demo_for_summarize_dataset[8-128-64-google/flan-t5-small]
```

## Results
The input is fed into the T5 model for conditional generation, and the output will be a summarized and simplified version of the input text.
265 changes: 265 additions & 0 deletions models/demos/grayskull/t5/demo/demo.py
@@ -0,0 +1,265 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import json
import pytest
import torch
import evaluate
from loguru import logger
from datasets import load_dataset
from models.generation_utils import get_logits_processor
import ttnn


from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Config
from models.demos.grayskull.t5.tt import ttnn_functional_t5
from models.demos.grayskull.t5.tt import ttnn_optimized_functional_t5
from ttnn.model_preprocessing import preprocess_model_parameters

from models.utility_functions import (
    disable_compilation_reports,
    disable_persistent_kernel_cache,
    profiler,
)


def load_inputs(input_path, batch):
    with open(input_path) as f:
        input_data = json.load(f)
        assert len(input_data) >= batch, f"Input data needs to have at least {batch} (batch size) entries."

        input_text = []
        for i in range(batch):
            input_text.append(input_data[i]["content"])

        return input_text


def run_generate(input_ids, model, config, parameters, device, max_tokens, batch_size):
    tt_model = ttnn_optimized_functional_t5

    logits_processor = get_logits_processor(input_ids, config)

    # Start decoding from a sequence of pad tokens; tokens are filled in greedily below.
    decoder_input_ids = model.generation_config.pad_token_id * torch.ones(batch_size, input_ids.shape[-1]).to(
        torch.long
    )

    input_ids = ttnn.from_torch(input_ids, device=device)

    profiler.start(f"inference_time")
    for iteration in range(max_tokens):
        decoder_input_ids = ttnn.from_torch(decoder_input_ids)
        decoder_input_ids = ttnn.to_device(decoder_input_ids, device)

        tt_output, encoder_hidden_states = tt_model.t5_for_conditional_generation(
            config,
            input_ids,
            decoder_input_ids,
            parameters=parameters,
        )
        tt_output = ttnn.from_device(tt_output)
        next_token_logits = ttnn.to_torch(tt_output)

        # Greedy decoding: pick the highest-scoring token and write it into the decoder input.
        next_tokens_scores = logits_processor(input_ids, next_token_logits)
        next_tokens = torch.argmax(next_tokens_scores, dim=-1)

        decoder_input_ids = ttnn.from_device(decoder_input_ids)
        decoder_input_ids = ttnn.to_torch(decoder_input_ids)
        decoder_input_ids[:, iteration + 1] = next_tokens[:, iteration]

    profiler.end(f"inference_time")

    return decoder_input_ids


def run_summarization_inference(input_path, device, batch_size, sequence_length, max_tokens, model_name):
    config = T5Config.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=32)

    input_sentance = load_inputs(input_path, batch_size)

    profiler.start(f"preprocessing_input")
    input_ids = tokenizer(
        input_sentance,
        padding="max_length",
        max_length=sequence_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids

    profiler.end(f"preprocessing_input")

    tt_model_name = "ttnn_optimized_" + model_name

    decoded_tt_output = []

    convert_to_ttnn = ttnn_optimized_functional_t5.convert_to_ttnn

    custom_preprocessor = ttnn_optimized_functional_t5.custom_preprocessor

    profiler.start(f"preprocessing_parameter")
    parameters = preprocess_model_parameters(
        model_name=tt_model_name,
        initialize_model=lambda: model,
        convert_to_ttnn=convert_to_ttnn,
        custom_preprocessor=custom_preprocessor,
        device=device,
    )
    profiler.end(f"preprocessing_parameter")

    tt_output = run_generate(
        input_ids,
        model,
        config,
        parameters,
        device,
        max_tokens,
        batch_size,
    )

    profiler.start(f"post_processing_output_to_string")
    for batch in range(batch_size):
        output = tokenizer.decode(tt_output[batch], skip_special_tokens=True)
        decoded_tt_output.append(output)
    profiler.end(f"post_processing_output_to_string")

    for i in range(batch_size):
        logger.info(
            f"------------------------------------------------------------------------------------------------------------------------"
        )
        logger.info(f"Input text {i} >> {input_sentance[i]}")
        logger.info(f"Output text {i} >> {decoded_tt_output[i]}")
        logger.info("")

    measurements = {
        "preprocessing_parameter": profiler.get("preprocessing_parameter"),
        "preprocessing_input": profiler.get("preprocessing_input"),
        "inference_time": profiler.get("inference_time"),
        "post_processing": profiler.get("post_processing_output_to_string"),
    }
    logger.info(f"preprocessing_parameter: {measurements['preprocessing_parameter']} s")
    logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s")
    logger.info(f"inference_time: {measurements['inference_time']} s")
    logger.info(f"post_processing : {measurements['post_processing']} s")

    return measurements


def run_summarization_dataset_inference(device, batch_size, sequence_length, max_tokens, model_name):
    config = T5Config.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=32)

    dataset = load_dataset("openai/summarize_from_feedback", "axis")
    dataset = dataset.shuffle(seed=19)
    bert_score = evaluate.load("bertscore")

    validation_split = dataset["validation"]["info"]
    reference_split = dataset["validation"]["summary"]

    input_sentance = []
    references = []

    for i in range(batch_size):
        references.append(reference_split[i]["text"][1:])
        input_sentance.append(f"summarize: {validation_split[i]['post']}")

    profiler.start(f"preprocessing_input")
    input_ids = tokenizer(
        input_sentance,
        padding="max_length",
        max_length=sequence_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids
    profiler.end(f"preprocessing_input")

    tt_model_name = "ttnn_optimized_" + model_name

    decoded_tt_output = []

    convert_to_ttnn = ttnn_optimized_functional_t5.convert_to_ttnn

    custom_preprocessor = ttnn_optimized_functional_t5.custom_preprocessor

    profiler.start(f"preprocessing_parameter")
    parameters = preprocess_model_parameters(
        model_name=tt_model_name,
        initialize_model=lambda: model,
        convert_to_ttnn=convert_to_ttnn,
        custom_preprocessor=custom_preprocessor,
        device=device,
    )
    profiler.end(f"preprocessing_parameter")

    tt_output = run_generate(
        input_ids,
        model,
        config,
        parameters,
        device,
        max_tokens,
        batch_size,
    )

    profiler.start(f"post_processing_output_to_string")
    for batch in range(batch_size):
        output = tokenizer.decode(tt_output[batch], skip_special_tokens=True)
        decoded_tt_output.append(output)
    profiler.end(f"post_processing_output_to_string")

    for i in range(batch_size):
        logger.info(
            f"------------------------------------------------------------------------------------------------------------------------"
        )
        logger.info(f"Input text {i} >> {input_sentance[i]}")
        logger.info(f"Output text {i} >> {decoded_tt_output[i]}")
        logger.info("")

    results = bert_score.compute(predictions=decoded_tt_output, references=references, lang="en")
    avg_f1 = sum(results["f1"]) / len(results["f1"])
    logger.info("")
    logger.info(f"Average F1 score: {avg_f1}")

    measurements = {
        "preprocessing_parameter": profiler.get("preprocessing_parameter"),
        "preprocessing_input": profiler.get("preprocessing_input"),
        "inference_time": profiler.get("inference_time"),
        "post_processing": profiler.get("post_processing_output_to_string"),
    }
    logger.info(f"preprocessing_parameter: {measurements['preprocessing_parameter']} s")
    logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s")
    logger.info(f"inference_time: {measurements['inference_time']} s")
    logger.info(f"post_processing : {measurements['post_processing']} s")

    return measurements


@pytest.mark.parametrize(
    ("batch_size", "sequence_length", "max_tokens", "model_name"),
    (
        (8, 128, 64, "t5-small"),
        (8, 128, 64, "google/flan-t5-small"),
    ),
)
def test_t5_demo_for_summarize(input_path, device, batch_size, sequence_length, max_tokens, model_name):
    disable_persistent_kernel_cache()
    disable_compilation_reports()

    return run_summarization_inference(input_path, device, batch_size, sequence_length, max_tokens, model_name)


@pytest.mark.parametrize(
    ("batch_size", "sequence_length", "max_tokens", "model_name"),
    (
        (8, 128, 64, "t5-small"),
        (8, 128, 64, "google/flan-t5-small"),
    ),
)
def test_t5_demo_for_summarize_dataset(device, batch_size, sequence_length, max_tokens, model_name):
    disable_persistent_kernel_cache()
    disable_compilation_reports()

    return run_summarization_dataset_inference(device, batch_size, sequence_length, max_tokens, model_name)