Skip to content

Commit

Permalink
Merge branch 'main' into llong/i2s-alignment-fix-dram-to-l1
Browse files Browse the repository at this point in the history
  • Loading branch information
llongTT authored Nov 27, 2024
2 parents 2c7b4d6 + 8750d3b commit b569c69
Show file tree
Hide file tree
Showing 114 changed files with 3,101 additions and 1,889 deletions.
41 changes: 31 additions & 10 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,18 @@ def get_dispatch_core_type():
return dispatch_core_type


def get_dispatch_core_config(device_params):
import ttnn

dispatch_core_type = get_dispatch_core_type()
dispatch_core_axis = device_params.pop(
"dispatch_core_axis",
ttnn.DispatchCoreAxis.COL if os.environ["ARCH_NAME"] == "blackhole" else ttnn.DispatchCoreAxis.ROW,
)
dispatch_core_config = ttnn.DispatchCoreConfig(dispatch_core_type, dispatch_core_axis)
return dispatch_core_config


@pytest.fixture(scope="function")
def device_params(request):
return getattr(request, "param", {})
Expand All @@ -117,7 +129,8 @@ def device(request, device_params):

num_devices = ttnn.GetNumPCIeDevices()
assert device_id < num_devices, "CreateDevice not supported for non-mmio device"
device = ttnn.CreateDevice(device_id=device_id, dispatch_core_type=get_dispatch_core_type(), **device_params)
dispatch_core_config = get_dispatch_core_config(device_params)
device = ttnn.CreateDevice(device_id=device_id, dispatch_core_config=dispatch_core_config, **device_params)
ttnn.SetDefaultDevice(device)

yield device
Expand All @@ -137,7 +150,8 @@ def pcie_devices(request, device_params):
request.node.pci_ids = device_ids

# Get only physical devices
devices = ttnn.CreateDevices(device_ids, dispatch_core_type=get_dispatch_core_type(), **device_params)
dispatch_core_config = get_dispatch_core_config(device_params)
devices = ttnn.CreateDevices(device_ids, dispatch_core_config=dispatch_core_config, **device_params)

yield [devices[i] for i in range(num_devices)]

Expand All @@ -156,7 +170,8 @@ def all_devices(request, device_params):
request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids]

# Get only physical devices
devices = ttnn.CreateDevices(device_ids, dispatch_core_type=get_dispatch_core_type(), **device_params)
dispatch_core_config = get_dispatch_core_config(device_params)
devices = ttnn.CreateDevices(device_ids, dispatch_core_config=dispatch_core_config, **device_params)

yield [devices[i] for i in range(num_devices)]

Expand Down Expand Up @@ -207,7 +222,10 @@ def mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device_par

request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids[:num_devices_requested]]

mesh_device = ttnn.open_mesh_device(mesh_shape, dispatch_core_type=get_dispatch_core_type(), **device_params)
dispatch_core_config = get_dispatch_core_config(device_params)
mesh_device = ttnn.open_mesh_device(
mesh_shape=mesh_shape, dispatch_core_config=dispatch_core_config, **device_params
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
yield mesh_device
Expand All @@ -234,9 +252,10 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic

request.node.pci_ids = device_ids[:num_pcie_devices_requested]

dispatch_core_config = get_dispatch_core_config(device_params)
mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(2, 2),
dispatch_core_type=get_dispatch_core_type(),
mesh_shape=ttnn.MeshShape(2, 2),
dispatch_core_config=dispatch_core_config,
**device_params,
offset=(0, 1),
mesh_type=ttnn.MeshType.Ring,
Expand All @@ -259,9 +278,10 @@ def n300_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic
if ttnn.get_num_devices() < 2:
pytest.skip()

dispatch_core_config = get_dispatch_core_config(device_params)
mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(1, 2),
dispatch_core_type=get_dispatch_core_type(),
mesh_shape=ttnn.MeshShape(1, 2),
dispatch_core_config=dispatch_core_config,
**device_params,
)

Expand All @@ -283,9 +303,10 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device
pytest.skip()

