Skip to content

Commit

Permalink
#0: add unsqueeze and move squeeze from py to cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
ntarafdar committed Sep 19, 2024
1 parent bd3f53e commit 189cc2f
Show file tree
Hide file tree
Showing 13 changed files with 360 additions and 16 deletions.
27 changes: 27 additions & 0 deletions tests/ttnn/unit_tests/test_squeeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch

import ttnn


@pytest.mark.parametrize(
"input_shape, dim",
[
((1, 1, 1, 256), 2),
((1, 1, 1, 256), -1),
((1, 1, 1, 30), 2),
((1, 1, 1, 30), -1),
],
)
def test_squeeze_as_reshape(device, input_shape, dim):
torch_input_tensor = torch.rand(input_shape, dtype=torch.float32)
torch_squeeze_tensor = torch.squeeze(torch_input_tensor, dim)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
ttnn_output = ttnn.squeeze(input_tensor, dim)
torch_output_tensor = ttnn.to_torch(ttnn_output)
assert torch.allclose(torch_output_tensor, torch_squeeze_tensor)
29 changes: 29 additions & 0 deletions tests/ttnn/unit_tests/test_unsqueeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch

import ttnn


@pytest.mark.parametrize(
"input_shape, dim",
[
((1, 1, 256), 2),
((1, 1, 256), -2),
((1, 256), 1),
((1, 1, 30), 2),
((1, 1, 30), -2),
((1, 30), 1),
],
)
def test_unsqueeze(device, input_shape, dim):
torch_input_tensor = torch.rand(input_shape, dtype=torch.bfloat16)
torch_unsqueeze_tensor = torch.unsqueeze(torch_input_tensor, dim)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
ttnn_output = ttnn.unsqueeze(input_tensor, dim)
torch_output_tensor = ttnn.to_torch(ttnn_output)
assert torch.allclose(torch_output_tensor, torch_unsqueeze_tensor)
4 changes: 4 additions & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat_interleave/repeat_interleave.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape/reshape.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape/reshape_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/squeeze/squeeze_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape/device/reshape_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape/device/reshape_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/sharded_partial/interleaved_to_sharded_partial/interleaved_to_sharded_partial.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include "ttnn/operations/data_movement/sharded_partial/sharded_to_interleaved_partial/sharded_to_interleaved_partial_pybind.hpp"
#include "ttnn/operations/data_movement/sharded_partial/interleaved_to_sharded_partial/interleaved_to_sharded_partial_pybind.hpp"
#include "ttnn/operations/data_movement/reshape/reshape_pybind.hpp"
#include "ttnn/operations/data_movement/unsqueeze/unsqueeze_pybind.hpp"
#include "ttnn/operations/data_movement/squeeze/squeeze_pybind.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded_pybind.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved_pybind.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/reshard_pybind.hpp"
Expand Down Expand Up @@ -61,6 +63,8 @@ void py_module(py::module& module) {
bind_fill_rm(module);
py_bind_repeat(module);
py_bind_reshape(module);
py_bind_unsqueeze(module);
py_bind_squeeze(module);
detail::bind_indexed_fill(module);
bind_fold_operation(module);
py_bind_sharded_to_interleaved_partial(module);
Expand Down
43 changes: 43 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0


#include "squeeze.hpp"
#include "ttnn/operations/core/core.hpp"

namespace ttnn::operations::data_movement {

ttnn::Tensor SqueezeOperation::invoke(
const ttnn::Tensor& input_tensor,
const int dim
) {

const auto tensor_shape = input_tensor.get_shape();
const auto rank = tensor_shape.rank();
std::vector<uint32_t> output_shape_vector;

int normal_dim = dim;
if (dim < 0) {
// Handle negative dimension by converting it to positive
normal_dim += rank;
}

// Remove the dimension if it is of size 1
for (size_t i = 0; i < tensor_shape.size(); ++i) {
if (static_cast<int>(i) != normal_dim || tensor_shape[i] != 1) {
output_shape_vector.push_back(tensor_shape[i]);
}
}

// If dim is out of range or original dimension was not of size 1, include all dimensions
if (dim >= static_cast<int>(tensor_shape.size()) || tensor_shape[dim] != 1) {
return input_tensor;
}

ttnn::Shape output_shape(output_shape_vector);
return ttnn::reshape(input_tensor, output_shape);

}

} // ttnn::operations::data_movement namespace
26 changes: 26 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn/decorators.hpp"


namespace ttnn {
namespace operations::data_movement {

struct SqueezeOperation {
static ttnn::Tensor invoke(
const ttnn::Tensor& input_tensor,
const int dim
);

};


} // namespace operations::data_movement

constexpr auto squeeze = ttnn::register_operation<"ttnn::squeeze", ttnn::operations::data_movement::SqueezeOperation>();

} // namespace ttnn
64 changes: 64 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze_pybind.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "squeeze_pybind.hpp"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/data_movement/squeeze/squeeze.hpp"
#include "ttnn/types.hpp"

namespace ttnn::operations::data_movement {

namespace detail {

template <typename data_movement_operation_t>
void bind_squeeze(pybind11::module& module, const data_movement_operation_t& operation, const char* doc) {
bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const data_movement_operation_t& self,
const ttnn::Tensor& input_tensor,
const int dim
) -> ttnn::Tensor {
return self(input_tensor, dim);
},
py::arg("input_tensor"),
py::arg("dim")
}
);
}

} // namespace detail


