Skip to content

Commit

Permalink
#0: use ttnn.from_torch
Browse files Browse the repository at this point in the history
  • Loading branch information
sjameelTT committed Oct 16, 2024
1 parent d9ebee4 commit 5475638
Showing 1 changed file with 3 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,11 @@ def transpose(
else:
x = torch.randn(input_shape).bfloat16().float()

xt = ttnn.to_layout(
ttnn.Tensor(
x,
input_dtype,
),
ttnn.TILE_LAYOUT,
ttnn_input = ttnn.from_torch(
x, layout=ttnn.TILE_LAYOUT, dtype=input_dtype, device=device, memory_config=input_mem_config
)
xtt = ttnn.transpose(ttnn_input, dim0, dim1, memory_config=output_mem_config)

xt = xt.to(device, input_mem_config)
xtt = ttnn.transpose(xt, dim0, dim1, memory_config=output_mem_config)
assert list(xtt.shape) == output_shape
transposed_ref = x.transpose(dim0, dim1)

Expand Down

0 comments on commit 5475638

Please sign in to comment.