#7379: Added proper testing method for program caching
caixunshiren committed May 16, 2024
1 parent 0c3704d commit 677b273
Showing 1 changed file with 31 additions and 4 deletions.
@@ -34,15 +34,14 @@ def unpadding_test(
         .to(ttl.tensor.Layout.TILE)
         .to(device)
     )
-    # breakpoint()
     test_tensor_tt = ttl.tensor.nlp_kv_cache_load_slice(test_tensor, seq_len_start, seq_len_end)
 
     test_tensor_pt = test_tensor_tt.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
 
     # Pytorch reference
     test_tensor_ref = inp[:, :, seq_len_start:seq_len_end]
 
-    return test_tensor_pt, test_tensor_ref, test_tensor_tt.memory_config()
+    return test_tensor_pt, test_tensor_ref, test_tensor_tt.memory_config(), device.num_program_cache_entries()
 
 
 @pytest.mark.parametrize(
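Note on the hunk above: the helper now also returns device.num_program_cache_entries(), the number of compiled programs currently held in the device's program cache. As a rough standalone illustration of what that counter measures (a hypothetical sketch, not part of this commit; it assumes ttl is tt_lib as imported by the test file, an already-opened device, and a shape/slice that satisfies the tile-alignment constraints of nlp_kv_cache_load_slice):

import torch
import tt_lib as ttl  # assumption: ttl is tt_lib, as in the test file

def cache_entries_after_repeated_runs(device, num_runs):
    # Run the same op on the same shape several times; the program should be
    # compiled once on the first run and served from the cache afterwards.
    inp = torch.rand(2, 2, 128, 32)
    for _ in range(num_runs):
        tt_in = (
            ttl.tensor.Tensor(
                inp.reshape(-1).tolist(),
                inp.shape,
                ttl.tensor.DataType.BFLOAT16,
                ttl.tensor.Layout.ROW_MAJOR,
            )
            .to(ttl.tensor.Layout.TILE)
            .to(device)
        )
        ttl.tensor.nlp_kv_cache_load_slice(tt_in, 0, 64)
    return device.num_program_cache_entries()

# Expected to return 1 for any num_runs >= 1: every run after the first
# hits the cached program instead of compiling a new one.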
@@ -85,14 +84,15 @@ def test_run_unpadding_test(
         pytest.skip("Skipping test on Grayskull")
 
     for i in range(3):
-        a_pt, a_ref, memory_config = unpadding_test(
+        a_pt, a_ref, memory_config, num_cache_entries = unpadding_test(
             kv_cache_shape,
             seq_len_start,
             seq_len_end,
             device,
             dtype,
         )
         assert a_pt.shape == a_ref.shape
+        assert num_cache_entries == 1
         if dtype == ttl.tensor.DataType.BFLOAT8_B:
             # inevitable precision loss for bfloat8_b
             eq, pcc = comp_pcc(a_pt, a_ref, 0.999)
@@ -107,19 +107,33 @@ def test_run_unpadding_test(
         assert memory_config.shard_spec.shape[0] == a_ref.shape[-2]
         assert memory_config.shard_spec.shape[1] == a_ref.shape[-1]
 
+        # shift input/output tensor addresses by creating a very small tensor between loop iterations
+        inp = torch.rand(1, 1, 32, 32)
+        test_tensor = (
+            ttl.tensor.Tensor(
+                inp.reshape(-1).tolist(),
+                inp.shape,
+                dtype,
+                ttl.tensor.Layout.ROW_MAJOR,
+            )
+            .to(ttl.tensor.Layout.TILE)
+            .to(device)
+        )
+
     # hardcoded test 2 to check program caching
     kv_cache_shape = (2, 2, 128, 32)
     seq_len_start = 0
     seq_len_end = 64
     for i in range(2):
-        a_pt, a_ref, memory_config = unpadding_test(
+        a_pt, a_ref, memory_config, num_cache_entries = unpadding_test(
             kv_cache_shape,
             seq_len_start,
             seq_len_end,
             device,
             dtype,
         )
         assert a_pt.shape == a_ref.shape
+        assert num_cache_entries == 2
         if dtype == ttl.tensor.DataType.BFLOAT8_B:
             # inevitable precision loss for bfloat8_b
             eq, pcc = comp_pcc(a_pt, a_ref, 0.999)
@@ -133,3 +147,16 @@ def test_run_unpadding_test(
         assert memory_config.buffer_type == ttl.tensor.BufferType.L1
         assert memory_config.shard_spec.shape[0] == a_ref.shape[-2]
         assert memory_config.shard_spec.shape[1] == a_ref.shape[-1]
+
+        # shift input/output tensor addresses by creating a very small tensor between loop iterations
+        inp = torch.rand(1, 1, 32, 32)
+        test_tensor = (
+            ttl.tensor.Tensor(
+                inp.reshape(-1).tolist(),
+                inp.shape,
+                dtype,
+                ttl.tensor.Layout.ROW_MAJOR,
+            )
+            .to(ttl.tensor.Layout.TILE)
+            .to(device)
+        )
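Two details of the new test are worth spelling out. First, the tiny tensor allocated after each iteration nudges the device memory allocator, so the next iteration's input and output buffers land at different addresses; a cached program with stale buffer addresses baked in would then read or write the wrong memory and fail the PCC check. Second, the hardcoded second phase runs a differently-shaped slice, which must compile exactly one additional program, hence num_cache_entries == 2. A minimal sketch of the address-shifting trick (hypothetical, under the same ttl/device assumptions as above; nudge_allocator is an illustrative name, not from the commit):

def nudge_allocator(device, dtype):
    # Allocate one tile (1x1x32x32) on the device and keep it alive. While it
    # lives, subsequent buffers are placed at shifted addresses, which is what
    # exposes a cached program that did not update its runtime arguments.
    bump = torch.rand(1, 1, 32, 32)
    return (
        ttl.tensor.Tensor(
            bump.reshape(-1).tolist(),
            bump.shape,
            dtype,
            ttl.tensor.Layout.ROW_MAJOR,
        )
        .to(ttl.tensor.Layout.TILE)
        .to(device)
    )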
