diff --git a/models/demos/lenet/demo/demo.py b/models/demos/lenet/demo/demo.py index 814b978a802..fa21d908497 100644 --- a/models/demos/lenet/demo/demo.py +++ b/models/demos/lenet/demo/demo.py @@ -10,7 +10,9 @@ from loguru import logger from torch.utils.data import DataLoader - +from models.utility_functions import ( + disable_persistent_kernel_cache, +) from ttnn.model_preprocessing import preprocess_model_parameters from models.demos.lenet.tt import tt_lenet from models.demos.lenet import lenet_utils @@ -50,6 +52,7 @@ def run_demo_dataset(device, batch_size, iterations, model_location_generator, r accuracy = correct / (batch_size * iterations) logger.info(f"Dataset Inference Accuracy for {batch_size}x{iterations} Samples : {accuracy}") + assert accuracy >= 1.0, f"Expected accuracy : {1.0} Actual accuracy: {accuracy}" @pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) @@ -62,6 +65,7 @@ def test_demo_dataset( model_location_generator, reset_seeds, ): + disable_persistent_kernel_cache() return run_demo_dataset( reset_seeds=reset_seeds, device=device, diff --git a/models/demos/lenet/tests/test_perf_lenet.py b/models/demos/lenet/tests/test_perf_lenet.py index d64bb633f9d..36cc76a6bdf 100644 --- a/models/demos/lenet/tests/test_perf_lenet.py +++ b/models/demos/lenet/tests/test_perf_lenet.py @@ -26,11 +26,11 @@ def get_expected_times(tt_lenet): if is_grayskull(): return { - tt_lenet: (7.525, 0.9495), + tt_lenet: (5.20, 0.63291), }[tt_lenet] elif is_wormhole_b0(): return { - tt_lenet: (9.52, 0.91), + tt_lenet: (7.95678, 0.8243), }[tt_lenet] @@ -77,26 +77,31 @@ def test_perf_lenet(device, batch_size, tt_lenet, model_location_generator, rese ) end = time.time() durations.append(end - start) + enable_persistent_kernel_cache() inference_and_compile_time, *inference_times = durations - average_inference_time = sum(inference_times) / len(inference_times) + inference_time = sum(inference_times) / len(inference_times) expected_compile_time, expected_inference_time = get_expected_times(tt_lenet) prep_perf_report( model_name="tt_lenet", batch_size=batch_size, inference_and_compile_time=inference_and_compile_time, - inference_time=average_inference_time, + inference_time=inference_time, expected_compile_time=expected_compile_time, expected_inference_time=expected_inference_time, comments="", inference_time_cpu=0.0, ) - logger.info(f"Compile time: {inference_and_compile_time - average_inference_time}") - logger.info(f"Inference time: {average_inference_time}") + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") logger.info(f"Inference times: {inference_times}") - logger.info(f"Sample(s) per second: {1 / average_inference_time * batch_size}") + logger.info(f"Sample(s) per second: {1 / inference_time * batch_size}") + assert ( + inference_time < expected_inference_time + ), f"Expected inference time: {expected_inference_time} Actual inference time: {inference_time}" + logger.info("Exit Lenet perf test") @pytest.mark.parametrize( @@ -109,9 +114,9 @@ def test_perf_device_bare_metal(batch_size, reset_seeds): num_iterations = 1 margin = 0.03 if is_grayskull(): - expected_perf = 6330.022 + expected_perf = 203642.6580 elif is_wormhole_b0(): - expected_perf = 20028.54 + expected_perf = 113208.6151 command = f"pytest tests/ttnn/integration_tests/lenet/test_lenet.py" cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] @@ -120,7 +125,7 @@ def test_perf_device_bare_metal(batch_size, reset_seeds): expected_perf_cols = {inference_time_key: expected_perf} post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size) - expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols) + expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=True) prep_device_perf_report( model_name=f"tt_lenet{batch_size}", batch_size=batch_size, diff --git a/models/demos/lenet/tt/tt_lenet.py b/models/demos/lenet/tt/tt_lenet.py index 98774c049c8..bc41af10b1b 100644 --- a/models/demos/lenet/tt/tt_lenet.py +++ b/models/demos/lenet/tt/tt_lenet.py @@ -50,10 +50,7 @@ def conv(device, input_tensor, batch_size, parameters): def Lenet(input_tensor, model, batch_size, num_classes, device, parameters, reset_seeds): conv_1, out_height, out_width = conv(device, input_tensor, batch_size, parameters.layer1) - - conv_1 = ttnn.from_device(conv_1) - conv_1 = ttnn.to_layout(conv_1, layout=ttnn.TILE_LAYOUT) - conv_1 = ttnn.to_device(conv_1, device=device) + conv_1 = ttnn.sharded_to_interleaved(conv_1, ttnn.L1_MEMORY_CONFIG) conv_1 = ttnn.reshape(conv_1, (batch_size, out_height, out_width, conv_1.shape[-1])) conv_1 = ttnn.permute(conv_1, (0, 3, 1, 2)) conv_1 = ttnn.to_torch(conv_1) @@ -67,21 +64,26 @@ def Lenet(input_tensor, model, batch_size, num_classes, device, parameters, rese conv_2, out_height, out_width = conv(device, maxpool_1, batch_size, parameters.layer2) - conv_2 = ttnn.from_device(conv_2) - conv_2 = ttnn.to_layout(conv_2, layout=ttnn.TILE_LAYOUT) - conv_2 = ttnn.to_device(conv_2, device=device) - conv_2 = ttnn.reshape(conv_2, (batch_size, out_height, out_width, conv_2.shape[-1])) - conv_2 = ttnn.permute(conv_2, (0, 3, 1, 2)) - conv_2 = ttnn.to_torch(conv_2) + conv_2 = ttnn.to_layout(conv_2, layout=ttnn.ROW_MAJOR_LAYOUT) + maxpool_2 = ttnn.max_pool2d( + input_tensor=conv_2, + batch_size=batch_size, + input_h=out_height, + input_w=out_width, + channels=conv_2.shape[3], + kernel_size=[2, 2], + stride=[2, 2], + padding=[0, 0], + dilation=[1, 1], + ) - max = nn.MaxPool2d(kernel_size=2, stride=2) - maxpool_2 = max(conv_2) + maxpool_2 = ttnn.sharded_to_interleaved(maxpool_2, ttnn.L1_MEMORY_CONFIG) - maxpool_2 = ttnn.from_torch(maxpool_2, dtype=ttnn.bfloat16) + maxpool_2 = ttnn.to_layout(maxpool_2, layout=ttnn.TILE_LAYOUT) + maxpool_2 = ttnn.reshape(maxpool_2, (batch_size, 5, 5, maxpool_2.shape[3])) + maxpool_2 = ttnn.permute(maxpool_2, (0, 3, 1, 2)) maxpool_2 = ttnn.reshape(maxpool_2, (maxpool_2.shape[0], -1)) - maxpool_2 = ttnn.to_device(maxpool_2, device=device) - maxpool_2 = ttnn.to_layout(maxpool_2, layout=ttnn.TILE_LAYOUT) linear_1 = ttnn.linear( maxpool_2,