Skip to content

Commit

Permalink
#5383: Modify Falcon7b demo to reduce number of decode compile passes (only need every 32 kvcache sizes)
Browse files Browse the repository at this point in the history

Signed-off-by: Salar <skhorasgani@tenstorrent.com>
  • Loading branch information
skhorasganiTT committed Apr 9, 2024
1 parent e5d9c85 commit ea2760a
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions models/demos/falcon7b/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,10 @@ def run_falcon_demo_kv(
for user_id, output_id in enumerate(output_ids):
decode_ids[user_id] = output_id

- kv_cache_len = num_input_tokens # This will increment by one after each decode
prompt_is_done = [False for _ in range(num_users)]

time_decode_compile = 0
- for output_token_index in tqdm(range(max_seq_len - num_input_tokens)):
+ for kv_cache_len in tqdm(range(num_input_tokens, max_seq_len, 32)):
time_decode_compile_start = time.time()
(
tt_decode_embeddings,
Expand Down Expand Up @@ -278,7 +277,6 @@ def run_falcon_demo_kv(
decode_ids = post_processor(logits=logits, index=...).reshape(batch_size, 1)

generated_ids = torch.concat((generated_ids, decode_ids[:num_users]), dim=1)
- kv_cache_len += 1

logger.info("Finished 1st run decode stage with compile!")
tt_lib.device.Synchronize(device)
Expand Down

0 comments on commit ea2760a

Please sign in to comment.