request.node.pci_ids = ttnn.get_pcie_device_ids()
dispatch_core_config = get_dispatch_core_config(device_params)
mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(2, 4),
dispatch_core_type=get_dispatch_core_type(),
mesh_shape=ttnn.MeshShape(2, 4),
dispatch_core_config=dispatch_core_config,
**device_params,
mesh_type=ttnn.MeshType.Ring,
)
Expand Down
4 changes: 2 additions & 2 deletions models/demos/wormhole/yolov4/test_yolov4_performant.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_run_yolov4_inference(device, use_program_cache, batch_size, act_dtype,


@run_for_wormhole_b0()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "trace_region_size": 1636352}], indirect=True)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "trace_region_size": 1843200}], indirect=True)
@pytest.mark.parametrize(
"batch_size, act_dtype, weight_dtype",
((1, ttnn.bfloat16, ttnn.bfloat16),),
Expand All @@ -50,7 +50,7 @@ def test_run_yolov4_trace_inference(

@run_for_wormhole_b0()
@pytest.mark.parametrize(
"device_params", [{"l1_small_size": 24576, "trace_region_size": 1636352, "num_command_queues": 2}], indirect=True
"device_params", [{"l1_small_size": 24576, "trace_region_size": 3686400, "num_command_queues": 2}], indirect=True
)
@pytest.mark.parametrize(
"batch_size, act_dtype, weight_dtype",
Expand Down
29 changes: 28 additions & 1 deletion models/demos/yolov4/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,34 @@ def do_detect(model, img, conf_thresh, nms_thresh, n_classes, device=None, class
if not is_torch_model:
input_shape = img.shape
input_tensor = torch.permute(img, (0, 2, 3, 1))
input_tensor = ttnn.from_torch(input_tensor, ttnn.bfloat16)
# input_tensor = ttnn.from_torch(input_tensor, ttnn.bfloat16)
input_tensor = torch.permute(img, (0, 2, 3, 1)) # put channel at the end
input_tensor = torch.nn.functional.pad(
input_tensor, (0, 13, 0, 0, 0, 0, 0, 0)
) # pad channel dim from 3 to 16
N, H, W, C = input_tensor.shape
input_tensor = torch.reshape(input_tensor, (N, 1, H * W, C))

shard_grid = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(7, 7),
),
}
)
n_cores = 64
shard_spec = ttnn.ShardSpec(shard_grid, [N * H * W // n_cores, C], ttnn.ShardOrientation.ROW_MAJOR, False)
input_mem_config = ttnn.MemoryConfig(
ttnn.types.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.types.BufferType.L1, shard_spec
)
input_tensor = ttnn.from_torch(
input_tensor,
dtype=ttnn.bfloat16,
layout=ttnn.ROW_MAJOR_LAYOUT,
device=device,
memory_config=input_mem_config,
)
img = input_tensor
t1 = time.time()

Expand Down
2 changes: 1 addition & 1 deletion models/demos/yolov4/tests/yolov4_test_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def validate(self, output_tensor=None):
output_tensor = output_tensor.reshape(1, 40, 40, 255)
output_tensor = torch.permute(output_tensor, (0, 3, 1, 2))

valid_pcc = 0.99
valid_pcc = 0.985
self.pcc_passed, self.pcc_message = assert_with_pcc(self.torch_output_tensor[0], output_tensor, pcc=valid_pcc)

logger.info(
Expand Down
3 changes: 3 additions & 0 deletions models/demos/yolov4/ttnn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
activation="",
fused_op=True,
width_sharding=False,
output_layout=ttnn.TILE_LAYOUT,
) -> None:
if fused_op:
self.weights, self.bias = fold_bn_to_conv_weights_bias(model, path)
Expand All @@ -57,6 +58,7 @@ def __init__(
self.out_channels = self.weights.shape[0]
self.act_block_h = act_block_h
self.reshard = reshard
self.output_layout = output_layout

if width_sharding:
self.shard_layout = ttnn.TensorMemoryLayout.WIDTH_SHARDED
Expand Down Expand Up @@ -86,6 +88,7 @@ def __call__(self, device, input_tensor):
reshard_if_not_optimal=self.reshard,
deallocate_activation=self.deallocate,
reallocate_halo_output=False,
output_layout=self.output_layout,
)
if self.act_block_h is not None:
conv_config.act_block_h_override = self.act_block_h
Expand Down
11 changes: 10 additions & 1 deletion models/demos/yolov4/ttnn/downsample1.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import torch
import ttnn
from models.demos.yolov4.ttnn.common import Conv
from tests.ttnn.ttnn_utility_fuction import get_shard_grid_from_num_cores
from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout


class Down1:
Expand All @@ -15,7 +17,7 @@ def __init__(self, model) -> None:
torch_model = model.torch_model
self.torch_model = torch_model
self.conv1 = Conv(torch_model, "down1.conv1", [1, 320, 320, 3], (1, 1, 1, 1), act_block_h=128)
self.conv2 = Conv(torch_model, "down1.conv2", [1, 320, 320, 32], (2, 2, 1, 1), reshard=True)
self.conv2 = Conv(torch_model, "down1.conv2", [1, 320, 320, 32], (2, 2, 1, 1))
self.conv3 = Conv(torch_model, "down1.conv3", [1, 160, 160, 64], (1, 1, 0, 0), deallocate=False)
self.conv4 = Conv(torch_model, "down1.conv4", [1, 160, 160, 64], (1, 1, 0, 0))
self.conv5 = Conv(torch_model, "down1.conv5", [1, 160, 160, 64], (1, 1, 0, 0), deallocate=False)
Expand All @@ -30,6 +32,13 @@ def __call__(self, device, input_tensor):
output_tensor_split = self.conv2(device, output_tensor)
output_tensor_split = ttnn.mish(output_tensor_split)

shard_grid = get_shard_grid_from_num_cores(50, device)
shard_spec = ttnn.ShardSpec(shard_grid, (512, 64), ttnn.ShardOrientation.ROW_MAJOR, False)
in_sharded_mem_config_conv_5 = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec
)
output_tensor_split = ttnn.to_memory_config(output_tensor_split, memory_config=in_sharded_mem_config_conv_5)

output_tensor_left = self.conv3(device, output_tensor_split)
output_tensor_left = ttnn.mish(output_tensor_left)

Expand Down
4 changes: 3 additions & 1 deletion models/demos/yolov4/ttnn/downsample4.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ def __init__(self, model) -> None:
else:
torch_model = model.torch_model
self.torch_model = torch_model
self.conv1 = Conv(torch_model, "down4.conv1", [1, 40, 40, 256], (2, 2, 1, 1), reshard=True)
self.conv1 = Conv(
torch_model, "down4.conv1", [1, 40, 40, 256], (2, 2, 1, 1), reshard=True, height_sharding=False
)
self.conv2 = Conv(torch_model, "down4.conv2", [1, 20, 20, 512], (1, 1, 0, 0), deallocate=False)
self.conv3 = Conv(torch_model, "down4.conv3", [1, 20, 20, 512], (1, 1, 0, 0))

Expand Down
11 changes: 10 additions & 1 deletion models/demos/yolov4/ttnn/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,15 @@ def __init__(self, model) -> None:
self.torch_model = torch_model
self.conv1 = Conv(torch_model, "head.conv1", [1, 40, 40, 128], (1, 1, 1, 1), reshard=True, deallocate=False)
self.conv2 = Conv(torch_model, "head.conv2", [1, 40, 40, 256], (1, 1, 0, 0), fused_op=False)
self.conv3 = Conv(torch_model, "head.conv3", [1, 40, 40, 128], (2, 2, 1, 1), reshard=True, deallocate=False)
self.conv3 = Conv(
torch_model,
"head.conv3",
[1, 40, 40, 128],
(2, 2, 1, 1),
reshard=True,
deallocate=False,
height_sharding=False,
)
self.conv4 = Conv(
torch_model,
"head.conv4",
Expand Down Expand Up @@ -71,6 +79,7 @@ def __init__(self, model) -> None:
[1, 20, 20, 256],
(2, 2, 1, 1),
reshard=True,
height_sharding=False,
)
self.conv12 = Conv(
torch_model,
Expand Down
85 changes: 77 additions & 8 deletions models/demos/yolov4/ttnn/neck.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, model) -> None:
"neek.conv3",
[1, 10, 10, 1024],
(1, 1, 0, 0),
reshard=True,
reshard=False,
)

self.conv4 = Conv(
Expand Down Expand Up @@ -99,7 +99,6 @@ def __init__(self, model) -> None:
"neek.conv12",
[1, 20, 20, 256],
(1, 1, 1, 1),
reshard=True,
)
self.conv7_5 = Conv(
torch_model,
Expand All @@ -115,6 +114,7 @@ def __init__(self, model) -> None:
[1, 20, 20, 256],
(1, 1, 0, 0),
deallocate=False,
height_sharding=False,
)
self.conv9_2 = Conv(
torch_model,
Expand Down Expand Up @@ -223,9 +223,38 @@ def __call__(self, device, input_tensor):
output_tensor = self.conv7(device, output_tensor_left_1)
output_tensor = ttnn.leaky_relu(output_tensor, negative_slope=0.1)

output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG)
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor_upsample_1 = ttnn.upsample(output_tensor, (1, 4, 1), memory_config=ttnn.L1_MEMORY_CONFIG)
output_shape = output_tensor.shape
output_tensor = ttnn.untilize_with_unpadding(
output_tensor,
output_tensor_end=(
output_shape[0] - 1,
output_shape[1] - 1,
output_shape[2] - 1,
output_shape[3] - 1,
),
memory_config=ttnn.L1_MEMORY_CONFIG,
)

output_tensor = ttnn.reshape(output_tensor, (1, 10, 10, 256))
shard_grid = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(7, 4),
),
}
)
shard_spec = ttnn.ShardSpec(shard_grid, (20, 32), ttnn.ShardOrientation.ROW_MAJOR, False)
in_sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.BufferType.L1, shard_spec)
output_tensor = ttnn.to_memory_config(output_tensor, memory_config=in_sharded_mem_config)
shard_spec = ttnn.ShardSpec(shard_grid, (80, 32), ttnn.ShardOrientation.ROW_MAJOR, False)
out_sharded_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec
)

