Support for new matmul1d op with gather_in0 #14964
base: main
Conversation
@@ -1659,8 +1989,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_(

     if (fp32_dest_acc_en) {
         TT_FATAL(
-            out_subblock_h * out_subblock_w <= 4,
-            "Total number of tiles in a subblock must be less than 4 when in fp32_dest_acc mode");
+            out_subblock_h * out_subblock_w <= 8,
Why change this to 8?
Isn't 16 cut in half because of half dest mode and in half again because of fp32?
If anything, this check could be removed, since the following has already been added to validate() in matmul_op.cpp:

TT_FATAL(
    (program_config.out_subblock_w * program_config.out_subblock_h) <= available_reg_count,
    "out_subblock_w {} times out_subblock_h {} needs to be at most {} to fit in hardware",
    program_config.out_subblock_w,
    program_config.out_subblock_h,
    available_reg_count);
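The register arithmetic the reviewer alludes to ("16 cut in half because of half dest mode and in half again because of fp32") can be sketched as a small model. This is an illustrative sketch of the capacity reasoning only, not the actual get_dest_reg_count implementation; the base count of 16 and the two halving rules are taken from the discussion above.

```python
def dest_reg_count(full_dest_mode: bool, fp32_dest_acc_en: bool) -> int:
    """Model how many dest register tiles a subblock may occupy.

    Assumptions (from the review discussion, not from the real code):
    - the dest register file holds 16 tiles at full capacity
    - half-dest sync mode halves the usable count
    - fp32 accumulation halves it again (each tile takes two slots)
    """
    count = 16
    if not full_dest_mode:
        count //= 2  # half-dest sync mode
    if fp32_dest_acc_en:
        count //= 2  # 32-bit accumulation doubles per-tile footprint
    return count
```

Under this model, half dest mode with fp32 accumulation gives 4 (the original TT_FATAL bound), while full dest mode with fp32 accumulation gives 8 (the value hard-coded in this diff).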
So one of the remaining TODOs in this PR is adding support for full dest mode (see here). If that isn't merged before this PR, I can revert the hard-coded value.
In my testing with the shapes used in the Llama models, I have found that we can actually use 8 here, and it results in significant speedups since there is no dest reload.
If you enable full dest mode then that will need to be reflected in the config and get_dest_reg_count needs to return 8. You cannot use a constant.
Please either remove this test or make it do the same thing as in validate.
It seems full dest mode support has already been added in main. However, the pybind for the compute kernel config has not been updated to allow users to enable it.
I've opened a PR here that updates the pybind, and I am currently waiting for confirmation from @amahmudTT that the change is correct.
        .defines = mm_kernel_defines});

    /* Create circular buffers */
    uint32_t src0_cb_index = 0;
Please just use the cb constants instead of raw numbers in the assignments.
    # 32, 2304, 3840
    (1, 32, 2304, 3840, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.MathFidelity.LoFi, True, True, (8, 3)),
    # 32, 2304, 3840
    (3, 32, 2304, 3840, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.MathFidelity.LoFi, True, True, (8, 3)),
    # 32, 2304, 3840
    (3, 32, 2304, 3840, ttnn.bfloat16, ttnn.bfloat8_b, ttnn.MathFidelity.LoFi, False, False, (8, 3)),
Has the exact arbitrary shard grid been tested (i.e., the cores placed near the DRAM banks)?
It has not been tested just yet, but I will do that and include results here!
Tested with the exact core grid, with cores near the DRAM banks. Results are 👍.
#ifdef MATMUL_DRAM_SHARDED
    const bool is_worker_core = get_arg_val<uint32_t>(0) == 1;
    // if not a worker core, skip
    if (not is_worker_core) {
        return;
    }
#endif
Should this be removed, since it will never be triggered?
Yes, I can get rid of this. I can also get rid of all the other things in this bmm kernel that aren't being used (untilize out, fused bias, etc.). Is this fine @bbradelTT?
That's fine.
Please add appropriate checks in validate() to ensure that the inputs are not DRAM sharded and that the other unused features are not being enabled.
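A hedged sketch of what such validate() checks could look like, written as plain Python for illustration. The flag names (in0_is_dram_sharded, has_bias, untilize_out) are hypothetical stand-ins, not the actual matmul_op.cpp API; in the real code these would be TT_FATAL calls.

```python
def validate_gather_in0(in0_is_dram_sharded: bool,
                        has_bias: bool,
                        untilize_out: bool) -> None:
    """Reject configurations the gather_in0 path does not support.

    Hypothetical mirror of the requested TT_FATAL checks; the
    parameter names are illustrative only.
    """
    if in0_is_dram_sharded:
        raise ValueError("gather_in0 does not support DRAM-sharded in0")
    if has_bias:
        raise ValueError("gather_in0 does not support fused bias")
    if untilize_out:
        raise ValueError("gather_in0 does not support untilized output")
```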
for (uint32_t shard_cnt = 0; shard_cnt < ring_size; shard_cnt++) {
    uint32_t curr_shard_write_addr = l1_write_addr_in0 + shard_size_bytes * shard_cnt;
    uint64_t remote_curr_shard_write_addr = get_noc_addr(next_core_noc_x, next_core_noc_y, curr_shard_write_addr, noc);
    uint32_t curr_shard_read_addr = shard_cnt == 0 ? local_shard_read_addr : l1_write_addr_in0 + shard_size_bytes * (shard_cnt - 1);

    // Wait for signal from previous core that data has been added to this core's in0
    noc_semaphore_wait_min(l1_signal_sem_addr, shard_cnt);

    // Send data to next core
    if (shard_cnt < ring_size - 1) { // Skip sending the last shard
        noc_async_write(curr_shard_read_addr, remote_curr_shard_write_addr, shard_size_bytes, noc);

        // Signal the next core that data is ready
        noc_semaphore_inc(remote_signal_semaphore_addr, 1, noc);
    }

    // Do stuff for matmul fusion here
    if (shard_cnt > 0) {
        cb_push_back(cb_id_in2, shard_size_in_tiles);
Should we have another issue tracking support for back-pressure (using a global CB)?
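The ring exchange in the kernel above can be modeled in plain Python to sanity-check its data movement: on each step a core forwards the shard it received on the previous step (its local shard on step 0) to the next core, so after ring_size - 1 steps every core holds every shard. This is a simplified simulation of the loop structure only, not of the NoC or semaphore APIs.

```python
def ring_gather(local_shards):
    """Simulate a ring all-gather over len(local_shards) cores.

    buffers[c] collects shards in arrival order at core c,
    with the local shard first (matching the kernel, where the
    local shard is consumed immediately and remote shards are
    pushed as they arrive).
    """
    n = len(local_shards)
    buffers = [[s] for s in local_shards]
    for step in range(n - 1):
        for core in range(n):
            prev = (core - 1) % n
            # shard that core `prev` obtained at the previous step
            buffers[core].append(buffers[prev][step])
    return buffers
```

For 4 cores holding shards [0, 1, 2, 3], core 0 receives them in ring order [0, 3, 2, 1], and every core ends up with the full set, which is the invariant the matmul loop relies on.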
uint32_t get_preferred_noc(const uint32_t src_x, const uint32_t dst_x, const tt_metal::Device* device) {
    /*
        NOC0: Preferred +x -> +y
        NOC1: Preferred -y -> -x
    */
    uint32_t MAX_X = device->grid_size().x;

    // Get the wrapped distances
    uint32_t dist_right = src_x < dst_x ? dst_x - src_x : MAX_X - src_x + dst_x;
    uint32_t dist_left = src_x < dst_x ? src_x + MAX_X - dst_x : src_x - dst_x;

    return dist_right < dist_left ? 0 : 1;
}
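The wrapped-distance logic above is easy to check with a direct Python translation (grid_x stands in for device->grid_size().x; otherwise this mirrors the C++ helper line for line):

```python
def get_preferred_noc(src_x: int, dst_x: int, grid_x: int) -> int:
    """Return 0 (NOC0, rightward/+x) when the wrapped rightward hop
    count is strictly shorter than the leftward one, else 1 (NOC1)."""
    dist_right = dst_x - src_x if src_x < dst_x else grid_x - src_x + dst_x
    dist_left = src_x + grid_x - dst_x if src_x < dst_x else src_x - dst_x
    return 0 if dist_right < dist_left else 1
```

Note that ties (dist_right == dist_left) fall through to NOC1, and that the wrap-around case (e.g. src near the right edge, dst near the left edge) correctly prefers NOC0.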
Might be good to try it out with the intended core arrangement (cores placed near the DRAM banks).
Hey @johanna-rock-tt, what is the largest layer needed to run? Would 32, 2304, 3840 be enough? Since we are buffering the full layer here, we need to make sure all layers pass.
…bally_allocated address.
…tep: choose the correct NOC based on which core is next in the ring. Also, clean up the test.
…lity such as fuse_op and bias. TODO: test batch.
…ure they are not being used.
…re grid configuration.
Ticket
Problem description
Currently, matmul 1d supports mcast_in0, for the case where the input is sharded across all the cores. However, in some cases (specifically, the matmuls used in the Llama models), this poses a bottleneck to the compute, as each core must wait to receive each shard prior to processing it.
To address this, a new option to matmul1d is proposed: gather_in0. With this option, the activations are gathered using a ring all-gather operation. This allows each core to start processing the local activation shard that is already available, and to process the other shards as soon as they arrive. Essentially, this overlaps the time taken to gather the activations with the time taken to do the computation.
For the FF1 matmul in Llama (M, K, N = 32, 2304, 3840), this new matmul1d takes 10us (gather in0, in1 sharded, with a hack to enable full dest in fp32_accum), compared to the 36us (mcast in0, in1 read from DRAM) measured before.
See the issue for a diagram of how the inputs are gathered.
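The benefit claimed above comes from overlapping the gather with the compute. A back-of-the-envelope model of that effect (purely illustrative; the 10us and 36us figures above are measurements, not outputs of this model):

```python
def serialized_time(gather_us: float, compute_us: float) -> float:
    # mcast_in0-style: each core waits for every shard before computing
    return gather_us + compute_us

def overlapped_time(gather_us: float, compute_us: float) -> float:
    # gather_in0-style: compute on the local shard starts immediately,
    # so total time approaches the max of the two phases
    return max(gather_us, compute_us)
```

For any split of the work, the overlapped time is bounded by the longer of the two phases, which is why hiding the gather behind compute (or vice versa) helps most when the two phases are comparable in length.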
What's changed
This PR adds the following changes:
- A gather_in0 flag, defaulted to false, in the MatmulMultiCoreReuseMultiCast1DProgramConfig header and pybind
- New kernels for gather_in0=True:
  i. a ring gather in0 kernel
  ii. an in1 kernel
  iii. a bmm kernel (close copy of the existing one)
- Changes in matmul_op.cpp for the new gather_in0 case
- A new test, test_matmul_1d_gathered.py, to test the new matmul configuration
Caveats
Remaining TODOs:
Checklist