#7273: Replace greedy decoding in Falcon7b demo with top-k-top-p sampling

Signed-off-by: Salar <skhorasgani@tenstorrent.com>
skhorasganiTT committed Apr 10, 2024
1 parent ea2760a commit 997704b
Showing 1 changed file with 34 additions and 22 deletions.
56 changes: 34 additions & 22 deletions models/demos/falcon7b/demo/demo.py
@@ -7,10 +7,12 @@
from functools import partial
import tt_lib
import torch
import torch.nn.functional as F
from loguru import logger
import time
from pathlib import Path
from transformers import AutoTokenizer
from transformers.generation.utils import top_k_top_p_filtering
import os
from tqdm import tqdm
from models.utility_functions import is_wormhole_b0
@@ -100,6 +102,29 @@ def update_model_config(model, model_config_str):
    model.model_config.update(get_model_config(model_config_str))


def top_pk_logits(logits, p=0.9, k=10, temperature=1.0, return_probs=False):
    next_token_logscores = top_k_top_p_filtering(logits, top_k=k, top_p=p)
    probs = F.softmax(next_token_logscores / temperature, dim=-1)
    token = torch.multinomial(probs, num_samples=1).squeeze(-1)
    if return_probs:
        return token, probs
    else:
        return token


def top_pk_logits_efficient(logits, p=0.9, k=10, temperature=1.0, return_probs=False):
    # do not keep the entire vocab size after top k. Instead, keep the k size tensor and record the associated indices
    top_k_values, top_k_indices = torch.topk(logits, k=k)
    top_p_values = top_k_top_p_filtering(top_k_values, top_p=p)
    probs = F.softmax(top_p_values / temperature, dim=-1)
    top_k_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
    token = top_k_indices.gather(-1, top_k_id.unsqueeze(-1)).squeeze(-1)
    if return_probs:
        return token, (probs, top_k_indices)
    else:
        return token
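
Note on the helpers above: they rely on top_k_top_p_filtering from transformers.generation.utils, which later transformers releases deprecate. For readers who want to try the sampling step outside the demo, below is a minimal, torch-only sketch of the same idea; the function and variable names are illustrative and not part of this commit, and details such as where temperature is applied differ slightly from the committed helpers.

import torch
import torch.nn.functional as F


def sample_top_k_top_p(logits, p=0.9, k=10, temperature=1.0):
    # Keep only the k largest logits; torch.topk returns them sorted, largest first.
    top_k_values, top_k_indices = torch.topk(logits, k=k)
    probs = F.softmax(top_k_values / temperature, dim=-1)
    # Nucleus (top-p) cut: zero out tokens once the cumulative probability exceeds p,
    # always keeping at least the single most likely token, then renormalize.
    cumulative = torch.cumsum(probs, dim=-1)
    probs = probs.masked_fill(cumulative - probs > p, 0.0)
    probs = probs / probs.sum(dim=-1, keepdim=True)
    # Sample a position inside the k-sized slice and map it back to a vocabulary id.
    sampled_pos = torch.multinomial(probs, num_samples=1)
    return top_k_indices.gather(-1, sampled_pos).squeeze(-1)


torch.manual_seed(0)
toy_logits = torch.randn(2, 65024)  # toy [batch, vocab] logits; 65024 is Falcon-7B's vocabulary size
print("sampled:", sample_top_k_top_p(toy_logits))   # stochastic pick from the filtered head
print("greedy: ", torch.argmax(toy_logits, dim=-1))  # what the demo's previous greedy decoding would emit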


def run_falcon_demo_kv(
    user_input,
    model_version,
@@ -110,6 +135,7 @@ def run_falcon_demo_kv(
    model_location_generator,
    device,
    perf_mode=False,
    greedy_sampling=False,
):
    torch.manual_seed(0)

@@ -188,9 +214,7 @@

    ### First prefill run with compile ###
    logger.info("Running 1st run prefill stage with compile...")
    post_processor = partial(post_process)
    use_cache = True
    output_ids = torch.zeros(num_users, 1, dtype=torch.int64)
    time_prefill_compile = 0
    N = num_users if not perf_mode else 1
    for user_id in tqdm(range(N)):
@@ -220,14 +244,8 @@
        if tt_prefill_attention_mask is not None:
            tt_prefill_attention_mask[0].deallocate()

        logits = tt2torch_tensor(tt_logits[0]).squeeze(1)
        tt_logits[0].deallocate()

        user_output_ids = post_processor(logits=logits, index=num_input_tokens - 1)
        output_ids[user_id] = user_output_ids

    generated_ids = torch.concat((prefill_ids[..., :num_input_tokens], output_ids), dim=1)

    tt_lib.device.Synchronize(device)
    logger.info("Finished 1st run prefill stage with compile!")

@@ -237,10 +255,7 @@
    # Update model config
    update_model_config(tt_FalconCausalLM_singlelayer, model_config_strs_prefill_decode[1])

    decode_ids = torch.zeros(batch_size, 1, dtype=torch.int64)

    for user_id, output_id in enumerate(output_ids):
        decode_ids[user_id] = output_id
    decode_ids = torch.randint(low=0, high=configuration.vocab_size - 1, size=(batch_size, 1), dtype=torch.int64)

    prompt_is_done = [False for _ in range(num_users)]

@@ -271,23 +286,14 @@
        if tt_decode_attention_mask is not None:
            tt_decode_attention_mask[0].deallocate()

        logits = tt2torch_tensor(tt_logits[0]).squeeze(1)
        tt_logits[0].deallocate()

        decode_ids = post_processor(logits=logits, index=...).reshape(batch_size, 1)

        generated_ids = torch.concat((generated_ids, decode_ids[:num_users]), dim=1)

    logger.info("Finished 1st run decode stage with compile!")
    tt_lib.device.Synchronize(device)

    del user_output_ids
    del output_ids
    del logits
    del tt_logits
    del tt_prefill_embeddings
    del tt_prefill_attention_mask
    del generated_ids
    del decode_ids
    del tt_decode_embeddings
    del tt_FalconCausalLM_singlelayer
@@ -410,7 +416,10 @@ def run_falcon_demo_kv(
        tt_logits[0].deallocate()

        if not perf_mode:
            decode_ids = post_processor(logits=logits, index=...).reshape(batch_size, 1)
            if greedy_sampling:
                decode_ids = post_processor(logits=logits, index=...).reshape(batch_size, 1)
            else:
                decode_ids = top_pk_logits_efficient(logits.reshape(batch_size, -1)).reshape(batch_size, 1)

            for user_id, user_decode_id in enumerate(decode_ids[:num_users]):
                if user_decode_id == END_OF_TEXT:
@@ -480,8 +489,10 @@ def run_falcon_demo_kv(

# Option to measure perf using max seq length (with invalid outputs)
@pytest.mark.parametrize("perf_mode", (False,))
@pytest.mark.parametrize("greedy_sampling", (False,))
def test_demo(
    perf_mode,
    greedy_sampling,
    user_input,
    model_location_generator,
    device,
@@ -505,4 +516,5 @@ def test_demo(
        model_location_generator=model_location_generator,
        device=device,
        perf_mode=perf_mode,
        greedy_sampling=greedy_sampling,
    )
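
As committed, greedy_sampling is parametrized with only (False,), so the test exercises the new top-k-top-p sampling path. If one wanted CI to cover both decoding modes, the parametrization could be extended along these lines; this is a sketch, not part of this commit, and the ids are illustrative.

@pytest.mark.parametrize("perf_mode", (False,))
@pytest.mark.parametrize("greedy_sampling", (True, False), ids=("greedy", "top_k_top_p"))
def test_demo(perf_mode, greedy_sampling, user_input, model_location_generator, device):
    ...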
