-
Notifications
You must be signed in to change notification settings - Fork 81
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for resharding width-sharded tensors to/from DRAM
This commit introduces a new width-shard to width-shard reshard kernel, modeled after the height-sharded tensor reshard special case implemented in `reshard_multi_core_same_width`. Supported operations include: - L1 to L1 - L1 to DRAM - DRAM to L1 Currently, only row-major tensors are supported. For unsupported cases, we fall back to the generalized reshard implementation. Unit tests have been added to validate the new kernel functionality.
- Loading branch information
Showing
5 changed files
with
511 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
40 changes: 40 additions & 0 deletions
40
...n/operations/data_movement/sharded/device/kernels/dataflow/reshard_same_height_reader.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <stdint.h> | ||
#include "dataflow_api.h" | ||
|
||
void kernel_main() { | ||
constexpr uint32_t shard_cb_id = get_compile_time_arg_val(0); | ||
|
||
const uint32_t total_num_sticks = get_arg_val<uint32_t>(0); | ||
const uint32_t local_stride_bytes = get_arg_val<uint32_t>(1); | ||
const uint32_t remote_stride_bytes = get_arg_val<uint32_t>(2); | ||
const uint32_t base_read_addr = get_arg_val<uint32_t>(3); | ||
const uint32_t num_segments = get_arg_val<uint32_t>(4); | ||
|
||
uint32_t args_idx = 0; | ||
tt_l1_ptr uint32_t* args = (tt_l1_ptr uint32_t*)(get_arg_addr(5)); | ||
|
||
uint32_t base_write_addr = get_read_ptr(shard_cb_id); | ||
|
||
for (uint32_t i = 0; i < num_segments; ++i) { | ||
uint32_t read_size = args[args_idx++]; | ||
|
||
uint32_t write_offset = args[args_idx++]; | ||
uint32_t l1_write_addr = base_write_addr + write_offset; | ||
|
||
uint32_t x_coord = args[args_idx++]; | ||
uint32_t y_coord = args[args_idx++]; | ||
uint32_t read_offset = base_read_addr + args[args_idx++]; | ||
uint64_t noc_read_addr = get_noc_addr(x_coord, y_coord, read_offset); | ||
|
||
for (uint32_t j = 0; j < total_num_sticks; ++j) { | ||
noc_async_read(noc_read_addr, l1_write_addr, read_size); | ||
l1_write_addr += local_stride_bytes; | ||
noc_read_addr += remote_stride_bytes; | ||
} | ||
} | ||
noc_async_write_barrier(); | ||
} |
40 changes: 40 additions & 0 deletions
40
...n/operations/data_movement/sharded/device/kernels/dataflow/reshard_same_height_writer.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <stdint.h> | ||
#include "dataflow_api.h" | ||
|
||
void kernel_main() { | ||
constexpr uint32_t shard_cb_id = get_compile_time_arg_val(0); | ||
|
||
const uint32_t total_num_sticks = get_arg_val<uint32_t>(0); | ||
const uint32_t local_stride_bytes = get_arg_val<uint32_t>(1); | ||
const uint32_t remote_stride_bytes = get_arg_val<uint32_t>(2); | ||
const uint32_t base_write_addr = get_arg_val<uint32_t>(3); | ||
const uint32_t num_segments = get_arg_val<uint32_t>(4); | ||
|
||
uint32_t args_idx = 0; | ||
tt_l1_ptr uint32_t* args = (tt_l1_ptr uint32_t*)(get_arg_addr(5)); | ||
|
||
uint32_t base_l1_read_addr = get_read_ptr(shard_cb_id); | ||
|
||
for (uint32_t i = 0; i < num_segments; ++i) { | ||
uint32_t write_size = args[args_idx++]; | ||
|
||
uint32_t read_offset = args[args_idx++]; | ||
uint32_t l1_read_addr = base_l1_read_addr + read_offset; | ||
|
||
uint32_t x_coord = args[args_idx++]; | ||
uint32_t y_coord = args[args_idx++]; | ||
uint32_t write_offset = base_write_addr + args[args_idx++]; | ||
uint64_t noc_write_addr = get_noc_addr(x_coord, y_coord, write_offset); | ||
|
||
for (uint32_t j = 0; j < total_num_sticks; ++j) { | ||
noc_async_write(l1_read_addr, noc_write_addr, write_size); | ||
l1_read_addr += local_stride_bytes; | ||
noc_write_addr += remote_stride_bytes; | ||
} | ||
} | ||
noc_async_write_barrier(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.