Skip to content

Commit

Permalink
#5079: add unit tests for vanilla YOLOv4
Browse files Browse the repository at this point in the history
  • Loading branch information
dvartaniansTT committed Mar 25, 2024
1 parent 50c10d2 commit db3d6ac
Showing 1 changed file with 106 additions and 0 deletions.
106 changes: 106 additions & 0 deletions tests/ttnn/unit_tests/operations/test_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1632,3 +1632,109 @@ def test_yolov4_conv_groups_1_low_resolution(
padded_input_channels=16 if input_channels == 3 else None,
output_layout=output_layout,
)


@skip_for_wormhole_b0()
@pytest.mark.parametrize(
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override, use_shallow_conv_variant",
(
(1, 64, 64, 80, 80, 3, 3, 1, 1, 1, 1, True, None, False),
(1, 255, 512, 20, 20, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 64, 128, 80, 80, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 32, 3, 320, 320, 3, 3, 1, 1, 1, 1, True, None, False),
(1, 128, 128, 80, 80, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 128, 256, 20, 20, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 512, 256, 20, 20, 3, 3, 1, 1, 1, 1, True, None, False),
(1, 64, 32, 320, 320, 3, 3, 2, 2, 1, 1, True, None, False),
(1, 256, 512, 20, 20, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 64, 64, 80, 80, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 256, 128, 40, 40, 3, 3, 1, 1, 1, 1, True, None, False),
(1, 512, 1024, 10, 10, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 512, 256, 40, 40, 3, 3, 2, 2, 1, 1, True, None, False),
(1, 64, 128, 160, 160, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 256, 512, 10, 10, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 256, 128, 40, 40, 3, 3, 2, 2, 1, 1, True, None, False),
(1, 64, 64, 160, 160, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 512, 512, 20, 20, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 512, 256, 20, 20, 3, 3, 2, 2, 1, 1, True, None, False),
(1, 128, 256, 40, 40, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 255, 256, 40, 40, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 512, 2048, 10, 10, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 512, 512, 10, 10, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 64, 32, 160, 160, 3, 3, 1, 1, 1, 1, True, None, False),
(1, 128, 128, 40, 40, 3, 3, 1, 1, 1, 1, True, None, False),
(1, 32, 64, 160, 160, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 256, 128, 80, 80, 3, 3, 2, 2, 1, 1, True, None, False),
(1, 256, 256, 20, 20, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 256, 256, 20, 20, 3, 3, 1, 1, 1, 1, True, None, False),
(1, 512, 512, 10, 10, 3, 3, 1, 1, 1, 1, True, None, False),
(1, 1024, 512, 10, 10, 3, 3, 1, 1, 1, 1, True, None, False),
(1, 128, 64, 160, 160, 3, 3, 2, 2, 1, 1, True, None, False),
(1, 255, 1024, 10, 10, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 1024, 1024, 10, 10, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 1024, 512, 20, 20, 3, 3, 2, 2, 1, 1, True, None, False),
(1, 256, 256, 40, 40, 1, 1, 1, 1, 0, 0, True, None, False),
(1, 128, 128, 40, 40, 1, 1, 1, 1, 0, 0, True, None, False),
),
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat8_b],
)
@pytest.mark.parametrize(
"activations_dtype",
# [ttnn.bfloat8_b, ttnn.bfloat16],
[ttnn.bfloat8_b],
)
@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi])
# @pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
@pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT])
def test_yolov4_REPO_320_320(
device,
use_program_cache,
math_fidelity,
activations_dtype,
weights_dtype,
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
use_1d_systolic_array,
config_override,
use_shallow_conv_variant,
output_layout,
):
if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat8_b:
pytest.skip("Row major layout not compatible with bfloat8_b")
if output_layout == ttnn.ROW_MAJOR_LAYOUT and input_height >= 1056:
pytest.skip("OOM")
run_conv(
device,
math_fidelity,
activations_dtype,
weights_dtype,
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
use_1d_systolic_array,
config_override,
use_shallow_conv_variant=use_shallow_conv_variant,
# groups=groups,
padded_input_channels=16 if input_channels == 3 else None,
output_layout=output_layout,
)

0 comments on commit db3d6ac

Please sign in to comment.