Skip to content

Commit

Permalink
Fold batches into channels and use grouped convolutions in UNet Shallow
Browse files Browse the repository at this point in the history
  • Loading branch information
esmalTT committed Nov 20, 2024
1 parent eacb47a commit a50e6ea
Show file tree
Hide file tree
Showing 12 changed files with 235 additions and 243 deletions.
3 changes: 2 additions & 1 deletion models/experimental/functional_unet/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from tests.ttnn.utils_for_testing import assert_with_pcc

UNET_FULL_MODEL_PCC = 0.9916
UNET_FULL_MODEL_PCC = 0.99995


def is_n300_with_eth_dispatch_cores(mesh_device) -> bool:
Expand All @@ -33,6 +33,7 @@ def verify_with_pcc(torch_tensor, ttnn_tensor, pcc):
)


# TODO: This is the same as the function below, we should consolidate them
def check_pcc_conv(torch_tensor, ttnn_tensor, pcc=0.999, mesh_composer=None):
B, C, H, W = torch_tensor.shape
ttnn_tensor = (
Expand Down
28 changes: 16 additions & 12 deletions models/experimental/functional_unet/tests/test_unet_bottleneck.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,22 @@
from models.experimental.functional_unet.tests.common import is_n300_with_eth_dispatch_cores, check_pcc_conv


@pytest.mark.parametrize("batch", [2])
@pytest.mark.parametrize("groups", [1])
@pytest.mark.parametrize("batch", [1])
@pytest.mark.parametrize("groups", [2])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
def test_unet_bottleneck(batch, groups, device, reset_seeds):
torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=False)
model = unet_shallow_torch.UNet.from_random_weights(groups=1)
def test_unet_bottleneck(batch: int, groups: int, device: ttnn.Device, reset_seeds):
torch_input, ttnn_input = create_unet_input_tensors(batch, groups, pad_input=False)
model = unet_shallow_torch.UNet.from_random_weights(groups=groups)

parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device)
ttnn_model = unet_shallow_ttnn.UNet(parameters, device)

torch_input, ttnn_input = create_unet_input_tensors(
device, batch, groups, pad_input=True, input_channels=32, input_height=66, input_width=10
batch, groups, pad_input=True, input_channels=32, input_height=66, input_width=10
)
logger.info(f"Created reference input tensors: {list(torch_input.shape)}")
logger.info(f"Created input tensors: shape={list(ttnn_input.shape)}")

torch_output = model.bottleneck(torch_input)

ttnn_input = ttnn.to_device(ttnn_input, device=device)
Expand All @@ -39,27 +42,28 @@ def test_unet_bottleneck(batch, groups, device, reset_seeds):
check_pcc_conv(torch_output, ttnn_output, pcc=0.999)


@pytest.mark.parametrize("batch", [2])
@pytest.mark.parametrize("groups", [1])
@pytest.mark.parametrize("batch", [1])
@pytest.mark.parametrize("groups", [2])
@pytest.mark.parametrize("enable_async_mode", (True,), indirect=True)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
def test_unet_bottleneck_multi_device(batch, groups, mesh_device, reset_seeds, enable_async_mode):
def test_unet_bottleneck_multi_device(
batch: int, groups: int, mesh_device: ttnn.MeshDevice, reset_seeds, enable_async_mode
):
if not is_n300_with_eth_dispatch_cores(mesh_device):
pytest.skip("Test is only valid for N300")

inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0)
weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device)
output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0)

torch_input, ttnn_input = create_unet_input_tensors(mesh_device, batch, groups, pad_input=False)
model = unet_shallow_torch.UNet.from_random_weights(groups=1)
torch_input, ttnn_input = create_unet_input_tensors(batch, groups, pad_input=False)
model = unet_shallow_torch.UNet.from_random_weights(groups=groups)

parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=mesh_device)
ttnn_model = unet_shallow_ttnn.UNet(parameters, mesh_device, mesh_mapper=weights_mesh_mapper)

num_devices = len(mesh_device.get_device_ids())
torch_input, ttnn_input = create_unet_input_tensors(
mesh_device,
num_devices * batch,
groups,
pad_input=True,
Expand Down
29 changes: 14 additions & 15 deletions models/experimental/functional_unet/tests/test_unet_downblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)