output_tensor_upsample_1 = ttnn.upsample(output_tensor, (2, 2, 1), memory_config=out_sharded_mem_config)
output_tensor_upsample_1 = ttnn.sharded_to_interleaved(output_tensor_upsample_1, ttnn.L1_MEMORY_CONFIG)
output_tensor_upsample_1 = ttnn.reshape(output_tensor_upsample_1, (1, 1, 400, 256))
output_tensor_upsample_1 = ttnn.to_layout(output_tensor_upsample_1, layout=ttnn.TILE_LAYOUT)

outDowSample5 = input_tensor[1]
Expand Down Expand Up @@ -254,12 +283,52 @@ def __call__(self, device, input_tensor):
output_tensor = self.conv7_5(device, output_tensor)
output_tensor_left_2 = ttnn.leaky_relu(output_tensor, negative_slope=0.1)

shard_grid = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(6, 3),
),
}
)
shard_spec = ttnn.ShardSpec(shard_grid, (64, 64), ttnn.ShardOrientation.COL_MAJOR, False)
in_sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.BufferType.L1, shard_spec)
output_tensor_left_2 = ttnn.to_memory_config(output_tensor_left_2, memory_config=in_sharded_mem_config)
output_tensor = self.conv9(device, output_tensor_left_2)
output_tensor = ttnn.leaky_relu(output_tensor, negative_slope=0.1)

