From 2035fd074e4e60912763c917a607f7d7401de6b4 Mon Sep 17 00:00:00 2001 From: Evan Smal Date: Tue, 27 Aug 2024 12:06:32 +0000 Subject: [PATCH] #10754: Add data-parallel support for UNet Shallow on N300 --- .../functional_unet/tests/common.py | 27 +++ .../tests/test_unet_bottleneck.py | 53 +++++- .../tests/test_unet_downblock.py | 74 ++++++-- .../functional_unet/tests/test_unet_model.py | 4 +- .../tests/test_unet_multi_device.py | 51 ++++++ .../tests/test_unet_output_layer.py | 10 +- .../tests/test_unet_upblock.py | 83 ++++++++- .../functional_unet/tt/model_preprocessing.py | 23 ++- .../functional_unet/tt/unet_shallow_ttnn.py | 158 ++++++++++++------ 9 files changed, 388 insertions(+), 95 deletions(-) create mode 100644 models/experimental/functional_unet/tests/common.py create mode 100644 models/experimental/functional_unet/tests/test_unet_multi_device.py diff --git a/models/experimental/functional_unet/tests/common.py b/models/experimental/functional_unet/tests/common.py new file mode 100644 index 00000000000..77ae999e60b --- /dev/null +++ b/models/experimental/functional_unet/tests/common.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from tests.ttnn.utils_for_testing import assert_with_pcc + + +def is_n300_with_eth_dispatch_cores(device_mesh) -> bool: + all_devices_using_full_grid = all( + [(8 == device.core_grid.x and 8 == device.core_grid.y) for device in device_mesh.get_devices()] + ) + return all_devices_using_full_grid and (len(device_mesh.get_devices()) == 2) + + +def check_pcc_conv(torch_tensor, ttnn_tensor, pcc=0.999, mesh_composer=None): + B, C, H, W = torch_tensor.shape + ttnn_tensor = ttnn.to_torch(ttnn_tensor, mesh_composer=mesh_composer).reshape(B, H, W, C).permute(0, 3, 1, 2) + assert_with_pcc(torch_tensor, ttnn_tensor, pcc) + + +def check_pcc_pool(torch_tensor, ttnn_tensor, pcc=0.999, mesh_composer=None): + B, C, H, W = torch_tensor.shape + ttnn_tensor = ( + ttnn.to_torch(ttnn_tensor, mesh_composer=mesh_composer).reshape(B, H, W, -1).permute(0, 3, 1, 2)[:, :C, :, :] + ) + assert_with_pcc(torch_tensor, ttnn_tensor, pcc) diff --git a/models/experimental/functional_unet/tests/test_unet_bottleneck.py b/models/experimental/functional_unet/tests/test_unet_bottleneck.py index bf76c24b9d2..c005294224c 100644 --- a/models/experimental/functional_unet/tests/test_unet_bottleneck.py +++ b/models/experimental/functional_unet/tests/test_unet_bottleneck.py @@ -5,6 +5,8 @@ import pytest import ttnn +from loguru import logger + from tests.ttnn.utils_for_testing import assert_with_pcc from models.experimental.functional_unet.tt.model_preprocessing import ( @@ -13,23 +15,19 @@ ) from models.experimental.functional_unet.tt import unet_shallow_torch from models.experimental.functional_unet.tt import unet_shallow_ttnn +from models.experimental.functional_unet.tests.common import is_n300_with_eth_dispatch_cores, check_pcc_conv @pytest.mark.parametrize("batch", [2]) @pytest.mark.parametrize("groups", [1]) @pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) -def test_unet_downblocks(batch, groups, device): +def test_unet_bottleneck(batch, groups, device, reset_seeds): torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=False) model = unet_shallow_torch.UNet.from_random_weights(groups=1) parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device) ttnn_model = unet_shallow_ttnn.UNet(parameters, device) - def 
check_pcc(torch_tensor, ttnn_tensor, pcc=0.995): - B, C, H, W = torch_tensor.shape - ttnn_tensor = ttnn.to_torch(ttnn_tensor).reshape(B, H, W, C).permute(0, 3, 1, 2) - assert_with_pcc(torch_tensor, ttnn_tensor, pcc) - torch_input, ttnn_input = create_unet_input_tensors( device, batch, groups, pad_input=True, input_channels=32, input_height=66, input_width=10 ) @@ -38,4 +36,45 @@ def check_pcc(torch_tensor, ttnn_tensor, pcc=0.995): ttnn_input = ttnn.to_device(ttnn_input, device=device) ttnn_output = ttnn_model.bottleneck(ttnn_input) - check_pcc(torch_output, ttnn_output) + check_pcc_conv(torch_output, ttnn_output, pcc=0.999) + + +@pytest.mark.parametrize("batch", [2]) +@pytest.mark.parametrize("groups", [1]) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +def test_unet_bottleneck_multi_device(batch, groups, device_mesh, reset_seeds): + if not is_n300_with_eth_dispatch_cores(device_mesh): + pytest.skip("Test is only valid for N300") + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(device_mesh, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(device_mesh) + output_mesh_composer = ttnn.ConcatMeshToTensor(device_mesh, dim=0) + + torch_input, ttnn_input = create_unet_input_tensors(device_mesh, batch, groups, pad_input=False) + model = unet_shallow_torch.UNet.from_random_weights(groups=1) + + parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device_mesh) + ttnn_model = unet_shallow_ttnn.UNet(parameters, device_mesh, mesh_mapper=weights_mesh_mapper) + + num_devices = len(device_mesh.get_device_ids()) + torch_input, ttnn_input = create_unet_input_tensors( + device_mesh, + num_devices * batch, + groups, + pad_input=True, + input_channels=32, + input_height=66, + input_width=10, + mesh_mapper=inputs_mesh_mapper, + ) + logger.info(f"Created reference input tensors: {list(torch_input.shape)}") + logger.info( + f"Created multi-device input tensors: shape={list(ttnn_input.shape)} on devices={device_mesh.get_device_ids()}" + ) + torch_output = model.bottleneck(torch_input) + + ttnn_input = ttnn_input.to(device_mesh) + ttnn_output = ttnn_model.bottleneck(ttnn_input) + + assert len(ttnn_output.devices()) == 2, "Expected output tensor to be sharded across 2 devices" + check_pcc_conv(torch_output, ttnn_output, mesh_composer=output_mesh_composer, pcc=0.999) diff --git a/models/experimental/functional_unet/tests/test_unet_downblock.py b/models/experimental/functional_unet/tests/test_unet_downblock.py index 9c1d94df625..1a3d045715d 100644 --- a/models/experimental/functional_unet/tests/test_unet_downblock.py +++ b/models/experimental/functional_unet/tests/test_unet_downblock.py @@ -4,8 +4,7 @@ import pytest import ttnn - -from tests.ttnn.utils_for_testing import assert_with_pcc +from loguru import logger from models.experimental.functional_unet.tt.model_preprocessing import ( create_unet_input_tensors, @@ -13,18 +12,11 @@ ) from models.experimental.functional_unet.tt import unet_shallow_torch from models.experimental.functional_unet.tt import unet_shallow_ttnn - - -def check_pcc_conv(torch_tensor, ttnn_tensor, pcc=0.999): - B, C, H, W = torch_tensor.shape - ttnn_tensor = ttnn.to_torch(ttnn_tensor).reshape(B, H, W, C).permute(0, 3, 1, 2) - assert_with_pcc(torch_tensor, ttnn_tensor, pcc) - - -def check_pcc_pool(torch_tensor, ttnn_tensor, pcc=0.999): - B, C, H, W = torch_tensor.shape - ttnn_tensor = ttnn.to_torch(ttnn_tensor).reshape(B, H, W, -1).permute(0, 3, 1, 2)[:, :C, :, :] - assert_with_pcc(torch_tensor, ttnn_tensor, pcc) +from 
models.experimental.functional_unet.tests.common import ( + check_pcc_conv, + check_pcc_pool, + is_n300_with_eth_dispatch_cores, +) @pytest.mark.parametrize("batch, groups", [(2, 1)]) @@ -61,3 +53,57 @@ def test_unet_downblock(batch, groups, block_name, input_channels, input_height, check_pcc_conv(torch_residual, ttnn_residual) check_pcc_pool(torch_output, ttnn_output) + + +@pytest.mark.parametrize("batch, groups", [(2, 1)]) +@pytest.mark.parametrize( + "block_name, input_channels, input_height, input_width", + [ + ("downblock1", 4, 1056, 160), + ("downblock2", 16, 528, 80), + ("downblock3", 16, 264, 40), + ("downblock4", 32, 132, 20), + ], +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +def test_unet_downblock_multi_device( + batch, groups, block_name, input_channels, input_height, input_width, device_mesh, reset_seeds +): + if not is_n300_with_eth_dispatch_cores(device_mesh): + pytest.skip("Test is only valid for N300") + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(device_mesh, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(device_mesh) + output_mesh_composer = ttnn.ConcatMeshToTensor(device_mesh, dim=0) + + torch_input, ttnn_input = create_unet_input_tensors(device_mesh, batch, groups, pad_input=False) + model = unet_shallow_torch.UNet.from_random_weights(groups=1) + + parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device_mesh) + ttnn_model = unet_shallow_ttnn.UNet(parameters, device_mesh, mesh_mapper=weights_mesh_mapper) + + num_devices = len(device_mesh.get_device_ids()) + torch_input, ttnn_input = create_unet_input_tensors( + device_mesh, + num_devices * batch, + groups, + pad_input=True, + input_channels=input_channels, + input_height=input_height, + input_width=input_width, + mesh_mapper=inputs_mesh_mapper, + ) + logger.info(f"Created reference input tensors: {list(torch_input.shape)}") + logger.info( + f"Created multi-device input tensors: shape={list(ttnn_input.shape)} on devices={device_mesh.get_device_ids()}" + ) + + torch_output, torch_residual = getattr(model, block_name)(torch_input) + + ttnn_input = ttnn_input.to(device_mesh) + ttnn_output, ttnn_residual = getattr(ttnn_model, block_name)(ttnn_input) + + assert len(ttnn_output.devices()) == 2, "Expected output tensor to be sharded across 2 devices" + assert len(ttnn_residual.devices()) == 2, "Expected residual output tensor to be sharded across 2 devices" + check_pcc_conv(torch_residual, ttnn_residual, mesh_composer=output_mesh_composer) + check_pcc_pool(torch_output, ttnn_output, mesh_composer=output_mesh_composer) diff --git a/models/experimental/functional_unet/tests/test_unet_model.py b/models/experimental/functional_unet/tests/test_unet_model.py index e189a320a9e..84f66d9c856 100644 --- a/models/experimental/functional_unet/tests/test_unet_model.py +++ b/models/experimental/functional_unet/tests/test_unet_model.py @@ -18,7 +18,7 @@ @pytest.mark.parametrize("batch", [2]) @pytest.mark.parametrize("groups", [1]) @pytest.mark.parametrize("device_params", [{"l1_small_size": 64768}], indirect=True) -def test_unet_model(batch, groups, device, use_program_cache): +def test_unet_model(batch, groups, device, use_program_cache, reset_seeds): torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=True) model = unet_shallow_torch.UNet.from_random_weights(groups=1) @@ -30,4 +30,4 @@ def test_unet_model(batch, groups, device, use_program_cache): B, C, H, W = torch_output_tensor.shape ttnn_tensor = 
ttnn.to_torch(output_tensor).reshape(B, H, W, -1)[:, :, :, :C].permute(0, 3, 1, 2) - assert_with_pcc(torch_output_tensor, ttnn_tensor, 0.985) + assert_with_pcc(torch_output_tensor, ttnn_tensor, 0.99) diff --git a/models/experimental/functional_unet/tests/test_unet_multi_device.py b/models/experimental/functional_unet/tests/test_unet_multi_device.py new file mode 100644 index 00000000000..c6ac1e43a51 --- /dev/null +++ b/models/experimental/functional_unet/tests/test_unet_multi_device.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import ttnn + +from loguru import logger + +from models.experimental.functional_unet.tt.model_preprocessing import ( + create_unet_input_tensors, + create_unet_model_parameters, +) +from models.experimental.functional_unet.tt import unet_shallow_torch +from models.experimental.functional_unet.tt import unet_shallow_ttnn +from models.experimental.functional_unet.tests.common import ( + check_pcc_conv, + is_n300_with_eth_dispatch_cores, +) + + +@pytest.mark.parametrize("batch", [2]) +@pytest.mark.parametrize("groups", [1]) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 64768}], indirect=True) +def test_unet_multi_device_model(batch, groups, device_mesh, use_program_cache, reset_seeds): + if not is_n300_with_eth_dispatch_cores(device_mesh): + pytest.skip("Test is only valid for N300") + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(device_mesh, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(device_mesh) + output_mesh_composer = ttnn.ConcatMeshToTensor(device_mesh, dim=0) + + torch_input, ttnn_input = create_unet_input_tensors(device_mesh, batch, groups, pad_input=True) + model = unet_shallow_torch.UNet.from_random_weights(groups=groups) + + parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device_mesh) + ttnn_model = unet_shallow_ttnn.UNet(parameters, device=device_mesh, mesh_mapper=weights_mesh_mapper) + + num_devices = len(device_mesh.get_device_ids()) + torch_input, ttnn_input = create_unet_input_tensors( + device_mesh, num_devices * batch, groups, pad_input=True, mesh_mapper=inputs_mesh_mapper + ) + logger.info(f"Created reference input tensors: {list(torch_input.shape)}") + logger.info( + f"Created multi-device input tensors: shape={list(ttnn_input.shape)} on devices={device_mesh.get_device_ids()}" + ) + + torch_output_tensor = model(torch_input) + output_tensor = ttnn_model(ttnn_input, list(torch_input.shape)) + + check_pcc_conv(torch_output_tensor, output_tensor, mesh_composer=output_mesh_composer, pcc=0.99) diff --git a/models/experimental/functional_unet/tests/test_unet_output_layer.py b/models/experimental/functional_unet/tests/test_unet_output_layer.py index 82c8c66be48..231e1a09f3a 100644 --- a/models/experimental/functional_unet/tests/test_unet_output_layer.py +++ b/models/experimental/functional_unet/tests/test_unet_output_layer.py @@ -17,19 +17,21 @@ @pytest.mark.parametrize("batch, groups", [(2, 1)]) @pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) -def test_unet_output_layer(batch, groups, device): +def test_unet_output_layer(batch, groups, device, reset_seeds): torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=False) model = unet_shallow_torch.UNet.from_random_weights(groups=groups) parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device) ttnn_model = unet_shallow_ttnn.UNet(parameters, device) - 
torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=True, input_channels=16) + torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=False, input_channels=16) torch_output = model.output_layer(torch_input) + + ttnn_input = ttnn.to_device(ttnn_input, device) ttnn_output = ttnn_model.output_layer(ttnn_input) B, C, H, W = torch_output.shape ttnn_output = ttnn.to_torch(ttnn_output) - assert list(ttnn_output.shape) == [1, 1, B * H * W, 32], "Expected output layer to return padded output" - ttnn_output = ttnn_output.reshape(B, H, W, 32)[:, :, :, :C].permute(0, 3, 1, 2) + assert list(ttnn_output.shape) == [1, 1, B * H * W, C], "Expected output layer to be [1, 1, BHW, C]" + ttnn_output = ttnn_output.reshape(B, H, W, C).permute(0, 3, 1, 2) assert_with_pcc(torch_output, ttnn_output, 0.99) diff --git a/models/experimental/functional_unet/tests/test_unet_upblock.py b/models/experimental/functional_unet/tests/test_unet_upblock.py index d33e8965f5f..7ff63e11349 100644 --- a/models/experimental/functional_unet/tests/test_unet_upblock.py +++ b/models/experimental/functional_unet/tests/test_unet_upblock.py @@ -5,6 +5,8 @@ import pytest import ttnn +from loguru import logger + from tests.ttnn.utils_for_testing import assert_with_pcc from models.experimental.functional_unet.tt.model_preprocessing import ( @@ -13,12 +15,10 @@ ) from models.experimental.functional_unet.tt import unet_shallow_torch from models.experimental.functional_unet.tt import unet_shallow_ttnn - - -def check_pcc_conv(torch_tensor, ttnn_tensor, pcc=0.995): - B, C, H, W = torch_tensor.shape - ttnn_tensor = ttnn.to_torch(ttnn_tensor).reshape(B, H, W, C).permute(0, 3, 1, 2) - assert_with_pcc(torch_tensor, ttnn_tensor, pcc) +from models.experimental.functional_unet.tests.common import ( + check_pcc_conv, + is_n300_with_eth_dispatch_cores, +) @pytest.mark.parametrize("batch, groups", [(2, 1)]) @@ -32,7 +32,9 @@ def check_pcc_conv(torch_tensor, ttnn_tensor, pcc=0.995): ], ) @pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) -def test_unet_upblock(batch, groups, block_name, input_channels, input_height, input_width, residual_channels, device): +def test_unet_upblock( + batch, groups, block_name, input_channels, input_height, input_width, residual_channels, device, reset_seeds +): torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=False) model = unet_shallow_torch.UNet.from_random_weights(groups=groups) @@ -62,4 +64,69 @@ def test_unet_upblock(batch, groups, block_name, input_channels, input_height, i ttnn_input, ttnn_residual = ttnn_input.to(device), ttnn_residual.to(device) ttnn_output = getattr(ttnn_model, block_name)(ttnn_input, ttnn_residual) - check_pcc_conv(torch_output, ttnn_output) + check_pcc_conv(torch_output, ttnn_output, pcc=0.998) + + +@pytest.mark.parametrize("batch, groups", [(2, 1)]) +@pytest.mark.parametrize( + "block_name, input_channels, input_height, input_width, residual_channels", + [ + ("upblock1", 64, 66, 10, 32), + ("upblock2", 32, 132, 20, 32), + ("upblock3", 32, 264, 40, 16), + ("upblock4", 16, 528, 80, 16), + ], +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +def test_unet_upblock_multi_device( + batch, groups, block_name, input_channels, input_height, input_width, residual_channels, device_mesh, reset_seeds +): + if not is_n300_with_eth_dispatch_cores(device_mesh): + pytest.skip("Test is only valid for N300") + + inputs_mesh_mapper = 
ttnn.ShardTensorToMesh(device_mesh, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(device_mesh) + output_mesh_composer = ttnn.ConcatMeshToTensor(device_mesh, dim=0) + + torch_input, ttnn_input = create_unet_input_tensors(device_mesh, batch, groups, pad_input=False) + model = unet_shallow_torch.UNet.from_random_weights(groups=groups) + + parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device_mesh) + ttnn_model = unet_shallow_ttnn.UNet(parameters, device_mesh, mesh_mapper=weights_mesh_mapper) + + num_devices = len(device_mesh.get_device_ids()) + torch_input, ttnn_input = create_unet_input_tensors( + device_mesh, + num_devices * batch, + groups, + pad_input=False, + input_channels=input_channels, + input_height=input_height, + input_width=input_width, + mesh_mapper=inputs_mesh_mapper, + ) + logger.info(f"Created reference input tensors: {list(torch_input.shape)}") + logger.info( + f"Created multi-device input tensors: shape={list(ttnn_input.shape)} on devices={device_mesh.get_device_ids()}" + ) + torch_residual, ttnn_residual = create_unet_input_tensors( + device_mesh, + num_devices * batch, + groups, + pad_input=False, + input_channels=residual_channels, + input_height=input_height * 2, + input_width=input_width * 2, + mesh_mapper=inputs_mesh_mapper, + ) + logger.info(f"Created reference residual input tensors: {list(torch_residual.shape)}") + logger.info( + f"Created multi-device residual input tensors: shape={list(ttnn_residual.shape)} on devices={device_mesh.get_device_ids()}" + ) + torch_output = getattr(model, block_name)(torch_input, torch_residual) + + ttnn_input, ttnn_residual = ttnn_input.to(device_mesh), ttnn_residual.to(device_mesh) + ttnn_output = getattr(ttnn_model, block_name)(ttnn_input, ttnn_residual) + + assert len(ttnn_output.devices()) == 2, "Expected output tensor to be sharded across 2 devices" + check_pcc_conv(torch_output, ttnn_output, mesh_composer=output_mesh_composer, pcc=0.998) diff --git a/models/experimental/functional_unet/tt/model_preprocessing.py b/models/experimental/functional_unet/tt/model_preprocessing.py index 55065c1c038..c92848c42a7 100644 --- a/models/experimental/functional_unet/tt/model_preprocessing.py +++ b/models/experimental/functional_unet/tt/model_preprocessing.py @@ -11,14 +11,14 @@ def create_unet_input_tensors( - device, batch, groups, pad_input=True, input_channels=4, input_height=1056, input_width=160 + device, batch, groups, pad_input=True, input_channels=4, input_height=1056, input_width=160, mesh_mapper=None ): torch_input_tensor = torch.randn(batch, input_channels * groups, input_height, input_width) ttnn_input_tensor = torch.permute(torch_input_tensor, (0, 2, 3, 1)) ttnn_input_tensor = ttnn_input_tensor.reshape( + ttnn_input_tensor.shape[0], 1, - 1, - ttnn_input_tensor.shape[0] * ttnn_input_tensor.shape[1] * ttnn_input_tensor.shape[2], + ttnn_input_tensor.shape[1] * ttnn_input_tensor.shape[2], ttnn_input_tensor.shape[3], ) if pad_input: @@ -30,8 +30,13 @@ def create_unet_input_tensors( ttnn_input_tensor, (0, max(0, pad - ttnn_input_tensor.shape[-1]), 0, max(0, hpad - ttnn_input_tensor.shape[-2])), ) - ttnn_input_tensor = ttnn.from_torch(ttnn_input_tensor, dtype=ttnn.bfloat16) - + ttnn_input_tensor = ttnn.from_torch(ttnn_input_tensor, dtype=ttnn.bfloat16, mesh_mapper=mesh_mapper) + ttnn_input_tensor = ttnn_input_tensor.reshape( + 1, + 1, + ttnn_input_tensor.shape[0] * ttnn_input_tensor.shape[1] * ttnn_input_tensor.shape[2], + ttnn_input_tensor.shape[3], + ) return torch_input_tensor, 
ttnn_input_tensor @@ -60,11 +65,11 @@ def create_unet_model_parameters(model: unet_shallow_torch.UNet, input_tensor: t "num_cores_nhw": 55, } - parameters.c1["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 5 * 32} - parameters.c1_2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 5 * 32} + parameters.c1["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 8 * 32} + parameters.c1_2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 8 * 32} - parameters.c2["conv_blocking_and_parallelization_config_override"] = None - parameters.c2_2["conv_blocking_and_parallelization_config_override"] = None + parameters.c2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 5 * 32} + parameters.c2_2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 5 * 32} parameters.c3["conv_blocking_and_parallelization_config_override"] = None parameters.c3_2["conv_blocking_and_parallelization_config_override"] = None parameters.c4["conv_blocking_and_parallelization_config_override"] = None diff --git a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py index 44292a1428d..3fa276dfbf1 100644 --- a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py +++ b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py @@ -75,6 +75,40 @@ def unet_concat(ttnn_tensors, dim=-1, use_reshard=True, perf_mode=False): return ttnn.concat(ttlib_tensors, dim=dim, memory_config=output_mem_config) +class UNetPointwiseConv2D: + def __init__( + self, + conv, + device=None, + activation_dtype=ttnn.bfloat16, + mesh_mapper=None, + ): + self.device = device + self.in_channels = conv.in_channels + self.mesh_mapper = mesh_mapper + self.activation_dtype = activation_dtype + + weight, bias = conv.module.weight, conv.module.bias + + assert conv.kernel_size == (1, 1) + assert conv.stride == (1, 1) + assert conv.padding == (0, 0) + + weight = weight.reshape(1, 1, self.in_channels, 1) + bias = torch.reshape(bias, (1, 1, 1, -1)) + + # Do this in two steps because tensors are padded differently in multi-device vs. 
single device + self.weight = ttnn.from_torch(weight, device=None, dtype=ttnn.bfloat16, mesh_mapper=mesh_mapper) + self.bias = ttnn.from_torch(bias, device=None, dtype=ttnn.bfloat16, mesh_mapper=mesh_mapper) + self.weight = ttnn.to_layout(self.weight, ttnn.TILE_LAYOUT).to(device) + self.bias = ttnn.to_layout(self.bias, ttnn.TILE_LAYOUT).to(device) + + def __call__(self, x): + x = ttnn.to_layout(x, ttnn.TILE_LAYOUT) + x = ttnn.linear(x, self.weight, bias=self.bias, dtype=self.activation_dtype) + return x + + class UNetConv2D: def __init__( self, @@ -85,6 +119,7 @@ def __init__( activation="relu", activation_dtype=ttnn.bfloat8_b, weights_dtype=ttnn.bfloat8_b, + mesh_mapper=None, ): self.device = device self.batch_size = conv.batch_size @@ -99,12 +134,17 @@ def __init__( self.use_1d_systolic_array = conv.use_1d_systolic_array self.deallocate_activation = True self.cache = cache + self.mesh_mapper = mesh_mapper self.conv_config = ttnn.Conv2dConfig( dtype=activation_dtype, weights_dtype=weights_dtype, math_fidelity=ttnn.MathFidelity.LoFi, - height_sharding=self.use_1d_systolic_array, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED + if self.use_1d_systolic_array + else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), deallocate_activation=self.deallocate_activation, fp32_dest_acc_enabled=True, packer_l1_accum_enabled=False, @@ -125,15 +165,11 @@ def __init__( weight = weight bias = torch.reshape(bias, (1, 1, 1, -1)) - # Required for pointwise convolutions (output layer) - if bias.shape[-1] == 1: - bias = bias.repeat((1, 1, 32, 32)) - - self.weight = ttnn.from_torch(weight, dtype=ttnn.float32) - self.bias = ttnn.from_torch(bias, dtype=ttnn.float32) + self.weight = ttnn.from_torch(weight, dtype=ttnn.float32, mesh_mapper=mesh_mapper) + self.bias = ttnn.from_torch(bias, dtype=ttnn.float32, mesh_mapper=mesh_mapper) def __call__(self, x): - x, output_height, output_width, self.weight, self.bias = ttnn.conv2d( + x, _, _, self.weight, self.bias = ttnn.conv2d( input_tensor=x, weight_tensor=self.weight, bias_tensor=self.bias, @@ -153,39 +189,51 @@ def __call__(self, x): class UNetMaxPool2D: - def __init__(self, pool, device=None, reader_patterns_cache={}): + def __init__(self, pool, channels, device=None, reader_patterns_cache={}): self.pool = pool - self.max_pool = ttnn.MaxPool2d( - kernel_size=pool.kernel_size, - stride=pool.stride, - padding=pool.padding, - dilation=pool.dilation, - dtype=pool.dtype, - batch_size=pool.batch_size, - input_height=pool.input_height, - input_width=pool.input_width, - reader_patterns_cache=reader_patterns_cache, - parallel_config_override=pool.parallel_config_override, - deallocate_activation=True, - device=device, - ) + self.channels = channels + self.device = device def __call__(self, x): # For some reason the shard widths don't always match - so don't assert on it - assert ( - x.memory_config().shard_spec.num_cores() - == self.max_pool.max_pool.input_sharded_memory_config.shard_spec.num_cores() - and x.memory_config().shard_spec.shape[0] - == self.max_pool.max_pool.input_sharded_memory_config.shard_spec.shape[0] - ), "Expected same input shard to match max pool's shard configuration" - return self.max_pool(x) + # assert ( + # x.memory_config().shard_spec.num_cores() + # == self.max_pool.max_pool.input_sharded_memory_config.shard_spec.num_cores() + # and x.memory_config().shard_spec.shape[0] + # == self.max_pool.max_pool.input_sharded_memory_config.shard_spec.shape[0] + # ), "Expected same input shard to match max pool's shard configuration" + x = 
ttnn.max_pool2d_new( + input_tensor=x, + batch_size=self.pool.batch_size, + input_h=self.pool.input_height, + input_w=self.pool.input_width, + channels=self.channels, + kernel_size=[self.pool.kernel_size, self.pool.kernel_size], + stride=[self.pool.stride, self.pool.stride], + padding=[self.pool.padding, self.pool.padding], + dilation=[self.pool.dilation, self.pool.dilation], + device=self.device, + ) + return x class UNetDownblock: - def __init__(self, conv1, bn1, conv2, bn2, pool, device, conv_cache={}, max_pool_cache={}, should_reshard=False): - self.conv1 = UNetConv2D(conv1, bn=bn1, device=device, cache=conv_cache) - self.conv2 = UNetConv2D(conv2, bn=bn2, device=device, cache=conv_cache) - self.pool1 = UNetMaxPool2D(pool, device=device, reader_patterns_cache=max_pool_cache) + def __init__( + self, + conv1, + bn1, + conv2, + bn2, + pool, + device, + conv_cache={}, + max_pool_cache={}, + should_reshard=False, + mesh_mapper=None, + ): + self.conv1 = UNetConv2D(conv1, bn=bn1, device=device, cache=conv_cache, mesh_mapper=mesh_mapper) + self.conv2 = UNetConv2D(conv2, bn=bn2, device=device, cache=conv_cache, mesh_mapper=mesh_mapper) + self.pool1 = UNetMaxPool2D(pool, conv2.out_channels, device=device, reader_patterns_cache=max_pool_cache) self.should_reshard = should_reshard if self.should_reshard: @@ -232,11 +280,13 @@ def __call__(self, x): class UNetUpblock: - def __init__(self, conv1, bn1, conv2, bn2, conv3, bn3, device, conv_cache={}, should_reshard=False): + def __init__( + self, conv1, bn1, conv2, bn2, conv3, bn3, device, conv_cache={}, should_reshard=False, mesh_mapper=None + ): self.device = device - self.conv1 = UNetConv2D(conv1, bn1, device, conv_cache) - self.conv2 = UNetConv2D(conv2, bn2, device, conv_cache) - self.conv3 = UNetConv2D(conv3, bn3, device, conv_cache) + self.conv1 = UNetConv2D(conv1, bn1, device, conv_cache, mesh_mapper=mesh_mapper) + self.conv2 = UNetConv2D(conv2, bn2, device, conv_cache, mesh_mapper=mesh_mapper) + self.conv3 = UNetConv2D(conv3, bn3, device, conv_cache, mesh_mapper=mesh_mapper) self.should_reshard = should_reshard if self.should_reshard: @@ -298,7 +348,7 @@ def __call__(self, x, residual): class UNet: - def __init__(self, parameters: ParameterDict, device) -> None: + def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: self.device = device self.conv_cache = {} self.max_pool_cache = {} @@ -312,6 +362,7 @@ def __init__(self, parameters: ParameterDict, device) -> None: conv_cache=self.conv_cache, max_pool_cache=self.max_pool_cache, should_reshard=True, + mesh_mapper=mesh_mapper, ) self.downblock2 = UNetDownblock( parameters.c2, @@ -323,6 +374,7 @@ def __init__(self, parameters: ParameterDict, device) -> None: conv_cache=self.conv_cache, max_pool_cache=self.max_pool_cache, should_reshard=True, + mesh_mapper=mesh_mapper, ) self.downblock3 = UNetDownblock( parameters.c3, @@ -334,6 +386,7 @@ def __init__(self, parameters: ParameterDict, device) -> None: conv_cache=self.conv_cache, max_pool_cache=self.max_pool_cache, should_reshard=True, + mesh_mapper=mesh_mapper, ) self.downblock4 = UNetDownblock( parameters.c4, @@ -345,10 +398,13 @@ def __init__(self, parameters: ParameterDict, device) -> None: conv_cache=self.conv_cache, max_pool_cache=self.max_pool_cache, should_reshard=True, + mesh_mapper=mesh_mapper, ) - self.bnc = UNetConv2D(parameters.bnc, parameters.bnb, device, cache=self.conv_cache) - self.bnc2 = UNetConv2D(parameters.bnc_2, parameters.bnb_2, device, cache=self.conv_cache) + self.bnc = UNetConv2D(parameters.bnc, 
parameters.bnb, device, cache=self.conv_cache, mesh_mapper=mesh_mapper) + self.bnc2 = UNetConv2D( + parameters.bnc_2, parameters.bnb_2, device, cache=self.conv_cache, mesh_mapper=mesh_mapper + ) bnc_parallel_config = determine_parallel_config( is_1d_systolic=True, batch_size=self.bnc.batch_size, @@ -380,6 +436,7 @@ def __init__(self, parameters: ParameterDict, device) -> None: device, conv_cache=self.conv_cache, should_reshard=False, + mesh_mapper=mesh_mapper, ) self.upblock2 = UNetUpblock( parameters.c6, @@ -391,6 +448,7 @@ def __init__(self, parameters: ParameterDict, device) -> None: device, conv_cache=self.conv_cache, should_reshard=True, + mesh_mapper=mesh_mapper, ) self.upblock3 = UNetUpblock( parameters.c7, @@ -402,6 +460,7 @@ def __init__(self, parameters: ParameterDict, device) -> None: device, conv_cache=self.conv_cache, should_reshard=True, + mesh_mapper=mesh_mapper, ) self.upblock4 = UNetUpblock( parameters.c8, @@ -413,22 +472,21 @@ def __init__(self, parameters: ParameterDict, device) -> None: device, conv_cache=self.conv_cache, should_reshard=True, + mesh_mapper=mesh_mapper, ) - self.output_layer = UNetConv2D( + self.output_layer = UNetPointwiseConv2D( parameters.output_layer, - bn=None, device=device, - cache=self.conv_cache, - activation="", - activation_dtype=ttnn.bfloat16, - weights_dtype=ttnn.bfloat8_b, + mesh_mapper=mesh_mapper, ) def bottleneck(self, x): - x = ttnn.to_memory_config( + if x.is_sharded(): + x = ttnn.sharded_to_interleaved(x, ttnn.L1_MEMORY_CONFIG) + x = ttnn.interleaved_to_sharded( x, - memory_config=self.bnc_sharded_memory_config, + self.bnc_sharded_memory_config, ) x = self.bnc(x) return self.bnc2(x) @@ -450,8 +508,6 @@ def __call__(self, x, original_shape, perf_mode=False): x = self.upblock3(x, c2_residual) x = self.upblock4(x, c1_residual) - # Pointwise convolutions currently don't handle padded inputs - x = x.cpu().pad_to_tile(0) x = self.output_layer(x) x = ttnn.from_device(x)
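
Usage sketch (not part of the patch): the multi-device tests added above all follow the same data-parallel pattern on N300 — replicate the weights across the mesh, shard the batch dimension of the input across the two devices, and concatenate the per-device outputs back on the host. The following is a minimal sketch of that flow under stated assumptions: it relies on a pytest-style `device_mesh` handle as used in the tests, reuses only helpers introduced or touched by this patch, and the wrapper function name `run_unet_data_parallel` is hypothetical.

    import ttnn

    from models.experimental.functional_unet.tt.model_preprocessing import (
        create_unet_input_tensors,
        create_unet_model_parameters,
    )
    from models.experimental.functional_unet.tt import unet_shallow_torch, unet_shallow_ttnn


    def run_unet_data_parallel(device_mesh, batch=2, groups=1):
        # Hypothetical helper illustrating the data-parallel pattern used by the new tests.
        # Shard activations over the batch dim, replicate weights, gather outputs on dim 0.
        inputs_mesh_mapper = ttnn.ShardTensorToMesh(device_mesh, dim=0)
        weights_mesh_mapper = ttnn.ReplicateTensorToMesh(device_mesh)
        output_mesh_composer = ttnn.ConcatMeshToTensor(device_mesh, dim=0)

        # Trace reference inputs once to derive conv configs, then build the TT-NN model
        # with weights replicated onto every device in the mesh.
        torch_input, _ = create_unet_input_tensors(device_mesh, batch, groups, pad_input=True)
        model = unet_shallow_torch.UNet.from_random_weights(groups=groups)
        parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device_mesh)
        ttnn_model = unet_shallow_ttnn.UNet(parameters, device=device_mesh, mesh_mapper=weights_mesh_mapper)

        # Global batch = per-device batch * number of devices; each device runs `batch` samples.
        num_devices = len(device_mesh.get_device_ids())
        torch_input, ttnn_input = create_unet_input_tensors(
            device_mesh, num_devices * batch, groups, pad_input=True, mesh_mapper=inputs_mesh_mapper
        )

        output = ttnn_model(ttnn_input, list(torch_input.shape))
        # Compose the sharded outputs back into a single host tensor for post-processing
        # or PCC checks against the reference model.
        return ttnn.to_torch(output, mesh_composer=output_mesh_composer)

Because ReplicateTensorToMesh keeps a full copy of the weights on each device while ShardTensorToMesh splits only the activation batch, the tests scale the input to num_devices * batch so that each device processes the same per-device batch size as the single-device tests.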