Skip to content

Commit

Permalink
#0: Add support for running UNet Shallow data transfers on same CQ
Browse files Browse the repository at this point in the history
  • Loading branch information
esmalTT committed Oct 22, 2024
1 parent c2caa95 commit bc40fbd
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 8 deletions.
2 changes: 1 addition & 1 deletion models/experimental/functional_unet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ When running this model on N300 or T3000, make sure to place dispatch on etherne
To run UNet Shallow for multiple iterations on single-chip at the best performance:

```sh
pytest --disable-warnings models/experimental/functional_unet/tests/test_unet_trace.py::test_unet_trace_2cq
pytest --disable-warnings models/experimental/functional_unet/tests/test_unet_trace.py::test_unet_trace_2cq_same_io
```

To run UNet Shallow for multiple iterations on N300 and T3000 at the best performance:
Expand Down
188 changes: 181 additions & 7 deletions models/experimental/functional_unet/tests/test_unet_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
@pytest.mark.parametrize("device_params", [{"l1_small_size": 68864, "trace_region_size": 444416}], indirect=True)
@pytest.mark.parametrize(
"batch, groups, iterations",
((2, 1, 16),),
((2, 1, 32),),
)
def test_unet_trace(
batch: int,
Expand All @@ -49,26 +49,57 @@ def test_unet_trace(
parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device)
ttnn_model = unet_shallow_ttnn.UNet(parameters, device)

dram_grid_size = device.dram_grid_size()
dram_shard_spec = ttnn.ShardSpec(
ttnn.CoreRangeSet(
{ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(dram_grid_size.x - 1, dram_grid_size.y - 1))}
),
[
divup(ttnn_input.volume() // ttnn_input.shape[-1], dram_grid_size.x),
ttnn_input.shape[-1],
],
ttnn.ShardOrientation.ROW_MAJOR,
False,
)
dram_memory_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.DRAM, dram_shard_spec
)
input_tensor = ttnn.allocate_tensor_on_device(
ttnn_input.shape, ttnn.bfloat16, ttnn.ROW_MAJOR_LAYOUT, device, ttnn.DRAM_MEMORY_CONFIG
ttnn_input.shape, ttnn.bfloat16, ttnn.ROW_MAJOR_LAYOUT, device, dram_memory_config
)

logger.info(f"Compiling model with warmup run")
ttnn.copy_host_to_device_tensor(ttnn_input, input_tensor, cq_id=0)
output_tensor = ttnn_model(input_tensor).cpu()
l1_input_tensor = ttnn.reshard(input_tensor, ttnn_model.input_sharded_memory_config)
output_tensor = ttnn_model(l1_input_tensor, move_input_tensor_to_device=False)
logger.info(f"Done compile run")

logger.info(f"Capturing trace")
ttnn.copy_host_to_device_tensor(ttnn_input, input_tensor, cq_id=0)
l1_input_tensor = ttnn.reshard(input_tensor, ttnn_model.input_sharded_memory_config)

input_trace_addr = l1_input_tensor.buffer_address()
shape = l1_input_tensor.shape
dtype = l1_input_tensor.dtype
layout = l1_input_tensor.layout
output_tensor.deallocate(force=True)

tid = ttnn.begin_trace_capture(device, cq_id=0)
output_tensor = ttnn_model(input_tensor)
output_tensor = ttnn_model(l1_input_tensor, move_input_tensor_to_device=False)

# Try allocating our persistent input tensor here and verifying it matches the address that trace captured
l1_input_tensor = ttnn.allocate_tensor_on_device(
shape, dtype, layout, device, ttnn_model.input_sharded_memory_config
)
assert input_trace_addr == l1_input_tensor.buffer_address()
ttnn.end_trace_capture(device, tid, cq_id=0)

logger.info(f"Running trace for {iterations} iterations...")

