#11351: Replace tt_lib usage in eltwise complex
VirdhatchaniKN committed Aug 15, 2024
1 parent 40ceb6b commit 76f94a4
Showing 12 changed files with 105 additions and 117 deletions.
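The change is mechanical: every tt_lib (ttl) construct in the complex eltwise tests is swapped for its ttnn equivalent. A minimal sketch of the mapping, using only names that appear in the hunks below (the helper function and its name are hypothetical, for illustration only):

    import ttnn

    # tt_lib (removed)                                      ttnn (added)
    # ttl.tensor.MemoryConfig(INTERLEAVED, DRAM)       ->   ttnn.DRAM_MEMORY_CONFIG
    # ttl.tensor.MemoryConfig(INTERLEAVED, L1)         ->   ttnn.L1_MEMORY_CONFIG
    # ttl.tensor.DataType.BFLOAT16                     ->   ttnn.bfloat16
    # ttl.tensor.Layout.TILE / Layout.ROW_MAJOR        ->   ttnn.TILE_LAYOUT / ttnn.ROW_MAJOR_LAYOUT
    # ttl.tensor.Tensor(...)                           ->   ttnn.Tensor(...)

    def to_complex_device_tensor(real, imag, device, memcfg=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat16):
        # Post-migration idiom repeated throughout these tests: tilize each
        # part and move it onto the device with the requested memory config.
        return ttnn.complex_tensor(
            ttnn.Tensor(real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
            ttnn.Tensor(imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
        )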
====================
@@ -3,8 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import torch
 
-import tt_lib as ttl
 import ttnn
 
 
 class Complex:
@@ -82,7 +81,7 @@ def random_complex_tensor(shape, real_range=(-100, 100), imag_range=(-100, 100)):
 
 def convert_to_torch_tensor(tt_dev):
     for i in range(len(tt_dev)):
-        tt_dev_r = tt_dev[i].real.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
-        tt_dev_i = tt_dev[i].imag.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
+        tt_dev_r = tt_dev[i].real.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
+        tt_dev_i = tt_dev[i].imag.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
         tt_dev[i] = Complex(re=tt_dev_r, im=tt_dev_i).metal
     return tt_dev
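Together with the mapping above, convert_to_torch_tensor closes the round trip: a backward op returns device tensors, and this helper pulls them back to row-major torch tensors for the golden comparison. A rough usage sketch under the same assumptions (ttnn.add_bw, the memcfg fixture, and the data names are taken from the hunks in this commit; to_complex_device_tensor is the hypothetical helper above):

    input_tensor = to_complex_device_tensor(in_data.real, in_data.imag, device, memcfg)
    other_tensor = to_complex_device_tensor(other_data.real, other_data.imag, device, memcfg)
    grad_tensor = to_complex_device_tensor(grad_data.real, grad_data.imag, device, memcfg)
    tt_dev = ttnn.add_bw(grad_tensor, input_tensor, other_tensor, alpha, memory_config=memcfg)
    tt_dev = convert_to_torch_tensor(tt_dev)  # back to torch for comparison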
====================
@@ -8,7 +8,6 @@
 
 import torch
 
-import tt_lib as ttl
 import pytest
 import ttnn
 from loguru import logger
@@ -26,12 +25,12 @@
 @pytest.mark.parametrize(
     "memcfg",
     (
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
+        ttnn.DRAM_MEMORY_CONFIG,
+        ttnn.L1_MEMORY_CONFIG,
     ),
     ids=["out_DRAM", "out_L1"],
 )
-@pytest.mark.parametrize("dtype", ((ttl.tensor.DataType.BFLOAT16,)))
+@pytest.mark.parametrize("dtype", ((ttnn.bfloat16,)))
 @pytest.mark.parametrize("bs", ((1, 1), (1, 2), (2, 2)))
 @pytest.mark.parametrize("hw", ((32, 64), (320, 384)))
 def test_level2_abs_bw(bs, hw, memcfg, dtype, device, function_level_defaults):
@@ -41,8 +40,8 @@ def test_level2_abs_bw(bs, hw, memcfg, dtype, device, function_level_defaults):
     in_data.requires_grad = True
 
     input_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(in_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(in_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
     )
 
     grad_data, grad_tensor = data_gen_with_range(input_shape, -50, 40, device)
@@ -65,12 +64,12 @@ def test_level2_abs_bw(bs, hw, memcfg, dtype, device, function_level_defaults):
 @pytest.mark.parametrize(
     "memcfg",
     (
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
+        ttnn.DRAM_MEMORY_CONFIG,
+        ttnn.L1_MEMORY_CONFIG,
     ),
     ids=["out_DRAM", "out_L1"],
 )
-@pytest.mark.parametrize("dtype", ((ttl.tensor.DataType.BFLOAT16,)))
+@pytest.mark.parametrize("dtype", ((ttnn.bfloat16,)))
 @pytest.mark.parametrize("bs", ((1, 1), (1, 2)))
 @pytest.mark.parametrize("hw", ((32, 64), (320, 384)))
 def test_level2_abs_bw_inp_zero(bs, hw, memcfg, dtype, device, function_level_defaults):
@@ -80,8 +79,8 @@ def test_level2_abs_bw_inp_zero(bs, hw, memcfg, dtype, device, function_level_defaults):
     in_data.requires_grad = True
 
     input_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(in_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(in_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
     )
 
     grad_data, grad_tensor = data_gen_with_range(input_shape, -50, 80, device)
====================
@@ -8,7 +8,6 @@
 
 import torch
 
-import tt_lib as ttl
 import pytest
 import ttnn
 from loguru import logger
@@ -26,12 +25,12 @@
 @pytest.mark.parametrize(
     "memcfg",
     (
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
+        ttnn.DRAM_MEMORY_CONFIG,
+        ttnn.L1_MEMORY_CONFIG,
     ),
     ids=["out_DRAM", "out_L1"],
 )
-@pytest.mark.parametrize("dtype", ((ttl.tensor.DataType.BFLOAT16,)))
+@pytest.mark.parametrize("dtype", ((ttnn.bfloat16,)))
 @pytest.mark.parametrize("bs", ((1, 1), (1, 2), (2, 2)))
 @pytest.mark.parametrize("hw", ((32, 64), (320, 384)))
 def test_level2_angle_bw(bs, hw, memcfg, dtype, device, function_level_defaults):
@@ -41,8 +40,8 @@ def test_level2_angle_bw(bs, hw, memcfg, dtype, device, function_level_defaults):
     in_data.requires_grad = True
 
     input_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(in_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(in_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
     )
 
     grad_data, grad_tensor = data_gen_with_range(input_shape, -50, 40, device)
====================
@@ -8,7 +8,6 @@
 
 import torch
 
-import tt_lib as ttl
 import pytest
 import ttnn
 from loguru import logger
@@ -26,12 +25,12 @@
 @pytest.mark.parametrize(
     "memcfg",
     (
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
+        ttnn.DRAM_MEMORY_CONFIG,
+        ttnn.L1_MEMORY_CONFIG,
     ),
     ids=["out_DRAM", "out_L1"],
 )
-@pytest.mark.parametrize("dtype", ((ttl.tensor.DataType.BFLOAT16,)))
+@pytest.mark.parametrize("dtype", ((ttnn.bfloat16,)))
 @pytest.mark.parametrize("bs", ((1, 1), (1, 2), (2, 2)))
 @pytest.mark.parametrize("hw", ((32, 64), (320, 384)))
 @pytest.mark.parametrize("alpha", [0.0, -5.0, 3.5])
@@ -45,18 +44,18 @@ def test_level2_complex_add_bw(bs, hw, alpha, memcfg, dtype, device, function_level_defaults):
     other_data.requires_grad = True
 
     input_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(in_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(in_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
     )
     other_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(other_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(other_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(other_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(other_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
     )
 
     grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60))
     grad_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(grad_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(grad_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
     )
     tt_dev = ttnn.add_bw(grad_tensor, input_tensor, other_tensor, alpha, memory_config=memcfg)
     tt_dev = convert_to_torch_tensor(tt_dev)
====================
@@ -8,7 +8,6 @@
 
 import torch
 
-import tt_lib as ttl
 from models.utility_functions import print_diff_argmax
 import pytest
 import ttnn
@@ -30,12 +29,12 @@
 @pytest.mark.parametrize(
     "memcfg",
     (
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
+        ttnn.DRAM_MEMORY_CONFIG,
+        ttnn.L1_MEMORY_CONFIG,
     ),
     ids=["out_DRAM", "out_L1"],
 )
-@pytest.mark.parametrize("dtype", ((ttl.tensor.DataType.BFLOAT16,)))
+@pytest.mark.parametrize("dtype", ((ttnn.bfloat16,)))
 @pytest.mark.parametrize("bs", ((1, 1), (1, 2), (2, 2)))
 @pytest.mark.parametrize("hw", ((32, 64), (320, 384)))
 def test_level2_complex_div_bw(bs, hw, memcfg, dtype, device, function_level_defaults):
@@ -48,18 +47,18 @@ def test_level2_complex_div_bw(bs, hw, memcfg, dtype, device, function_level_defaults):
     other_data.requires_grad = True
 
     input_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(in_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(in_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
     )
     other_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(other_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(other_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(other_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(other_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
    )
 
     grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60))
     grad_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(grad_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(grad_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
    )
     tt_dev = ttnn.div_bw(grad_tensor, input_tensor, other_tensor, memory_config=memcfg)
     tt_dev = convert_to_torch_tensor(tt_dev)
@@ -79,12 +78,12 @@ def test_level2_complex_div_bw(bs, hw, memcfg, dtype, device, function_level_defaults):
 @pytest.mark.parametrize(
     "memcfg",
     (
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
+        ttnn.DRAM_MEMORY_CONFIG,
+        ttnn.L1_MEMORY_CONFIG,
     ),
     ids=["out_DRAM", "out_L1"],
 )
-@pytest.mark.parametrize("dtype", ((ttl.tensor.DataType.BFLOAT16,)))
+@pytest.mark.parametrize("dtype", ((ttnn.bfloat16,)))
 @pytest.mark.parametrize("bs", ((1, 1), (1, 2)))
 @pytest.mark.parametrize("hw", ((32, 64), (320, 384)))
 @skip_for_wormhole_b0()
@@ -98,18 +97,18 @@ def test_level2_complex_div_bw_other_zero(bs, hw, memcfg, dtype, device, function_level_defaults):
     other_data.requires_grad = True
 
     input_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(in_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(in_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
    )
     other_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(other_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(other_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(other_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(other_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
    )
 
     grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60))
     grad_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(grad_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(grad_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
    )
     tt_dev = ttnn.div_bw(grad_tensor, input_tensor, other_tensor, memory_config=memcfg)
     tt_dev = convert_to_torch_tensor(tt_dev)
====================
@@ -8,7 +8,6 @@
 
 import torch
 
-import tt_lib as ttl
 from models.utility_functions import print_diff_argmax
 import pytest
 import ttnn
@@ -27,12 +26,12 @@
 @pytest.mark.parametrize(
     "memcfg",
     (
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
+        ttnn.DRAM_MEMORY_CONFIG,
+        ttnn.L1_MEMORY_CONFIG,
     ),
     ids=["out_DRAM", "out_L1"],
 )
-@pytest.mark.parametrize("dtype", ((ttl.tensor.DataType.BFLOAT16,)))
+@pytest.mark.parametrize("dtype", ((ttnn.bfloat16,)))
 @pytest.mark.parametrize("bs", ((1, 1), (1, 2), (2, 2)))
 @pytest.mark.parametrize("hw", ((32, 64), (320, 384)))
 def test_level2_complex_mul_bw(bs, hw, memcfg, dtype, device, function_level_defaults):
@@ -45,18 +44,18 @@ def test_level2_complex_mul_bw(bs, hw, memcfg, dtype, device, function_level_defaults):
     other_data.requires_grad = True
 
     input_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(in_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(in_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
    )
     other_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(other_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(other_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(other_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(other_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
    )
 
     grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60))
     grad_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(grad_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(grad_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
    )
     tt_dev = ttnn.mul_bw(grad_tensor, input_tensor, other_tensor, memory_config=memcfg)
     tt_dev = convert_to_torch_tensor(tt_dev)
====================
@@ -4,7 +4,6 @@
 
 import torch
 
-import tt_lib as ttl
 import pytest
 import ttnn
 from loguru import logger
@@ -21,12 +20,12 @@
 @pytest.mark.parametrize(
     "memcfg",
     (
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
+        ttnn.DRAM_MEMORY_CONFIG,
+        ttnn.L1_MEMORY_CONFIG,
     ),
     ids=["out_DRAM", "out_L1"],
 )
-@pytest.mark.parametrize("dtype", ((ttl.tensor.DataType.BFLOAT16,)))
+@pytest.mark.parametrize("dtype", ((ttnn.bfloat16,)))
 @pytest.mark.parametrize("bs", ((1, 1), (1, 2), (2, 2)))
 @pytest.mark.parametrize("hw", ((32, 64), (320, 384)))
 @pytest.mark.parametrize("alpha", [-5.0, 1.0, 3.5])
@@ -40,18 +39,18 @@ def test_level2_complex_sub_bw(bs, hw, alpha, memcfg, dtype, device, function_level_defaults):
     other_data.requires_grad = True
 
     input_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(in_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(in_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
    )
     other_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(other_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(other_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(other_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(other_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
    )
 
     grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60))
     grad_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(grad_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(grad_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
    )
     tt_dev = ttnn.sub_bw(grad_tensor, input_tensor, other_tensor, alpha, memory_config=memcfg)
     tt_dev = convert_to_torch_tensor(tt_dev)
====================
@@ -8,7 +8,6 @@
 
 import torch
 
-import tt_lib as ttl
 import pytest
 import ttnn
 from loguru import logger
@@ -26,12 +25,12 @@
 @pytest.mark.parametrize(
     "memcfg",
     (
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
+        ttnn.DRAM_MEMORY_CONFIG,
+        ttnn.L1_MEMORY_CONFIG,
     ),
     ids=["out_DRAM", "out_L1"],
 )
-@pytest.mark.parametrize("dtype", ((ttl.tensor.DataType.BFLOAT16,)))
+@pytest.mark.parametrize("dtype", ((ttnn.bfloat16,)))
 @pytest.mark.parametrize("bs", ((1, 1), (1, 2), (2, 2)))
 @pytest.mark.parametrize("hw", ((32, 64), (320, 384)))
 def test_level2_conj_bw(bs, hw, memcfg, dtype, device, function_level_defaults):
@@ -41,14 +40,14 @@ def test_level2_conj_bw(bs, hw, memcfg, dtype, device, function_level_defaults):
     in_data.requires_grad = True
 
     input_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(in_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(in_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
    )
 
     grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60))
     grad_tensor = ttnn.complex_tensor(
-        ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
-        ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
+        ttnn.Tensor(grad_data.real, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
+        ttnn.Tensor(grad_data.imag, dtype).to(ttnn.TILE_LAYOUT).to(device, memcfg),
    )
     tt_dev = ttnn.conj_bw(grad_tensor, input_tensor, memory_config=memcfg)
     tt_dev = convert_to_torch_tensor(tt_dev)
[Diffs for the remaining 4 changed files did not load.]