diff --git a/models/experimental/functional_mobilenetv2/README.md b/models/experimental/functional_mobilenetv2/README.md
new file mode 100644
index 00000000000..898887d2a0d
--- /dev/null
+++ b/models/experimental/functional_mobilenetv2/README.md
@@ -0,0 +1,21 @@
+# MobileNetV2
+The MobileNetV2 model is a convolutional neural network (CNN) architecture designed for efficient mobile and embedded vision applications. It was introduced in the paper ["MobileNetV2: Inverted Residuals and Linear Bottlenecks"](https://arxiv.org/abs/1801.04381).
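+
+The core building block is the inverted residual: a 1x1 "expand" convolution, a 3x3 depthwise convolution, and a 1x1 linear "projection" convolution, with a skip connection when the block keeps spatial size and channel count. As orientation for the unrolled reference model in this folder, here is a minimal PyTorch sketch of that pattern (illustrative only; the class name and layout below are not part of this codebase):
+```python
+import torch.nn as nn
+
+class InvertedResidual(nn.Module):
+    def __init__(self, in_ch, out_ch, stride, expand_ratio):
+        super().__init__()
+        hidden = in_ch * expand_ratio
+        self.use_skip = stride == 1 and in_ch == out_ch
+        self.block = nn.Sequential(
+            nn.Conv2d(in_ch, hidden, 1, bias=False),  # 1x1 expand
+            nn.BatchNorm2d(hidden),
+            nn.ReLU6(inplace=True),
+            nn.Conv2d(hidden, hidden, 3, stride, 1, groups=hidden, bias=False),  # 3x3 depthwise
+            nn.BatchNorm2d(hidden),
+            nn.ReLU6(inplace=True),
+            nn.Conv2d(hidden, out_ch, 1, bias=False),  # 1x1 linear projection (no activation)
+            nn.BatchNorm2d(out_ch),
+        )
+
+    def forward(self, x):
+        return x + self.block(x) if self.use_skip else self.block(x)
+```
+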
+The MobileNetV2 model has been pre-trained on the ImageNet dataset and can be used for various tasks such as image classification, object detection, and semantic segmentation. It has achieved state-of-the-art performance on several benchmarks for mobile and embedded vision applications.
+
+## How to Run
+
+To run the demo, make sure to build the project, activate the environment, and set the appropriate environment variables.
+For more information, refer to the [installation and build guide](https://docs.tenstorrent.com/tt-metalium/latest/get_started/get_started.html#install-and-build).
+
+To run the functional MobileNetV2 model on a single chip:
+```sh
+pytest --disable-warnings models/experimental/functional_mobilenetv2/test/test_ttnn_mobilenetv2.py
+```
+
+## Supported Hardware
+- N150
+
+## Other Details
+
+- Inputs are random data by default.
+- The model weights are downloaded automatically from download.pytorch.org (via wget in weights_download.sh) the first time the test runs.
diff --git a/models/experimental/functional_mobilenetv2/reference/mobilenetv2.py b/models/experimental/functional_mobilenetv2/reference/mobilenetv2.py
new file mode 100644
index 00000000000..f60e1e4f759
--- /dev/null
+++ b/models/experimental/functional_mobilenetv2/reference/mobilenetv2.py
@@ -0,0 +1,384 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

+# SPDX-License-Identifier: Apache-2.0


+import torch
+import torch.nn as nn


+class Mobilenetv2(nn.Module):
+    """Unrolled MobileNetV2 reference model.
+
+    Each cN/bN pair is a Conv2d + BatchNorm2d; the inverted-residual blocks
+    and their skip connections are written out explicitly in forward().
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.c1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
+        self.b1 = nn.BatchNorm2d(32)
+        self.relu = nn.ReLU6(inplace=True)
+
+        self.c2 = nn.Conv2d(32, 32, 3, 1, 1, groups=32, bias=False)
+        self.b2 = nn.BatchNorm2d(32)
+
+        self.c3 = nn.Conv2d(32, 16, 1, 1, bias=False)
+        self.b3 = nn.BatchNorm2d(16)
+
+        self.c4 = nn.Conv2d(16, 96, 1, 1, bias=False)
+        self.b4 = nn.BatchNorm2d(96)
+
+        self.c5 = nn.Conv2d(96, 96, 3, 2, 1, groups=96, bias=False)
+        self.b5 = nn.BatchNorm2d(96)
+
+        self.c6 = nn.Conv2d(96, 24, 1, 1, bias=False)
+        self.b6 = nn.BatchNorm2d(24)
+
+        self.c7 = nn.Conv2d(24, 144, 1, 1, bias=False)
+        self.b7 = nn.BatchNorm2d(144)
+
+        self.c8 = nn.Conv2d(144, 144, 3, 1, 1, groups=144, bias=False)
+        self.b8 = nn.BatchNorm2d(144)
+
+        self.c9 = nn.Conv2d(144, 24, 1, 1, bias=False)
+        self.b9 = nn.BatchNorm2d(24)
+
+        self.c10 = nn.Conv2d(24, 144, 1, 1, bias=False)
+        self.b10 = nn.BatchNorm2d(144)
+
+        self.c11 = nn.Conv2d(144, 144, 3, 2, 1, groups=144, bias=False)
+        self.b11 = nn.BatchNorm2d(144)
+
+        self.c12 = nn.Conv2d(144, 32, 1, 1, bias=False)
+        self.b12 = nn.BatchNorm2d(32)
+
+        self.c13 = nn.Conv2d(32, 192, 1, 1, bias=False)
+        self.b13 = nn.BatchNorm2d(192)
+
+        self.c14 = nn.Conv2d(192, 192, 3, 1, 1, groups=192, bias=False)
+        self.b14 = nn.BatchNorm2d(192)
+
+        self.c15 = nn.Conv2d(192, 32, 1, 1, bias=False)
+        self.b15 = nn.BatchNorm2d(32)
+
+        self.c16 = nn.Conv2d(32, 192, 1, 1, bias=False)
+        self.b16 = nn.BatchNorm2d(192)
+
+        self.c17 = nn.Conv2d(192, 192, 3, 1, 1, groups=192, bias=False)
+        self.b17 = nn.BatchNorm2d(192)
+
+        self.c18 = nn.Conv2d(192, 32, 1, 1, bias=False)
+        self.b18 = nn.BatchNorm2d(32)
+
+        self.c19 = nn.Conv2d(32, 192, 1, 1, bias=False)
+        self.b19 = nn.BatchNorm2d(192)
+
+        self.c20 = nn.Conv2d(192, 192, 3, 2, 1, groups=192, bias=False)
+        self.b20 = nn.BatchNorm2d(192)
+
+        self.c21 = nn.Conv2d(192, 64, 1, 1, bias=False)
+        self.b21 = nn.BatchNorm2d(64)
+
+        self.c22 = nn.Conv2d(64, 384, 1, 1, bias=False)
+        self.b22 = nn.BatchNorm2d(384)
+
+        self.c23 = nn.Conv2d(384, 384, 3, 1, 1, groups=384, bias=False)
+        self.b23
= nn.BatchNorm2d(384) + + self.c24 = nn.Conv2d(384, 64, 1, 1, bias=False) + self.b24 = nn.BatchNorm2d(64) + + self.c25 = nn.Conv2d(64, 384, 1, 1, bias=False) + self.b25 = nn.BatchNorm2d(384) + + self.c26 = nn.Conv2d(384, 384, 3, 1, 1, groups=384, bias=False) + self.b26 = nn.BatchNorm2d(384) + + self.c27 = nn.Conv2d(384, 64, 1, 1, bias=False) + self.b27 = nn.BatchNorm2d(64) + + self.c28 = nn.Conv2d(64, 384, 1, 1, bias=False) + self.b28 = nn.BatchNorm2d(384) + + self.c29 = nn.Conv2d(384, 384, 3, 1, 1, groups=384, bias=False) + self.b29 = nn.BatchNorm2d(384) + + self.c30 = nn.Conv2d(384, 64, 1, 1, bias=False) + self.b30 = nn.BatchNorm2d(64) + + self.c31 = nn.Conv2d(64, 384, 1, 1, bias=False) + self.b31 = nn.BatchNorm2d(384) + + self.c32 = nn.Conv2d(384, 384, 3, 1, 1, groups=384, bias=False) + self.b32 = nn.BatchNorm2d(384) + + self.c33 = nn.Conv2d(384, 96, 1, 1, bias=False) + self.b33 = nn.BatchNorm2d(96) + + self.c34 = nn.Conv2d(96, 576, 1, 1, bias=False) + self.b34 = nn.BatchNorm2d(576) + + self.c35 = nn.Conv2d(576, 576, 3, 1, 1, groups=576, bias=False) + self.b35 = nn.BatchNorm2d(576) + + self.c36 = nn.Conv2d(576, 96, 1, 1, bias=False) + self.b36 = nn.BatchNorm2d(96) + + self.c37 = nn.Conv2d(96, 576, 1, 1, bias=False) + self.b37 = nn.BatchNorm2d(576) + + self.c38 = nn.Conv2d(576, 576, 3, 1, 1, groups=576, bias=False) + self.b38 = nn.BatchNorm2d(576) + + self.c39 = nn.Conv2d(576, 96, 1, 1, bias=False) + self.b39 = nn.BatchNorm2d(96) + + self.c40 = nn.Conv2d(96, 576, 1, 1, bias=False) + self.b40 = nn.BatchNorm2d(576) + + self.c41 = nn.Conv2d(576, 576, 3, 2, 1, groups=576, bias=False) + self.b41 = nn.BatchNorm2d(576) + + self.c42 = nn.Conv2d(576, 160, 1, 1, bias=False) + self.b42 = nn.BatchNorm2d(160) + + self.c43 = nn.Conv2d(160, 960, 1, 1, bias=False) + self.b43 = nn.BatchNorm2d(960) + + self.c44 = nn.Conv2d(960, 960, 3, 1, 1, groups=960, bias=False) + self.b44 = nn.BatchNorm2d(960) + + self.c45 = nn.Conv2d(960, 160, 1, 1, bias=False) + self.b45 = nn.BatchNorm2d(160) + + self.c46 = nn.Conv2d(160, 960, 1, 1, bias=False) + self.b46 = nn.BatchNorm2d(960) + + self.c47 = nn.Conv2d(960, 960, 3, 1, 1, groups=960, bias=False) + self.b47 = nn.BatchNorm2d(960) + + self.c48 = nn.Conv2d(960, 160, 1, 1, bias=False) + self.b48 = nn.BatchNorm2d(160) + + self.c49 = nn.Conv2d(160, 960, 1, 1, bias=False) + self.b49 = nn.BatchNorm2d(960) + + self.c50 = nn.Conv2d(960, 960, 3, 1, 1, groups=960, bias=False) + self.b50 = nn.BatchNorm2d(960) + + self.c51 = nn.Conv2d(960, 320, 1, 1, bias=False) + self.b51 = nn.BatchNorm2d(320) + + self.c52 = nn.Conv2d(320, 1280, 1, 1, bias=False) + self.b52 = nn.BatchNorm2d(1280) + + self.l1 = nn.Linear(in_features=1280, out_features=1000) + + def forward(self, input: torch.Tensor): + x1 = self.c1(input) + x1_b = self.b1(x1) + x1_m = self.relu(x1_b) + + x2 = self.c2(x1_m) + x2_b = self.b2(x2) + x2_m = self.relu(x2_b) + + x3 = self.c3(x2_m) + x3_b = self.b3(x3) + + x4 = self.c4(x3_b) + x4_b = self.b4(x4) + x4_m = self.relu(x4_b) + + x5 = self.c5(x4_m) + x5_b = self.b5(x5) + x5_m = self.relu(x5_b) + + x6 = self.c6(x5_m) + x6_b = self.b6(x6) + + x7 = self.c7(x6_b) + x7_b = self.b7(x7) + x7_m = self.relu(x7_b) + + x8 = self.c8(x7_m) + x8_b = self.b8(x8) + x8_m = self.relu(x8_b) + + x9 = self.c9(x8_m) + x9_b = self.b9(x9) + a1 = x9_b + x6_b + x10 = self.c10(a1) + x10_b = self.b10(x10) + x10_m = self.relu(x10_b) + + x11 = self.c11(x10_m) + x11_b = self.b11(x11) + x11_m = self.relu(x11_b) + + x12 = self.c12(x11_m) + x12_b = self.b12(x12) + + x13 = self.c13(x12_b) + x13_b = self.b13(x13) + 
x13_m = self.relu(x13_b) + + x14 = self.c14(x13_m) + x14_b = self.b14(x14) + x14_m = self.relu(x14_b) + + x15 = self.c15(x14_m) + x15_b = self.b15(x15) + + a2 = x15_b + x12_b + + x16 = self.c16(a2) + x16_b = self.b16(x16) + x16_m = self.relu(x16_b) + + x17 = self.c17(x16_m) + x17_b = self.b17(x17) + x17_m = self.relu(x17_b) + + x18 = self.c18(x17_m) + x18_b = self.b18(x18) + + a3 = a2 + x18_b + + x19 = self.c19(a3) + x19_b = self.b19(x19) + x19_m = self.relu(x19_b) + + x20 = self.c20(x19_m) + x20_b = self.b20(x20) + x20_m = self.relu(x20_b) + + x21 = self.c21(x20_m) + x21_b = self.b21(x21) + + x22 = self.c22(x21_b) + x22_b = self.b22(x22) + x22_m = self.relu(x22_b) + + x23 = self.c23(x22_m) + x23_b = self.b23(x23) + x23_m = self.relu(x23_b) + + x24 = self.c24(x23_m) + x24_b = self.b24(x24) + + a4 = x21_b + x24_b + + x25 = self.c25(a4) + x25_b = self.b25(x25) + x25_m = self.relu(x25_b) + + x26 = self.c26(x25_m) + x26_b = self.b26(x26) + x26_m = self.relu(x26_b) + + x27 = self.c27(x26_m) + x27_b = self.b27(x27) + + a5 = a4 + x27_b + + x28 = self.c28(a5) + x28_b = self.b28(x28) + x28_m = self.relu(x28_b) + + x29 = self.c29(x28_m) + x29_b = self.b29(x29) + x29_m = self.relu(x29_b) + + x30 = self.c30(x29_m) + x30_b = self.b30(x30) + + a6 = a5 + x30_b + + x31 = self.c31(a6) + x31_b = self.b31(x31) + x31_m = self.relu(x31_b) + + x32 = self.c32(x31_m) + x32_b = self.b32(x32) + x32_m = self.relu(x32_b) + + x33 = self.c33(x32_m) + x33_b = self.b33(x33) + + x34 = self.c34(x33_b) + x34_b = self.b34(x34) + x34_m = self.relu(x34_b) + + x35 = self.c35(x34_m) + x35_b = self.b35(x35) + x35_m = self.relu(x35_b) + + x36 = self.c36(x35_m) + x36_b = self.b36(x36) + + a7 = x33_b + x36_b + + x37 = self.c37(a7) + x37_b = self.b37(x37) + x37_m = self.relu(x37_b) + + x38 = self.c38(x37_m) + x38_b = self.b38(x38) + x38_m = self.relu(x38_b) + + x39 = self.c39(x38_m) + x39_b = self.b39(x39) + + a8 = a7 + x39_b + + x40 = self.c40(a8) + x40_b = self.b40(x40) + x40_m = self.relu(x40_b) + + x41 = self.c41(x40_m) + x41_b = self.b41(x41) + x41_m = self.relu(x41_b) + + x42 = self.c42(x41_m) + x42_b = self.b42(x42) + + x43 = self.c43(x42_b) + x43_b = self.b43(x43) + x43_m = self.relu(x43_b) + + x44 = self.c44(x43_m) + x44_b = self.b44(x44) + x44_m = self.relu(x44_b) + + x45 = self.c45(x44_m) + x45_b = self.b45(x45) + + a9 = x45_b + x42_b + + x46 = self.c46(a9) + x46_b = self.b46(x46) + x46_m = self.relu(x46_b) + + x47 = self.c47(x46_m) + x47_b = self.b47(x47) + x47_m = self.relu(x47_b) + + x48 = self.c48(x47_m) + x48_b = self.b48(x48) + + a10 = a9 + x48_b + + x49 = self.c49(a10) + x49_b = self.b49(x49) + x49_m = self.relu(x49_b) + + x50 = self.c50(x49_m) + x50_b = self.b50(x50) + x50_m = self.relu(x50_b) + + x51 = self.c51(x50_m) + x51_b = self.b51(x51) + + x52 = self.c52(x51_b) + x52_b = self.b52(x52) + x52_m = self.relu(x52_b) + x = nn.functional.adaptive_avg_pool2d(x52_m, (1, 1)) + x = torch.flatten(x, 1) + x53 = self.l1(x) + return x53 diff --git a/models/experimental/functional_mobilenetv2/test/test_ttnn_mobilenetv2.py b/models/experimental/functional_mobilenetv2/test/test_ttnn_mobilenetv2.py new file mode 100644 index 00000000000..c46cf738071 --- /dev/null +++ b/models/experimental/functional_mobilenetv2/test/test_ttnn_mobilenetv2.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + + +import pytest +import ttnn +import torch + +from tests.ttnn.utils_for_testing import assert_with_pcc + +from models.experimental.functional_mobilenetv2.reference.mobilenetv2 import Mobilenetv2 +from models.experimental.functional_mobilenetv2.tt.model_preprocessing import ( + create_mobilenetv2_input_tensors, + create_mobilenetv2_model_parameters, +) +from models.experimental.functional_mobilenetv2.tt import ttnn_mobilenetv2 +import os +from models.utility_functions import ( + skip_for_grayskull, +) + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +@skip_for_grayskull() +def test_mobilenetv2(device, reset_seeds): + if not os.path.exists("models/experimental/functional_mobilenetv2/mobilenet_v2-b0353104.pth"): + os.system( + "bash models/experimental/functional_mobilenetv2/weights_download.sh" + ) # execute the weights_download.sh file + + state_dict = torch.load("models/experimental/functional_mobilenetv2/mobilenet_v2-b0353104.pth") + ds_state_dict = {k: v for k, v in state_dict.items()} + torch_model = Mobilenetv2() + + new_state_dict = {} + + for (name1, parameter1), (name2, parameter2) in zip(torch_model.state_dict().items(), ds_state_dict.items()): + if isinstance(parameter2, torch.FloatTensor): + new_state_dict[name1] = parameter2 + + torch_model.load_state_dict(new_state_dict) + torch_model.eval() + torch_input_tensor, ttnn_input_tensor = create_mobilenetv2_input_tensors() + torch_output_tensor = torch_model(torch_input_tensor) + + parameters = create_mobilenetv2_model_parameters(torch_model, torch_input_tensor, device=device) + + ttnn_model = ttnn_mobilenetv2.MobileNetV2(parameters, device, torch_model) + output_tensor = ttnn_model(device, ttnn_input_tensor) + + # + # Tensor Postprocessing + # + output_tensor = ttnn.to_torch(output_tensor) + output_tensor = output_tensor.reshape(torch_output_tensor.shape) + output_tensor = output_tensor.to(torch_input_tensor.dtype) + assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.95) diff --git a/models/experimental/functional_mobilenetv2/tt/model_preprocessing.py b/models/experimental/functional_mobilenetv2/tt/model_preprocessing.py new file mode 100644 index 00000000000..a21a3d2f45f --- /dev/null +++ b/models/experimental/functional_mobilenetv2/tt/model_preprocessing.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import ttnn
+
+from models.experimental.functional_mobilenetv2.reference.mobilenetv2 import Mobilenetv2
+from ttnn.model_preprocessing import infer_ttnn_module_args


+def create_mobilenetv2_input_tensors(batch=1, input_channels=3, input_height=128, input_width=128):
+    torch_input_tensor = torch.randn(batch, input_channels, input_height, input_width)
+    # Convert the NCHW torch input into a flattened NHWC tensor of shape (1, 1, N*H*W, C) for ttnn.
+    ttnn_input_tensor = torch.permute(torch_input_tensor, (0, 2, 3, 1))
+    ttnn_input_tensor = ttnn_input_tensor.reshape(
+        1,
+        1,
+        ttnn_input_tensor.shape[0] * ttnn_input_tensor.shape[1] * ttnn_input_tensor.shape[2],
+        ttnn_input_tensor.shape[3],
+    )
+    ttnn_input_tensor = ttnn.from_torch(ttnn_input_tensor, dtype=ttnn.bfloat16)
+
+    return torch_input_tensor, ttnn_input_tensor


+def create_mobilenetv2_model_parameters(model: Mobilenetv2, input_tensor, device):
+    # Infer per-conv shapes by running the torch model on host.
+    parameters = infer_ttnn_module_args(model=model, run_model=lambda model: model(input_tensor), device=None)
+    assert parameters is not None
+    for key in parameters.keys():
+        parameters[key].module = getattr(model, key)
+
+    # The final linear layer is not inferred as a conv; pass its weights through directly.
+    parameters["l1"] = {}
+    parameters["l1"]["weight"] = model.l1.weight
+    parameters["l1"]["bias"] = model.l1.bias
+
+    return parameters
diff --git a/models/experimental/functional_mobilenetv2/tt/ttnn_mobilenetv2.py b/models/experimental/functional_mobilenetv2/tt/ttnn_mobilenetv2.py
new file mode 100644
index 00000000000..5be2c09a59c
--- /dev/null
+++ b/models/experimental/functional_mobilenetv2/tt/ttnn_mobilenetv2.py
@@ -0,0 +1,464 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

+# SPDX-License-Identifier: Apache-2.0

+import ttnn
+import torch
+
+from ttnn.model_preprocessing import ParameterDict, preprocess_linear_weight, preprocess_linear_bias


+class MobileNetV2Conv2D:
+    def fold_batch_norm2d_into_conv2d(self, conv, bn):
+        # Fold BatchNorm into the preceding conv using its running statistics:
+        #   W' = W * gamma / sqrt(var + eps)
+        #   b' = beta - mean * gamma / sqrt(var + eps)
+        if not bn.track_running_stats:
+            raise RuntimeError("BatchNorm2d must have track_running_stats=True to be folded into Conv2d")
+        weight = conv.weight
+        running_mean = bn.running_mean
+        running_var = bn.running_var
+        eps = bn.eps
+        scale = bn.weight
+        shift = bn.bias
+        weight = weight * (scale / torch.sqrt(running_var + eps))[:, None, None, None]
+        bias = shift - running_mean * (scale / torch.sqrt(running_var + eps))
+        return weight, bias
+
+    def __init__(
+        self,
+        conv,
+        bn=None,
+        device=None,
+        cache=None,
+        activation="",
+        activation_dtype=ttnn.bfloat8_b,
+        weights_dtype=ttnn.bfloat8_b,
+        use_1d_systolic_array=True,
+        shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+    ):
+        self.device = device
+        self.batch_size = conv.batch_size
+        self.input_height = conv.input_height
+        self.input_width = conv.input_width
+        self.in_channels = conv.in_channels
+        self.out_channels = conv.out_channels
+        self.kernel_size = conv.kernel_size
+        self.padding = conv.padding
+        self.stride = conv.stride
+        self.groups = conv.groups
+        self.use_1d_systolic_array = use_1d_systolic_array
+        self.deallocate_activation = True
+        # Avoid a shared mutable default: each instance gets its own conv-op cache.
+        self.cache = cache if cache is not None else {}
+
+        self.conv_config = ttnn.Conv2dConfig(
+            dtype=activation_dtype,
+            weights_dtype=weights_dtype,
+            math_fidelity=ttnn.MathFidelity.LoFi,
+            shard_layout=shard_layout,
+            deallocate_activation=self.deallocate_activation,
+            fp32_dest_acc_enabled=True,
+            packer_l1_accum_enabled=False,
+            enable_act_double_buffer=False,
+            enable_split_reader=False,
+            enable_subblock_padding=False,
+            reshard_if_not_optimal=self.use_1d_systolic_array,
+            activation=activation,
+        )
+        config_override =
conv.conv_blocking_and_parallelization_config_override + if config_override and "act_block_h" in config_override: + self.conv_config.act_block_h_override = config_override["act_block_h"] + + if bn is not None: + weight, bias = self.fold_batch_norm2d_into_conv2d(conv.module, bn.module) + else: + weight, bias = conv.module.weight, conv.module.bias + + weight = weight + bias = torch.reshape(bias, (1, 1, 1, -1)) + self.weight = ttnn.from_torch(weight, dtype=ttnn.float32) + self.bias = ttnn.from_torch(bias, dtype=ttnn.float32) + + def __call__(self, x): + x, output_height, output_width, self.weight, self.bias = ttnn.conv2d( + input_tensor=x, + weight_tensor=self.weight, + bias_tensor=self.bias, + device=self.device, + in_channels=self.in_channels, + out_channels=self.out_channels, + input_height=self.input_height, + input_width=self.input_width, + batch_size=self.batch_size, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + conv_config=self.conv_config, + conv_op_cache=self.cache, + groups=self.groups, + ) + return x + + +class MobileNetV2: + def input_preprocessor(self, tensor, n, c, h, w): + tensor = ttnn.to_torch(tensor).to(torch.float32) + tensor = torch.reshape(tensor, (n, h, w, c)) + tensor = torch.permute(tensor, (0, 3, 1, 2)) + return tensor + + def __init__(self, parameters: ParameterDict, device, model) -> None: + self.device = device + self.model = model + self.parameters = parameters + + self.c1 = MobileNetV2Conv2D(parameters.c1, parameters.b1, device) + self.c2 = MobileNetV2Conv2D(parameters.c2, parameters.b2, device) + + self.c3 = MobileNetV2Conv2D(parameters.c3, parameters.b3, device) + + self.c4 = MobileNetV2Conv2D(parameters.c4, parameters.b4, device) + + self.c5 = MobileNetV2Conv2D(parameters.c5, parameters.b5, device) + + self.c6 = MobileNetV2Conv2D(parameters.c6, parameters.b6, device) + + self.c7 = MobileNetV2Conv2D(parameters.c7, parameters.b7, device) + + self.c8 = MobileNetV2Conv2D(parameters.c8, parameters.b8, device) + + self.c9 = MobileNetV2Conv2D(parameters.c9, parameters.b9, device) + + self.c10 = MobileNetV2Conv2D(parameters.c10, parameters.b10, device) + + self.c11 = MobileNetV2Conv2D(parameters.c11, parameters.b11, device) + + self.c12 = MobileNetV2Conv2D(parameters.c12, parameters.b12, device) + + self.c13 = MobileNetV2Conv2D(parameters.c13, parameters.b13, device) + self.c14 = MobileNetV2Conv2D(parameters.c14, parameters.b14, device) + self.c15 = MobileNetV2Conv2D(parameters.c15, parameters.b15, device) + self.c16 = MobileNetV2Conv2D(parameters.c16, parameters.b16, device) + self.c17 = MobileNetV2Conv2D(parameters.c17, parameters.b17, device) + self.c18 = MobileNetV2Conv2D(parameters.c18, parameters.b18, device) + self.c19 = MobileNetV2Conv2D(parameters.c19, parameters.b19, device) + self.c20 = MobileNetV2Conv2D(parameters.c20, parameters.b20, device) + self.c21 = MobileNetV2Conv2D(parameters.c21, parameters.b21, device) + self.c22 = MobileNetV2Conv2D(parameters.c22, parameters.b22, device) + self.c23 = MobileNetV2Conv2D(parameters.c23, parameters.b23, device) + self.c24 = MobileNetV2Conv2D(parameters.c24, parameters.b24, device) + self.c25 = MobileNetV2Conv2D(parameters.c25, parameters.b25, device) + self.c26 = MobileNetV2Conv2D(parameters.c26, parameters.b26, device) + self.c27 = MobileNetV2Conv2D(parameters.c27, parameters.b27, device) + self.c28 = MobileNetV2Conv2D(parameters.c28, parameters.b28, device) + self.c29 = MobileNetV2Conv2D(parameters.c29, parameters.b29, device) + self.c30 = MobileNetV2Conv2D(parameters.c30, 
parameters.b30, device) + self.c31 = MobileNetV2Conv2D(parameters.c31, parameters.b31, device) + self.c32 = MobileNetV2Conv2D(parameters.c32, parameters.b32, device) + self.c33 = MobileNetV2Conv2D(parameters.c33, parameters.b33, device) + self.c34 = MobileNetV2Conv2D(parameters.c34, parameters.b34, device) + + self.c35 = MobileNetV2Conv2D( + parameters.c35, parameters.b35, device, shard_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED + ) + + self.c36 = MobileNetV2Conv2D(parameters.c36, parameters.b36, device) + self.c37 = MobileNetV2Conv2D(parameters.c37, parameters.b37, device) + + self.c38 = MobileNetV2Conv2D( + parameters.c38, parameters.b38, device, shard_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED + ) + + self.c39 = MobileNetV2Conv2D(parameters.c39, parameters.b39, device) + self.c40 = MobileNetV2Conv2D(parameters.c40, parameters.b40, device) + + self.c41 = MobileNetV2Conv2D( + parameters.c41, parameters.b41, device, shard_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED + ) + + self.c42 = MobileNetV2Conv2D(parameters.c42, parameters.b42, device) + self.c43 = MobileNetV2Conv2D(parameters.c43, parameters.b43, device) + + self.c44 = MobileNetV2Conv2D( + parameters.c44, parameters.b44, device, shard_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED + ) + + self.c45 = MobileNetV2Conv2D(parameters.c45, parameters.b45, device) + self.c46 = MobileNetV2Conv2D(parameters.c46, parameters.b46, device) + + self.c47 = MobileNetV2Conv2D( + parameters.c47, parameters.b47, device, shard_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED + ) + + self.c48 = MobileNetV2Conv2D(parameters.c48, parameters.b48, device) + self.c49 = MobileNetV2Conv2D(parameters.c49, parameters.b49, device) + + self.c50 = MobileNetV2Conv2D( + parameters.c50, parameters.b50, device, shard_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED + ) + + self.c51 = MobileNetV2Conv2D(parameters.c51, parameters.b51, device) + self.c52 = MobileNetV2Conv2D(parameters.c52, parameters.b52, device) + + self.l1_weight = parameters.l1["weight"] + self.l1_bias = parameters.l1["bias"] + + def __call__( + self, + device, + x, + ): + output_tensor = self.c1(x) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c2(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c3(output_tensor) + + output_tensor = self.c4(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c5(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c6(output_tensor) + output_tensor_c6 = output_tensor + + if output_tensor_c6.is_sharded(): + output_tensor_c6 = ttnn.sharded_to_interleaved(output_tensor_c6, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c7(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c8(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c9(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = ttnn.add(output_tensor_c6, output_tensor) + + output_tensor = self.c10(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c11(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c12(output_tensor) + output_tensor_c12 = output_tensor + + if output_tensor_c12.is_sharded(): + output_tensor_c12 = ttnn.sharded_to_interleaved(output_tensor_c12, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c13(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + 
output_tensor = self.c14(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c15(output_tensor) + output_tensor_c15 = output_tensor + + if output_tensor_c15.is_sharded(): + output_tensor_c15 = ttnn.sharded_to_interleaved(output_tensor_c15, ttnn.L1_MEMORY_CONFIG) + + output_tensor = output_tensor_c15 + output_tensor_c12 + output_tensor_a2 = output_tensor + + if output_tensor_a2.is_sharded(): + output_tensor_a2 = ttnn.sharded_to_interleaved(output_tensor_a2, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c16(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c17(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c18(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = output_tensor_a2 + output_tensor + + output_tensor = self.c19(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c20(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c21(output_tensor) + + output_tensor_c21 = output_tensor + if output_tensor_c21.is_sharded(): + output_tensor_c21 = ttnn.sharded_to_interleaved(output_tensor_c21, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c22(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c23(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c24(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = output_tensor_c21 + output_tensor + + output_tensor_a4 = output_tensor + + if output_tensor_a4.is_sharded(): + output_tensor_a4 = ttnn.sharded_to_interleaved(output_tensor_a4, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c25(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c26(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c27(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = output_tensor_a4 + output_tensor + output_tensor_a5 = output_tensor + if output_tensor_a5.is_sharded(): + output_tensor_a5 = ttnn.sharded_to_interleaved(output_tensor_a5, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c28(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c29(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c30(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = ttnn.add(output_tensor_a5, output_tensor) + + output_tensor = self.c31(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c32(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c33(output_tensor) + + output_tensor_c33 = output_tensor + if output_tensor_c33.is_sharded(): + output_tensor_c33 = ttnn.sharded_to_interleaved(output_tensor_c33, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c34(output_tensor_c33) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c35(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c36(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor 
= output_tensor_c33 + output_tensor + + output_tensor_a7 = output_tensor + + if output_tensor_a7.is_sharded(): + output_tensor_a7 = ttnn.sharded_to_interleaved(output_tensor_a7, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c37(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c38(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c39(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = output_tensor_a7 + output_tensor + + output_tensor = self.c40(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c41(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c42(output_tensor) + + output_tensor_c42 = output_tensor + if output_tensor_c42.is_sharded(): + output_tensor_c42 = ttnn.sharded_to_interleaved(output_tensor_c42, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c43(output_tensor_c42) + output_tensor = ttnn.relu6(output_tensor) + output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) + + output_tensor = self.c44(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c45(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = output_tensor_c42 + output_tensor + output_tensor_a9 = output_tensor + + if output_tensor_a9.is_sharded(): + output_tensor_a9 = ttnn.sharded_to_interleaved(output_tensor_a9, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c46(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) + output_tensor = self.c47(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c48(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = output_tensor + output_tensor_a9 + + output_tensor = self.c49(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) + output_tensor = self.c50(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c51(output_tensor) + + output_tensor = self.c52(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = ttnn.global_avg_pool2d(output_tensor) + + output_tensor = self.input_preprocessor(output_tensor, 1, 1280, 1, 1) + + output_tensor = torch.flatten(output_tensor, 1) + + output_tensor = ttnn.from_torch(output_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + output_tensor = ttnn.to_memory_config(output_tensor, ttnn.L1_MEMORY_CONFIG) + + self.l1_weight = preprocess_linear_weight(self.l1_weight, dtype=ttnn.bfloat16) + self.l1_bias = preprocess_linear_bias(self.l1_bias, dtype=ttnn.bfloat16) + self.l1_weight = ttnn.to_device(self.l1_weight, device) + self.l1_bias = ttnn.to_device(self.l1_bias, device) + + output_tensor = ttnn.linear(output_tensor, self.l1_weight, bias=self.l1_bias) + + return ttnn.from_device(output_tensor) diff --git a/models/experimental/functional_mobilenetv2/weights_download.sh b/models/experimental/functional_mobilenetv2/weights_download.sh new file mode 100644 index 00000000000..89dbc939007 --- /dev/null +++ 
b/models/experimental/functional_mobilenetv2/weights_download.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# Output filename +OUTPUT="models/experimental/functional_mobilenetv2/mobilenet_v2-b0353104.pth" + +# Create output directory if it doesn't exist +mkdir -p "$(dirname "$OUTPUT")" + +# Download the file using wget +if wget "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" -O "${OUTPUT}"; then + echo "File downloaded successfully: ${OUTPUT}" +else + echo "Error downloading the file." + exit 1 +fi
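+
+# Note: this script is invoked automatically by test_ttnn_mobilenetv2.py when the
+# checkpoint is missing; it can also be run manually from the repository root:
+#   bash models/experimental/functional_mobilenetv2/weights_download.sh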