outputs = []
start = time.time()
for _ in range(iterations):
ttnn.copy_host_to_device_tensor(ttnn_input, input_tensor, cq_id=0)
l1_input_tensor = ttnn.reshard(input_tensor, ttnn_model.input_sharded_memory_config, l1_input_tensor)
ttnn.execute_trace(device, tid, cq_id=0, blocking=False)
outputs.append(output_tensor.cpu(blocking=False))
ttnn.synchronize_device(device)
Expand All @@ -86,7 +117,7 @@ def test_unet_trace(
)
@pytest.mark.parametrize(
"batch, groups, iterations",
((2, 1, 16),),
((2, 1, 32),),
)
def test_unet_trace_2cq(
batch: int,
Expand Down Expand Up @@ -178,6 +209,7 @@ def test_unet_trace_2cq(
outputs.append(output_tensor.cpu(blocking=False))
ttnn.synchronize_device(device)
end = time.time()
logger.info(f"Average model time={1000.0 * (end-start) / iterations : .2f} ms")
logger.info(f"Average model performance={iterations * batch / (end-start) : .2f} fps")

ttnn.DumpDeviceProfiler(device)
Expand All @@ -203,7 +235,7 @@ def buffer_address(tensor):
)
@pytest.mark.parametrize(
"batch, groups, iterations",
((2, 1, 16),),
((2, 1, 32),),
)
def test_unet_trace_2cq_multi_device(
batch: int, groups: int, iterations: int, mesh_device, use_program_cache, reset_seeds, enable_async_mode
Expand Down Expand Up @@ -317,3 +349,145 @@ def test_unet_trace_2cq_multi_device(
check_pcc_conv(torch_output_tensor, outputs[-1], UNET_FULL_MODEL_PCC, mesh_composer=output_mesh_composer)

ttnn.release_trace(mesh_device, tid)


@skip_for_grayskull("UNet not currently supported on GS")
@pytest.mark.models_performance_bare_metal
@pytest.mark.parametrize(
    "device_params", [{"l1_small_size": 68864, "trace_region_size": 423936, "num_command_queues": 2}], indirect=True
)
@pytest.mark.parametrize(
    "batch, groups, iterations",
    ((2, 1, 32),),
)
def test_unet_trace_2cq_same_io(
    batch: int,
    groups: int,
    iterations: int,
    device,
    use_program_cache,
    reset_seeds,
):
    """Measure UNet Shallow performance with trace execution and two command queues.

    CQ 0 runs the captured model trace (and reshards the output back to DRAM);
    CQ 1 performs both host->device input writes and device->host output reads
    ("same IO" queue). Four events coordinate the two queues:

    - op_event:    CQ 0 has consumed the DRAM input (resharded it into L1), so
                   CQ 1 may overwrite the DRAM input buffer with the next frame.
    - write_event: CQ 1 has finished the host->device input write, so CQ 0 may
                   reshard it into L1 and run the trace.
    - model_event: CQ 0 has resharded the model output into the persistent DRAM
                   output tensor, so CQ 1 may start reading it back to host.
    - read_event:  CQ 1 has issued the device->host readback, so CQ 0 may
                   overwrite the DRAM output tensor on the next iteration.

    Asserts PCC of the final iteration's output against the torch reference.
    """
    torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=True)

    # NOTE(review): groups is hard-coded to 1 here while the ttnn parameters
    # below use the parametrized `groups`; presumably fine because the test is
    # only parametrized with groups=1 — confirm if other group counts are added.
    model = unet_shallow_torch.UNet.from_random_weights(groups=1)
    torch_output_tensor = model(torch_input)  # reference output for the PCC check

    parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device)
    ttnn_model = unet_shallow_ttnn.UNet(parameters, device)

    # Cross-queue synchronization events (see docstring for their meaning).
    op_event = ttnn.create_event(device)
    write_event = ttnn.create_event(device)
    model_event = ttnn.create_event(device)
    read_event = ttnn.create_event(device)

    # Persistent DRAM-height-sharded buffer for the model input: the rows
    # (volume / last dim) are split evenly across the DRAM banks' grid.
    dram_grid_size = device.dram_grid_size()
    dram_shard_spec = ttnn.ShardSpec(
        ttnn.CoreRangeSet(
            {ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(dram_grid_size.x - 1, dram_grid_size.y - 1))}
        ),
        [
            divup(ttnn_input.volume() // ttnn_input.shape[-1], dram_grid_size.x),
            ttnn_input.shape[-1],
        ],
        ttnn.ShardOrientation.ROW_MAJOR,
        False,
    )
    dram_memory_config = ttnn.MemoryConfig(
        ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.DRAM, dram_shard_spec
    )

    input_tensor = ttnn.allocate_tensor_on_device(
        ttnn_input.shape, ttnn.bfloat16, ttnn.ROW_MAJOR_LAYOUT, device, dram_memory_config
    )
    # Prime the events so the first wait_for_event on each queue is satisfied.
    ttnn.record_event(0, op_event)
    ttnn.record_event(1, read_event)

    logger.info(f"Compiling model with warmup run")
    ttnn.copy_host_to_device_tensor(ttnn_input, input_tensor, cq_id=1)

    ttnn.record_event(1, write_event)
    ttnn.wait_for_event(0, write_event)

    # Move the DRAM input into the model's L1 sharded layout, then signal CQ 1
    # that the DRAM input buffer is free to be overwritten.
    l1_input_tensor = ttnn.reshard(input_tensor, ttnn_model.input_sharded_memory_config)
    ttnn.record_event(0, op_event)
    output_tensor = ttnn_model(l1_input_tensor, move_input_tensor_to_device=False)
    # Persistent DRAM-height-sharded buffer for the model output, sized from
    # the warmup run's output shape (dram_shard_spec/dram_memory_config are
    # rebound to the output configuration from here on).
    dram_shard_spec = ttnn.ShardSpec(
        ttnn.CoreRangeSet(
            {ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(dram_grid_size.x - 1, dram_grid_size.y - 1))}
        ),
        [
            divup(output_tensor.volume() // output_tensor.shape[-1], dram_grid_size.x),
            output_tensor.shape[-1],
        ],
        ttnn.ShardOrientation.ROW_MAJOR,
        False,
    )
    dram_memory_config = ttnn.MemoryConfig(
        ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.DRAM, dram_shard_spec
    )
    dram_output_tensor = ttnn.reshard(output_tensor, dram_memory_config)
    logger.info(f"Done compile run")

    logger.info(f"Capturing trace")
    # Stage one more input through the same write -> reshard handshake so the
    # L1 input tensor is live at a known address when capture begins.
    ttnn.wait_for_event(1, op_event)
    ttnn.copy_host_to_device_tensor(ttnn_input, input_tensor, cq_id=1)
    ttnn.record_event(1, write_event)
    ttnn.wait_for_event(0, write_event)
    l1_input_tensor = ttnn.reshard(input_tensor, ttnn_model.input_sharded_memory_config)
    ttnn.record_event(0, op_event)

    # Remember the L1 input's address/metadata; the trace will bake this
    # address in, so the persistent tensor must land at the same spot.
    input_trace_addr = l1_input_tensor.buffer_address()
    shape = l1_input_tensor.shape
    dtype = l1_input_tensor.dtype
    layout = l1_input_tensor.layout
    # Free the warmup output so allocations inside the capture match steady state.
    output_tensor.deallocate(force=True)

    tid = ttnn.begin_trace_capture(device, cq_id=0)
    output_tensor = ttnn_model(l1_input_tensor, move_input_tensor_to_device=False)

    # Try allocating our persistent input tensor here and verifying it matches the address that trace captured
    l1_input_tensor = ttnn.allocate_tensor_on_device(
        shape, dtype, layout, device, ttnn_model.input_sharded_memory_config
    )
    assert input_trace_addr == l1_input_tensor.buffer_address()
    ttnn.end_trace_capture(device, tid, cq_id=0)
    dram_output_tensor = ttnn.reshard(output_tensor, dram_memory_config, dram_output_tensor)
    ttnn.synchronize_device(device)

    outputs = []
    start = time.time()
    # Software-pipelined loop: the write for iteration i+1 overlaps the model
    # execution of iteration i, and readback overlaps the next execution.
    ttnn.wait_for_event(1, op_event)
    ttnn.copy_host_to_device_tensor(ttnn_input, input_tensor, cq_id=1)
    ttnn.record_event(1, write_event)
    for _ in range(iterations - 1):
        # CQ 0: consume the freshly-written input and run the trace.
        ttnn.wait_for_event(0, write_event)
        l1_input_tensor = ttnn.reshard(input_tensor, ttnn_model.input_sharded_memory_config, l1_input_tensor)
        ttnn.record_event(0, op_event)
        ttnn.execute_trace(device, tid, cq_id=0, blocking=False)
        # CQ 0: wait until the previous readback was issued before clobbering
        # the DRAM output tensor with this iteration's result.
        ttnn.wait_for_event(0, read_event)
        dram_output_tensor = ttnn.reshard(output_tensor, dram_memory_config, dram_output_tensor)
        ttnn.record_event(0, model_event)
        # CQ 1: write the next iteration's input as soon as CQ 0 is done with it.
        ttnn.wait_for_event(1, op_event)
        ttnn.copy_host_to_device_tensor(ttnn_input, input_tensor, cq_id=1)
        ttnn.record_event(1, write_event)
        # CQ 1: read this iteration's output back to host.
        ttnn.wait_for_event(1, model_event)
        outputs.append(dram_output_tensor.cpu(blocking=False, cq_id=1))
        ttnn.record_event(1, read_event)
    # Drain the pipeline: final iteration's execute + readback (no next write).
    ttnn.wait_for_event(0, write_event)
    l1_input_tensor = ttnn.reshard(input_tensor, ttnn_model.input_sharded_memory_config, l1_input_tensor)
    ttnn.record_event(0, op_event)
    ttnn.execute_trace(device, tid, cq_id=0, blocking=False)
    ttnn.wait_for_event(0, read_event)
    dram_output_tensor = ttnn.reshard(output_tensor, dram_memory_config, dram_output_tensor)
    ttnn.record_event(0, model_event)
    ttnn.wait_for_event(1, model_event)
    outputs.append(dram_output_tensor.cpu(blocking=False, cq_id=1))
    ttnn.synchronize_device(device)
    end = time.time()
    logger.info(f"Average model time={1000.0 * (end-start) / iterations : .2f} ms")
    logger.info(f"Average model performance={iterations * batch / (end-start) : .2f} fps")

    logger.info(f"Running sanity check against reference model output")
    check_pcc_conv(torch_output_tensor, outputs[-1], UNET_FULL_MODEL_PCC)
    ttnn.release_trace(device, tid)

0 comments on commit bc40fbd

Please sign in to comment.