@pytest.mark.parametrize("batch, groups", [(2, 1)])
@pytest.mark.parametrize("batch, groups", [(1, 2)])
@pytest.mark.parametrize(
"block_name, input_channels, input_height, input_width",
[
Expand All @@ -31,30 +31,30 @@
)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
def test_unet_downblock(
batch,
groups,
block_name,
input_channels,
input_height,
input_width,
device,
batch: int,
groups: int,
block_name: str,
input_channels: int,
input_height: int,
input_width: int,
device: ttnn.Device,
reset_seeds,
):
torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=False)
model = unet_shallow_torch.UNet.from_random_weights(groups=1)
torch_input, ttnn_input = create_unet_input_tensors(batch, groups, pad_input=False)
model = unet_shallow_torch.UNet.from_random_weights(groups=groups)

parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device)
ttnn_model = unet_shallow_ttnn.UNet(parameters, device)

torch_input, ttnn_input = create_unet_input_tensors(
device,
batch,
groups,
pad_input=True,
input_channels=input_channels,
input_height=input_height,
input_width=input_width,
)

torch_output, torch_residual = getattr(model, block_name)(torch_input)

ttnn_input = ttnn_input.to(device)
Expand All @@ -64,7 +64,7 @@ def test_unet_downblock(
check_pcc_pool(torch_output, ttnn_output)


@pytest.mark.parametrize("batch, groups", [(2, 1)])
@pytest.mark.parametrize("batch, groups", [(1, 2)])
@pytest.mark.parametrize(
"block_name, input_channels, input_height, input_width",
[
Expand All @@ -86,15 +86,14 @@ def test_unet_downblock_multi_device(
weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device)
output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0)

torch_input, ttnn_input = create_unet_input_tensors(mesh_device, batch, groups, pad_input=False)
model = unet_shallow_torch.UNet.from_random_weights(groups=1)
torch_input, ttnn_input = create_unet_input_tensors(batch, groups, pad_input=False)
model = unet_shallow_torch.UNet.from_random_weights(groups=groups)

parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=mesh_device)
ttnn_model = unet_shallow_ttnn.UNet(parameters, mesh_device, mesh_mapper=weights_mesh_mapper)

num_devices = len(mesh_device.get_device_ids())
torch_input, ttnn_input = create_unet_input_tensors(
mesh_device,
num_devices * batch,
groups,
pad_input=True,
Expand Down
8 changes: 4 additions & 4 deletions models/experimental/functional_unet/tests/test_unet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from models.experimental.functional_unet.tests.common import check_pcc_conv, UNET_FULL_MODEL_PCC


@pytest.mark.parametrize("batch", [2])
@pytest.mark.parametrize("groups", [1])
@pytest.mark.parametrize("batch", [1])
@pytest.mark.parametrize("groups", [2])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 79104}], indirect=True)
def test_unet_model(batch, groups, device, use_program_cache, reset_seeds):
torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=True)
model = unet_shallow_torch.UNet.from_random_weights(groups=1)
torch_input, ttnn_input = create_unet_input_tensors(batch, groups, pad_input=True)
model = unet_shallow_torch.UNet.from_random_weights(groups=groups)

parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device)
ttnn_model = unet_shallow_ttnn.UNet(parameters, device)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
)


@pytest.mark.parametrize("batch", [2])
@pytest.mark.parametrize("groups", [1])
@pytest.mark.parametrize("batch", [1])
@pytest.mark.parametrize("groups", [2])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 79104}], indirect=True)
def test_unet_multi_device_model(batch, groups, mesh_device, use_program_cache, reset_seeds):
if not is_n300_with_eth_dispatch_cores(mesh_device) and not is_t3k_with_eth_dispatch_cores(mesh_device):
Expand All @@ -32,7 +32,7 @@ def test_unet_multi_device_model(batch, groups, mesh_device, use_program_cache,
weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device)
output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0)

torch_input, ttnn_input = create_unet_input_tensors(mesh_device, batch, groups, pad_input=True)
torch_input, ttnn_input = create_unet_input_tensors(batch, groups, pad_input=True)
model = unet_shallow_torch.UNet.from_random_weights(groups=groups)

parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=mesh_device)
Expand All @@ -42,7 +42,7 @@ def test_unet_multi_device_model(batch, groups, mesh_device, use_program_cache,
logger.info(f"Using {num_devices} devices for this test")

torch_input, ttnn_input = create_unet_input_tensors(
mesh_device, num_devices * batch, groups, pad_input=True, mesh_mapper=inputs_mesh_mapper
num_devices * batch, groups, pad_input=True, mesh_mapper=inputs_mesh_mapper
)
logger.info(f"Created reference input tensors: {list(torch_input.shape)}")
logger.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@
from models.experimental.functional_unet.tt import unet_shallow_ttnn


@pytest.mark.parametrize("batch, groups", [(2, 1)])
@pytest.mark.parametrize("batch, groups", [(1, 2)])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
def test_unet_output_layer(batch, groups, device, reset_seeds):
torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=False)
torch_input, ttnn_input = create_unet_input_tensors(batch, groups, pad_input=False)
model = unet_shallow_torch.UNet.from_random_weights(groups=groups)

parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device)
ttnn_model = unet_shallow_ttnn.UNet(parameters, device)

torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=False, input_channels=16)
torch_input, ttnn_input = create_unet_input_tensors(batch, groups, pad_input=False, input_channels=16)
torch_output = model.output_layer(torch_input)

ttnn_input = ttnn.to_device(ttnn_input, device=device)
Expand All @@ -35,4 +35,4 @@ def test_unet_output_layer(batch, groups, device, reset_seeds):
ttnn_output = ttnn.to_torch(ttnn_output)
assert list(ttnn_output.shape) == [1, 1, B * H * W, C], "Expected output layer to be [1, 1, BHW, C]"
ttnn_output = ttnn_output.reshape(B, H, W, C).permute(0, 3, 1, 2)
assert_with_pcc(torch_output, ttnn_output, 0.99995)
assert_with_pcc(torch_output, ttnn_output, 0.9995)
20 changes: 11 additions & 9 deletions models/experimental/functional_unet/tests/test_unet_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,25 @@
@pytest.mark.models_device_performance_bare_metal
@pytest.mark.parametrize(
"batch, groups, expected_device_perf_fps",
((2, 1, 779.0),),
((1, 2, 977.0),),
)
def test_unet_perf_device(batch: int, groups: int, expected_device_perf_fps: float):
command = f"pytest models/experimental/functional_unet/tests/test_unet_model.py::test_unet_model[device_params0-{groups}-{batch}]"
cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]

total_batch = groups * batch

inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"
post_processed_results = run_device_perf(
command, subdir="unet_shallow", num_iterations=1, cols=cols, batch_size=batch
command, subdir="unet_shallow", num_iterations=1, cols=cols, batch_size=total_batch
)
expected_perf_cols = {inference_time_key: expected_device_perf_fps}
expected_results = check_device_perf(
post_processed_results, margin=0.01, expected_perf_cols=expected_perf_cols, assert_on_fail=True
)
prep_device_perf_report(
model_name=f"unet-shallow_batch-{batch}_groups-{groups}",
batch_size=batch,
batch_size=total_batch,
post_processed_results=post_processed_results,
expected_results=expected_results,
comments="",
Expand All @@ -62,7 +64,7 @@ def test_unet_perf_device(batch: int, groups: int, expected_device_perf_fps: flo
@pytest.mark.parametrize("device_params", [{"l1_small_size": 79104}], indirect=True)
@pytest.mark.parametrize(
"batch, groups, iterations, expected_compile_time, expected_inference_time_ms",
((2, 1, 16, 25.0, 39.0),),
((1, 2, 16, 25.0, 39.0),),
)
def test_unet_perf_e2e(
batch: int,
Expand All @@ -76,10 +78,10 @@ def test_unet_perf_e2e(
):
profiler.clear()

torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=True)
torch_input, ttnn_input = create_unet_input_tensors(batch, groups, pad_input=True)

profiler.start(f"initialize_ref_model")
model = unet_shallow_torch.UNet.from_random_weights(groups=1)
model = unet_shallow_torch.UNet.from_random_weights(groups=groups)
profiler.end(f"initialize_ref_model")

profiler.start(f"initialize_model")
Expand Down Expand Up @@ -137,7 +139,7 @@ def test_unet_perf_e2e(
@pytest.mark.parametrize("device_params", [{"l1_small_size": 79104}], indirect=True)
@pytest.mark.parametrize(
"batch, groups, iterations, expected_compile_time, expected_inference_time_ms",
((2, 1, 16, 25.0, 61.0),),
((1, 2, 16, 25.0, 61.0),),
)
def test_unet_data_parallel_perf_e2e(
batch: int,
Expand All @@ -159,7 +161,7 @@ def test_unet_data_parallel_perf_e2e(
weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device)
output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0)

torch_input, ttnn_input = create_unet_input_tensors(mesh_device, batch, groups, pad_input=True)
torch_input, ttnn_input = create_unet_input_tensors(batch, groups, pad_input=True)

profiler.start(f"initialize_ref_model")
model = unet_shallow_torch.UNet.from_random_weights(groups=groups)
Expand All @@ -173,7 +175,7 @@ def test_unet_data_parallel_perf_e2e(
num_devices = len(mesh_device.get_device_ids())
total_batch = num_devices * batch
torch_input, ttnn_input = create_unet_input_tensors(
mesh_device, total_batch, groups, pad_input=True, mesh_mapper=inputs_mesh_mapper
total_batch, groups, pad_input=True, mesh_mapper=inputs_mesh_mapper
)
logger.info(f"Created reference input tensors: {list(torch_input.shape)}")
logger.info(
Expand Down
24 changes: 12 additions & 12 deletions models/experimental/functional_unet/tests/test_unet_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
@pytest.mark.parametrize("device_params", [{"l1_small_size": 68864, "trace_region_size": 444416}], indirect=True)
@pytest.mark.parametrize(
"batch, groups, iterations",
((2, 1, 32),),
((1, 2, 32),),
)
def test_unet_trace(
batch: int,
Expand All @@ -41,9 +41,9 @@ def test_unet_trace(
use_program_cache,
reset_seeds,
):
torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=True)
torch_input, ttnn_input = create_unet_input_tensors(batch, groups, pad_input=True)

model = unet_shallow_torch.UNet.from_random_weights(groups=1)
model = unet_shallow_torch.UNet.from_random_weights(groups=groups)
torch_output_tensor = model(torch_input)

parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device)
Expand Down Expand Up @@ -117,7 +117,7 @@ def test_unet_trace(
)
@pytest.mark.parametrize(
"batch, groups, iterations",
((2, 1, 32),),
((1, 2, 32),),
)
def test_unet_trace_2cq(
batch: int,
Expand All @@ -127,9 +127,9 @@ def test_unet_trace_2cq(
use_program_cache,
reset_seeds,
):
torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=True)
torch_input, ttnn_input = create_unet_input_tensors(batch, groups, pad_input=True)

model = unet_shallow_torch.UNet.from_random_weights(groups=1)
model = unet_shallow_torch.UNet.from_random_weights(groups=groups)
torch_output_tensor = model(torch_input)

parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device)
Expand Down Expand Up @@ -235,7 +235,7 @@ def buffer_address(tensor):
)
@pytest.mark.parametrize(
"batch, groups, iterations",
((2, 1, 32),),
((1, 2, 32),),
)
def test_unet_trace_2cq_multi_device(
batch: int, groups: int, iterations: int, mesh_device, use_program_cache, reset_seeds, enable_async_mode
Expand All @@ -247,7 +247,7 @@ def test_unet_trace_2cq_multi_device(
weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device)
output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0)

torch_input, ttnn_input = create_unet_input_tensors(mesh_device, batch, groups, pad_input=True)
torch_input, ttnn_input = create_unet_input_tensors(batch, groups, pad_input=True)
model = unet_shallow_torch.UNet.from_random_weights(groups=groups)

parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=mesh_device)
Expand All @@ -258,7 +258,7 @@ def test_unet_trace_2cq_multi_device(

total_batch = num_devices * batch
torch_input, ttnn_input = create_unet_input_tensors(
mesh_device, total_batch, groups, pad_input=True, mesh_mapper=inputs_mesh_mapper
total_batch, groups, pad_input=True, mesh_mapper=inputs_mesh_mapper
)
logger.info(f"Created reference input tensors: {list(torch_input.shape)}")
logger.info(
Expand Down Expand Up @@ -358,7 +358,7 @@ def test_unet_trace_2cq_multi_device(
)
@pytest.mark.parametrize(
"batch, groups, iterations",
((2, 1, 32),),
((1, 2, 32),),
)
def test_unet_trace_2cq_same_io(
batch: int,
Expand All @@ -368,9 +368,9 @@ def test_unet_trace_2cq_same_io(
use_program_cache,
reset_seeds,
):
torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=True)
torch_input, ttnn_input = create_unet_input_tensors(batch, groups, pad_input=True)

model = unet_shallow_torch.UNet.from_random_weights(groups=1)
model = unet_shallow_torch.UNet.from_random_weights(groups=groups)
torch_output_tensor = model(torch_input)

parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device)
Expand Down
Loading

0 comments on commit a50e6ea

Please sign in to comment.