Skip to content

Commit

Permalink
#0: Updated Conv2dConfig in test_small_resnet_block
Browse files Browse the repository at this point in the history
  • Loading branch information
sankarmanoj-tt committed Nov 20, 2024
1 parent 76c1a82 commit 372d237
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions tests/ttnn/unit_tests/operations/test_small_resnet50_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
),
compute_config=ttnn.GetComputeKernelConfig(
math_fidelity=self.model_config["MATH_FIDELITY"],
),
conv_op_cache=conv_op_cache,
Expand All @@ -139,9 +141,11 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
),
compute_config=ttnn.GetComputeKernelConfig(
math_fidelity=self.model_config["MATH_FIDELITY"],
),
conv_op_cache=conv_op_cache,
)

Expand All @@ -162,6 +166,8 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
),
compute_config=ttnn.GetComputeKernelConfig(
math_fidelity=self.model_config["MATH_FIDELITY"],
),
conv_op_cache=conv_op_cache,
Expand All @@ -187,9 +193,11 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
),
compute_config=ttnn.GetComputeKernelConfig(
math_fidelity=self.model_config["MATH_FIDELITY"],
),
conv_op_cache=conv_op_cache,
)

Expand All @@ -211,6 +219,8 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
),
compute_config=ttnn.GetComputeKernelConfig(
math_fidelity=self.model_config["MATH_FIDELITY"],
),
conv_op_cache=conv_op_cache,
Expand Down

0 comments on commit 372d237

Please sign in to comment.