Skip to content

Commit

Permalink
#13405: Replaced one torch maxpool with ttnn maxpool
Browse files (browse the repository at this point in the history)
  • Loading branch information
sabira-mcw committed Nov 22, 2024
1 parent bf301b1 commit a309102
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 26 deletions.
6 changes: 5 additions & 1 deletion models/demos/lenet/demo/demo.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,7 +10,9 @@
from loguru import logger

from torch.utils.data import DataLoader

from models.utility_functions import (
disable_persistent_kernel_cache,
)
from ttnn.model_preprocessing import preprocess_model_parameters
from models.demos.lenet.tt import tt_lenet
from models.demos.lenet import lenet_utils
Expand Down Expand Up @@ -50,6 +52,7 @@ def run_demo_dataset(device, batch_size, iterations, model_location_generator, r

accuracy = correct / (batch_size * iterations)
logger.info(f"Dataset Inference Accuracy for {batch_size}x{iterations} Samples : {accuracy}")
assert accuracy >= 1.0, f"Expected accuracy : {1.0} Actual accuracy: {accuracy}"


@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
Expand All @@ -62,6 +65,7 @@ def test_demo_dataset(
model_location_generator,
reset_seeds,
):
disable_persistent_kernel_cache()
return run_demo_dataset(
reset_seeds=reset_seeds,
device=device,
Expand Down
25 changes: 15 additions & 10 deletions models/demos/lenet/tests/test_perf_lenet.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -26,11 +26,11 @@
def get_expected_times(tt_lenet):
if is_grayskull():
return {
tt_lenet: (7.525, 0.9495),
tt_lenet: (5.20, 0.63291),
}[tt_lenet]
elif is_wormhole_b0():
return {
tt_lenet: (9.52, 0.91),
tt_lenet: (7.95678, 0.8243),
}[tt_lenet]


Expand Down Expand Up @@ -77,26 +77,31 @@ def test_perf_lenet(device, batch_size, tt_lenet, model_location_generator, rese
)
end = time.time()
durations.append(end - start)
enable_persistent_kernel_cache()

inference_and_compile_time, *inference_times = durations
average_inference_time = sum(inference_times) / len(inference_times)
inference_time = sum(inference_times) / len(inference_times)
expected_compile_time, expected_inference_time = get_expected_times(tt_lenet)

prep_perf_report(
model_name="tt_lenet",
batch_size=batch_size,
inference_and_compile_time=inference_and_compile_time,
inference_time=average_inference_time,
inference_time=inference_time,
expected_compile_time=expected_compile_time,
expected_inference_time=expected_inference_time,
comments="",
inference_time_cpu=0.0,
)

logger.info(f"Compile time: {inference_and_compile_time - average_inference_time}")
logger.info(f"Inference time: {average_inference_time}")
logger.info(f"Compile time: {inference_and_compile_time - inference_time}")
logger.info(f"Inference time: {inference_time}")
logger.info(f"Inference times: {inference_times}")
logger.info(f"Sample(s) per second: {1 / average_inference_time * batch_size}")
logger.info(f"Sample(s) per second: {1 / inference_time * batch_size}")
assert (
inference_time < expected_inference_time
), f"Expected inference time: {expected_inference_time} Actual inference time: {inference_time}"
logger.info("Exit Lenet perf test")


@pytest.mark.parametrize(
Expand All @@ -109,9 +114,9 @@ def test_perf_device_bare_metal(batch_size, reset_seeds):
num_iterations = 1
margin = 0.03
if is_grayskull():
expected_perf = 6330.022
expected_perf = 203642.6580
elif is_wormhole_b0():
expected_perf = 20028.54
expected_perf = 113208.6151

command = f"pytest tests/ttnn/integration_tests/lenet/test_lenet.py"
cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]
Expand All @@ -120,7 +125,7 @@ def test_perf_device_bare_metal(batch_size, reset_seeds):
expected_perf_cols = {inference_time_key: expected_perf}

post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size)
expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols)
expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=True)
prep_device_perf_report(
model_name=f"tt_lenet{batch_size}",
batch_size=batch_size,
Expand Down
32 changes: 17 additions & 15 deletions models/demos/lenet/tt/tt_lenet.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -50,10 +50,7 @@ def conv(device, input_tensor, batch_size, parameters):

def Lenet(input_tensor, model, batch_size, num_classes, device, parameters, reset_seeds):
conv_1, out_height, out_width = conv(device, input_tensor, batch_size, parameters.layer1)

conv_1 = ttnn.from_device(conv_1)
conv_1 = ttnn.to_layout(conv_1, layout=ttnn.TILE_LAYOUT)
conv_1 = ttnn.to_device(conv_1, device=device)
conv_1 = ttnn.sharded_to_interleaved(conv_1, ttnn.L1_MEMORY_CONFIG)
conv_1 = ttnn.reshape(conv_1, (batch_size, out_height, out_width, conv_1.shape[-1]))
conv_1 = ttnn.permute(conv_1, (0, 3, 1, 2))
conv_1 = ttnn.to_torch(conv_1)
Expand All @@ -67,21 +64,26 @@ def Lenet(input_tensor, model, batch_size, num_classes, device, parameters, rese

conv_2, out_height, out_width = conv(device, maxpool_1, batch_size, parameters.layer2)

conv_2 = ttnn.from_device(conv_2)
conv_2 = ttnn.to_layout(conv_2, layout=ttnn.TILE_LAYOUT)
conv_2 = ttnn.to_device(conv_2, device=device)
conv_2 = ttnn.reshape(conv_2, (batch_size, out_height, out_width, conv_2.shape[-1]))
conv_2 = ttnn.permute(conv_2, (0, 3, 1, 2))
conv_2 = ttnn.to_torch(conv_2)
conv_2 = ttnn.to_layout(conv_2, layout=ttnn.ROW_MAJOR_LAYOUT)
maxpool_2 = ttnn.max_pool2d(
input_tensor=conv_2,
batch_size=batch_size,
input_h=out_height,
input_w=out_width,
channels=conv_2.shape[3],
kernel_size=[2, 2],
stride=[2, 2],
padding=[0, 0],
dilation=[1, 1],
)

max = nn.MaxPool2d(kernel_size=2, stride=2)
maxpool_2 = max(conv_2)
maxpool_2 = ttnn.sharded_to_interleaved(maxpool_2, ttnn.L1_MEMORY_CONFIG)

maxpool_2 = ttnn.from_torch(maxpool_2, dtype=ttnn.bfloat16)
maxpool_2 = ttnn.to_layout(maxpool_2, layout=ttnn.TILE_LAYOUT)
maxpool_2 = ttnn.reshape(maxpool_2, (batch_size, 5, 5, maxpool_2.shape[3]))

maxpool_2 = ttnn.permute(maxpool_2, (0, 3, 1, 2))
maxpool_2 = ttnn.reshape(maxpool_2, (maxpool_2.shape[0], -1))
maxpool_2 = ttnn.to_device(maxpool_2, device=device)
maxpool_2 = ttnn.to_layout(maxpool_2, layout=ttnn.TILE_LAYOUT)

linear_1 = ttnn.linear(
maxpool_2,
Expand Down

0 comments on commit a309102

Please sign in to comment.