void py_bind_squeeze(pybind11::module& module) {
detail::bind_squeeze(
module,
ttnn::squeeze,
R"doc(squeeze(input_tensor: ttnn.Tensor, dim: int) -> ttnn.Tensor
Returns a tensor squeezed at the specified dimension. Pytorch supports a tuple as well as a single scalar value for dim, currently our version only supports scalar values. We will address this in the future. If input_tensor.shape()[dim] is not 1, squeeze will be ignored for that shape.
Equivalent pytorch code:
.. code-block:: python
input_tensor = torch.rand((1,1,1,256), dtype=torch.bfloat16)
output_tensor = torch.squeeze(input_tensor, 2) # tensor of shape (1,1,256), where at dimension 2 we removed it
Args:
* :attr:`input_tensor`: Input Tensor.
* :attr:`dim`: Dim where we want to squeeze
)doc");
}

} // namespace ttnn::operations::data_movement
13 changes: 13 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze_pybind.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "pybind11/pybind_fwd.hpp"

namespace ttnn::operations::data_movement {

void py_bind_squeeze(pybind11::module& module);

} // namespace ttnn::operations::data_movement
47 changes: 47 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0


#include "unsqueeze.hpp"
#include "ttnn/operations/core/core.hpp"

namespace ttnn::operations::data_movement {

ttnn::Tensor UnsqueezeOperation::invoke(
const ttnn::Tensor& input_tensor,
const int dim
) {

const auto tensor_shape = input_tensor.get_shape();
const auto rank = tensor_shape.rank();
std::vector<uint32_t> output_shape_vector;

TT_FATAL(input_tensor.get_layout() == Layout::ROW_MAJOR or (!tensor_shape.has_tile_padding()), "Currently supporing ROW-MAJOR tensors or TILE tensors with no padding");

int normal_dim = dim;
// Handle negative dimension by converting it to positive
if (dim < 0) {
normal_dim += rank + 1;
}

// Insert new dimension
for (int i = 0; i < rank; ++i) {
if (i == normal_dim) {
output_shape_vector.push_back(1);
}
output_shape_vector.push_back(tensor_shape[i]);
}

// If the dimension is at the end, append it
if (normal_dim >= tensor_shape.size()) {
output_shape_vector.push_back(1);
}

ttnn::Shape output_shape(output_shape_vector);
return ttnn::reshape(input_tensor, output_shape);


}

} // ttnn::operations::data_movement namespace
26 changes: 26 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn/decorators.hpp"


namespace ttnn {
namespace operations::data_movement {

struct UnsqueezeOperation {
static ttnn::Tensor invoke(
const ttnn::Tensor& input_tensor,
const int dim
);

};


} // namespace operations::data_movement

constexpr auto unsqueeze = ttnn::register_operation<"ttnn::unsqueeze", ttnn::operations::data_movement::UnsqueezeOperation>();

} // namespace ttnn
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "unsqueeze_pybind.hpp"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/data_movement/unsqueeze/unsqueeze.hpp"
#include "ttnn/types.hpp"

namespace ttnn::operations::data_movement {

namespace detail {

template <typename data_movement_operation_t>
void bind_unsqueeze(pybind11::module& module, const data_movement_operation_t& operation, const char* doc) {
bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const data_movement_operation_t& self,
const ttnn::Tensor& input_tensor,
const int dim
) -> ttnn::Tensor {
return self(input_tensor, dim);
},
py::arg("input_tensor"),
py::arg("dim")
}
);
}

} // namespace detail


void py_bind_unsqueeze(pybind11::module& module) {
detail::bind_unsqueeze(
module,
ttnn::unsqueeze,
R"doc(unsqueeze(input_tensor: ttnn.Tensor, dim: int) -> ttnn.Tensor
Returns a tensor unsqueezed at the specified dimension
Equivalent pytorch code:
.. code-block:: python
input_tensor = torch.rand((1,1,256), dtype=torch.bfloat16)
output_tensor = torch.unsqueeze(input_tensor, 2) # tensor of shape (1,1,1,256), where at dimension 2 we added a new dim of size 1
Args:
* :attr:`input_tensor`: Input Tensor.
* :attr:`dim`: Dim where we want to unsqueeze (add a new dimension of size 1)
)doc");
}

} // namespace ttnn::operations::data_movement
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "pybind11/pybind_fwd.hpp"

namespace ttnn::operations::data_movement {

void py_bind_unsqueeze(pybind11::module& module);

} // namespace ttnn::operations::data_movement
Loading

0 comments on commit 189cc2f

Please sign in to comment.