#13405: TTNN implementation of LENET model
1 parent a73be8a · commit ec20433. Showing 8 changed files with 504 additions and 1 deletion.
@@ -0,0 +1,31 @@
# LeNet

## Platforms

E150, WH N300, N150

## Introduction

The LeNet model is a foundational convolutional neural network (CNN) architecture, originally developed for handwritten digit recognition on the MNIST dataset. The model consists of convolutional layers interspersed with pooling layers, followed by fully connected layers that produce the final classification. The convolutional layers capture spatial hierarchies and local patterns in images, giving LeNet significantly better performance than the simpler fully connected architectures that preceded it. Its design laid the groundwork for many modern deep learning models used in image classification.
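
For orientation, here is a minimal sketch of the LeNet-5 topology the demo targets. The layer names (`layer1`, `layer2`, `fc`, `fc1`, `fc2`) mirror those consumed by `lenet_utils.custom_preprocessor` in this commit; the filter sizes and activation placement are the classic LeNet-5 choices and are assumed here rather than copied from the repository's reference model.

```
import torch
import torch.nn as nn


class LeNet5(nn.Module):
    # Classic LeNet-5 dimensions (assumed); layer names mirror those
    # expected by lenet_utils.custom_preprocessor.
    def __init__(self, num_classes=10):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5),  # 1x32x32 -> 6x28x28
            nn.BatchNorm2d(6),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # -> 6x14x14
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5),  # -> 16x10x10
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # -> 16x5x5
        )
        self.fc = nn.Linear(16 * 5 * 5, 120)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = self.layer2(self.layer1(x))
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc(x))
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
```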

### Batch size: 8

Batch size determines the number of inputs processed simultaneously during training or inference, which affects computational efficiency and memory usage. For this demo it is recommended to set `batch_size` to 8.

## How to Run

To run the demo for digit classification using the LeNet model, follow these instructions:

Ensure you have the necessary dependencies installed and that your environment is set up correctly for running the model.

Use the following command to execute the LeNet demo:

```
pytest models/demos/lenet/demo/demo.py::test_demo_dataset
```

This command runs the test for the demo dataset, allowing you to observe the model's performance in classifying handwritten digits.

## Inputs

The demo accepts inputs from the MNIST dataset, which consists of a large collection of labeled handwritten digits. The dataset's diverse examples enable the model to generalize effectively. Each input is a grayscale image of a handwritten digit, which is processed through the model to produce a predicted classification.
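
The preprocessing is the one implemented in `lenet_utils.get_test_data` below: each 28x28 MNIST digit is resized to the 32x32 input LeNet expects and normalized with the statistics used in this repo. A standalone sketch:

```
import torchvision
import torchvision.transforms as transforms

# Same transform as lenet_utils.get_test_data: resize to LeNet's
# 32x32 input and normalize with the repo's MNIST statistics.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1325,), std=(0.3105,)),
    ]
)

dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True)
img, label = dataset[0]
x = transform(img).unsqueeze(0)  # shape: [1, 1, 32, 32]
print(x.shape, label)
```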

### Owner: [sabira-mcw](https://github.com/sabira-mcw)
@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import ttnn

from loguru import logger

from ttnn.model_preprocessing import preprocess_model_parameters
from models.demos.lenet.tt import tt_lenet
from models.demos.lenet import lenet_utils


def run_demo_dataset(device, batch_size, iterations, model_location_generator, reset_seeds):
    num_classes = 10
    test_input, images, outputs = lenet_utils.get_test_data(batch_size)

    pt_model_path = model_location_generator("model.pt", model_subdir="LeNet")
    torch_LeNet, state_dict = lenet_utils.load_torch_lenet(pt_model_path, num_classes)
    model = torch_LeNet.float()
    model = model.eval()

    torch_output = model(test_input)
    parameters = preprocess_model_parameters(
        initialize_model=lambda: torch_LeNet,
        custom_preprocessor=lenet_utils.custom_preprocessor,
    )
    parameters = lenet_utils.custom_preprocessor_device(parameters, device)
    correct = 0
    for iters in range(iterations):
        # The TTNN model consumes channels-last (NHWC) inputs, so permute
        # from torch's NCHW before conversion.
        x = test_input.permute(0, 2, 3, 1)
        x = ttnn.from_torch(x, dtype=ttnn.bfloat16)
        tt_output = tt_lenet.Lenet(x, model, batch_size, num_classes, device, parameters, reset_seeds)
        tt_output = ttnn.to_torch(tt_output)
        _, torch_predicted = torch.max(torch_output.data, -1)
        _, ttnn_predicted = torch.max(tt_output.data, -1)

        for i in range(batch_size):
            logger.info(f"Iter: {iters} Sample {i}:")
            logger.info(f"torch Label: {torch_predicted[i]}")
            logger.info(f"Predicted Label: {ttnn_predicted[i]}")

            if torch_predicted[i] == ttnn_predicted[i]:
                correct += 1

    # Accuracy is measured against the torch reference predictions,
    # not the dataset's ground-truth labels.
    accuracy = correct / (batch_size * iterations)
    logger.info(f"MNIST Inference Accuracy for {batch_size}x{iterations} Samples : {accuracy}")


@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("iterations", [1])
def test_demo_dataset(
    device,
    batch_size,
    iterations,
    model_location_generator,
    reset_seeds,
):
    return run_demo_dataset(
        reset_seeds=reset_seeds,
        device=device,
        batch_size=batch_size,
        iterations=iterations,
        model_location_generator=model_location_generator,
    )
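
The demo permutes its NCHW batch to NHWC before `ttnn.from_torch`, matching the channels-last layout the TTNN model consumes here. A torch-only sketch of that layout round trip:

```
import torch

# NCHW -> NHWC, mirroring the demo's test_input.permute(0, 2, 3, 1)
x_nchw = torch.randn(8, 1, 32, 32)
x_nhwc = x_nchw.permute(0, 2, 3, 1).contiguous()
assert x_nhwc.shape == (8, 32, 32, 1)

# And back to NCHW after inference, if a torch-style layout is needed.
x_back = x_nhwc.permute(0, 3, 1, 2)
assert torch.equal(x_back, x_nchw)
```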
@@ -0,0 +1,102 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import torchvision
import torchvision.transforms as transforms
from models.experimental.lenet.reference.lenet import LeNet5
import ttnn


def get_test_data(batch_size=64):
    # Resize MNIST's 28x28 images to LeNet's 32x32 input and normalize.
    transform = transforms.Compose(
        [
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.1325,), std=(0.3105,)),
        ]
    )

    test_dataset = torchvision.datasets.MNIST(
        root="./data",
        train=False,
        download=True,
    )

    batch = []
    images = []
    outputs = []

    for i in range(batch_size):
        img, output = test_dataset[i]
        tensor = transform(img).unsqueeze(0)
        batch.append(tensor)
        images.append(img)
        outputs.append(output)

    batch = torch.cat(batch)
    return batch, images, outputs


def load_torch_lenet(weka_path, num_classes):
    model2 = LeNet5(num_classes).to("cpu")
    checkpoint = torch.load(weka_path, map_location=torch.device("cpu"))
    model2.load_state_dict(checkpoint["model_state_dict"])
    model2.eval()
    return model2, checkpoint["model_state_dict"]


def custom_preprocessor(model, device):
    parameters = {}

    layers_to_process = ["layer1", "layer2", "fc", "fc1", "fc2"]

    for layer in layers_to_process:
        if layer.startswith("layer"):
            # Convolutional blocks: fold the BatchNorm that follows each conv
            # into the conv's weight and bias, so the TTNN model runs a single
            # convolution per block.
            conv_layer = getattr(model, layer)[0]
            bn_layer = getattr(model, layer)[1]

            weight = conv_layer.weight
            bias = conv_layer.bias

            running_mean = bn_layer.running_mean
            running_var = bn_layer.running_var
            eps = 1e-05  # torch's default BatchNorm2d eps

            scale = bn_layer.weight
            shift = bn_layer.bias

            # W' = W * gamma / sqrt(var + eps), applied per output channel.
            weight = weight * (scale / torch.sqrt(running_var + eps))[:, None, None, None]

            # b' = (b - mean) * gamma / sqrt(var + eps) + beta
            if bias is not None:
                bias = (bias - running_mean) * (scale / torch.sqrt(running_var + eps)) + shift
            else:
                bias = shift - running_mean * (scale / torch.sqrt(running_var + eps))

            weight = ttnn.from_torch(weight, dtype=ttnn.bfloat16)
            bias = ttnn.from_torch(bias, dtype=ttnn.bfloat16)
            bias = ttnn.reshape(bias, (1, 1, 1, -1))

        else:  # Handling linear layers
            linear_layer = getattr(model, layer)
            weight = linear_layer.weight
            # Transpose torch's (out_features, in_features) weight to
            # (in_features, out_features), the layout the TTNN linear
            # op consumes here.
            weight = torch.permute(weight, (1, 0))
            bias = linear_layer.bias
            weight = ttnn.from_torch(weight, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
            bias = ttnn.from_torch(bias, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)

        parameters[layer] = {"weight": weight, "bias": bias}

    return parameters


def custom_preprocessor_device(parameters, device):
    # Move the fully connected weights and biases to device; the conv
    # tensors remain on host at this point.
    parameters.fc.weight = ttnn.to_device(parameters.fc.weight, device)
    parameters.fc.bias = ttnn.to_device(parameters.fc.bias, device)
    parameters.fc1.weight = ttnn.to_device(parameters.fc1.weight, device)
    parameters.fc1.bias = ttnn.to_device(parameters.fc1.bias, device)
    parameters.fc2.weight = ttnn.to_device(parameters.fc2.weight, device)
    parameters.fc2.bias = ttnn.to_device(parameters.fc2.bias, device)

    return parameters
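
The batch-norm folding in `custom_preprocessor` relies on the identity BN(conv(x)) = conv'(x), with W' = W * gamma / sqrt(var + eps) per output channel and b' = (b - mean) * gamma / sqrt(var + eps) + beta. A quick torch-only check of that identity (a standalone sketch, not part of this commit):

```
import torch
import torch.nn as nn

torch.manual_seed(0)

conv = nn.Conv2d(1, 6, kernel_size=5)
bn = nn.BatchNorm2d(6).eval()
# Give BN non-trivial statistics so the check is meaningful.
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)

# Fold BN into the conv, as lenet_utils.custom_preprocessor does.
inv_std = bn.weight / torch.sqrt(bn.running_var + bn.eps)
folded = nn.Conv2d(1, 6, kernel_size=5)
folded.weight.data = conv.weight * inv_std[:, None, None, None]
folded.bias.data = (conv.bias - bn.running_mean) * inv_std + bn.bias

x = torch.randn(2, 1, 32, 32)
with torch.no_grad():
    assert torch.allclose(bn(conv(x)), folded(x), atol=1e-4)
```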
@@ -0,0 +1,130 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import ttnn
import time

from loguru import logger
from ttnn.model_preprocessing import preprocess_model_parameters
from models.utility_functions import (
    disable_persistent_kernel_cache,
    is_grayskull,
    is_wormhole_b0,
)
from models.demos.lenet.tt import tt_lenet
from models.demos.lenet import lenet_utils
from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report
from models.perf.perf_utils import prep_perf_report


def get_expected_times(tt_lenet):
    # (expected compile time, expected inference time) in seconds, per architecture.
    if is_grayskull():
        return {
            tt_lenet: (3.4, 0.58),
        }[tt_lenet]
    elif is_wormhole_b0():
        return {
            tt_lenet: (9.52, 0.91),
        }[tt_lenet]


@pytest.mark.parametrize(
    "batch_size",
    [8],
)
@pytest.mark.parametrize(
    "tt_lenet",
    [tt_lenet],
)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
@pytest.mark.models_performance_bare_metal
def test_perf_lenet(device, batch_size, tt_lenet, model_location_generator, reset_seeds):
    num_classes = 10
    test_input, images, outputs = lenet_utils.get_test_data(batch_size)

    pt_model_path = model_location_generator("model.pt", model_subdir="LeNet")
    torch_LeNet, state_dict = lenet_utils.load_torch_lenet(pt_model_path, num_classes)
    model = torch_LeNet.float()
    model = model.eval()
    disable_persistent_kernel_cache()

    parameters = preprocess_model_parameters(
        initialize_model=lambda: torch_LeNet,
        custom_preprocessor=lenet_utils.custom_preprocessor,
    )
    parameters = lenet_utils.custom_preprocessor_device(parameters, device)

    # The TTNN model consumes channels-last inputs, so permute NCHW -> NHWC.
    x = test_input.permute(0, 2, 3, 1)
    x = ttnn.from_torch(x, dtype=ttnn.bfloat16)
    durations = []
    for _ in range(2):
        start = time.time()

        ttnn_output = tt_lenet.Lenet(
            device=device,
            model=model,
            input_tensor=x,
            batch_size=batch_size,
            parameters=parameters,
            num_classes=num_classes,
            reset_seeds=reset_seeds,
        )
        end = time.time()
        durations.append(end - start)

    # The first run includes kernel compilation; subsequent runs measure
    # steady-state inference.
    inference_and_compile_time, *inference_times = durations
    average_inference_time = sum(inference_times) / len(inference_times)
    expected_compile_time, expected_inference_time = get_expected_times(tt_lenet)

    prep_perf_report(
        model_name="tt_lenet",
        batch_size=batch_size,
        inference_and_compile_time=inference_and_compile_time,
        inference_time=average_inference_time,
        expected_compile_time=expected_compile_time,
        expected_inference_time=expected_inference_time,
        comments="",
        inference_time_cpu=0.0,
    )

    logger.info(f"Compile time: {inference_and_compile_time - average_inference_time}")
    logger.info(f"Inference time: {average_inference_time}")
    logger.info(f"Inference times: {inference_times}")
    logger.info(f"Sample(s) per second: {1 / average_inference_time * batch_size}")


@pytest.mark.parametrize(
    "batch_size",
    [8],
)
@pytest.mark.models_device_performance_bare_metal
def test_perf_device_bare_metal(batch_size, reset_seeds):
    subdir = "tt_lenet"
    num_iterations = 1
    margin = 0.03
    if is_grayskull():
        expected_perf = 419.5
    elif is_wormhole_b0():
        expected_perf = 988.29

    command = "pytest tests/ttnn/integration_tests/lenet/test_lenet.py"
    cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]

    inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"
    expected_perf_cols = {inference_time_key: expected_perf}

    post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size)
    expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols)
    prep_device_perf_report(
        model_name=f"tt_lenet{batch_size}",
        batch_size=batch_size,
        post_processed_results=post_processed_results,
        expected_results=expected_results,
        comments="",
    )
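
For orientation, the 3% margin gates throughput near the expected figure. Assuming `check_device_perf` applies the margin symmetrically (an assumption about its internals, not confirmed by this diff), the Wormhole bound works out roughly as follows:

```
# Hedged sketch of the margin arithmetic; the real check lives in
# models.perf.device_perf_utils.check_device_perf, whose exact policy
# (symmetric vs. one-sided) is assumed here.
expected_perf = 988.29  # samples/s, Wormhole B0
margin = 0.03

lower = expected_perf * (1 - margin)
upper = expected_perf * (1 + margin)
print(f"accepted range: {lower:.2f} .. {upper:.2f} samples/s")
# accepted range: 958.64 .. 1017.94 samples/s
```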