#10754: Add data-parallel support for UNet Shallow on N300
esmalTT committed Aug 30, 2024
1 parent af7a2a4 commit 2035fd0
Showing 9 changed files with 388 additions and 95 deletions.
27 changes: 27 additions & 0 deletions models/experimental/functional_unet/tests/common.py
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc


def is_n300_with_eth_dispatch_cores(device_mesh) -> bool:
    all_devices_using_full_grid = all(
        [(8 == device.core_grid.x and 8 == device.core_grid.y) for device in device_mesh.get_devices()]
    )
    return all_devices_using_full_grid and (len(device_mesh.get_devices()) == 2)


def check_pcc_conv(torch_tensor, ttnn_tensor, pcc=0.999, mesh_composer=None):
    B, C, H, W = torch_tensor.shape
    ttnn_tensor = ttnn.to_torch(ttnn_tensor, mesh_composer=mesh_composer).reshape(B, H, W, C).permute(0, 3, 1, 2)
    assert_with_pcc(torch_tensor, ttnn_tensor, pcc)


def check_pcc_pool(torch_tensor, ttnn_tensor, pcc=0.999, mesh_composer=None):
    B, C, H, W = torch_tensor.shape
    ttnn_tensor = (
        ttnn.to_torch(ttnn_tensor, mesh_composer=mesh_composer).reshape(B, H, W, -1).permute(0, 3, 1, 2)[:, :C, :, :]
    )
    assert_with_pcc(torch_tensor, ttnn_tensor, pcc)
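Both helpers convert the device output from its flattened NHWC layout back to NCHW (stitching per-device shards together via mesh_composer for multi-device runs) and then defer the comparison to assert_with_pcc from tests/ttnn/utils_for_testing.py. A minimal sketch of the metric being asserted, assuming assert_with_pcc checks a Pearson correlation coefficient over the flattened tensors (the helper name compute_pcc is hypothetical, shown for illustration only):

import torch

def compute_pcc(expected: torch.Tensor, actual: torch.Tensor) -> float:
    # Flatten both tensors and compute the Pearson correlation coefficient;
    # 1.0 is a perfect match, and the tests in this commit require 0.99-0.999.
    stacked = torch.stack([expected.flatten().float(), actual.flatten().float()])
    return torch.corrcoef(stacked)[0, 1].item()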
53 changes: 46 additions & 7 deletions models/experimental/functional_unet/tests/test_unet_bottleneck.py
@@ -5,6 +5,8 @@
import pytest
import ttnn

from loguru import logger

from tests.ttnn.utils_for_testing import assert_with_pcc

from models.experimental.functional_unet.tt.model_preprocessing import (
@@ -13,23 +15,19 @@
)
from models.experimental.functional_unet.tt import unet_shallow_torch
from models.experimental.functional_unet.tt import unet_shallow_ttnn
from models.experimental.functional_unet.tests.common import is_n300_with_eth_dispatch_cores, check_pcc_conv


@pytest.mark.parametrize("batch", [2])
@pytest.mark.parametrize("groups", [1])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
def test_unet_downblocks(batch, groups, device):
def test_unet_bottleneck(batch, groups, device, reset_seeds):
    torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=False)
    model = unet_shallow_torch.UNet.from_random_weights(groups=1)

    parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device)
    ttnn_model = unet_shallow_ttnn.UNet(parameters, device)

    def check_pcc(torch_tensor, ttnn_tensor, pcc=0.995):
        B, C, H, W = torch_tensor.shape
        ttnn_tensor = ttnn.to_torch(ttnn_tensor).reshape(B, H, W, C).permute(0, 3, 1, 2)
        assert_with_pcc(torch_tensor, ttnn_tensor, pcc)

    torch_input, ttnn_input = create_unet_input_tensors(
        device, batch, groups, pad_input=True, input_channels=32, input_height=66, input_width=10
    )
@@ -38,4 +36,45 @@ def check_pcc(torch_tensor, ttnn_tensor, pcc=0.995):
    ttnn_input = ttnn.to_device(ttnn_input, device=device)
    ttnn_output = ttnn_model.bottleneck(ttnn_input)

    check_pcc(torch_output, ttnn_output)
    check_pcc_conv(torch_output, ttnn_output, pcc=0.999)


@pytest.mark.parametrize("batch", [2])
@pytest.mark.parametrize("groups", [1])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
def test_unet_bottleneck_multi_device(batch, groups, device_mesh, reset_seeds):
    if not is_n300_with_eth_dispatch_cores(device_mesh):
        pytest.skip("Test is only valid for N300")

    inputs_mesh_mapper = ttnn.ShardTensorToMesh(device_mesh, dim=0)
    weights_mesh_mapper = ttnn.ReplicateTensorToMesh(device_mesh)
    output_mesh_composer = ttnn.ConcatMeshToTensor(device_mesh, dim=0)

    torch_input, ttnn_input = create_unet_input_tensors(device_mesh, batch, groups, pad_input=False)
    model = unet_shallow_torch.UNet.from_random_weights(groups=1)

    parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device_mesh)
    ttnn_model = unet_shallow_ttnn.UNet(parameters, device_mesh, mesh_mapper=weights_mesh_mapper)

    num_devices = len(device_mesh.get_device_ids())
    torch_input, ttnn_input = create_unet_input_tensors(
        device_mesh,
        num_devices * batch,
        groups,
        pad_input=True,
        input_channels=32,
        input_height=66,
        input_width=10,
        mesh_mapper=inputs_mesh_mapper,
    )
    logger.info(f"Created reference input tensors: {list(torch_input.shape)}")
    logger.info(
        f"Created multi-device input tensors: shape={list(ttnn_input.shape)} on devices={device_mesh.get_device_ids()}"
    )
    torch_output = model.bottleneck(torch_input)

    ttnn_input = ttnn_input.to(device_mesh)
    ttnn_output = ttnn_model.bottleneck(ttnn_input)

    assert len(ttnn_output.devices()) == 2, "Expected output tensor to be sharded across 2 devices"
    check_pcc_conv(torch_output, ttnn_output, mesh_composer=output_mesh_composer, pcc=0.999)
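The mesh wiring above is the core of this change: ShardTensorToMesh(dim=0) splits the global batch across the two N300 devices, ReplicateTensorToMesh copies the weights to both, and ConcatMeshToTensor(dim=0) reassembles the per-device outputs before the PCC check. A rough single-process PyTorch equivalent of what that mapper/composer pair does in this bottleneck test (a sketch only, reusing the torch reference model from above; input values are hypothetical):

import torch

# Global batch is num_devices * batch = 2 * 2 = 4 for this test.
global_input = torch.randn(4, 32, 66, 10)

# ShardTensorToMesh(device_mesh, dim=0): each device receives half of the batch.
shards = torch.chunk(global_input, chunks=2, dim=0)

# ReplicateTensorToMesh: both devices run identical bottleneck weights.
shard_outputs = [model.bottleneck(shard) for shard in shards]

# ConcatMeshToTensor(device_mesh, dim=0): stitch the shards back into one tensor.
global_output = torch.cat(shard_outputs, dim=0)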
74 changes: 60 additions & 14 deletions models/experimental/functional_unet/tests/test_unet_downblock.py
@@ -4,27 +4,19 @@

import pytest
import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc
from loguru import logger

from models.experimental.functional_unet.tt.model_preprocessing import (
    create_unet_input_tensors,
    create_unet_model_parameters,
)
from models.experimental.functional_unet.tt import unet_shallow_torch
from models.experimental.functional_unet.tt import unet_shallow_ttnn


def check_pcc_conv(torch_tensor, ttnn_tensor, pcc=0.999):
    B, C, H, W = torch_tensor.shape
    ttnn_tensor = ttnn.to_torch(ttnn_tensor).reshape(B, H, W, C).permute(0, 3, 1, 2)
    assert_with_pcc(torch_tensor, ttnn_tensor, pcc)


def check_pcc_pool(torch_tensor, ttnn_tensor, pcc=0.999):
    B, C, H, W = torch_tensor.shape
    ttnn_tensor = ttnn.to_torch(ttnn_tensor).reshape(B, H, W, -1).permute(0, 3, 1, 2)[:, :C, :, :]
    assert_with_pcc(torch_tensor, ttnn_tensor, pcc)
from models.experimental.functional_unet.tests.common import (
    check_pcc_conv,
    check_pcc_pool,
    is_n300_with_eth_dispatch_cores,
)


@pytest.mark.parametrize("batch, groups", [(2, 1)])
@@ -61,3 +53,57 @@ def test_unet_downblock(batch, groups, block_name, input_channels, input_height,

    check_pcc_conv(torch_residual, ttnn_residual)
    check_pcc_pool(torch_output, ttnn_output)


@pytest.mark.parametrize("batch, groups", [(2, 1)])
@pytest.mark.parametrize(
    "block_name, input_channels, input_height, input_width",
    [
        ("downblock1", 4, 1056, 160),
        ("downblock2", 16, 528, 80),
        ("downblock3", 16, 264, 40),
        ("downblock4", 32, 132, 20),
    ],
)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
def test_unet_downblock_multi_device(
    batch, groups, block_name, input_channels, input_height, input_width, device_mesh, reset_seeds
):
    if not is_n300_with_eth_dispatch_cores(device_mesh):
        pytest.skip("Test is only valid for N300")

    inputs_mesh_mapper = ttnn.ShardTensorToMesh(device_mesh, dim=0)
    weights_mesh_mapper = ttnn.ReplicateTensorToMesh(device_mesh)
    output_mesh_composer = ttnn.ConcatMeshToTensor(device_mesh, dim=0)

    torch_input, ttnn_input = create_unet_input_tensors(device_mesh, batch, groups, pad_input=False)
    model = unet_shallow_torch.UNet.from_random_weights(groups=1)

    parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device_mesh)
    ttnn_model = unet_shallow_ttnn.UNet(parameters, device_mesh, mesh_mapper=weights_mesh_mapper)

    num_devices = len(device_mesh.get_device_ids())
    torch_input, ttnn_input = create_unet_input_tensors(
        device_mesh,
        num_devices * batch,
        groups,
        pad_input=True,
        input_channels=input_channels,
        input_height=input_height,
        input_width=input_width,
        mesh_mapper=inputs_mesh_mapper,
    )
    logger.info(f"Created reference input tensors: {list(torch_input.shape)}")
    logger.info(
        f"Created multi-device input tensors: shape={list(ttnn_input.shape)} on devices={device_mesh.get_device_ids()}"
    )

    torch_output, torch_residual = getattr(model, block_name)(torch_input)

    ttnn_input = ttnn_input.to(device_mesh)
    ttnn_output, ttnn_residual = getattr(ttnn_model, block_name)(ttnn_input)

    assert len(ttnn_output.devices()) == 2, "Expected output tensor to be sharded across 2 devices"
    assert len(ttnn_residual.devices()) == 2, "Expected residual output tensor to be sharded across 2 devices"
    check_pcc_conv(torch_residual, ttnn_residual, mesh_composer=output_mesh_composer)
    check_pcc_pool(torch_output, ttnn_output, mesh_composer=output_mesh_composer)
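The parametrized shapes mirror the UNet encoder: each downblock ends with a pooling step that halves the spatial resolution, so every row feeds the next (1056x160 -> 528x80 -> 264x40 -> 132x20). A quick sanity check of that progression, assuming 2x2 pooling with stride 2:

# (input_channels, input_height, input_width) for downblock1..downblock4
shapes = [(4, 1056, 160), (16, 528, 80), (16, 264, 40), (32, 132, 20)]
for (_, h, w), (_, next_h, next_w) in zip(shapes, shapes[1:]):
    # The pooled output of one block is the input resolution of the next.
    assert (h // 2, w // 2) == (next_h, next_w)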
4 changes: 2 additions & 2 deletions models/experimental/functional_unet/tests/test_unet_model.py
@@ -18,7 +18,7 @@
@pytest.mark.parametrize("batch", [2])
@pytest.mark.parametrize("groups", [1])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 64768}], indirect=True)
def test_unet_model(batch, groups, device, use_program_cache):
def test_unet_model(batch, groups, device, use_program_cache, reset_seeds):
    torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=True)
    model = unet_shallow_torch.UNet.from_random_weights(groups=1)

@@ -30,4 +30,4 @@ def test_unet_model(batch, groups, device, use_program_cache):

    B, C, H, W = torch_output_tensor.shape
    ttnn_tensor = ttnn.to_torch(output_tensor).reshape(B, H, W, -1)[:, :, :, :C].permute(0, 3, 1, 2)
    assert_with_pcc(torch_output_tensor, ttnn_tensor, 0.985)
    assert_with_pcc(torch_output_tensor, ttnn_tensor, 0.99)
@@ -0,0 +1,51 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import ttnn

from loguru import logger

from models.experimental.functional_unet.tt.model_preprocessing import (
    create_unet_input_tensors,
    create_unet_model_parameters,
)
from models.experimental.functional_unet.tt import unet_shallow_torch
from models.experimental.functional_unet.tt import unet_shallow_ttnn
from models.experimental.functional_unet.tests.common import (
    check_pcc_conv,
    is_n300_with_eth_dispatch_cores,
)


@pytest.mark.parametrize("batch", [2])
@pytest.mark.parametrize("groups", [1])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 64768}], indirect=True)
def test_unet_multi_device_model(batch, groups, device_mesh, use_program_cache, reset_seeds):
    if not is_n300_with_eth_dispatch_cores(device_mesh):
        pytest.skip("Test is only valid for N300")

    inputs_mesh_mapper = ttnn.ShardTensorToMesh(device_mesh, dim=0)
    weights_mesh_mapper = ttnn.ReplicateTensorToMesh(device_mesh)
    output_mesh_composer = ttnn.ConcatMeshToTensor(device_mesh, dim=0)

    torch_input, ttnn_input = create_unet_input_tensors(device_mesh, batch, groups, pad_input=True)
    model = unet_shallow_torch.UNet.from_random_weights(groups=groups)

    parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device_mesh)
    ttnn_model = unet_shallow_ttnn.UNet(parameters, device=device_mesh, mesh_mapper=weights_mesh_mapper)

    num_devices = len(device_mesh.get_device_ids())
    torch_input, ttnn_input = create_unet_input_tensors(
        device_mesh, num_devices * batch, groups, pad_input=True, mesh_mapper=inputs_mesh_mapper
    )
    logger.info(f"Created reference input tensors: {list(torch_input.shape)}")
    logger.info(
        f"Created multi-device input tensors: shape={list(ttnn_input.shape)} on devices={device_mesh.get_device_ids()}"
    )

    torch_output_tensor = model(torch_input)
    output_tensor = ttnn_model(ttnn_input, list(torch_input.shape))

    check_pcc_conv(torch_output_tensor, output_tensor, mesh_composer=output_mesh_composer, pcc=0.99)
@@ -17,19 +17,21 @@

@pytest.mark.parametrize("batch, groups", [(2, 1)])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
def test_unet_output_layer(batch, groups, device):
def test_unet_output_layer(batch, groups, device, reset_seeds):
    torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=False)
    model = unet_shallow_torch.UNet.from_random_weights(groups=groups)

    parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device)
    ttnn_model = unet_shallow_ttnn.UNet(parameters, device)

    torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=True, input_channels=16)
    torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=False, input_channels=16)
    torch_output = model.output_layer(torch_input)

    ttnn_input = ttnn.to_device(ttnn_input, device)
    ttnn_output = ttnn_model.output_layer(ttnn_input)

    B, C, H, W = torch_output.shape
    ttnn_output = ttnn.to_torch(ttnn_output)
    assert list(ttnn_output.shape) == [1, 1, B * H * W, 32], "Expected output layer to return padded output"
    ttnn_output = ttnn_output.reshape(B, H, W, 32)[:, :, :, :C].permute(0, 3, 1, 2)
    assert list(ttnn_output.shape) == [1, 1, B * H * W, C], "Expected output layer to be [1, 1, BHW, C]"
    ttnn_output = ttnn_output.reshape(B, H, W, C).permute(0, 3, 1, 2)
    assert_with_pcc(torch_output, ttnn_output, 0.99)
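With the channel padding removed, the output layer hands back a [1, 1, B*H*W, C] tensor that maps onto NCHW with a single reshape and permute. A small worked example of that round trip (a sketch with hypothetical sizes, not the model's real output dimensions):

import torch

B, C, H, W = 2, 1, 1056, 160                               # hypothetical NCHW shape
flat = torch.zeros(1, 1, B * H * W, C)                     # layout returned by the output layer
restored = flat.reshape(B, H, W, C).permute(0, 3, 1, 2)    # back to NCHW for comparison
assert list(restored.shape) == [B, C, H, W]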