Skip to content

Commit

Permalink
#13405: TTNN implementation of LENET model
Browse files Browse the repository at this point in the history
  • Loading branch information
sabira-mcw committed Oct 16, 2024
1 parent a73be8a commit ec20433
Show file tree
Hide file tree
Showing 8 changed files with 504 additions and 1 deletion.
31 changes: 31 additions & 0 deletions models/demos/lenet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# LENET

# Platforms:
E150, WH N300, N150

## Introduction

The LeNet model is a foundational convolutional neural network (CNN) architecture that was specifically developed for handwritten digit recognition on the MNIST dataset. This pioneering model consists of several convolutional layers interspersed with pooling layers, followed by fully connected layers that output the final classification. By utilizing convolutional layers, LeNet effectively captures spatial hierarchies and local patterns in images, leading to significantly enhanced performance compared to traditional, simpler architectures. Its design laid the groundwork for many modern deep learning models used in image classification tasks today.

### Batch size: 8

Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set the batch_size to 8

## How to Run

To run the demo for digit classification using the LeNet model, follow these instructions:

Ensure you have the necessary dependencies installed and that your environment is set up correctly for running the model.

Use the following command to execute the LeNet demo
```
pytest models/demos/lenet/demo/demo.py::test_demo_dataset
```
This command will initiate the test for the demo dataset, allowing you to observe the model's performance in classifying handwritten digits


## Inputs

The demo accepts inputs from the MNIST dataset, which consists of a large collection of labeled handwritten digits. The dataset provides a diverse range of examples, enabling the model to learn and generalize effectively. Each input consists of a grayscale image of a handwritten digit, which is processed through the model to produce a predicted classification.

### Owner: [sabira-mcw](https://github.com/sabira-mcw)
71 changes: 71 additions & 0 deletions models/demos/lenet/demo/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import ttnn

from torchvision import transforms, datasets
from loguru import logger

from torch.utils.data import DataLoader

from ttnn.model_preprocessing import preprocess_model_parameters
from models.demos.lenet.tt import tt_lenet
from models.demos.lenet import lenet_utils


def run_demo_dataset(device, batch_size, iterations, model_location_generator, reset_seeds):
num_classes = 10
test_input, images, outputs = lenet_utils.get_test_data(batch_size)

pt_model_path = model_location_generator("model.pt", model_subdir="LeNet")
torch_LeNet, state_dict = lenet_utils.load_torch_lenet(pt_model_path, num_classes)
model = torch_LeNet.float()
model = torch_LeNet.eval()

torch_output = model(test_input)
parameters = preprocess_model_parameters(
initialize_model=lambda: torch_LeNet,
custom_preprocessor=lenet_utils.custom_preprocessor,
)
parameters = lenet_utils.custom_preprocessor_device(parameters, device)
correct = 0
for iters in range(iterations):
x = test_input.permute(0, 2, 3, 1)
x = ttnn.from_torch(x, dtype=ttnn.bfloat16)
tt_output = tt_lenet.Lenet(x, model, batch_size, num_classes, device, parameters, reset_seeds)
tt_output = ttnn.to_torch(tt_output)
_, torch_predicted = torch.max(torch_output.data, -1)
_, ttnn_predicted = torch.max(tt_output.data, -1)

for i in range(batch_size):
logger.info(f"Iter: {iters} Sample {i}:")
logger.info(f"torch Label: {torch_predicted[i]}")
logger.info(f"Predicted Label: {ttnn_predicted[i]}")

if torch_predicted[i] == ttnn_predicted[i]:
correct += 1

accuracy = correct / (batch_size * iterations)
logger.info(f"ImageNet Inference Accuracy for {batch_size}x{iterations} Samples : {accuracy}")


@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("iterations", [1])
def test_demo_dataset(
device,
batch_size,
iterations,
model_location_generator,
reset_seeds,
):
return run_demo_dataset(
reset_seeds=reset_seeds,
device=device,
batch_size=batch_size,
iterations=iterations,
model_location_generator=model_location_generator,
)
102 changes: 102 additions & 0 deletions models/demos/lenet/lenet_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import torchvision
import torchvision.transforms as transforms
from models.experimental.lenet.reference.lenet import LeNet5
import ttnn


def get_test_data(batch_size=64):
transform = transforms.Compose(
[
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.1325,), std=(0.3105,)),
]
)

test_dataset = torchvision.datasets.MNIST(
root="./data",
train=False,
download=True,
)

batch = []
images = []
outputs = []

for i in range(batch_size):
img, output = test_dataset[i]
tensor = transform(img).unsqueeze(0)
batch.append(tensor)
images.append(img)
outputs.append(output)

batch = torch.cat(batch)
return batch, images, outputs


def load_torch_lenet(weka_path, num_classes):
model2 = LeNet5(num_classes).to("cpu")
checkpoint = torch.load(weka_path, map_location=torch.device("cpu"))
model2.load_state_dict(checkpoint["model_state_dict"])
model2.eval()
return model2, checkpoint["model_state_dict"]


def custom_preprocessor(model, device):
parameters = {}

layers_to_process = ["layer1", "layer2", "fc", "fc1", "fc2"]

for layer in layers_to_process:
if layer.startswith("layer"):
conv_layer = getattr(model, layer)[0]
bn_layer = getattr(model, layer)[1]

weight = conv_layer.weight
bias = conv_layer.bias

