Skip to content

Commit

Permalink
#99999: Remove interim avg_pool2d
Browse files Browse the repository at this point in the history
  • Loading branch information
jerrysky3 committed Oct 28, 2024
1 parent ada4ad8 commit 1f2c2d9
Show file tree
Hide file tree
Showing 11 changed files with 20 additions and 43 deletions.
4 changes: 2 additions & 2 deletions tests/tt_eager/ops/test_average_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/operations/pool/avgpool/avg_pool.hpp"
#include "ttnn/operations/pool/global_avg_pool/global_avg_pool.hpp"
#include "ttnn/operations/experimental/auto_format/auto_format.hpp"
#include "ttnn/operations/numpy/functions.hpp"

Expand All @@ -23,7 +23,7 @@ Tensor run_avg_pool_2d_resnet(tt::tt_metal::LegacyShape& tensor_shape, Device* d
if (!AutoFormat::check_input_tensor_format(input_tensor, padded_input_shape)) {
padded_input_tensor = AutoFormat::format_input_tensor(input_tensor, device, padded_input_shape, 0, Layout::TILE); // pad with 0s
}
auto device_output = avg_pool2d(padded_input_tensor);
auto device_output = global_avg_pool2d(padded_input_tensor);
return device_output.cpu();
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_run_average_pool(act_shape, dtype, device, use_program_cache, enable_as
ttact_res = ttact.to(device)

def run_ops(ttact_res):
return ttnn.avg_pool2d(ttact_res)
return ttnn.global_avg_pool2d(ttact_res)

# Compile
run_ops(ttact_res)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_run_average_pool(act_shape, dtype, device):
ttact = ttact.pad_to_tile(0.0)
ttact = ttact.to(device)

out = ttnn.avg_pool2d(ttact)
out = ttnn.global_avg_pool2d(ttact)

out = out.cpu().to(ttnn.ROW_MAJOR_LAYOUT)
out_shape = [batch_size, 1, 1, channels]
Expand Down
3 changes: 1 addition & 2 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,11 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/softmax/softmax_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/softmax/device/multi_core/softmax_op_multi_core.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/avgpool/avg_pool.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/downsample/device/downsample_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/downsample/device/downsample_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/downsample/downsample.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/downsample/downsample_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/global_avg_pool/global_avg_pool.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp
Expand Down Expand Up @@ -371,7 +371,6 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/uniform_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/device/uniform_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/device/uniform_program_factory.cpp

${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam_pybind.cpp
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/__init__.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
#include "ttnn/operations/matmul/matmul_pybind.hpp"
#include "ttnn/operations/moreh/moreh_pybind.hpp"
#include "ttnn/operations/normalization/normalization_pybind.hpp"
#include "ttnn/operations/pool/avgpool/avg_pool_pybind.hpp"
#include "ttnn/operations/pool/downsample/downsample_pybind.hpp"
#include "ttnn/operations/pool/global_avg_pool/global_avg_pool_pybind.hpp"
#include "ttnn/operations/pool/maxpool/max_pool2d_pybind.hpp"
#include "ttnn/operations/pool/upsample/upsample_pybind.hpp"
#include "ttnn/operations/reduction/reduction_pybind.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/operations/pool/avgpool/avg_pool.hpp"
#include "ttnn/operations/pool/global_avg_pool/global_avg_pool.hpp"
#include "ttnn/operations/reduction/generic/generic_reductions.hpp"

namespace tt {
Expand All @@ -22,7 +22,7 @@ Tensor pool_2d(const Tensor& input, const MemoryConfig& memory_config, const std
}
}

Tensor avg_pool2d(const Tensor& input, const MemoryConfig& memory_config, const std::optional<DataType>& output_dtype) {
Tensor global_avg_pool2d(const Tensor& input, const MemoryConfig& memory_config, const std::optional<DataType>& output_dtype) {
TT_FATAL(input.storage_type() == StorageType::DEVICE, "Input tensor needs to be on device");
auto output = input;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
namespace tt {
namespace tt_metal {

enum class PoolType {
AVG
};
enum class PoolType { AVG };

Tensor avg_pool2d(const Tensor& input, const MemoryConfig& memory_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, const std::optional<DataType>& output_dtype = std::nullopt);
Tensor global_avg_pool2d(
const Tensor& input,
const MemoryConfig& memory_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
const std::optional<DataType>& output_dtype = std::nullopt);

} // namespace tt_metal
} // namespace tt


#include "ttnn/operations/pool/avgpool/avg_pool.hpp"
#include "ttnn/operations/pool/global_avg_pool/global_avg_pool.hpp"
#include "ttnn/decorators.hpp"
#include "ttnn/operations/core/core.hpp"

Expand All @@ -36,7 +36,7 @@ struct GlobalAveragePool2D {
const std::optional<MemoryConfig>& memory_config_arg = std::nullopt,
const std::optional<DataType>& output_dtype = std::nullopt) {
auto memory_config = memory_config_arg.value_or(input.memory_config());
auto result = tt::tt_metal::avg_pool2d(input, memory_config, output_dtype);
auto result = tt::tt_metal::global_avg_pool2d(input, memory_config, output_dtype);
return result;
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/pool/avgpool/avg_pool.hpp"
#include "ttnn/operations/pool/global_avg_pool/global_avg_pool.hpp"
#include "ttnn/types.hpp"

namespace py = pybind11;
Expand Down Expand Up @@ -64,23 +64,6 @@ void bind_global_avg_pool2d(py::module& module) {

void py_module(py::module& module) {
detail::bind_global_avg_pool2d(module);
module.def(
"avg_pool2d",
&avg_pool2d,
py::arg().noconvert(),
py::kw_only(),
py::arg("memory_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
py::arg("dtype").noconvert() = std::nullopt,
R"doc(
Average Pool 2D
It operates on tensors that have channels as the last dimension.
+----------+----------------------------+------------+-------------------------------+----------+
| Argument | Description | Data type | Valid range | Required |
+==========+============================+============+===============================+==========+
| act | Input activations tensor | Tensor | | Yes |
+----------+----------------------------+------------+-------------------------------+----------+
)doc");
}

} // namespace avgpool
Expand Down
8 changes: 4 additions & 4 deletions ttnn/tt_lib/fused_ops/average_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import ttnn


def run_avg_pool_on_device_wrapper(device):
def avg_pool2d(x, output_mem_config, output_dtype=None):
out = ttnn.avg_pool2d(x, memory_config=output_mem_config, dtype=output_dtype)
def run_global_avg_pool_on_device_wrapper(device):
def global_avg_pool2d(x, output_mem_config, output_dtype=None):
out = ttnn.global_avg_pool2d(x, memory_config=output_mem_config, dtype=output_dtype)
return out

return avg_pool2d
return global_avg_pool2d
1 change: 0 additions & 1 deletion ttnn/ttnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ def prelu(*args, **kwargs): # Alias for leaky_relu. TODO(#8544): implement PReL
)

from ttnn.operations.conv2d import Conv2dConfig, get_conv_padded_input_shape_and_mem_config, get_conv_output_dim
from ttnn.operations.pool import avg_pool2d
from ttnn.operations.conv1d import Conv1d, Conv1dConfig

from ttnn.operations.transformer import SDPAProgramConfig
Expand Down
4 changes: 0 additions & 4 deletions ttnn/ttnn/operations/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,3 @@ def golden_global_avg_pool2d(input_tensor: ttnn.Tensor):


ttnn.attach_golden_function(ttnn.global_avg_pool2d, golden_global_avg_pool2d)

avg_pool2d = ttnn.register_python_operation(name="ttnn.avg_pool2d", golden_function=golden_global_avg_pool2d)(
ttnn._ttnn.operations.pool.avg_pool2d
)

0 comments on commit 1f2c2d9

Please sign in to comment.