Support for new matmul1d op with gather_in0 #14964

Open — wants to merge 17 commits into base: main
Conversation

@avoraTT avoraTT commented Nov 12, 2024

Ticket

Problem description

Currently, matmul1d supports mcast_in0, for when the input is sharded across all the cores. However, in some cases (specifically, the matmuls used in the Llama models) this becomes a compute bottleneck, as each core must wait to receive each shard before it can process it.

To address this, a new option to matmul1d is proposed: gather_in0. With this option, the activations are gathered using a ring all-gather operation. Each core can start processing the local activation shard that is already available, and process the other shards as soon as they arrive. Essentially, this overlaps the time taken to gather the activations with the computation.
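As a rough mental model (not code from this PR), the order in which each core can process shards under the ring all-gather can be sketched in a few lines of Python; `gather_order` is a hypothetical helper name, assuming each core forwards shards to its next neighbor one hop per step:

```python
def gather_order(core: int, ring_size: int) -> list[int]:
    """Order in which `core` can process shards in the ring all-gather.

    Step 0 is the core's own (local) shard, available immediately;
    each later step is the shard forwarded from the previous core,
    so core c sees shards c, c-1, c-2, ... (mod ring_size).
    """
    return [(core - step) % ring_size for step in range(ring_size)]
```

For example, in a 4-core ring, core 0 can compute on shard 0 right away instead of idling until all shards have been multicast.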

For the FF1 matmul in Llama (M, K, N = 32, 2304, 3840), this new matmul1d takes 10 us (gather_in0, in1 sharded, with a hack to enable full dest in fp32_acc mode), compared to the 36 us (mcast_in0, in1 read from DRAM) measured before.

See issue for a diagram of how the inputs are gathered.

What's changed

This PR adds the following changes:

  1. A new gather_in0 flag, defaulted to false, in the MatmulMultiCoreReuseMultiCast1DProgramConfig header and pybind
  2. A new helper function in the matmul1d program factory, to specifically handle the case when gather_in0=True
  3. 3 new kernels
    i. a ring gather in0 kernel
    ii. an in1 kernel
    iii. a bmm kernel (close copy of the existing one)
  4. Validation in matmul_op.cpp for the new gather_in0 case
  5. test_matmul_1d_gathered.py to test the new matmul configuration

Caveats

  • Inputs MUST be sharded and on the same cores. For its intended use cases, the DRAM prefetcher will be used to distribute the weights across the cores, so this does not degrade performance.
  • This op does not support bias; it is not required in the Llama use case, and omitting it simplifies the implementation.
  • The existing bmm kernel is duplicated. Most of the structure remains the same, but there are changes to how the input CBs are read.
  • This op supports sharding on arbitrary cores (not necessarily rectangular) and uses the dynamic NOC functionality to maintain performance.

Remaining TODOs:

  • Test on Grayskull
  • Get official support for full dest when fp32_accum_mode = True (this is not necessary but leads to significant perf gains)
  • Create a PR that adds support for array inputs in CoreRangeSet, to retain ordering of arbitrary cores in a shard spec
  • Check perf for DRAM prefetcher core grid configuration

Checklist

  • Post commit CI passes
  • New/Existing tests provide coverage for changes

@avoraTT avoraTT added metal tt-metal issue LLMs on Metal labels Nov 12, 2024
@avoraTT avoraTT self-assigned this Nov 12, 2024
@@ -1659,8 +1989,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_(

     if (fp32_dest_acc_en) {
         TT_FATAL(
-            out_subblock_h * out_subblock_w <= 4,
+            out_subblock_h * out_subblock_w <= 8,
             "Total number of tiles in a subblock must be less than 4 when in fp32_dest_acc mode");
Contributor:
Why change this to 8?

Isn't 16 cut in half because of half dest mode and in half again because of fp32?

If anything, this check could be removed, since

                TT_FATAL(
                    (program_config.out_subblock_w * program_config.out_subblock_h) <= available_reg_count,
                    "out_subblock_w {} times out_subblock_h {} needs to be at most {} to fit in hardware",
                    program_config.out_subblock_w,
                    program_config.out_subblock_h,
                    available_reg_count);

has been added to validate() in matmul_op.cpp
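For illustration only, the dest-register arithmetic implied by this question can be sketched in Python (`dest_tile_capacity` is a hypothetical helper, not a tt-metal API): start from 16 tiles of dest capacity at 16-bit precision, halve it in half-dest mode, and halve it again when accumulating in fp32:

```python
def dest_tile_capacity(half_dest: bool, fp32_acc: bool) -> int:
    # Hypothetical helper mirroring the reviewer's arithmetic:
    # 16 tiles of dest at 16-bit, halved when only half of dest is
    # available to math, halved again when tiles are 32-bit.
    tiles = 16
    if half_dest:
        tiles //= 2
    if fp32_acc:
        tiles //= 2
    return tiles
```

Under this arithmetic, half dest plus fp32 gives 4 tiles (the original bound), while full dest with fp32 gives the 8 used in the PR.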

Contributor Author (avoraTT):
So one of the remaining TODOs in the PR is getting support for full dest mode (see here). If it isn't merged before this PR, I can revert the hard-coded value.

In my testing for the shapes used in the Llama models, I have found that we can actually use 8 here, and it results in significant speedups since there's no reload.

Contributor:
If you enable full dest mode then that will need to be reflected in the config and get_dest_reg_count needs to return 8. You cannot use a constant.

Please either remove this test or make it do the same thing as in validate.

Contributor Author (avoraTT):
It seems that full dest mode support has already been added in main. However, the pybind for the compute kernel config has not been updated to allow users to enable it.

I've opened a PR here that updates the pybind, and I am currently waiting for confirmation from @amahmudTT that this change is correct.

.defines = mm_kernel_defines});

/* Create circular buffers */
uint32_t src0_cb_index = 0;
Contributor:
Please just use the cb constants instead of raw numbers in the assignments.

Comment on lines +28 to +95
# 32, 2304, 3840
(1, 32, 2304, 3840, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.MathFidelity.LoFi, True, True, (8, 3)),
# 32, 2304, 3840
(3, 32, 2304, 3840, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.MathFidelity.LoFi, True, True, (8, 3)),
# 32, 2304, 3840
(3, 32, 2304, 3840, ttnn.bfloat16, ttnn.bfloat8_b, ttnn.MathFidelity.LoFi, False, False, (8, 3)),
Contributor:
Has the exact arbitrary shard grid been tested (i.e., the cores placed near the DRAM banks)?

Contributor Author (avoraTT):
It has not been tested just yet. But I will do that and include results here!

Contributor Author (avoraTT):
Tested with the exact core grid, with cores near the DRAM banks. Results are 👍.

ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp (outdated; resolved)
Comment on lines 86 to 92
#ifdef MATMUL_DRAM_SHARDED
const bool is_worker_core = get_arg_val<uint32_t>(0) == 1;
// if not worker core, skip
if (not is_worker_core) {
return;
}
#endif
Contributor:
Should this be removed, since it will never be triggered?

Contributor Author (avoraTT):
Yes, I can get rid of this. I can also get rid of all the other things in this bmm kernel that aren't being used (untilize out, fused bias, etc.). Is this fine @bbradelTT?

Contributor:
That's fine.

Please add appropriate checks in validate() to ensure that the inputs are not dram sharded and the other things are not being used.

Comment on lines +49 to +69
for (uint32_t shard_cnt = 0; shard_cnt < ring_size; shard_cnt++) {
    uint32_t curr_shard_write_addr = l1_write_addr_in0 + shard_size_bytes * shard_cnt;
    uint64_t remote_curr_shard_write_addr = get_noc_addr(next_core_noc_x, next_core_noc_y, curr_shard_write_addr, noc);
    uint32_t curr_shard_read_addr = shard_cnt == 0 ? local_shard_read_addr : l1_write_addr_in0 + shard_size_bytes * (shard_cnt - 1);

    // Wait for signal from previous core that data has been added to this core's in0
    noc_semaphore_wait_min(l1_signal_sem_addr, shard_cnt);

    // Send data to next core
    if (shard_cnt < ring_size - 1) { // Skip sending the last shard
        noc_async_write(curr_shard_read_addr, remote_curr_shard_write_addr, shard_size_bytes, noc);

        // Signal the next core that data is ready
        noc_semaphore_inc(remote_signal_semaphore_addr, 1, noc);
    }

    // Do stuff for matmul fusion here
    if (shard_cnt > 0) {
        cb_push_back(cb_id_in2, shard_size_in_tiles);
Contributor:
Should we have another issue tracking support for back-pressure (using a global CB)?
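The data movement in the kernel snippet above can be sanity-checked with a plain Python simulation (an illustrative model, not code from the PR): each step, every core forwards one shard to its next neighbor's in0 buffer, and shards land in arrival order, matching `curr_shard_write_addr = l1_write_addr_in0 + shard_size_bytes * shard_cnt`. After ring_size - 1 steps, every core holds all remote shards:

```python
def simulate_ring_gather(ring_size: int) -> list[list[int]]:
    """slots[c][s] models core c's in0 buffer at offset s * shard_size."""
    slots = [[None] * (ring_size - 1) for _ in range(ring_size)]
    for step in range(ring_size - 1):
        for core in range(ring_size):
            # shard_cnt == 0 forwards the local shard; later steps forward
            # the shard received into the previous slot.
            shard = core if step == 0 else slots[core][step - 1]
            slots[(core + 1) % ring_size][step] = shard
    return slots
```

For a 4-core ring, core 0 receives shards 3, 2, 1 in that order, and together with its local shard 0 ends up with the full activation.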

Comment on lines 25 to 54
uint32_t get_preferred_noc(const uint32_t src_x, const uint32_t dst_x, const tt_metal::Device* device) {
    /*
        NOC0: Preferred +x -> +y
        NOC1: Preferred -y -> -x
    */
    uint32_t MAX_X = device->grid_size().x;

    // Get the wrapped distances
    uint32_t dist_right = src_x < dst_x ? dst_x - src_x : MAX_X - src_x + dst_x;
    uint32_t dist_left = src_x < dst_x ? src_x + MAX_X - dst_x : src_x - dst_x;

    return dist_right < dist_left ? 0 : 1;
}
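For reference, the wrapped-distance heuristic above can be checked with a direct Python port (illustrative only; `grid_width` stands in for `device->grid_size().x`):

```python
def get_preferred_noc(src_x: int, dst_x: int, grid_width: int) -> int:
    """Pick NOC0 (prefers +x hops) when the rightward wrapped distance
    is shorter, else NOC1 (prefers -x hops); ties go to NOC1."""
    dist_right = dst_x - src_x if src_x < dst_x else grid_width - src_x + dst_x
    dist_left = src_x + grid_width - dst_x if src_x < dst_x else src_x - dst_x
    return 0 if dist_right < dist_left else 1
```

For example, on an 8-wide grid, going from x=1 to x=3 is 2 hops rightward versus 6 leftward, so NOC0 is chosen; the reverse trip prefers NOC1.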

Contributor:
It might be good to try it out with the intended core arrangement (cores placed near the DRAM banks).

Contributor @yugaoTT left a comment:
Hey @johanna-rock-tt, what is the largest layer needed to run? Would 32, 2304, 3840 be enough? Since we are buffering the full layer here, we need to make sure all layers pass.

@avoraTT avoraTT marked this pull request as ready for review November 26, 2024 12:49
3 participants