running_mean = bn_layer.running_mean
running_var = bn_layer.running_var
eps = 1e-05

scale = bn_layer.weight
shift = bn_layer.bias

weight = weight * (scale / torch.sqrt(running_var + eps))[:, None, None, None]

if bias is not None:
bias = (bias - running_mean) * (scale / torch.sqrt(running_var + eps)) + shift
else:
bias = shift - running_mean * (scale / torch.sqrt(running_var + eps))

weight = ttnn.from_torch(weight, dtype=ttnn.bfloat16)
bias = ttnn.from_torch(bias, dtype=ttnn.bfloat16)
bias = ttnn.reshape(bias, (1, 1, 1, -1))

else: # Handling linear layers
linear_layer = getattr(model, layer)
weight = linear_layer.weight
weight = torch.permute(weight, (1, 0))
bias = linear_layer.bias
weight = ttnn.from_torch(weight, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
bias = ttnn.from_torch(bias, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)

parameters[layer] = {"weight": weight, "bias": bias}

return parameters


def custom_preprocessor_device(parameters, device):
parameters.fc.weight = ttnn.to_device(parameters.fc.weight, device)
parameters.fc.bias = ttnn.to_device(parameters.fc.bias, device)
parameters.fc1.weight = ttnn.to_device(parameters.fc1.weight, device)
parameters.fc1.bias = ttnn.to_device(parameters.fc1.bias, device)
parameters.fc2.weight = ttnn.to_device(parameters.fc2.weight, device)
parameters.fc2.bias = ttnn.to_device(parameters.fc2.bias, device)

return parameters
130 changes: 130 additions & 0 deletions models/demos/lenet/tests/test_perf_lenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import ttnn
import time
from pathlib import Path

from torchvision import models
from loguru import logger
import ttnn
from ttnn.model_preprocessing import preprocess_model_parameters
from models.utility_functions import (
enable_persistent_kernel_cache,
disable_persistent_kernel_cache,
)
from models.demos.lenet.tt import tt_lenet
from models.utility_functions import is_grayskull, is_wormhole_b0
from models.demos.lenet import lenet_utils
from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report
from models.perf.perf_utils import prep_perf_report


def get_expected_times(tt_lenet):
if is_grayskull():
return {
tt_lenet: (3.4, 0.58),
}[tt_lenet]
elif is_wormhole_b0():
return {
tt_lenet: (9.52, 0.91),
}[tt_lenet]


@pytest.mark.parametrize(
"batch_size",
[8],
)
@pytest.mark.parametrize(
"tt_lenet",
[tt_lenet],
)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
@pytest.mark.models_performance_bare_metal
def test_perf_lenet(device, batch_size, tt_lenet, model_location_generator, reset_seeds):
num_classes = 10
test_input, images, outputs = lenet_utils.get_test_data(batch_size)

pt_model_path = model_location_generator("model.pt", model_subdir="LeNet")
torch_LeNet, state_dict = lenet_utils.load_torch_lenet(pt_model_path, num_classes)
model = torch_LeNet.float()
model = model.eval()
disable_persistent_kernel_cache()

parameters = preprocess_model_parameters(
initialize_model=lambda: torch_LeNet,
custom_preprocessor=lenet_utils.custom_preprocessor,
)
parameters = lenet_utils.custom_preprocessor_device(parameters, device)

x = test_input.permute(0, 2, 3, 1)
x = ttnn.from_torch(x, dtype=ttnn.bfloat16)
durations = []
for _ in range(2):
start = time.time()

ttnn_output = tt_lenet.Lenet(
device=device,
model=model,
input_tensor=x,
batch_size=batch_size,
parameters=parameters,
num_classes=num_classes,
reset_seeds=reset_seeds,
)
end = time.time()
durations.append(end - start)

inference_and_compile_time, *inference_times = durations
average_inference_time = sum(inference_times) / len(inference_times)
expected_compile_time, expected_inference_time = get_expected_times(tt_lenet)

prep_perf_report(
model_name="tt_lenet",
batch_size=batch_size,
inference_and_compile_time=inference_and_compile_time,
inference_time=average_inference_time,
expected_compile_time=expected_compile_time,
expected_inference_time=expected_inference_time,
comments="",
inference_time_cpu=0.0,
)

logger.info(f"Compile time: {inference_and_compile_time - average_inference_time}")
logger.info(f"Inference time: {average_inference_time}")
logger.info(f"Inference times: {inference_times}")
logger.info(f"Sample(s) per second: {1 / average_inference_time * batch_size}")


@pytest.mark.parametrize(
"batch_size",
[8],
)
@pytest.mark.models_device_performance_bare_metal
def test_perf_device_bare_metal(batch_size, reset_seeds):
subdir = "tt_lenet"
num_iterations = 1
margin = 0.03
if is_grayskull():
expected_perf = 419.5
elif is_wormhole_b0():
expected_perf = 988.29

command = f"pytest tests/ttnn/integration_tests/lenet/test_lenet.py"
cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]

inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"
expected_perf_cols = {inference_time_key: expected_perf}

post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size)
expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols)
prep_device_perf_report(
model_name=f"tt_lenet{batch_size}",
batch_size=batch_size,
post_processed_results=post_processed_results,
expected_results=expected_results,
comments="",
)
Loading

0 comments on commit ec20433

Please sign in to comment.