diff --git a/models/demos/grayskull/README.md b/models/demos/grayskull/README.md
deleted file mode 100644
index 24bcf57ebac..00000000000
--- a/models/demos/grayskull/README.md
+++ /dev/null
@@ -1 +0,0 @@
-This is place for Grayskull models.
diff --git a/models/demos/grayskull/t5/README.md b/models/demos/grayskull/t5/README.md
new file mode 100644
index 00000000000..073550b03a7
--- /dev/null
+++ b/models/demos/grayskull/t5/README.md
@@ -0,0 +1,46 @@
+# T5 Model for Conditional Generation
+
+## Introduction
+Text-To-Text Transfer Transformer (T5) is an encoder-decoder model that reframes all NLP tasks into a unified text-to-text format where the input and output are always text strings.
+With this model, tasks are specified by adding a task-specific prefix to the input. For instance, summarization inputs begin with 'summarize:'. This flexibility enables T5 to excel across a wide range of tasks.
+
+## Details
+
+The entry point to the T5 model is `t5_for_conditional_generation` in `ttnn_optimized_functional_t5.py`. The model picks up its config and weights from the Hugging Face pretrained model; the `t5-small` and `google/flan-t5-small` checkpoints are used as references.
+
+In this demo, the model accepts input text and produces a summarized version of it.
+
+## Inputs
+Inputs for the demo are read from `input_data.json` by default. If you need to change the inputs or provide a different path, modify the `--input-path` parameter in the command accordingly. We recommend against modifying the `input_data.json` file directly.
+
+## How to Run
+
+- Use the following command to run the T5 for conditional generation demo using `t5-small`:
+```
+pytest --disable-warnings --input-path="models/demos/grayskull/t5/demo/input_data.json" models/demos/grayskull/t5/demo/demo.py::test_t5_demo_for_summarize[8-128-64-t5-small]
+```
+
+- Alternatively, use the following command to run the T5 for conditional generation demo using `google/flan-t5-small`:
+```
+pytest --disable-warnings --input-path="models/demos/grayskull/t5/demo/input_data.json" models/demos/grayskull/t5/demo/demo.py::test_t5_demo_for_summarize[8-128-64-google/flan-t5-small]
+```
+
+- If you wish to run the demo with a different input file, supply the path to your JSON file after `--input-path=` in the following command:
+```
+pytest --disable-warnings --input-path= models/demos/grayskull/t5/demo/demo.py::test_t5_demo_for_summarize[8-128-64-t5-small]
+```
+
+The second demo is designed to run with the `openai/summarize_from_feedback` dataset. Posts from the dataset's validation split are used as input text to the model, and the accompanying human-written summaries are used to validate the model's output.
+
+- Use the following command to run the second demo, which summarizes the dataset inputs using the `t5-small` variant:
+```
+pytest --disable-warnings models/demos/grayskull/t5/demo/demo.py::test_t5_demo_for_summarize_dataset[8-128-64-t5-small]
+```
+
+- Alternatively, use the following command to run the second demo using the `google/flan-t5-small` variant:
+```
+pytest --disable-warnings models/demos/grayskull/t5/demo/demo.py::test_t5_demo_for_summarize_dataset[8-128-64-google/flan-t5-small]
+```
+
+## Results
+The input is fed into the T5 model for conditional generation, and the output is a summarized and simplified version of the input text.
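Reviewer note (not part of this diff): for a quick sanity check of the expected behavior outside the device flow, the sketch below is a minimal CPU-only reference that mirrors the greedy decode loop in `run_generate` from `demo.py`, using the standard Hugging Face `transformers` API and the same `t5-small` checkpoint. The `max_tokens` value of 64 matches the demo's parametrization; this is an illustrative sketch, not the demo itself.

```
# CPU-only reference sketch of the summarization flow (hypothetical example, not in the PR).
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name).eval()

# T5 expects the task prefix, e.g. "summarize:", as used in input_data.json.
text = "summarize: Artificial Intelligence (AI) has revolutionized various industries ..."
input_ids = tokenizer(text, return_tensors="pt", truncation=True, max_length=128).input_ids

# Greedy decode, analogous to the demo's run_generate loop.
decoder_input_ids = torch.full((1, 1), model.generation_config.pad_token_id, dtype=torch.long)
for _ in range(64):  # max_tokens, as in the demo parametrization
    with torch.no_grad():
        logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # pick the most likely token
    decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
    if next_token.item() == model.generation_config.eos_token_id:
        break

print(tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True))
```

The on-device demo follows the same structure but runs the encoder/decoder math through `ttnn` ops, while tokenization and the argmax over logits stay on the host.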
diff --git a/models/demos/grayskull/t5/demo/demo.py b/models/demos/grayskull/t5/demo/demo.py new file mode 100644 index 00000000000..a0d25a25c98 --- /dev/null +++ b/models/demos/grayskull/t5/demo/demo.py @@ -0,0 +1,265 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import json +import pytest +import torch +import evaluate +from loguru import logger +from datasets import load_dataset +from models.generation_utils import get_logits_processor +import ttnn + + +from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Config +from models.demos.grayskull.t5.tt import ttnn_functional_t5 +from models.demos.grayskull.t5.tt import ttnn_optimized_functional_t5 +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.utility_functions import ( + disable_compilation_reports, + disable_persistent_kernel_cache, + profiler, +) + + +def load_inputs(input_path, batch): + with open(input_path) as f: + input_data = json.load(f) + assert len(input_data) >= batch, f"Input data needs to have at least {batch} (batch size) entries." + + input_text = [] + for i in range(batch): + input_text.append(input_data[i]["content"]) + + return input_text + + +def run_generate(input_ids, model, config, parameters, device, max_tokens, batch_size): + tt_model = ttnn_optimized_functional_t5 + + logits_processor = get_logits_processor(input_ids, config) + + decoder_input_ids = model.generation_config.pad_token_id * torch.ones(batch_size, input_ids.shape[-1]).to( + torch.long + ) + + input_ids = ttnn.from_torch(input_ids, device=device) + + profiler.start(f"inference_time") + for iteration in range(max_tokens): + decoder_input_ids = ttnn.from_torch(decoder_input_ids) + decoder_input_ids = ttnn.to_device(decoder_input_ids, device) + + tt_output, encoder_hidden_states = tt_model.t5_for_conditional_generation( + config, + input_ids, + decoder_input_ids, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + next_token_logits = ttnn.to_torch(tt_output) + + next_tokens_scores = logits_processor(input_ids, next_token_logits) + next_tokens = torch.argmax(next_tokens_scores, dim=-1) + + decoder_input_ids = ttnn.from_device(decoder_input_ids) + decoder_input_ids = ttnn.to_torch(decoder_input_ids) + decoder_input_ids[:, iteration + 1] = next_tokens[:, iteration] + + profiler.end(f"inference_time") + + return decoder_input_ids + + +def run_summarization_inference(input_path, device, batch_size, sequence_length, max_tokens, model_name): + config = T5Config.from_pretrained(model_name) + model = T5ForConditionalGeneration.from_pretrained(model_name).eval() + tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=32) + + input_sentance = load_inputs(input_path, batch_size) + + profiler.start(f"preprocessing_input") + input_ids = tokenizer( + input_sentance, + padding="max_length", + max_length=sequence_length, + truncation=True, + return_tensors="pt", + ).input_ids + + profiler.end(f"preprocessing_input") + + tt_model_name = "ttnn_optimized_" + model_name + + decoded_tt_output = [] + + convert_to_ttnn = ttnn_optimized_functional_t5.convert_to_ttnn + + custom_preprocessor = ttnn_optimized_functional_t5.custom_preprocessor + + profiler.start(f"preprocessing_parameter") + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=convert_to_ttnn, + custom_preprocessor=custom_preprocessor, + device=device, + ) + profiler.end(f"preprocessing_parameter") + + tt_output = 
run_generate( + input_ids, + model, + config, + parameters, + device, + max_tokens, + batch_size, + ) + + profiler.start(f"post_processing_output_to_string") + for batch in range(batch_size): + output = tokenizer.decode(tt_output[batch], skip_special_tokens=True) + decoded_tt_output.append(output) + profiler.end(f"post_processing_output_to_string") + + for i in range(batch_size): + logger.info( + f"------------------------------------------------------------------------------------------------------------------------" + ) + logger.info(f"Input text {i} >> {input_sentance[i]}") + logger.info(f"Output text {i} >> {decoded_tt_output[i]}") + logger.info("") + + measurements = { + "preprocessing_parameter": profiler.get("preprocessing_parameter"), + "preprocessing_input": profiler.get("preprocessing_input"), + "inference_time": profiler.get("inference_time"), + "post_processing": profiler.get("post_processing_output_to_string"), + } + logger.info(f"preprocessing_parameter: {measurements['preprocessing_parameter']} s") + logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s") + logger.info(f"inference_time: {measurements['inference_time']} s") + logger.info(f"post_processing : {measurements['post_processing']} s") + + return measurements + + +def run_summarization_dataset_inference(device, batch_size, sequence_length, max_tokens, model_name): + config = T5Config.from_pretrained(model_name) + model = T5ForConditionalGeneration.from_pretrained(model_name).eval() + tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=32) + + dataset = load_dataset("openai/summarize_from_feedback", "axis") + dataset = dataset.shuffle(seed=19) + bert_score = evaluate.load("bertscore") + + validation_split = dataset["validation"]["info"] + reference_split = dataset["validation"]["summary"] + + input_sentance = [] + references = [] + + for i in range(batch_size): + references.append(reference_split[i]["text"][1:]) + input_sentance.append(f"summarize: {validation_split[i]['post']}") + + profiler.start(f"preprocessing_input") + input_ids = tokenizer( + input_sentance, + padding="max_length", + max_length=sequence_length, + truncation=True, + return_tensors="pt", + ).input_ids + profiler.end(f"preprocessing_input") + + tt_model_name = "ttnn_optimized_" + model_name + + decoded_tt_output = [] + + convert_to_ttnn = ttnn_optimized_functional_t5.convert_to_ttnn + + custom_preprocessor = ttnn_optimized_functional_t5.custom_preprocessor + + profiler.start(f"preprocessing_parameter") + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + convert_to_ttnn=convert_to_ttnn, + custom_preprocessor=custom_preprocessor, + device=device, + ) + profiler.end(f"preprocessing_parameter") + + tt_output = run_generate( + input_ids, + model, + config, + parameters, + device, + max_tokens, + batch_size, + ) + + profiler.start(f"post_processing_output_to_string") + for batch in range(batch_size): + output = tokenizer.decode(tt_output[batch], skip_special_tokens=True) + decoded_tt_output.append(output) + profiler.end(f"post_processing_output_to_string") + + for i in range(batch_size): + logger.info( + f"------------------------------------------------------------------------------------------------------------------------" + ) + logger.info(f"Input text {i} >> {input_sentance[i]}") + logger.info(f"Output text {i} >> {decoded_tt_output[i]}") + logger.info("") + + results = bert_score.compute(predictions=decoded_tt_output, references=references, lang="en") + 
avg_f1 = sum(results["f1"]) / len(results["f1"]) + logger.info("") + logger.info(f"Average F1 score: {avg_f1}") + + measurements = { + "preprocessing_parameter": profiler.get("preprocessing_parameter"), + "preprocessing_input": profiler.get("preprocessing_input"), + "inference_time": profiler.get("inference_time"), + "post_processing": profiler.get("post_processing_output_to_string"), + } + logger.info(f"preprocessing_parameter: {measurements['preprocessing_parameter']} s") + logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s") + logger.info(f"inference_time: {measurements['inference_time']} s") + logger.info(f"post_processing : {measurements['post_processing']} s") + + return measurements + + +@pytest.mark.parametrize( + ("batch_size", "sequence_length", "max_tokens", "model_name"), + ( + (8, 128, 64, "t5-small"), + (8, 128, 64, "google/flan-t5-small"), + ), +) +def test_t5_demo_for_summarize(input_path, device, batch_size, sequence_length, max_tokens, model_name): + disable_persistent_kernel_cache() + disable_compilation_reports() + + return run_summarization_inference(input_path, device, batch_size, sequence_length, max_tokens, model_name) + + +@pytest.mark.parametrize( + ("batch_size", "sequence_length", "max_tokens", "model_name"), + ( + (8, 128, 64, "t5-small"), + (8, 128, 64, "google/flan-t5-small"), + ), +) +def test_t5_demo_for_summarize_dataset(device, batch_size, sequence_length, max_tokens, model_name): + disable_persistent_kernel_cache() + disable_compilation_reports() + + return run_summarization_dataset_inference(device, batch_size, sequence_length, max_tokens, model_name) diff --git a/models/demos/grayskull/t5/demo/input_data.json b/models/demos/grayskull/t5/demo/input_data.json new file mode 100644 index 00000000000..44aa1dde956 --- /dev/null +++ b/models/demos/grayskull/t5/demo/input_data.json @@ -0,0 +1,32 @@ +[ + { + "content": "summarize: Artificial Intelligence (AI) has revolutionized various industries, including healthcare, finance, and transportation. AI-powered algorithms can analyze large datasets to identify patterns and make predictions with remarkable accuracy. In healthcare, AI helps diagnose diseases, personalize treatment plans, and improve patient outcomes. In finance, AI algorithms analyze market trends to make investment decisions and detect fraudulent activities. In transportation, self-driving cars equipped with AI technology promise to make roads safer and reduce accidents. As AI continues to advance, its potential to transform industries and improve lives is limitless." + }, + { + "content": "summarize: Climate change is one of the most pressing issues facing our planet today. The burning of fossil fuels, deforestation, and industrial activities have led to a rise in greenhouse gas emissions, resulting in global warming and extreme weather events. The impacts of climate change are already being felt, with rising sea levels, melting glaciers, and more frequent heatwaves. Urgent action is needed to mitigate the effects of climate change and transition to a sustainable future. This includes reducing carbon emissions, investing in renewable energy sources, and implementing policies to protect the environment. By taking decisive action now, we can work towards a healthier planet for future generations." + }, + { + "content": "summarize: Space exploration has captivated the imagination of humanity for centuries. From the first moon landing to the exploration of Mars, humans have always been drawn to the mysteries of the cosmos. 
Advances in technology have made space exploration more accessible, with missions to distant planets and asteroids becoming a reality. The search for extraterrestrial life and the possibility of colonizing other planets has fueled interest in space exploration. However, space exploration also presents challenges, including the high cost of missions and the risks to astronauts. Despite these challenges, the quest to explore the universe continues to inspire scientists, engineers, and explorers around the world." + }, + { + "content": "summarize: Mental health is a fundamental aspect of overall well-being, yet it remains a neglected and stigmatized issue in many societies. Mental illnesses, such as depression, anxiety, and schizophrenia, affect millions of people worldwide, impacting their quality of life, productivity, and relationships. Despite the prevalence of mental health disorders, access to mental health care and support services remains limited, particularly in low- and middle-income countries where resources are scarce. Addressing mental health requires a comprehensive approach that encompasses prevention, early intervention, treatment, and support. This includes promoting mental health literacy and awareness to reduce stigma and discrimination, integrating mental health services into primary care settings, and ensuring access to affordable and culturally appropriate treatments. Moreover, addressing social determinants of mental health, such as poverty, unemployment, and social isolation, is essential for promoting mental well-being and resilience. As the world grapples with the mental health impacts of the COVID-19 pandemic, there is a growing recognition of the need to prioritize mental health and invest in mental health infrastructure and services. From teletherapy and online support groups to community-based interventions and peer support networks, innovative approaches are emerging to address the diverse needs of individuals living with mental illness and promote mental health for all." + }, + { + "content": "summarize: Renewable energy sources such as solar, wind, and hydropower offer a sustainable alternative to fossil fuels and help reduce carbon emissions. Solar panels convert sunlight into electricity, while wind turbines harness the power of the wind to generate energy. Hydropower plants use flowing water to produce electricity, with minimal environmental impact. The transition to renewable energy is essential for mitigating climate change and reducing our dependence on finite resources. Investing in renewable energy infrastructure and technology is key to achieving a greener and more sustainable future for generations to come." + }, + { + "content": "summarize: Cultural diversity enriches societies by fostering creativity, innovation, and mutual understanding. Different cultures bring unique perspectives, traditions, and languages that contribute to the richness of human experience. Embracing cultural diversity promotes tolerance and respect for people of all backgrounds, helping to build more inclusive communities. However, cultural diversity also faces challenges, including discrimination, prejudice, and cultural assimilation. By celebrating diversity and promoting intercultural exchange, we can create a more harmonious and interconnected world where all cultures are valued and respected." 
+ }, + { + "content": "summarize: The world's oceans are home to a vast array of marine life and play a crucial role in regulating the Earth's climate and providing essential resources. However, pollution, overfishing, and habitat destruction threaten the health of our oceans and marine ecosystems. Plastic waste, in particular, poses a significant threat to marine life, with millions of tons of plastic entering the ocean each year. Conservation efforts such as marine protected areas, sustainable fishing practices, and beach clean-up initiatives are essential for protecting marine biodiversity and preserving fragile ecosystems. By taking action to reduce pollution and protect marine habitats, we can ensure a healthy and thriving ocean for future generations." + }, + { + "content": "summarize: Sustainable agriculture is essential for ensuring food security, preserving natural resources, and mitigating the environmental impacts of agriculture. By adopting practices that prioritize environmental stewardship, social equity, and economic viability, sustainable agriculture aims to meet the needs of the present without compromising the ability of future generations to meet their own needs. This includes minimizing the use of synthetic pesticides and fertilizers, promoting crop diversity and rotation, conserving water and soil, and supporting small-scale farmers and rural communities. Sustainable agriculture also emphasizes the importance of agroecology, which seeks to understand the ecological processes that underpin agricultural systems and leverage them to enhance productivity and resilience. Agroecological principles such as biodiversity conservation, soil health, and natural pest control contribute to more resilient and sustainable farming practices that can withstand environmental challenges such as climate change and resource scarcity. In addition to environmental benefits, sustainable agriculture offers social and economic advantages by creating opportunities for rural livelihoods, promoting food sovereignty, and strengthening local food systems. By reconnecting consumers with the source of their food and supporting local farmers, sustainable agriculture fosters a sense of community and resilience in the face of global challenges." + }, + { + "content": "summarize: Cybersecurity is a growing concern in an increasingly interconnected world where cyber threats pose risks to individuals, businesses, and governments. Malicious actors exploit vulnerabilities in computer systems and networks to steal sensitive information, disrupt services, and commit financial fraud. Protecting against cyber threats requires robust cybersecurity measures, including firewalls, encryption, and multi-factor authentication. Cybersecurity awareness and education are also essential for helping individuals and organizations recognize and respond to cyber threats effectively. By prioritizing cybersecurity and implementing proactive measures, we can strengthen our defenses against cyber attacks and safeguard digital infrastructure." + }, + { + "content": "summarize: Globalization refers to the interconnectedness and interdependence of economies, cultures, and societies around the world. Advances in technology, transportation, and communication have facilitated the exchange of goods, ideas, and information on a global scale. 
While globalization has brought benefits such as increased trade, cultural exchange, and technological innovation, it has also raised concerns about inequality, environmental degradation, and cultural homogenization. Managing the challenges of globalization requires cooperation and coordination among nations to ensure that its benefits are shared equitably and its negative impacts are mitigated. By embracing the opportunities of globalization while addressing its challenges, we can build a more sustainable and inclusive global community." + } +] diff --git a/models/demos/grayskull/t5/reference/torch_functional_t5.py b/models/demos/grayskull/t5/reference/torch_functional_t5.py new file mode 100644 index 00000000000..cb3bcbfb4dc --- /dev/null +++ b/models/demos/grayskull/t5/reference/torch_functional_t5.py @@ -0,0 +1,426 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import math +import functools +from typing import Optional + +import torch +import transformers + +from models.experimental.functional_common.attention_mask_functions import ( + get_extended_attention_mask, + invert_attention_mask, +) + + +def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + +def compute_bias(config, query_length, key_length, device=None, *, is_decoder, parameters): + """Compute binned relative position bias""" + if device is None: + device = parameters.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = _relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not is_decoder), + num_buckets=config.relative_attention_num_buckets, + max_distance=config.relative_attention_max_distance, + ) + values = torch.nn.functional.embedding( + relative_position_bucket, parameters.relative_attention_bias.weight + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + +def t5_layer_norm(config, hidden_states, *, weight): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + config.layer_norm_epsilon) + + # convert into half-precision if necessary + if weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(weight.dtype) + + return weight * hidden_states + + +def gelu_new(input_tensor): + # TODO: compare against torch.nn.functional.gelu + return transformers.activations.NewGELUActivation()(input_tensor) + + +def get_activation_function(dense_act_fn): + if dense_act_fn == "relu": + return torch.nn.functional.relu + elif dense_act_fn == "gelu_new": + return gelu_new + else: + raise RuntimeError(f"Unsupported activation function: {dense_act_fn}") + + +def t5_dense_act_dense(config, hidden_states, parameters): + activation_function = get_activation_function(config.dense_act_fn) + + hidden_states = hidden_states @ parameters.wi.weight + hidden_states = activation_function(hidden_states) + hidden_states = hidden_states @ parameters.wo.weight + return hidden_states + + +def t5_dense_gated_act_dense(config, hidden_states, parameters): + activation_function = get_activation_function(config.dense_act_fn) + + hidden_gelu = hidden_states @ parameters.wi_0.weight + hidden_gelu = activation_function(hidden_gelu) + hidden_linear = hidden_states @ parameters.wi_1.weight + hidden_states = hidden_gelu * hidden_linear + + hidden_states = hidden_states @ parameters.wo.weight + return hidden_states + + +def t5_layer_ff(config, hidden_states, parameters): + forwarded_states = t5_layer_norm(config, hidden_states, weight=parameters.layer_norm.weight) + if config.is_gated_act: + forwarded_states = t5_dense_gated_act_dense(config, forwarded_states, parameters.DenseReluDense) + else: + forwarded_states = t5_dense_act_dense(config, forwarded_states, parameters.DenseReluDense) + hidden_states = hidden_states + forwarded_states + return hidden_states + + +def t5_attention( + config, + hidden_states, + key_value_states=None, + mask=None, + layer_head_mask=None, + position_bias=None, + *, + is_decoder, + parameters, +): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
+ """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length, _ = hidden_states.shape + + real_seq_length = seq_length + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states, head_size): + """projection""" + return states.view(batch_size, -1, config.num_heads, head_size).transpose(1, 2) + + def unshape(states, hidden_size): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, hidden_size) + + def project(hidden_states, weight): + hidden_size = weight.shape[-1] + head_size = hidden_size // config.num_heads + """projects hidden states correctly to key/query states""" + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(hidden_states @ weight, head_size) + return hidden_states + + # get query states + hidden_size = parameters.q.weight.shape[-1] + query_states = project(hidden_states, parameters.q.weight) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states if key_value_states is None else key_value_states, + parameters.k.weight, + ) + value_states = project( + hidden_states if key_value_states is None else key_value_states, + parameters.v.weight, + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if "relative_attention_bias" in parameters: + position_bias = compute_bias( + config, real_seq_length, key_length, device=scores.device, is_decoder=is_decoder, parameters=parameters + ) + else: + position_bias = torch.zeros( + (1, config.num_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + scores += position_bias + + attn_weights = torch.nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape(torch.matmul(attn_weights, value_states), hidden_size) # (batch_size, seq_length, dim) + attn_output = attn_output @ parameters.o.weight + + return attn_output, position_bias + + +def t5_layer_self_attention( + config, + hidden_states, + attention_mask=None, + position_bias=None, + *, + is_decoder, + parameters, +): + normed_hidden_states = t5_layer_norm(config, hidden_states, weight=parameters.layer_norm.weight) + attention_output, position_bias = t5_attention( + config, + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + is_decoder=is_decoder, + parameters=parameters.SelfAttention, + ) + hidden_states = hidden_states + attention_output + return hidden_states, position_bias + + +def t5_layer_cross_attention( + config, hidden_states, key_value_states, attention_mask=None, position_bias=None, *, is_decoder, parameters +): + normed_hidden_states = t5_layer_norm(config, hidden_states, weight=parameters.layer_norm.weight) + attention_output, position_bias = t5_attention( + config, + normed_hidden_states, + key_value_states, + mask=attention_mask, + position_bias=position_bias, + is_decoder=is_decoder, + 
parameters=parameters.EncDecAttention, + ) + layer_output = hidden_states + attention_output + return layer_output, position_bias + + +def t5_block( + config, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + *, + is_decoder, + parameters, +): + hidden_states, position_bias = t5_layer_self_attention( + config, + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + is_decoder=is_decoder, + parameters=parameters.layer[0], + ) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = encoder_hidden_states is not None + if do_cross_attention: + hidden_states, encoder_decoder_position_bias = t5_layer_cross_attention( + config, + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + is_decoder=is_decoder, + parameters=parameters.layer[1], + ) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Apply Feed Forward layer + hidden_states = t5_layer_ff(config, hidden_states, parameters.layer[-1]) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + return hidden_states, position_bias, encoder_decoder_position_bias + + +def t5_stack( + config, + input_ids, + shared_embedding_weight, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + *, + parameters, +): + input_shape = input_ids.size() + + hidden_states = torch.nn.functional.embedding(input_ids, shared_embedding_weight) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = seq_length + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=hidden_states.device) + + is_decoder = encoder_hidden_states is not None + extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape, is_decoder=is_decoder) + + if is_decoder: + if encoder_attention_mask is None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, encoder_seq_length, device=hidden_states.device, dtype=torch.long + ) + + encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + position_bias = None + encoder_decoder_position_bias = None + + for block_parameters in parameters.block: + hidden_states, position_bias, encoder_decoder_position_bias = t5_block( + config, + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + 
encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + is_decoder=is_decoder, + parameters=block_parameters, + ) + + hidden_states = t5_layer_norm(config, hidden_states, weight=parameters.final_layer_norm.weight) + + return hidden_states + + +def t5_for_conditional_generation( + config, + input_ids: Optional[torch.LongTensor], + decoder_input_ids: Optional[torch.LongTensor], + parameters, + *, + encoder_last_hidden_state=None, +) -> torch.FloatTensor: + # Encode + if encoder_last_hidden_state is None: + encoder_last_hidden_state = t5_stack( + config, + input_ids=input_ids, + shared_embedding_weight=parameters.shared.weight, + parameters=parameters.encoder, + ) + + # Decode + sequence_output = t5_stack( + config, + input_ids=decoder_input_ids, + encoder_hidden_states=encoder_last_hidden_state, + shared_embedding_weight=parameters.shared.weight, + parameters=parameters.decoder, + ) + + if config.tie_word_embeddings: + sequence_output *= config.d_model**-0.5 + lm_logits = sequence_output @ parameters.lm_head.weight + + return lm_logits, encoder_last_hidden_state diff --git a/models/demos/grayskull/t5/tt/ttnn_functional_t5.py b/models/demos/grayskull/t5/tt/ttnn_functional_t5.py new file mode 100644 index 00000000000..6435edd23b2 --- /dev/null +++ b/models/demos/grayskull/t5/tt/ttnn_functional_t5.py @@ -0,0 +1,442 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import functools +import math + +import torch + +import ttnn + +from models.experimental.functional_common.attention_mask_functions import ( + get_extended_attention_mask, + invert_attention_mask, +) + + +def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + +def compute_bias(config, query_length, key_length, *, is_decoder, parameters): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = _relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not is_decoder), + num_buckets=config.relative_attention_num_buckets, + max_distance=config.relative_attention_max_distance, + ) + values = torch.nn.functional.embedding( + relative_position_bucket, parameters.relative_attention_bias.weight + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + +def t5_layer_norm(config, hidden_states, *, weight): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + squared_hidden_states = ttnn.pow(hidden_states, 2) + averaged_squared_hidden_states = ttnn.mean( + squared_hidden_states, + dim=-1, + keepdim=True, + ) + + variance = averaged_squared_hidden_states + config.layer_norm_epsilon + std = ttnn.rsqrt(variance) + + hidden_states = hidden_states * std + hidden_states = hidden_states * weight + + return hidden_states + + +def get_activation_function(dense_act_fn): + if dense_act_fn == "relu": + return ttnn.relu + elif dense_act_fn == "gelu_new": + return ttnn.gelu + else: + raise RuntimeError(f"Unsupported activation function: {dense_act_fn}") + + +def t5_dense_act_dense(config, hidden_states, parameters): + activation_function = get_activation_function(config.dense_act_fn) + + hidden_states = hidden_states @ parameters.wi.weight + hidden_states = activation_function(hidden_states) + hidden_states = hidden_states @ parameters.wo.weight + return hidden_states + + +def t5_dense_gated_act_dense(config, hidden_states, parameters): + activation_function = get_activation_function(config.dense_act_fn) + + hidden_gelu = hidden_states @ parameters.wi_0.weight + hidden_gelu = activation_function(hidden_gelu) + hidden_linear = hidden_states @ parameters.wi_1.weight + hidden_states = hidden_gelu * hidden_linear + + hidden_states = hidden_states @ parameters.wo.weight + return hidden_states + + +def t5_layer_ff(config, hidden_states, parameters): + forwarded_states = t5_layer_norm(config, hidden_states, weight=parameters.layer_norm.weight) + if config.is_gated_act: + forwarded_states = t5_dense_gated_act_dense(config, forwarded_states, parameters.DenseReluDense) + else: + forwarded_states = t5_dense_act_dense(config, forwarded_states, parameters.DenseReluDense) + hidden_states = ttnn.add(hidden_states, forwarded_states, memory_config=ttnn.L1_MEMORY_CONFIG) + return hidden_states + + +def t5_attention( + config, + hidden_states, + key_value_states=None, + mask=None, + layer_head_mask=None, + position_bias=None, + *, + is_decoder, + parameters, +): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
+ """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length, _ = hidden_states.shape + + real_seq_length = seq_length + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states, head_size, is_key=False): + """projection""" + states = ttnn.to_layout(states, layout=ttnn.ROW_MAJOR_LAYOUT) + states = ttnn.reshape(states, (batch_size, seq_length, config.num_heads, head_size)) + if is_key: + states = ttnn.permute(states, (0, 2, 3, 1)) + else: + states = ttnn.permute(states, (0, 2, 1, 3)) + states = ttnn.to_layout(states, ttnn.TILE_LAYOUT) + return states + + def unshape(states, hidden_size): + """reshape""" + states = ttnn.permute(states, (0, 2, 1, 3)) + states = ttnn.to_layout(states, ttnn.ROW_MAJOR_LAYOUT) + states = ttnn.reshape(states, (batch_size, seq_length, hidden_size)) + states = ttnn.to_layout(states, ttnn.TILE_LAYOUT) + return states + + def project(hidden_states, weight, is_key=False): + hidden_size = weight.shape[-1] + head_size = hidden_size // config.num_heads + """projects hidden states correctly to key/query states""" + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(hidden_states @ weight, head_size, is_key=is_key) + return hidden_states + + # get query states + hidden_size = parameters.q.weight.shape[-1] + query_states = project(hidden_states, parameters.q.weight) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states if key_value_states is None else key_value_states, + parameters.k.weight, + is_key=True, + ) + value_states = project( + hidden_states if key_value_states is None else key_value_states, + parameters.v.weight, + ) + + # compute scores + scores = ttnn.matmul(query_states, key_states) + + if position_bias is None: + if "relative_attention_bias" in parameters: + position_bias = compute_bias( + config, real_seq_length, key_length, is_decoder=is_decoder, parameters=parameters + ) + else: + position_bias = torch.zeros((1, config.num_heads, real_seq_length, key_length), dtype=torch.float32) + + position_bias = ttnn.from_torch( + position_bias, dtype=ttnn.bfloat16, device=scores.device(), layout=ttnn.TILE_LAYOUT + ) + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + scores += position_bias + + attn_weights = ttnn.softmax(scores, dim=-1) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape(ttnn.matmul(attn_weights, value_states), hidden_size) # (batch_size, seq_length, dim) + attn_output = attn_output @ parameters.o.weight + + return attn_output, position_bias + + +def t5_layer_self_attention( + config, + hidden_states, + attention_mask=None, + position_bias=None, + *, + is_decoder, + parameters, +): + normed_hidden_states = t5_layer_norm(config, hidden_states, weight=parameters.layer_norm.weight) + attention_output, position_bias = t5_attention( + config, + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + is_decoder=is_decoder, + parameters=parameters.SelfAttention, + ) + hidden_states = hidden_states + attention_output + return hidden_states, position_bias + + +def t5_layer_cross_attention( + config, hidden_states, key_value_states, attention_mask=None, 
position_bias=None, *, is_decoder, parameters +): + normed_hidden_states = t5_layer_norm(config, hidden_states, weight=parameters.layer_norm.weight) + attention_output, position_bias = t5_attention( + config, + normed_hidden_states, + key_value_states, + mask=attention_mask, + position_bias=position_bias, + is_decoder=is_decoder, + parameters=parameters.EncDecAttention, + ) + layer_output = hidden_states + attention_output + return layer_output, position_bias + + +def t5_block( + config, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + *, + is_decoder, + parameters, +): + hidden_states, position_bias = t5_layer_self_attention( + config, + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + is_decoder=is_decoder, + parameters=parameters.layer[0], + ) + + do_cross_attention = encoder_hidden_states is not None + if do_cross_attention: + hidden_states, encoder_decoder_position_bias = t5_layer_cross_attention( + config, + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + is_decoder=is_decoder, + parameters=parameters.layer[1], + ) + + # Apply Feed Forward layer + hidden_states = t5_layer_ff(config, hidden_states, parameters.layer[-1]) + + return hidden_states, position_bias, encoder_decoder_position_bias + + +def t5_stack( + config, + input_ids, + shared_embedding_weight, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + *, + parameters, +): + input_shape = tuple(input_ids.shape) + + hidden_states = ttnn.embedding(input_ids, shared_embedding_weight, layout=ttnn.TILE_LAYOUT) + + is_decoder = encoder_hidden_states is not None + if attention_mask is None: + attention_mask = create_attention_mask(input_shape, config.num_heads, input_ids.device(), is_decoder=is_decoder) + if encoder_hidden_states is not None: + encoder_attention_mask = create_encoder_attention_mask(input_shape, config.num_heads, input_ids.device()) + else: + encoder_attention_mask = None + + position_bias = None + encoder_decoder_position_bias = None + + for block_parameters in parameters.block: + hidden_states, position_bias, encoder_decoder_position_bias = t5_block( + config, + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + is_decoder=is_decoder, + parameters=block_parameters, + ) + + hidden_states = t5_layer_norm(config, hidden_states, weight=parameters.final_layer_norm.weight) + + return hidden_states + + +def t5_for_conditional_generation( + config, + input_ids: ttnn.Tensor, + decoder_input_ids: ttnn.Tensor, + parameters, + *, + encoder_last_hidden_state=None, +) -> ttnn.Tensor: + # Encode + if encoder_last_hidden_state is None: + encoder_last_hidden_state = t5_stack( + config, + input_ids=input_ids, + shared_embedding_weight=parameters.shared.weight, + parameters=parameters.encoder, + ) + + # Decode + sequence_output = t5_stack( + config, + input_ids=decoder_input_ids, + encoder_hidden_states=encoder_last_hidden_state, + shared_embedding_weight=parameters.shared.weight, + parameters=parameters.decoder, + ) + + if config.tie_word_embeddings: + sequence_output *= config.d_model**-0.5 + + lm_logits = sequence_output @ parameters.lm_head.weight + + return lm_logits, 
encoder_last_hidden_state + + +@functools.lru_cache +def create_attention_mask(input_shape, num_heads, device, is_decoder): + batch_size, seq_length = input_shape + + attention_mask = torch.ones(batch_size, seq_length) + + extended_attention_mask = get_extended_attention_mask( + attention_mask, input_shape, is_decoder=is_decoder, dtype=torch.bfloat16 + ) + + extended_attention_mask = extended_attention_mask.expand((-1, num_heads, seq_length, -1)) + extended_attention_mask = ttnn.from_torch(extended_attention_mask) + extended_attention_mask = ttnn.to_layout(extended_attention_mask, ttnn.TILE_LAYOUT) + extended_attention_mask = ttnn.to_device(extended_attention_mask, device) + return extended_attention_mask + + +@functools.lru_cache +def create_encoder_attention_mask(input_shape, num_heads, device): + batch_size, seq_length = input_shape + + encoder_attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long) + + encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask) + + encoder_extended_attention_mask = encoder_extended_attention_mask.expand((-1, num_heads, seq_length, -1)) + encoder_extended_attention_mask = ttnn.from_torch(encoder_extended_attention_mask) + encoder_extended_attention_mask = ttnn.to_layout(encoder_extended_attention_mask, ttnn.TILE_LAYOUT) + encoder_extended_attention_mask = ttnn.to_device(encoder_extended_attention_mask, device) + return encoder_extended_attention_mask + + +def convert_to_ttnn(model, name): + return "relative_attention_bias" not in name + + +def custom_preprocessor(model, name): + import transformers + from ttnn.model_preprocessing import preprocess_layernorm_parameter + + parameters = {} + if isinstance(model, transformers.models.t5.modeling_t5.T5LayerNorm): + parameters["weight"] = preprocess_layernorm_parameter(model.weight, dtype=ttnn.bfloat16) + + return parameters diff --git a/models/demos/grayskull/t5/tt/ttnn_optimized_functional_t5.py b/models/demos/grayskull/t5/tt/ttnn_optimized_functional_t5.py new file mode 100644 index 00000000000..45f2863f771 --- /dev/null +++ b/models/demos/grayskull/t5/tt/ttnn_optimized_functional_t5.py @@ -0,0 +1,507 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import functools +import math +from typing import Optional + +import torch + +import ttnn + +from models.experimental.functional_common.attention_mask_functions import ( + get_extended_attention_mask, + invert_attention_mask, +) + + +def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + +def compute_bias(config, query_length, key_length, *, is_decoder, parameters): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = _relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not is_decoder), + num_buckets=config.relative_attention_num_buckets, + max_distance=config.relative_attention_max_distance, + ) + values = torch.nn.functional.embedding( + relative_position_bucket, parameters.relative_attention_bias.weight + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + +def t5_layer_norm(config, hidden_states, *, weight): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + # return ttnn.rms_norm(hidden_states, weight, epsilon=config.layer_norm_epsilon) + + squared_hidden_states = ttnn.pow(hidden_states, 2) + averaged_squared_hidden_states = ttnn.mean( + squared_hidden_states, + dim=-1, + keepdim=True, + ) + + variance = averaged_squared_hidden_states + config.layer_norm_epsilon + std = ttnn.rsqrt(variance) + + hidden_states = hidden_states * std + hidden_states = hidden_states * weight + + return hidden_states + + +def get_activation_function(dense_act_fn): + if dense_act_fn == "relu": + return ttnn.relu + elif dense_act_fn == "gelu_new": + return ttnn.gelu + else: + raise RuntimeError(f"Unsupported activation function: {dense_act_fn}") + + +def t5_dense_act_dense(config, hidden_states, parameters): + if config.dense_act_fn == "relu": + ff1_activation = "relu" + elif config.dense_act_fn == "gelu_new": + ff1_activation = "gelu" + else: + raise RuntimeError(f"Unsupported activation function: {config.dense_act_fn}") + + _, height, _ = hidden_states.shape + hidden_states = ttnn.linear( + hidden_states, + parameters.wi.weight, + dtype=ttnn.bfloat8_b, + activation=ff1_activation, + core_grid=ttnn.CoreGrid(y=height // 32, x=12), + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + hidden_states = ttnn.linear( + hidden_states, + parameters.wo.weight, + dtype=ttnn.bfloat16, + core_grid=ttnn.CoreGrid(y=9, x=12), + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + return hidden_states + + +def t5_dense_gated_act_dense(config, hidden_states, parameters): + activation_function = get_activation_function(config.dense_act_fn) + + hidden_gelu = hidden_states @ parameters.wi_0.weight + hidden_gelu = activation_function(hidden_gelu) + hidden_linear = hidden_states @ parameters.wi_1.weight + hidden_states = hidden_gelu * hidden_linear + + hidden_states = hidden_states @ parameters.wo.weight + return hidden_states + + +def t5_layer_ff(config, hidden_states, parameters): + forwarded_states = t5_layer_norm(config, hidden_states, weight=parameters.layer_norm.weight) + if config.is_gated_act: + forwarded_states = t5_dense_gated_act_dense(config, forwarded_states, parameters.DenseReluDense) + else: + forwarded_states = t5_dense_act_dense(config, forwarded_states, parameters.DenseReluDense) + hidden_states = ttnn.add(hidden_states, forwarded_states, memory_config=ttnn.L1_MEMORY_CONFIG) + return hidden_states + + +def t5_attention( + config, + hidden_states, + key_value_states=None, + mask=None, + layer_head_mask=None, + position_bias=None, + *, + is_decoder, + parameters, + num_cores_x=12, +): + batch_size, seq_length, _ = hidden_states.shape + + real_seq_length = seq_length + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + if key_value_states is None: + query_key_value_output = ttnn.linear( + hidden_states, + parameters.query_key_value.weight, + memory_config=ttnn.L1_MEMORY_CONFIG, + # dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x), + ) + + ( + query, + key, + value, + ) = ttnn.transformer.split_query_key_value_and_split_heads( + query_key_value_output, + memory_config=ttnn.L1_MEMORY_CONFIG, + num_heads=config.num_heads, + ) + ttnn.deallocate(query_key_value_output) + + else: + query_proj = ttnn.linear( + hidden_states, + parameters.q.weight, + memory_config=ttnn.L1_MEMORY_CONFIG, + # dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x), + ) + + key_value_proj = ttnn.linear( + key_value_states, + 
parameters.key_value.weight, + memory_config=ttnn.L1_MEMORY_CONFIG, + # dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x), + ) + query, key, value = ttnn.transformer.split_query_key_value_and_split_heads( + query_proj, key_value_proj, num_heads=config.num_heads + ) + ttnn.deallocate(query_proj) + ttnn.deallocate(key_value_proj) + + scores = ttnn.matmul( + query, + key, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + # core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x), + ) + ttnn.deallocate(query) + ttnn.deallocate(key) + + if position_bias is None: + if "relative_attention_bias" in parameters: + position_bias = compute_bias( + config, real_seq_length, key_length, is_decoder=is_decoder, parameters=parameters + ) + else: + position_bias = torch.zeros((1, config.num_heads, real_seq_length, key_length), dtype=torch.float32) + + position_bias = ttnn.from_torch( + position_bias, dtype=ttnn.bfloat16, device=scores.device(), layout=ttnn.TILE_LAYOUT + ) + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + scores = ttnn.add(scores, position_bias, memory_config=ttnn.L1_MEMORY_CONFIG) + + attn_weights = ttnn.softmax(scores, dim=-1, memory_config=ttnn.L1_MEMORY_CONFIG) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + context_layer = ttnn.matmul( + attn_weights, + value, + memory_config=ttnn.L1_MEMORY_CONFIG, + # dtype=ttnn.bfloat8_b, + # core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x), + ) + ttnn.deallocate(attn_weights) + ttnn.deallocate(value) + + context_layer = ttnn.transformer.concatenate_heads( + context_layer, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + self_output = ttnn.linear( + context_layer, + parameters.o.weight, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + # core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x), + ) + ttnn.deallocate(context_layer) + + return self_output, position_bias + + +def t5_layer_self_attention( + config, + hidden_states, + attention_mask=None, + position_bias=None, + *, + is_decoder, + parameters, +): + normed_hidden_states = t5_layer_norm(config, hidden_states, weight=parameters.layer_norm.weight) + attention_output, position_bias = t5_attention( + config, + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + is_decoder=is_decoder, + parameters=parameters.SelfAttention, + ) + hidden_states = ttnn.add(hidden_states, attention_output, memory_config=ttnn.L1_MEMORY_CONFIG) + return hidden_states, position_bias + + +def t5_layer_cross_attention( + config, hidden_states, key_value_states, attention_mask=None, position_bias=None, *, is_decoder, parameters +): + normed_hidden_states = t5_layer_norm(config, hidden_states, weight=parameters.layer_norm.weight) + attention_output, position_bias = t5_attention( + config, + normed_hidden_states, + key_value_states, + mask=attention_mask, + position_bias=position_bias, + is_decoder=is_decoder, + parameters=parameters.EncDecAttention, + ) + layer_output = ttnn.add(hidden_states, attention_output, memory_config=ttnn.L1_MEMORY_CONFIG) + return layer_output, position_bias + + +def t5_block( + config, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + *, + is_decoder, + parameters, +): + hidden_states, position_bias = t5_layer_self_attention( + config, + hidden_states, + 
attention_mask=attention_mask, + position_bias=position_bias, + is_decoder=is_decoder, + parameters=parameters.layer[0], + ) + + do_cross_attention = encoder_hidden_states is not None + if do_cross_attention: + hidden_states, encoder_decoder_position_bias = t5_layer_cross_attention( + config, + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + is_decoder=is_decoder, + parameters=parameters.layer[1], + ) + + # Apply Feed Forward layer + hidden_states = t5_layer_ff(config, hidden_states, parameters.layer[-1]) + + return hidden_states, position_bias, encoder_decoder_position_bias + + +def t5_stack( + config, + input_ids, + shared_embedding_weight, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + *, + parameters, +): + input_shape = tuple(input_ids.shape) + + hidden_states = ttnn.embedding( + input_ids, shared_embedding_weight, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG + ) + + is_decoder = encoder_hidden_states is not None + if attention_mask is None: + attention_mask = create_attention_mask(input_shape, config.num_heads, input_ids.device(), is_decoder=is_decoder) + if encoder_hidden_states is not None: + encoder_attention_mask = create_encoder_attention_mask(input_shape, config.num_heads, input_ids.device()) + else: + encoder_attention_mask = None + + position_bias = None + encoder_decoder_position_bias = None + + for block_parameters in parameters.block: + hidden_states, position_bias, encoder_decoder_position_bias = t5_block( + config, + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + is_decoder=is_decoder, + parameters=block_parameters, + ) + + hidden_states = t5_layer_norm(config, hidden_states, weight=parameters.final_layer_norm.weight) + + return hidden_states + + +def t5_for_conditional_generation( + config, + input_ids: Optional[torch.LongTensor], + decoder_input_ids: Optional[torch.LongTensor], + parameters, + *, + encoder_last_hidden_state=None, +) -> torch.FloatTensor: + # Encode + if encoder_last_hidden_state is None: + encoder_last_hidden_state = t5_stack( + config, + input_ids=input_ids, + shared_embedding_weight=parameters.shared.weight, + parameters=parameters.encoder, + ) + + # Decode + sequence_output = t5_stack( + config, + input_ids=decoder_input_ids, + encoder_hidden_states=encoder_last_hidden_state, + shared_embedding_weight=parameters.shared.weight, + parameters=parameters.decoder, + ) + + if config.tie_word_embeddings: + sequence_output = ttnn.mul(sequence_output, config.d_model**-0.5, memory_config=ttnn.L1_MEMORY_CONFIG) + + lm_logits = ttnn.linear(sequence_output, parameters.lm_head.weight, memory_config=ttnn.L1_MEMORY_CONFIG) + + return lm_logits, encoder_last_hidden_state + + +@functools.lru_cache +def create_attention_mask(input_shape, num_heads, device, is_decoder): + batch_size, seq_length = input_shape + + attention_mask = torch.ones(batch_size, seq_length) + + extended_attention_mask = get_extended_attention_mask( + attention_mask, input_shape, is_decoder=is_decoder, dtype=torch.bfloat16 + ) + + extended_attention_mask = extended_attention_mask.expand((-1, num_heads, seq_length, -1)) + extended_attention_mask = ttnn.from_torch(extended_attention_mask) + extended_attention_mask = ttnn.to_layout(extended_attention_mask, 
ttnn.TILE_LAYOUT) + extended_attention_mask = ttnn.to_device(extended_attention_mask, device) + return extended_attention_mask + + +@functools.lru_cache +def create_encoder_attention_mask(input_shape, num_heads, device): + batch_size, seq_length = input_shape + + encoder_attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long) + + encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask) + + encoder_extended_attention_mask = encoder_extended_attention_mask.expand((-1, num_heads, seq_length, -1)) + encoder_extended_attention_mask = ttnn.from_torch(encoder_extended_attention_mask) + encoder_extended_attention_mask = ttnn.to_layout(encoder_extended_attention_mask, ttnn.TILE_LAYOUT) + encoder_extended_attention_mask = ttnn.to_device(encoder_extended_attention_mask, device) + return encoder_extended_attention_mask + + +def convert_to_ttnn(model, name): + return "relative_attention_bias" not in name + + +def custom_preprocessor(model, name): + import transformers + from ttnn.model_preprocessing import preprocess_linear_weight, preprocess_layernorm_parameter + + parameters = {} + if isinstance(model, transformers.models.t5.modeling_t5.T5LayerNorm): + parameters["weight"] = preprocess_layernorm_parameter(model.weight, dtype=ttnn.bfloat16) + + elif isinstance(model, transformers.models.t5.modeling_t5.T5Attention): + if "EncDecAttention" in name: + # Cross Attention + preprocessed_kv_weight = torch.cat([model.k.weight, model.v.weight], dim=0) + parameters = { + "q": {"weight": preprocess_linear_weight(model.q.weight, dtype=ttnn.bfloat16)}, + "key_value": {"weight": preprocess_linear_weight(preprocessed_kv_weight, dtype=ttnn.bfloat16)}, + } + else: + # Self Attention + preprocessed_qkv_weight = torch.cat([model.q.weight, model.k.weight, model.v.weight], dim=0) + parameters = { + "query_key_value": {"weight": preprocess_linear_weight(preprocessed_qkv_weight, dtype=ttnn.bfloat16)}, + "o": {"weight": preprocess_linear_weight(model.o.weight, dtype=ttnn.bfloat16)}, + } + if hasattr(model, "relative_attention_bias"): + parameters["relative_attention_bias"] = model.relative_attention_bias + if hasattr(model, "o"): + parameters["o"] = {"weight": preprocess_linear_weight(model.o.weight, dtype=ttnn.bfloat16)} + + return parameters
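
As a reference point, the following is a minimal, self-contained PyTorch sketch of the relative-position bucketing that `_relative_position_bucket` and `compute_bias` above implement; it is not part of the diff. The helper name `relative_position_bucket` and the small `num_buckets`/`max_distance` values are illustrative only (the real code takes them from `config.relative_attention_num_buckets` and `config.relative_attention_max_distance`).

```python
# Illustrative sketch: reproduce the bidirectional relative-position bucketing
# on a tiny query/key grid so the bucket layout is easy to inspect.
import math
import torch


def relative_position_bucket(relative_position, bidirectional=True, num_buckets=8, max_distance=16):
    relative_buckets = torch.zeros_like(relative_position)
    if bidirectional:
        # Split the buckets between "key after query" and "key before query".
        num_buckets //= 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))

    # First half of each direction's buckets: one bucket per exact offset.
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # Second half: logarithmically spaced buckets out to max_distance, then clamped.
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )

    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets


query_length, key_length = 6, 6
context_position = torch.arange(query_length)[:, None]
memory_position = torch.arange(key_length)[None, :]
relative_position = memory_position - context_position  # shape (query_length, key_length)

buckets = relative_position_bucket(relative_position)
print(buckets)  # nearby offsets get their own buckets; distant offsets share log-spaced buckets
```

In `compute_bias`, these bucket indices are then used to look up learned per-head bias values, which is why only the bucket id (not the raw offset) has to stay in range for long sequences.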