Skip to content

Commit

Permalink
#13745:replace tensor.reshape with tensor.reshape_unsafe
Browse files Browse the repository at this point in the history
  • Loading branch information
nardoTT committed Nov 29, 2024
1 parent 5af6427 commit 01163e9
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions ttnn/cpp/pybind11/pytensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/tensor_impl.hpp"

#include "ttnn/common/constants.hpp"
#include "ttnn/operations/core/core.hpp"

using namespace tt::tt_metal;

namespace py = pybind11;
Expand Down Expand Up @@ -1557,7 +1560,7 @@ void pytensor_module(py::module& m_tensor) {
dtype = tt_tensor.get_dtype()
)doc")
.def(
"reshape",
"reshape_unsafe",
[](Tensor& self, int N, int C, int H, int W) {
return self.reshape(infer_dims_for_reshape(self, ttnn::SmallVector<int>{N, C, H, W}));
},
Expand All @@ -1569,7 +1572,7 @@ void pytensor_module(py::module& m_tensor) {
reshaped_tensor = tt_tensor.reshape(N, C, H, W)
)doc")
.def(
"reshape",
"reshape_unsafe",
[](Tensor& self, const ttnn::Shape& shape) -> Tensor { return self.reshape(shape); },
R"doc(
Reshapes TT tensor
Expand All @@ -1579,7 +1582,7 @@ void pytensor_module(py::module& m_tensor) {
reshaped_tensor = tt_tensor.reshape((4, 3, 32))
)doc")
.def(
"reshape",
"reshape_unsafe",
[](Tensor& self, const ttnn::SmallVector<int32_t>& shape) -> Tensor {
return self.reshape(infer_dims_for_reshape(self, shape));
},
Expand Down

0 comments on commit 01163e9

Please sign in to comment.