#13707: Port reshape op to SimpleShape, related refactoring (#13838)
* #13707: Port reshape op to SimpleShape, related refactoring

* #13707: Rework

* #13707: Fix failing tests

* #13707: Review fixes: simplification & better error messages

* #13707: Trying to fix stable diffusion

* #13707: Rebase fix
sminakov-tt authored Oct 17, 2024
1 parent 68d66cc commit 4e1fef9
Showing 49 changed files with 444 additions and 689 deletions.
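The recurring change across the model files below swaps arithmetic on the tile-padded shape, shape.with_tile_padding(), for arithmetic on the logical shape. A minimal sketch of the distinction, assuming a tensor in TILE_LAYOUT (sizes are illustrative; the exact API surface depends on the ttnn version):

    import torch
    import ttnn

    torch_tensor = torch.rand(1, 1, 30, 30)
    # TILE_LAYOUT stores data in 32x32 tiles, so the last two dims are
    # physically padded up to multiples of 32 while the logical shape
    # stays [1, 1, 30, 30].
    tt_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT)

    print(tt_tensor.shape)                      # [1, 1, 30, 30] (logical)
    print(tt_tensor.shape.with_tile_padding())  # [1, 1, 32, 32] (padded)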
@@ -75,10 +75,8 @@ def ResnetLinear(
     """

     matmul_config = hardcoded_matmul_config_linear[batch_size]
-    weight_shape = weight.shape.with_tile_padding()
-    weight = weight.reshape(1, 1, weight_shape[-2], weight_shape[-1])
-    bias_shape = bias.shape.with_tile_padding()
-    bias = bias.reshape(1, 1, bias_shape[-2], bias_shape[-1])
+    weight = weight.reshape(weight.shape.to_rank(4))
+    bias = bias.reshape(bias.shape.to_rank(4))

     def linear_(act):
         output = ttnn.linear(
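Shape.to_rank(4) replaces the hand-written promotion that the removed lines built via reshape(1, 1, ...). A rough Python sketch of the rank-raising behavior this hunk relies on; the helper name comes from the diff, but the body below is an assumption rather than the library's implementation:

    def to_rank_4(dims):
        # Assumed semantics for the rank-raising case seen here: prepend 1s
        # so a [K, N] weight becomes [1, 1, K, N], matching the removed
        # reshape(1, 1, shape[-2], shape[-1]) calls.
        dims = list(dims)
        if len(dims) > 4:
            raise ValueError("rank-lowering is outside this sketch")
        return [1] * (4 - len(dims)) + dims

    assert to_rank_4([64, 1000]) == [1, 1, 64, 1000]

Note the behavioral nuance: the removed code promoted the padded dimensions, while the new code promotes the logical shape and leaves padding to the layout.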
@@ -717,9 +715,9 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )
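The batch-unfolding reshapes above also drop the C-style truncation, (int)(x / batch), in favor of floor division on the logical dims. For the positive integer sizes involved the two agree, and // never leaves integer arithmetic:

    batch_size = 16
    height = 3136 * batch_size

    old_style = int(height / batch_size)  # true division, then truncate
    new_style = height // batch_size      # floor division, stays integral
    assert old_style == new_style == 3136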

@@ -806,9 +804,9 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )
     # for _, tensor in conv_op_cache["reader_patterns_cache"]["conv"].items():
@@ -948,9 +946,9 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )

@@ -1037,9 +1035,9 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )

@@ -67,10 +67,8 @@ def ResnetLinear(
     """

     matmul_config = hardcoded_matmul_config_linear[batch_size]
-    weight_shape = weight.shape.with_tile_padding()
-    weight = weight.reshape(1, 1, weight_shape[-2], weight_shape[-1])
-    bias_shape = bias.shape.with_tile_padding()
-    bias = bias.reshape(1, 1, bias_shape[-2], bias_shape[-1])
+    weight = weight.reshape(weight.shape.to_rank(4))
+    bias = bias.reshape(bias.shape.to_rank(4))

     def linear_(act):
         output = ttnn.linear(
@@ -1136,9 +1134,9 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            x.shape.with_tile_padding()[2] // self.batch_size,
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )

@@ -1207,9 +1205,9 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )

@@ -75,10 +75,8 @@ def ResnetLinear(
     """

     matmul_config = hardcoded_matmul_config_linear[batch_size]
-    weight_shape = weight.shape.with_tile_padding()
-    weight = weight.reshape(1, 1, weight_shape[-2], weight_shape[-1])
-    bias_shape = bias.shape.with_tile_padding()
-    bias = bias.reshape(1, 1, bias_shape[-2], bias_shape[-1])
+    weight = weight.reshape(weight.shape.to_rank(4))
+    bias = bias.reshape(bias.shape.to_rank(4))

     def linear_(act):
         output = ttnn.linear(
@@ -713,9 +711,9 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )

@@ -802,9 +800,9 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )

@@ -930,9 +928,9 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )

@@ -1020,7 +1018,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
         (
             self.batch_size,
             x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
+            x.shape.with_tile_padding()[2] // self.batch_size,
             x.shape.with_tile_padding()[3],
         ),
     )
@@ -76,10 +76,8 @@ def ResnetLinear(
     """

     matmul_config = hardcoded_matmul_config_linear[batch_size]
-    weight_shape = weight.shape.with_tile_padding()
-    weight = weight.reshape(1, 1, weight_shape[-2], weight_shape[-1])
-    bias_shape = bias.shape.with_tile_padding()
-    bias = bias.reshape(1, 1, bias_shape[-2], bias_shape[-1])
+    weight = weight.reshape(weight.shape.to_rank(4))
+    bias = bias.reshape(bias.shape.to_rank(4))

     def linear_(act):
         output = ttnn.linear(
@@ -743,9 +741,9 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )

@@ -832,9 +830,9 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )

@@ -956,9 +954,9 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )

@@ -1045,9 +1043,9 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )

@@ -76,10 +76,8 @@ def ResnetLinear(
     """

     matmul_config = hardcoded_matmul_config_linear[batch_size]
-    weight_shape = weight.shape.with_tile_padding()
-    weight = weight.reshape(1, 1, weight_shape[-2], weight_shape[-1])
-    bias_shape = bias.shape.with_tile_padding()
-    bias = bias.reshape(1, 1, bias_shape[-2], bias_shape[-1])
+    weight = weight.reshape(weight.shape.to_rank(4))
+    bias = bias.reshape(bias.shape.to_rank(4))

     def linear_(act):
         output = ttnn.linear(
@@ -801,9 +799,9 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )

@@ -897,9 +895,9 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )

@@ -1030,9 +1028,9 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )

@@ -1126,9 +1124,9 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
         x,
         (
             self.batch_size,
-            x.shape.with_tile_padding()[1],
-            (int)(x.shape.with_tile_padding()[2] / self.batch_size),
-            x.shape.with_tile_padding()[3],
+            x.shape[1],
+            x.shape[2] // self.batch_size,
+            x.shape[3],
         ),
     )

@@ -681,7 +681,7 @@ def __call__(
             self.batch_size,
             self.conv_out_input_height,
             self.conv_out_input_width,
-            32,  # Padded to tile dim
+            -1,
         ),
     )
     sample = ttnn.permute(sample, (0, 3, 1, 2))  # permute from NHWC to NCHW
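Here the trailing dimension, previously hard-coded to the tile-padded channel count of 32, becomes -1 so reshape infers it from the tensor's volume. A hedged before/after sketch using torch semantics (ttnn's reshape follows the same -1 convention; the sizes below are illustrative, not taken from the model):

    import torch

    batch_size, h, w, c = 2, 64, 64, 32
    sample = torch.rand(batch_size * h * w * c)

    before = sample.reshape(batch_size, h, w, 32)  # bakes in the padded dim
    after = sample.reshape(batch_size, h, w, -1)   # inferred from the volume
    assert before.shape == after.shape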
4 changes: 2 additions & 2 deletions tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp
@@ -39,9 +39,9 @@ int main(int argc, char **argv) {
     ////////////////////////////////////////////////////////////////////////////
     // Application Setup
     ////////////////////////////////////////////////////////////////////////////
-    tt::tt_metal::LegacyShape shape = {1, 32, 61, 32};
+    ttnn::SimpleShape shape{1, 32, 61, 32};
     // Allocates a DRAM buffer on device populated with values specified by initialize
-    Tensor a = ttnn::numpy::arange<bfloat16>(0, tt_metal::compute_volume(shape), 1).reshape(shape).to(device);
+    Tensor a = ttnn::numpy::arange<bfloat16>(0, shape.volume(), 1).reshape(shape).to(device);
     Tensor b = ttnn::tilize_with_zero_padding(a);
     Tensor c = b.cpu();
     ////////////////////////////////////////////////////////////////////////////
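The free function tt_metal::compute_volume(shape) becomes the member SimpleShape::volume(); both return the product of the dimensions, which arange needs as its element count. In plain Python terms:

    import math

    shape = (1, 32, 61, 32)
    volume = math.prod(shape)  # what shape.volume() returns here
    # arange(0, volume, 1) then produces exactly enough values to fill the tensor
    assert volume == 62464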
4 changes: 2 additions & 2 deletions tests/tt_eager/tensors/test_async_tensor_apis.cpp
@@ -51,7 +51,7 @@ TEST_F(CommonFixture, TestTensorOwnershipSanity) {
         },
         host_tensor.get_storage());
     // Send tensor to device, read it back and copy it to empty tensor initialized by main thread
-    Tensor reshaped_tensor = host_tensor.reshape(1, 1, 32, 128);
+    Tensor reshaped_tensor = host_tensor.reshape(ttnn::SimpleShape{1, 1, 32, 128});
     auto device_tensor = reshaped_tensor.to(Layout::TILE).to(device);
     auto thread_local_tensor = device_tensor.cpu().to(Layout::ROW_MAJOR);
     readback_tensor.set_storage(thread_local_tensor.get_storage());
@@ -262,7 +262,7 @@ TEST_F(CommonFixture, TestTensorAsyncDataMovement) {
         },
         host_tensor.get_storage());

-    Tensor reshaped_tensor = host_tensor.reshape(1, 1, 32, tensor_stop / 32);
+    Tensor reshaped_tensor = host_tensor.reshape(ttnn::SimpleShape{1, 1, 32, tensor_stop / 32});
     auto device_tensor = reshaped_tensor.to(Layout::TILE).to(device);
     auto thread_local_tensor = device_tensor.cpu().to(Layout::ROW_MAJOR);
     log_info(LogTest, "Worker populating empty host readback_tensor");
4 changes: 2 additions & 2 deletions tests/ttnn/unit_tests/gtests/test_multi_device.cpp
@@ -13,11 +13,11 @@ using namespace tt::tt_metal;

 Tensor create_host_multi_device_tensor(const Tensor& tensor, const ReplicateTensor& strategy) {
     std::vector<OwnedBuffer> owned_buffers;
-    std::vector<tt::tt_metal::LegacyShape> shapes;
+    std::vector<ttnn::Shape> shapes;

     for (int i = 0; i < strategy.replication_factor; i++) {
         owned_buffers.push_back(std::get<OwnedStorage>(tensor.get_storage()).buffer);
-        shapes.push_back(tensor.get_legacy_shape());
+        shapes.push_back(tensor.get_shape());
     }

     return Tensor{
18 changes: 14 additions & 4 deletions ttnn/cpp/pybind11/pytensor.cpp
@@ -419,11 +419,11 @@ Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optiona
         tt_shards.push_back(detail::convert_python_tensor_to_tt_tensor(shard, data_type, tile, false));
     }
     std::vector<OwnedBuffer> host_owned_buffers;
-    std::vector<tt::tt_metal::LegacyShape> host_owned_shapes;
+    std::vector<ttnn::Shape> host_owned_shapes;
     for (const auto &shard : tt_shards) {
         TT_ASSERT(std::holds_alternative<OwnedStorage>(shard.get_storage()), "Unexpected type {}", tt::stl::get_active_type_name_in_variant(shard.get_storage()));
         host_owned_buffers.push_back(std::get<OwnedStorage>(shard.get_storage()).buffer);
-        host_owned_shapes.push_back(shard.get_legacy_shape());
+        host_owned_shapes.push_back(shard.shape());
     }
     auto distributed_tensor_config = get_distributed_tensor_config(strategy);
     auto storage = MultiDeviceHostStorage{distributed_tensor_config, std::move(host_owned_buffers), host_owned_shapes};
@@ -1593,7 +1593,7 @@ void pytensor_module(py::module &m_tensor) {
             )doc")
         .def(
             "reshape",
-            [](Tensor &self, int N, int C, int H, int W) { return self.reshape(N, C, H, W); },
+            [](Tensor &self, int N, int C, int H, int W) { return self.reshape(infer_dims_for_reshape(self, {N, C, H, W})); },
             R"doc(
                 Reshapes TT tensor
@@ -1603,14 +1603,24 @@ void pytensor_module(py::module &m_tensor) {
             )doc")
         .def(
             "reshape",
-            [](Tensor &self, const tt::tt_metal::LegacyShape &shape) -> Tensor { return self.reshape(shape); },
+            [](Tensor &self, const ttnn::Shape &shape) -> Tensor { return self.reshape(shape); },
             R"doc(
                 Reshapes TT tensor
                 .. code-block:: python
                     reshaped_tensor = tt_tensor.reshape((4, 3, 32))
             )doc")
+        .def(
+            "reshape",
+            [](Tensor &self, const std::vector<int32_t> &shape) -> Tensor { return self.reshape(infer_dims_for_reshape(self, shape)); },
+            R"doc(
+                Reshapes TT tensor
+                .. code-block:: python
+                    reshaped_tensor = tt_tensor.reshape((4, -1, 32))
+            )doc")
         .def_property(
             "tensor_id",
             [](const Tensor &self) { return self.tensor_id; },
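The new std::vector<int32_t> overload above routes through infer_dims_for_reshape, so Python callers may pass a single -1 and have it resolved against the tensor's volume. The function name is taken from the diff; the sketch below is an assumed reference implementation of that resolution, not the actual ttnn source:

    import math

    def infer_dims_for_reshape(volume, dims):
        # Resolve at most one -1 entry so that prod(dims) == volume.
        holes = [i for i, d in enumerate(dims) if d == -1]
        if len(holes) > 1:
            raise ValueError("only one dimension can be inferred")
        if not holes:
            assert math.prod(dims) == volume, "shape does not match volume"
            return list(dims)
        known = math.prod(d for d in dims if d != -1)
        assert volume % known == 0, "inferred dimension is not integral"
        resolved = list(dims)
        resolved[holes[0]] = volume // known
        return resolved

    # Mirrors the docstring example: reshape((4, -1, 32)) on a 384-element tensor.
    assert infer_dims_for_reshape(4 * 3 * 32, [4, -1, 32]) == [4, 3, 32]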