output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG)
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor_upsample_2 = ttnn.upsample(output_tensor, (1, 4, 1), memory_config=ttnn.L1_MEMORY_CONFIG)
output_shape = output_tensor.shape
output_tensor = ttnn.untilize_with_unpadding(
output_tensor,
output_tensor_end=(
output_shape[0] - 1,
output_shape[1] - 1,
output_shape[2] - 1,
output_shape[3] - 1,
),
memory_config=ttnn.L1_MEMORY_CONFIG,
)

output_tensor = ttnn.reshape(output_tensor, (1, 20, 20, 128))
shard_grid = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(7, 4),
),
}
)
shard_spec = ttnn.ShardSpec(shard_grid, (80, 16), ttnn.ShardOrientation.ROW_MAJOR, False)
in_sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.BufferType.L1, shard_spec)
output_tensor = ttnn.to_memory_config(output_tensor, memory_config=in_sharded_mem_config)
shard_spec = ttnn.ShardSpec(shard_grid, (80 * 4, 16), ttnn.ShardOrientation.ROW_MAJOR, False)
out_sharded_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec
)

output_tensor_upsample_2 = ttnn.upsample(output_tensor, (2, 2, 1), memory_config=out_sharded_mem_config)
output_tensor_upsample_2 = ttnn.sharded_to_interleaved(output_tensor_upsample_2, ttnn.L1_MEMORY_CONFIG)
output_tensor_upsample_2 = ttnn.reshape(output_tensor_upsample_2, (1, 1, 1600, 128))
output_tensor_upsample_2 = ttnn.to_layout(output_tensor_upsample_2, ttnn.TILE_LAYOUT)

outDowSample3 = input_tensor[2]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
@pytest.mark.models_device_performance_bare_metal
@pytest.mark.parametrize(
"batch, groups, expected_device_perf_fps",
((1, 2, 1115.0),),
((1, 2, 1122.0),),
)
def test_unet_perf_device(batch: int, groups: int, expected_device_perf_fps: float):
command = f"pytest models/experimental/functional_unet/tests/test_unet_model.py::test_unet_model[device_params0-{groups}-{batch}]"
Expand Down
Binary file modified models/perf/images/example_perf_report.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tech_reports/LLMs/images/4.6-op-to-op-gap.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tech_reports/LLMs/images/4.6-overview.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit b569c69

Please sign in to comment.