From 997704bbaad9d1891a9489e0d0d2f2edb0b7ae72 Mon Sep 17 00:00:00 2001
From: Salar
Date: Wed, 10 Apr 2024 16:00:37 +0000
Subject: [PATCH] #7273: Replace greedy decoding in Falcon7b demo with top-k-top-p sampling

Signed-off-by: Salar
---
 models/demos/falcon7b/demo/demo.py | 56 ++++++++++++++++++------------
 1 file changed, 34 insertions(+), 22 deletions(-)

diff --git a/models/demos/falcon7b/demo/demo.py b/models/demos/falcon7b/demo/demo.py
index 9ca439e8439..4204e13d320 100644
--- a/models/demos/falcon7b/demo/demo.py
+++ b/models/demos/falcon7b/demo/demo.py
@@ -7,10 +7,12 @@
 from functools import partial
 import tt_lib
 import torch
+import torch.nn.functional as F
 from loguru import logger
 import time
 from pathlib import Path
 from transformers import AutoTokenizer
+from transformers.generation.utils import top_k_top_p_filtering
 import os
 from tqdm import tqdm
 from models.utility_functions import is_wormhole_b0
@@ -100,6 +102,29 @@ def update_model_config(model, model_config_str):
     model.model_config.update(get_model_config(model_config_str))
 
 
+def top_pk_logits(logits, p=0.9, k=10, temperature=1.0, return_probs=False):
+    next_token_logscores = top_k_top_p_filtering(logits, top_k=k, top_p=p)
+    probs = F.softmax(next_token_logscores / temperature, dim=-1)
+    token = torch.multinomial(probs, num_samples=1).squeeze(-1)
+    if return_probs:
+        return token, probs
+    else:
+        return token
+
+
+def top_pk_logits_efficient(logits, p=0.9, k=10, temperature=1.0, return_probs=False):
+    # do not keep the entire vocab size after top k. Instead, keep the k size tensor and record the associated indices
+    top_k_values, top_k_indices = torch.topk(logits, k=k)
+    top_p_values = top_k_top_p_filtering(top_k_values, top_p=p)
+    probs = F.softmax(top_p_values / temperature, dim=-1)
+    top_k_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
+    token = top_k_indices.gather(-1, top_k_id.unsqueeze(-1)).squeeze(-1)
+    if return_probs:
+        return token, (probs, top_k_indices)
+    else:
+        return token
+
+
 def run_falcon_demo_kv(
     user_input,
     model_version,
@@ -110,6 +135,7 @@ def run_falcon_demo_kv(
     model_location_generator,
     device,
     perf_mode=False,
+    greedy_sampling=False,
 ):
 
     torch.manual_seed(0)
@@ -188,9 +214,7 @@ def run_falcon_demo_kv(
 
     ### First prefill run with compile ###
     logger.info("Running 1st run prefill stage with compile...")
-    post_processor = partial(post_process)
     use_cache = True
-    output_ids = torch.zeros(num_users, 1, dtype=torch.int64)
     time_prefill_compile = 0
     N = num_users if not perf_mode else 1
     for user_id in tqdm(range(N)):
@@ -220,14 +244,8 @@ def run_falcon_demo_kv(
         if tt_prefill_attention_mask is not None:
             tt_prefill_attention_mask[0].deallocate()
 
-        logits = tt2torch_tensor(tt_logits[0]).squeeze(1)
         tt_logits[0].deallocate()
 
-        user_output_ids = post_processor(logits=logits, index=num_input_tokens - 1)
-        output_ids[user_id] = user_output_ids
-
-    generated_ids = torch.concat((prefill_ids[..., :num_input_tokens], output_ids), dim=1)
-
     tt_lib.device.Synchronize(device)
     logger.info("Finished 1st run prefill stage with compile!")
 
@@ -237,10 +255,7 @@ def run_falcon_demo_kv(
 
     # Update model config
     update_model_config(tt_FalconCausalLM_singlelayer, model_config_strs_prefill_decode[1])
-    decode_ids = torch.zeros(batch_size, 1, dtype=torch.int64)
-
-    for user_id, output_id in enumerate(output_ids):
-        decode_ids[user_id] = output_id
+    decode_ids = torch.randint(low=0, high=configuration.vocab_size - 1, size=(batch_size, 1), dtype=torch.int64)
 
     prompt_is_done = [False for _ in range(num_users)]
 
@@ -271,23 +286,14 @@ def run_falcon_demo_kv(
         if tt_decode_attention_mask is not None:
             tt_decode_attention_mask[0].deallocate()
 
-        logits = tt2torch_tensor(tt_logits[0]).squeeze(1)
         tt_logits[0].deallocate()
 
-        decode_ids = post_processor(logits=logits, index=...).reshape(batch_size, 1)
-
-        generated_ids = torch.concat((generated_ids, decode_ids[:num_users]), dim=1)
-
     logger.info("Finished 1st run decode stage with compile!")
     tt_lib.device.Synchronize(device)
 
-    del user_output_ids
-    del output_ids
-    del logits
     del tt_logits
     del tt_prefill_embeddings
     del tt_prefill_attention_mask
-    del generated_ids
     del decode_ids
     del tt_decode_embeddings
     del tt_FalconCausalLM_singlelayer
@@ -410,7 +416,10 @@ def run_falcon_demo_kv(
         tt_logits[0].deallocate()
 
         if not perf_mode:
-            decode_ids = post_processor(logits=logits, index=...).reshape(batch_size, 1)
+            if greedy_sampling:
+                decode_ids = post_processor(logits=logits, index=...).reshape(batch_size, 1)
+            else:
+                decode_ids = top_pk_logits_efficient(logits.reshape(batch_size, -1)).reshape(batch_size, 1)
 
             for user_id, user_decode_id in enumerate(decode_ids[:num_users]):
                 if user_decode_id == END_OF_TEXT:
@@ -480,8 +489,10 @@ def run_falcon_demo_kv(
 
 # Option to measure perf using max seq length (with invalid outputs)
 @pytest.mark.parametrize("perf_mode", (False,))
+@pytest.mark.parametrize("greedy_sampling", (False,))
 def test_demo(
     perf_mode,
+    greedy_sampling,
     user_input,
     model_location_generator,
     device,
@@ -505,4 +516,5 @@ def test_demo(
         model_location_generator=model_location_generator,
         device=device,
         perf_mode=perf_mode,
+        greedy_sampling=greedy_sampling,
     )
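For reference, the sampling step this patch introduces can be exercised on its own. The sketch below is a minimal, torch-only illustration of the same top-k-then-top-p idea as top_pk_logits_efficient; it is not the demo's code path. The function name sample_top_k_top_p, the inline cumulative-probability mask (standing in for the transformers top_k_top_p_filtering helper the patch imports), and the dummy tensor shapes are illustrative assumptions.

# Standalone sketch (not the demo code path): sample token ids from a
# [batch, vocab] logits tensor by keeping the top-k logits, masking out the
# low-probability tail beyond cumulative probability p, sampling within the
# surviving candidates, and mapping the result back to vocabulary indices.
import torch
import torch.nn.functional as F


def sample_top_k_top_p(logits, p=0.9, k=10, temperature=1.0):
    # Keep only the k largest logits and remember their vocabulary indices.
    top_k_values, top_k_indices = torch.topk(logits, k=k, dim=-1)

    # Within the (descending-sorted) top-k set, drop tokens once the cumulative
    # probability exceeds p; the highest-probability token is always kept.
    probs_k = F.softmax(top_k_values, dim=-1)
    cumulative = torch.cumsum(probs_k, dim=-1)
    remove = (cumulative - probs_k) > p
    filtered = top_k_values.masked_fill(remove, float("-inf"))

    # Sample a position inside the filtered top-k set and gather its vocab id.
    probs = F.softmax(filtered / temperature, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return top_k_indices.gather(-1, choice).squeeze(-1)


if __name__ == "__main__":
    torch.manual_seed(0)
    dummy_logits = torch.randn(32, 65024)  # assumed batch_size x vocab_size
    print(sample_top_k_top_p(dummy_logits).shape)  # torch.Size([32])

As the in-code comment in top_pk_logits_efficient notes, keeping only the k-sized tensor plus its indices means the softmax and multinomial run over k entries instead of the full vocabulary, which is what distinguishes the "efficient" variant from top_pk_logits.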