diff --git a/models/demos/wormhole/mnist/README.md b/models/demos/wormhole/mnist/README.md index 7d7d5ed92489..611be3881d88 100644 --- a/models/demos/wormhole/mnist/README.md +++ b/models/demos/wormhole/mnist/README.md @@ -8,9 +8,9 @@ WH N150, WH N300 The MNIST model uses only fully connected linear layers to classify handwritten digits from the MNIST dataset. Despite the absence of convolutional layers, the model efficiently processes the 28x28 pixel images by flattening them into a 1D vector and passing them through multiple linear layers to predict the corresponding digit (0-9). This approach demonstrates how even simpler architectures can be applied for image classification tasks. -### Batch size: 512 +### Batch size: 256 -Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set the batch_size to 512 +Batch size determines the number of input images processed simultaneously during training or inference, impacting computational efficiency and memory usage. It is recommended to set the batch_size to 256. ## How to Run @@ -28,5 +28,3 @@ The demo receives inputs from respective dataset MNIST. ## Additional Information If you encounter issues when running the model, ensure that device has support for all required operations. 
- -### Owner: [sabira-mcw](https://github.com/sabira-mcw) diff --git a/models/demos/wormhole/mnist/demo/demo.py b/models/demos/wormhole/mnist/demo/demo.py index 59e526353cde..dc6043966d1a 100644 --- a/models/demos/wormhole/mnist/demo/demo.py +++ b/models/demos/wormhole/mnist/demo/demo.py @@ -12,7 +12,7 @@ from torch.utils.data import DataLoader from models.demos.wormhole.mnist.reference.mnist import MnistModel from models.demos.wormhole.mnist.tt import tt_mnist - +from models.utility_functions import disable_persistent_kernel_cache from ttnn.model_preprocessing import preprocess_model_parameters from models.utility_functions import is_wormhole_b0, skip_for_grayskull @@ -25,7 +25,8 @@ def run_demo_dataset(batch_size, iterations, model_location_generator, mesh_devi state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist")) model = MnistModel(state_dict) model = model.eval() - + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size if mesh_device_flag else batch_size // 2 inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): @@ -65,10 +66,11 @@ def run_demo_dataset(batch_size, iterations, model_location_generator, mesh_devi accuracy = correct / (batch_size * iterations) logger.info(f"MNIST Inference Accuracy for {batch_size}x{iterations} Samples : {accuracy}") + assert accuracy >= 0.96484375, f"Expected accuracy : { 0.96484375} Actual accuracy: {accuracy}" @skip_for_grayskull() -@pytest.mark.parametrize("batch_size", [512]) +@pytest.mark.parametrize("batch_size", [256]) @pytest.mark.parametrize("iterations", [1]) def test_demo_dataset( batch_size, @@ -76,6 +78,7 @@ def test_demo_dataset( model_location_generator, mesh_device, ): + disable_persistent_kernel_cache() return run_demo_dataset( batch_size=batch_size, iterations=iterations, diff --git 
a/models/demos/wormhole/mnist/tests/test_perf_mnist.py b/models/demos/wormhole/mnist/tests/test_perf_mnist_wh.py similarity index 82% rename from models/demos/wormhole/mnist/tests/test_perf_mnist.py rename to models/demos/wormhole/mnist/tests/test_perf_mnist_wh.py index 31f6648b6331..d9d8008202bf 100644 --- a/models/demos/wormhole/mnist/tests/test_perf_mnist.py +++ b/models/demos/wormhole/mnist/tests/test_perf_mnist_wh.py @@ -29,7 +29,7 @@ def get_expected_times(tt_mnist): if is_wormhole_b0(): return { - tt_mnist: (10.460, 0.0139), + tt_mnist: (10.89, 0.017), }[tt_mnist] @@ -37,7 +37,7 @@ def get_expected_times(tt_mnist): @pytest.mark.models_performance_virtual_machine @pytest.mark.parametrize( "batch_size", - [512], + [256], ) @pytest.mark.parametrize( "tt_mnist", @@ -52,7 +52,8 @@ def test_performance_mnist(mesh_device, batch_size, tt_mnist, model_location_gen test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True) dataloader = DataLoader(test_dataset, batch_size=batch_size) x, labels = next(iter(dataloader)) - + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size if mesh_device_flag else batch_size // 2 inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) @@ -71,34 +72,38 @@ def test_performance_mnist(mesh_device, batch_size, tt_mnist, model_location_gen ttnn_output = tt_mnist.mnist(mesh_device, batch_size, x, parameters) end = time.time() durations.append(end - start) - # enable_persistent_kernel_cache() + enable_persistent_kernel_cache() inference_and_compile_time, *inference_times = durations - average_inference_time = sum(inference_times) / len(inference_times) + inference_time = sum(inference_times) / len(inference_times) expected_compile_time, expected_inference_time = get_expected_times(tt_mnist) prep_perf_report( 
model_name="MNIST", batch_size=batch_size, inference_and_compile_time=inference_and_compile_time, - inference_time=average_inference_time, + inference_time=inference_time, expected_compile_time=expected_compile_time, expected_inference_time=expected_inference_time, comments="", inference_time_cpu=0.0, ) - logger.info(f"Compile time: {inference_and_compile_time - average_inference_time}") - logger.info(f"Inference time: {average_inference_time}") + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") logger.info(f"Inference times: {inference_times}") - logger.info(f"Sample(s) per second: {1 / average_inference_time * batch_size}") + logger.info(f"Sample(s) per second: {1 / inference_time * batch_size}") + assert ( + inference_time < expected_inference_time + ), f"Expected inference time: {expected_inference_time} Actual inference time: {inference_time}" + logger.info("Exit MNIST perf test") @skip_for_grayskull() @pytest.mark.parametrize( "batch_size, expected_perf", [ - [512, 2899420.682], + [256, 1520045.60], ], ) @pytest.mark.models_device_performance_bare_metal @@ -107,14 +112,14 @@ def test_perf_device_bare_metal(batch_size, expected_perf): num_iterations = 1 margin = 0.03 - command = f"pytest tests/ttnn/integration_tests/mnist/test_mnist.py" + command = f"pytest tests/ttnn/integration_tests/mnist/test_mnist_wh.py" cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" expected_perf_cols = {inference_time_key: expected_perf} post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size) - expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols) + expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=True) prep_device_perf_report( model_name=f"tt_mnist{batch_size}", batch_size=batch_size, diff --git 
a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index a4cf4bad30bc..e11605b3e77e 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -18,7 +18,7 @@ run_perf_models_other() { if [ "$tt_arch" == "wormhole_b0" ]; then env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/resnet50/tests/test_perf_e2e_resnet50.py -m $test_marker - env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/mnist/tests/test_perf_mnist.py -m $test_marker + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/mnist/tests/test_perf_mnist_wh.py -m $test_marker fi env pytest -n auto tests/ttnn/integration_tests/bert/test_performance.py -m $test_marker @@ -113,6 +113,8 @@ run_device_perf_models() { fi if [ "$tt_arch" == "wormhole_b0" ]; then + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/mnist/tests -m $test_marker + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/resnet50/tests -m $test_marker env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/experimental/functional_unet/tests/test_unet_perf.py -m $test_marker @@ -123,7 +125,6 @@ run_device_perf_models() { env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/falcon7b_common/tests -m $test_marker - env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yam pytets models/demos/wormhole/mnist/tests/test_perf_mnist.py::test_performance_mnist -m $test_marker fi ## Merge all the generated reports diff --git a/tests/ttnn/integration_tests/mnist/test_mnist.py b/tests/ttnn/integration_tests/mnist/test_mnist_wh.py similarity index 91% rename from tests/ttnn/integration_tests/mnist/test_mnist.py rename to tests/ttnn/integration_tests/mnist/test_mnist_wh.py index 973dcf3a244b..153cb6b51fcb 100644 --- a/tests/ttnn/integration_tests/mnist/test_mnist.py +++ 
b/tests/ttnn/integration_tests/mnist/test_mnist_wh.py @@ -15,7 +15,7 @@ @skip_for_grayskull() @pytest.mark.parametrize( "batch_size", - [512], + [256], ) def test_mnist(mesh_device, reset_seeds, batch_size, model_location_generator): state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist")) @@ -26,10 +26,12 @@ def test_mnist(mesh_device, reset_seeds, batch_size, model_location_generator): dataloader = DataLoader(test_dataset, batch_size=batch_size) x, labels = next(iter(dataloader)) torch_output = model(x) + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size if mesh_device_flag else batch_size // 2 inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - mesh_device_flag = True + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): parameters = preprocess_model_parameters(initialize_model=lambda: model, device=mesh_device)