Skip to content

Commit

Permalink
#10428: Demonstrate new calls to SetRuntimeArgs by refactoring ttnn ops
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickroberts committed Oct 18, 2024
1 parent f7a5c0e commit 4534636
Show file tree
Hide file tree
Showing 20 changed files with 117 additions and 75 deletions.
5 changes: 5 additions & 0 deletions cmake/dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ set(ENV{CPM_SOURCE_CACHE} "${PROJECT_SOURCE_DIR}/.cpmcache")

include(${PROJECT_SOURCE_DIR}/cmake/fetch_boost.cmake)

add_library(span INTERFACE)
if(CMAKE_CXX_STANDARD LESS 20)
fetch_boost_library(core)
target_link_libraries(span INTERFACE Boost::core)
endif()
fetch_boost_library(smart_ptr)

############################################################################################################################
Expand Down
2 changes: 1 addition & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
enable_testing()
include(GoogleTest)
add_library(test_common_libs INTERFACE)
target_link_libraries(test_common_libs INTERFACE pthread gtest gtest_main magic_enum fmt)
target_link_libraries(test_common_libs INTERFACE pthread gtest gtest_main magic_enum fmt span)

if(TT_METAL_BUILD_TESTS)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tt_metal)
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/unit_tests_common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ set(UNIT_TESTS_COMMON_SRC
${CMAKE_CURRENT_SOURCE_DIR}/watcher/test_link_training.cpp
)
add_library(unit_tests_common_o OBJECT ${UNIT_TESTS_COMMON_SRC})
target_link_libraries(unit_tests_common_o PUBLIC compiler_flags metal_header_directories gtest gtest_main magic_enum fmt)
target_link_libraries(unit_tests_common_o PUBLIC compiler_flags metal_header_directories gtest gtest_main magic_enum fmt span)
target_include_directories(unit_tests_common_o PUBLIC
${UMD_HOME}
${PROJECT_SOURCE_DIR}
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ set(TT_METAL_OBJECTS

add_library(tt_metal ${TT_METAL_OBJECTS})

target_link_libraries(tt_metal PUBLIC metal_header_directories umd_device metal_common_libs magic_enum fmt)
target_link_libraries(tt_metal PUBLIC metal_header_directories umd_device metal_common_libs magic_enum fmt span)

target_precompile_headers(tt_metal PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/third_party/tracy/public/tracy/Tracy.hpp
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ set(COMMON_SRCS

add_library(common OBJECT ${COMMON_SRCS})
target_link_libraries(common PRIVATE yaml-cpp::yaml-cpp)
target_link_libraries(common PUBLIC compiler_flags metal_header_directories magic_enum fmt)
target_link_libraries(common PUBLIC compiler_flags metal_header_directories magic_enum fmt span)

target_include_directories(common PUBLIC
${UMD_HOME}
Expand Down
35 changes: 32 additions & 3 deletions tt_metal/tt_stl/span.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,33 @@

#pragma once

#include <ranges>
// prefer standard library implementation
#if __has_include(<span>)

#include <span>
#define _TT_STL_SPAN_NS ::std

// fallback to boost library implementation
#elif __has_include(<boost/core/span.hpp>)

#include <boost/core/span.hpp>
#define _TT_STL_SPAN_NS ::boost

#else

#error "No implementation available for tt::stl::Span"

#endif

namespace tt::stl {

using std::dynamic_extent;
using _TT_STL_SPAN_NS::dynamic_extent;

namespace detail {

using std::span;
using _TT_STL_SPAN_NS::span;

#undef _TT_STL_SPAN_NS

template <class T, std::size_t Extent>
class SpanBase : public span<T, Extent> {
Expand Down Expand Up @@ -95,8 +112,20 @@ Span(R &&) -> Span<std::remove_reference_t<std::ranges::range_reference_t<R>>>;

} // namespace tt::stl

#if __has_include(<ranges>)

#include <ranges>

// https://en.cppreference.com/w/cpp/ranges/borrowed_range
// The concept std::ranges::borrowed_range defines the requirements of a range such that a function can take it by value
// and return iterators obtained from it without danger of dangling.
template <class T, std::size_t Extent>
constexpr bool std::ranges::enable_borrowed_range<tt::stl::Span<T, Extent>> = true;

// https://en.cppreference.com/w/cpp/ranges/view
// The std::ranges::view concept specifies the requirements of a range type that has suitable semantic properties for
// use in constructing range adaptor pipelines.
template <class T, std::size_t Extent>
constexpr bool std::ranges::enable_view<tt::stl::Span<T, Extent>> = true;

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,13 @@ operation::ProgramWithCallbacks bcast_multi_core_h(const Tensor &a, const Tensor
} else if (core_group_2.core_coord_in_core_ranges(core)) {
Ht_per_core = Ht_per_core_group_2;
} else {
tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector<uint32_t>(15, 0));
tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector<uint32_t>(3, 0));
tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector<uint32_t>(9, 0));
constexpr std::array<uint32_t, 15> binary_reader_kernel_args{0};
constexpr std::array<uint32_t, 3> bcast_kernel_args{0};
constexpr std::array<uint32_t, 9> unary_writer_kernel_args{0};

tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, binary_reader_kernel_args);
tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, bcast_kernel_args);
tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, unary_writer_kernel_args);
continue;
}
uint32_t num_tensor_tiles_per_core = NC * Ht_per_core * Wt;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,13 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tenso
} else if (core_group_2.core_coord_in_core_ranges(core)) {
num_tensor_tiles_per_core = num_tiles_per_core_group_2;
} else {
tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector<uint32_t>(7, 0));
tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, {1, 1, 0});
tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector<uint32_t>(3, 0));
constexpr std::array<uint32_t, 7> binary_reader_kernel_args{0};
constexpr std::array<uint32_t, 3> bcast_kernel_args{1, 1, 0};
constexpr std::array<uint32_t, 3> unary_writer_kernel_args{0};

tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, binary_reader_kernel_args);
tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, bcast_kernel_args);
tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, unary_writer_kernel_args);
continue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,13 @@ operation::ProgramWithCallbacks bcast_multi_core_w(const Tensor &a, const Tensor
} else if (core_group_2.core_coord_in_core_ranges(core)) {
Wt_per_core = Wt_per_core_group_2;
} else {
tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector<uint32_t>(16, 0));
tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector<uint32_t>(3, 0));
tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector<uint32_t>(9, 0));
constexpr std::array<uint32_t, 16> binary_reader_kernel_args{0};
constexpr std::array<uint32_t, 3> bcast_kernel_args{0};
constexpr std::array<uint32_t, 9> unary_writer_kernel_args{0};

tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, binary_reader_kernel_args);
tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, bcast_kernel_args);
tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, unary_writer_kernel_args);
continue;
}
uint32_t num_tensor_tiles_per_core = NC*Ht*Wt_per_core;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ Fold::SingleCore::cached_program_t fold_single_core(
SetRuntimeArgs(program, reader_kernel_id, core, {src_buffer->address(), pixel_size, num_pixels, 0});

// Writer run-time args
std::vector<uint32_t> writer_kernel_args = {
const std::array writer_kernel_args = {
dst_buffer->address(),
dst_pixel_size,
scratch_buffer->address(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ operation::ProgramWithCallbacks indexed_fill_multi_core(const Tensor &batch_ids,
uint32_t local_b = (i<B) ? b : 0;
uint32_t local_batch_size_in_sticks = (i<B) ? batch_size_in_sticks : 0;

std::vector<uint32_t> reader_runtime_args = {
const std::array reader_runtime_args = {
batch_ids.buffer()->address(),
local_b,
input_a.buffer()->address(),
Expand All @@ -108,7 +108,7 @@ operation::ProgramWithCallbacks indexed_fill_multi_core(const Tensor &batch_ids,
i
};
tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args);
std::vector<uint32_t> writer_runtime_args = {
const std::array writer_runtime_args = {
output.buffer()->address(),
page_size,
local_batch_size_in_sticks,
Expand All @@ -135,7 +135,7 @@ operation::ProgramWithCallbacks indexed_fill_multi_core(const Tensor &batch_ids,
for (const auto &core : cores) {
uint32_t local_b = (core_id<B) ? b : 0;
uint32_t local_batch_size_in_sticks = (core_id<B) ? batch_size_in_sticks : 0;
std::vector<uint32_t> reader_runtime_args = {
const std::array reader_runtime_args = {
batch_ids.buffer()->address(),
local_b,
input_a.buffer()->address(),
Expand All @@ -146,7 +146,7 @@ operation::ProgramWithCallbacks indexed_fill_multi_core(const Tensor &batch_ids,
};
tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args);

std::vector<uint32_t> writer_runtime_args = {
const std::array writer_runtime_args = {
output.buffer()->address(),
page_size,
local_batch_size_in_sticks,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ operation::ProgramWithCallbacks move_multi_core_sharded(const Tensor& input, Ten
shard_grid,
DataMovementConfig{
.processor = DataMovementProcessor::RISCV_1, .noc = NOC::NOC_1, .compile_args = reader_compile_time_args});
std::vector<uint32_t> runtime_args = {
const std::array runtime_args = {
total_size_bytes, num_chunks, move_chunk_size_bytes, remainder_chunk_size_bytes};
SetRuntimeArgs(program, kernel_id, shard_grid, runtime_args);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ operation::ProgramWithCallbacks non_zero_indices_single_core(const Tensor &input
(std::uint32_t) out_is_dram_1,
};

std::vector<uint32_t> run_time_args = {
const std::array run_time_args = {
(std::uint32_t) input.buffer()->address(),
(std::uint32_t) out_num_indices.buffer()->address(),
(std::uint32_t) out_indices.buffer()->address(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ operation::ProgramWithCallbacks pad_rm_reader_writer(const Tensor &a,

uint32_t start_src_stick_id = 0;
uint32_t start_dst_stick_id = 0;
vector<uint32_t> reader_rt_args = {src0_buffer->address(),
const std::array reader_rt_args = {src0_buffer->address(),
dst_buffer->address(),
a.get_legacy_shape()[0],
output_shape[0],
Expand All @@ -125,17 +125,17 @@ operation::ProgramWithCallbacks pad_rm_reader_writer(const Tensor &a,
packed_pad_value,
start_src_stick_id,
start_dst_stick_id,
0,
0,
0,
std::uint32_t{0},
std::uint32_t{0},
std::uint32_t{0},
output_shape[2],
a.get_legacy_shape()[2],
unpadded_row_size_nbytes,
padded_row_size_nbytes,
0,
std::uint32_t{0},
output.get_legacy_shape()[0]
};
vector<uint32_t> writer_rt_args = reader_rt_args;
const auto &writer_rt_args = reader_rt_args;
tt::tt_metal::SetRuntimeArgs(program,
reader_kernel_id,
cores,
Expand Down Expand Up @@ -247,7 +247,7 @@ operation::ProgramWithCallbacks pad_rm_opt(const Tensor &a,
}
#endif

vector<uint32_t> reader_rt_args = {src0_buffer->address(),
const std::array reader_rt_args = {src0_buffer->address(),
dst_buffer->address(),
a.get_legacy_shape()[0],
output_shape[0],
Expand Down Expand Up @@ -321,7 +321,7 @@ operation::ProgramWithCallbacks pad_rm(const Tensor &a, Tensor &output, const Sh
bfloat16 bfloat_pad_value = bfloat16(pad_value);
uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_pad_value, bfloat_pad_value});

vector<uint32_t> reader_kernel_args = {
const std::array reader_kernel_args = {
src0_buffer->address(),
dst_buffer->address(),
a.get_legacy_shape()[0],
Expand Down Expand Up @@ -445,11 +445,12 @@ operation::ProgramWithCallbacks pad_tile(const Tensor &a, Tensor& output, const

uint32_t num_unpadded_tiles = a.volume() / TILE_HW;

vector<uint32_t> reader_kernel_args = {
const std::array reader_kernel_args = {
src0_buffer->address(),
num_unpadded_tiles, 0
num_unpadded_tiles,
std::uint32_t{0},
};
vector<uint32_t> writer_kernel_args = {
const std::array writer_kernel_args = {
dst_buffer->address(),
num_unpadded_W,
num_padded_Wt,
Expand Down Expand Up @@ -785,7 +786,7 @@ operation::ProgramWithCallbacks pad_rm_reader_writer_multi_core(const Tensor &a,
curr_stick_diff_nbytes = dst_nbytes_per_core_w - curr_stick_size_nbytes;
rem_src_stick_size_nbytes = 0;
}
vector<uint32_t> reader_rt_args = {src0_buffer->address(),
const std::array reader_rt_args = {src0_buffer->address(),
dst_buffer->address(),
a.get_legacy_shape()[0],
output_shape[0],
Expand Down Expand Up @@ -822,7 +823,7 @@ operation::ProgramWithCallbacks pad_rm_reader_writer_multi_core(const Tensor &a,
// log_debug("{} :: nbatch_per_core_h: {}", core.y, nbatch_per_core_h);
// log_debug("{} :: ncores_per_batch_h: {}", core.y, ncores_per_batch_h);
// }
vector<uint32_t> writer_rt_args = reader_rt_args;
const auto &writer_rt_args = reader_rt_args;
tt::tt_metal::SetRuntimeArgs(program,
reader_kernel_id,
core,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ inline __attribute__((always_inline)) void set_slice_runtime_args_tile(
}

if constexpr (initialize_args) {
vector<uint32_t> writer_kernel_args = {output_buffer->address(), num_tiles_per_core, num_tiles_written};
const std::array writer_kernel_args = {output_buffer->address(), num_tiles_per_core, num_tiles_written};
tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_kernel_args);
} else {
auto& writer_kernel_args = writer_kernel_args_by_core[core.x][core.y];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ void setup_runtime(
uint32_t reader_core_id = id_c * per_core_tiles_y;
reader_core_id += id_r_reader;

std::vector<uint32_t> reader_runtime_args = {
const std::array reader_runtime_args = {
(std::uint32_t)reader_core_id,
(std::uint32_t)(in0_buffer->address()), // in0_tensor_addr
(std::uint32_t)0 // split on last dim
Expand All @@ -66,7 +66,7 @@ void setup_runtime(

uint32_t writer_core_id = id_c_inner * per_core_tiles_y + (id_r_writer);

std::vector<uint32_t> writer_runtime_args = {
const std::array writer_runtime_args = {
writer_core_id,
(std::uint32_t)out0_buffer->address(), // first base addr
(std::uint32_t)out1_buffer->address(), // second base addr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ operation::ProgramWithCallbacks tilize_single_core(const Tensor& a, Tensor& outp
.set_page_size(output_cb_index, output_single_tile_size);
auto cb_output = tt::tt_metal::CreateCircularBuffer(program, core, cb_output_config);

vector<uint32_t> reader_kernel_args = {
const std::array reader_kernel_args = {
src0_buffer->address(),
num_sticks,
stick_size,
Expand All @@ -89,7 +89,7 @@ operation::ProgramWithCallbacks tilize_single_core(const Tensor& a, Tensor& outp
num_full_blocks_in_row,
num_leftover_tiles,
leftover_width_in_row,
0 // row_start_id
std::uint32_t{0}, // row_start_id
};

// Reader compile-time args
Expand Down Expand Up @@ -236,19 +236,19 @@ operation::ProgramWithCallbacks tilize_multi_core_interleaved(const Tensor& a, T
const CoreCoord& core = cores[i];

// reader runtime args
vector<uint32_t> reader_rt_args = {
const std::array reader_rt_args = {
src0_buffer->address(),
nblocks_per_core * TILE_HEIGHT,
block_size_nbytes,
ntiles_per_block,
block_size_nbytes,
1, // full blocks in row
0, // num leftover tiles
0, // leftover width in row
std::uint32_t{1}, // full blocks in row
std::uint32_t{0}, // num leftover tiles
std::uint32_t{0}, // leftover width in row
row_start_id};

// writer runtime args
vector<uint32_t> writer_rt_args = {
const std::array writer_rt_args = {
dst_buffer->address(),
ntiles_per_block * nblocks_per_core, // ntiles per core
tile_start_id // start id
Expand All @@ -265,19 +265,19 @@ operation::ProgramWithCallbacks tilize_multi_core_interleaved(const Tensor& a, T
const CoreCoord& core = cores.back();

// reader runtime args
vector<uint32_t> reader_rt_args = {
const std::array reader_rt_args = {
src0_buffer->address(),
nblocks_per_core_cliff * TILE_HEIGHT,
block_size_nbytes,
ntiles_per_block,
block_size_nbytes,
1, // full blocks in row
0, // num leftover tiles
0, // leftover width in row
std::uint32_t{1}, // full blocks in row
std::uint32_t{0}, // num leftover tiles
std::uint32_t{0}, // leftover width in row
row_start_id};

// writer runtime args
vector<uint32_t> writer_rt_args = {
const std::array writer_rt_args = {
dst_buffer->address(),
ntiles_per_block * nblocks_per_core_cliff, // ntiles per core
tile_start_id // start id
Expand Down
Loading

0 comments on commit 4534636

Please sign in to comment.