-
Notifications
You must be signed in to change notification settings - Fork 82
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#0: add unsqueeze and move squeeze from py to cpp
- Loading branch information
Showing
13 changed files
with
360 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
|
||
import torch | ||
|
||
import ttnn | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shape, dim", | ||
[ | ||
((1, 1, 1, 256), 2), | ||
((1, 1, 1, 256), -1), | ||
((1, 1, 1, 30), 2), | ||
((1, 1, 1, 30), -1), | ||
], | ||
) | ||
def test_squeeze_as_reshape(device, input_shape, dim): | ||
torch_input_tensor = torch.rand(input_shape, dtype=torch.float32) | ||
torch_squeeze_tensor = torch.squeeze(torch_input_tensor, dim) | ||
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) | ||
ttnn_output = ttnn.squeeze(input_tensor, dim) | ||
torch_output_tensor = ttnn.to_torch(ttnn_output) | ||
assert torch.allclose(torch_output_tensor, torch_squeeze_tensor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
|
||
import torch | ||
|
||
import ttnn | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shape, dim", | ||
[ | ||
((1, 1, 256), 2), | ||
((1, 1, 256), -2), | ||
((1, 256), 1), | ||
((1, 1, 30), 2), | ||
((1, 1, 30), -2), | ||
((1, 30), 1), | ||
], | ||
) | ||
def test_unsqueeze(device, input_shape, dim): | ||
torch_input_tensor = torch.rand(input_shape, dtype=torch.bfloat16) | ||
torch_unsqueeze_tensor = torch.unsqueeze(torch_input_tensor, dim) | ||
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) | ||
ttnn_output = ttnn.unsqueeze(input_tensor, dim) | ||
torch_output_tensor = ttnn.to_torch(ttnn_output) | ||
assert torch.allclose(torch_output_tensor, torch_unsqueeze_tensor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
43 changes: 43 additions & 0 deletions
43
ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
#include "squeeze.hpp" | ||
#include "ttnn/operations/core/core.hpp" | ||
|
||
namespace ttnn::operations::data_movement { | ||
|
||
ttnn::Tensor SqueezeOperation::invoke( | ||
const ttnn::Tensor& input_tensor, | ||
const int dim | ||
) { | ||
|
||
const auto tensor_shape = input_tensor.get_shape(); | ||
const auto rank = tensor_shape.rank(); | ||
std::vector<uint32_t> output_shape_vector; | ||
|
||
int normal_dim = dim; | ||
if (dim < 0) { | ||
// Handle negative dimension by converting it to positive | ||
normal_dim += rank; | ||
} | ||
|
||
// Remove the dimension if it is of size 1 | ||
for (size_t i = 0; i < tensor_shape.size(); ++i) { | ||
if (static_cast<int>(i) != normal_dim || tensor_shape[i] != 1) { | ||
output_shape_vector.push_back(tensor_shape[i]); | ||
} | ||
} | ||
|
||
// If dim is out of range or original dimension was not of size 1, include all dimensions | ||
if (dim >= static_cast<int>(tensor_shape.size()) || tensor_shape[dim] != 1) { | ||
return input_tensor; | ||
} | ||
|
||
ttnn::Shape output_shape(output_shape_vector); | ||
return ttnn::reshape(input_tensor, output_shape); | ||
|
||
} | ||
|
||
} // ttnn::operations::data_movement namespace |
26 changes: 26 additions & 0 deletions
26
ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "ttnn/decorators.hpp" | ||
|
||
|
||
namespace ttnn { | ||
namespace operations::data_movement { | ||
|
||
struct SqueezeOperation { | ||
static ttnn::Tensor invoke( | ||
const ttnn::Tensor& input_tensor, | ||
const int dim | ||
); | ||
|
||
}; | ||
|
||
|
||
} // namespace operations::data_movement | ||
|
||
constexpr auto squeeze = ttnn::register_operation<"ttnn::squeeze", ttnn::operations::data_movement::SqueezeOperation>(); | ||
|
||
} // namespace ttnn |
64 changes: 64 additions & 0 deletions
64
ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze_pybind.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "squeeze_pybind.hpp" | ||
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
#include "ttnn/cpp/pybind11/decorators.hpp" | ||
#include "ttnn/operations/data_movement/squeeze/squeeze.hpp" | ||
#include "ttnn/types.hpp" | ||
|
||
namespace ttnn::operations::data_movement { | ||
|
||
namespace detail { | ||
|
||
template <typename data_movement_operation_t> | ||
void bind_squeeze(pybind11::module& module, const data_movement_operation_t& operation, const char* doc) { | ||
bind_registered_operation( | ||
module, | ||
operation, | ||
doc, | ||
ttnn::pybind_overload_t{ | ||
[](const data_movement_operation_t& self, | ||
const ttnn::Tensor& input_tensor, | ||
const int dim | ||
) -> ttnn::Tensor { | ||
return self(input_tensor, dim); | ||
}, | ||
py::arg("input_tensor"), | ||
py::arg("dim") | ||
} | ||
); | ||
} | ||
|
||
} // namespace detail | ||
|
||
|
||
void py_bind_squeeze(pybind11::module& module) { | ||
detail::bind_squeeze( | ||
module, | ||
ttnn::squeeze, | ||
R"doc(squeeze(input_tensor: ttnn.Tensor, dim: int) -> ttnn.Tensor | ||
Returns a tensor squeezed at the specified dimension. Pytorch supports a tuple as well as a single scalar value for dim, currently our version only supports scalar values. We will address this in the future. If input_tensor.shape()[dim] is not 1, squeeze will be ignored for that shape. | ||
Equivalent pytorch code: | ||
.. code-block:: python | ||
input_tensor = torch.rand((1,1,1,256), dtype=torch.bfloat16) | ||
output_tensor = torch.squeeze(input_tensor, 2) # tensor of shape (1,1,256), where at dimension 2 we removed it | ||
Args: | ||
* :attr:`input_tensor`: Input Tensor. | ||
* :attr:`dim`: Dim where we want to squeeze | ||
)doc"); | ||
} | ||
|
||
} // namespace ttnn::operations::data_movement |
13 changes: 13 additions & 0 deletions
13
ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze_pybind.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "pybind11/pybind_fwd.hpp" | ||
|
||
namespace ttnn::operations::data_movement { | ||
|
||
void py_bind_squeeze(pybind11::module& module); | ||
|
||
} // namespace ttnn::operations::data_movement |
47 changes: 47 additions & 0 deletions
47
ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
#include "unsqueeze.hpp" | ||
#include "ttnn/operations/core/core.hpp" | ||
|
||
namespace ttnn::operations::data_movement { | ||
|
||
ttnn::Tensor UnsqueezeOperation::invoke( | ||
const ttnn::Tensor& input_tensor, | ||
const int dim | ||
) { | ||
|
||
const auto tensor_shape = input_tensor.get_shape(); | ||
const auto rank = tensor_shape.rank(); | ||
std::vector<uint32_t> output_shape_vector; | ||
|
||
TT_FATAL(input_tensor.get_layout() == Layout::ROW_MAJOR or (!tensor_shape.has_tile_padding()), "Currently supporing ROW-MAJOR tensors or TILE tensors with no padding"); | ||
|
||
int normal_dim = dim; | ||
// Handle negative dimension by converting it to positive | ||
if (dim < 0) { | ||
normal_dim += rank + 1; | ||
} | ||
|
||
// Insert new dimension | ||
for (int i = 0; i < rank; ++i) { | ||
if (i == normal_dim) { | ||
output_shape_vector.push_back(1); | ||
} | ||
output_shape_vector.push_back(tensor_shape[i]); | ||
} | ||
|
||
// If the dimension is at the end, append it | ||
if (normal_dim >= tensor_shape.size()) { | ||
output_shape_vector.push_back(1); | ||
} | ||
|
||
ttnn::Shape output_shape(output_shape_vector); | ||
return ttnn::reshape(input_tensor, output_shape); | ||
|
||
|
||
} | ||
|
||
} // ttnn::operations::data_movement namespace |
26 changes: 26 additions & 0 deletions
26
ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "ttnn/decorators.hpp" | ||
|
||
|
||
namespace ttnn { | ||
namespace operations::data_movement { | ||
|
||
struct UnsqueezeOperation { | ||
static ttnn::Tensor invoke( | ||
const ttnn::Tensor& input_tensor, | ||
const int dim | ||
); | ||
|
||
}; | ||
|
||
|
||
} // namespace operations::data_movement | ||
|
||
constexpr auto unsqueeze = ttnn::register_operation<"ttnn::unsqueeze", ttnn::operations::data_movement::UnsqueezeOperation>(); | ||
|
||
} // namespace ttnn |
64 changes: 64 additions & 0 deletions
64
ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze_pybind.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "unsqueeze_pybind.hpp" | ||
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
#include "ttnn/cpp/pybind11/decorators.hpp" | ||
#include "ttnn/operations/data_movement/unsqueeze/unsqueeze.hpp" | ||
#include "ttnn/types.hpp" | ||
|
||
namespace ttnn::operations::data_movement { | ||
|
||
namespace detail { | ||
|
||
template <typename data_movement_operation_t> | ||
void bind_unsqueeze(pybind11::module& module, const data_movement_operation_t& operation, const char* doc) { | ||
bind_registered_operation( | ||
module, | ||
operation, | ||
doc, | ||
ttnn::pybind_overload_t{ | ||
[](const data_movement_operation_t& self, | ||
const ttnn::Tensor& input_tensor, | ||
const int dim | ||
) -> ttnn::Tensor { | ||
return self(input_tensor, dim); | ||
}, | ||
py::arg("input_tensor"), | ||
py::arg("dim") | ||
} | ||
); | ||
} | ||
|
||
} // namespace detail | ||
|
||
|
||
void py_bind_unsqueeze(pybind11::module& module) { | ||
detail::bind_unsqueeze( | ||
module, | ||
ttnn::unsqueeze, | ||
R"doc(unsqueeze(input_tensor: ttnn.Tensor, dim: int) -> ttnn.Tensor | ||
Returns a tensor unsqueezed at the specified dimension | ||
Equivalent pytorch code: | ||
.. code-block:: python | ||
input_tensor = torch.rand((1,1,256), dtype=torch.bfloat16) | ||
output_tensor = torch.unsqueeze(input_tensor, 2) # tensor of shape (1,1,1,256), where at dimension 2 we added a new dim of size 1 | ||
Args: | ||
* :attr:`input_tensor`: Input Tensor. | ||
* :attr:`dim`: Dim where we want to unsqueeze (add a new dimension of size 1) | ||
)doc"); | ||
} | ||
|
||
} // namespace ttnn::operations::data_movement |
13 changes: 13 additions & 0 deletions
13
ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze_pybind.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "pybind11/pybind_fwd.hpp" | ||
|
||
namespace ttnn::operations::data_movement { | ||
|
||
void py_bind_unsqueeze(pybind11::module& module); | ||
|
||
} // namespace ttnn::operations::data_movement |
Oops, something went wrong.