Skip to content

Commit

Permalink
#8865: Optimize softmax dispatch time
Browse files Browse the repository at this point in the history
  • Loading branch information
nemanjagrujic committed Aug 27, 2024
1 parent 8209bb4 commit b05f91e
Showing 1 changed file with 45 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -363,16 +363,29 @@ operation::ProgramWithCallbacks scale_mask_softmax_multi_core(

uint32_t curr_row = 0;
union { float f; uint32_t u; } s; s.f = scale.value_or(1.0f); // scale for fused scale-mask-softmax

auto& cached_reader_args = GetRuntimeArgs(program, reader_kernels_id);
auto& cached_softmax_args = GetRuntimeArgs(program, softmax_kernels_id);
auto& cached_writer_args = GetRuntimeArgs(program, writer_kernels_id);

for (uint32_t i = 0; i < grid_size.x * grid_size.y; ++i) {
CoreCoord core = {i % grid_size.x, i / grid_size.x};
uint32_t num_tile_rows_per_core = 0;

auto& reader_kernel_args = cached_reader_args.at(core.x).at(core.y);
auto& softmax_kernel_args = cached_softmax_args.at(core.x).at(core.y);
auto& writer_kernel_args = cached_writer_args.at(core.x).at(core.y);

if (i >= num_cores) {
SetRuntimeArgs(program, reader_kernels_id, core, { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }); // [8]=1.0f is scaler
SetRuntimeArgs(program, softmax_kernels_id, core, { 0, 0, 0, 0, 0, 0 });
SetRuntimeArgs(program, writer_kernels_id, core, { 0, 0, 0, 0, 0, 0, 0});
reader_kernel_args[3] = 0;
softmax_kernel_args[0] = 0;
writer_kernel_args[1] = 0;
// SetRuntimeArgs(program, reader_kernels_id, core, { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }); // [8]=1.0f is scaler
// SetRuntimeArgs(program, softmax_kernels_id, core, { 0, 0, 0, 0, 0, 0 });
// SetRuntimeArgs(program, writer_kernels_id, core, { 0, 0, 0, 0, 0, 0, 0});
continue;
}

uint32_t num_tile_rows_per_core = 0;
if (core_group_1.core_coord_in_core_ranges(core)) {
num_tile_rows_per_core = num_tile_rows_per_core_group_1;
} else if (core_group_2.core_coord_in_core_ranges(core)) {
Expand All @@ -387,15 +400,37 @@ operation::ProgramWithCallbacks scale_mask_softmax_multi_core(
uint32_t mask_offset = curr_row / Ht * Wt * Wt; // causal mask batch offset
uint32_t mask_id = causal_mask ? (mask_curr_ht * Wt + mask_offset) : (curr_row / Ht * Wt); // causal mask start offset + causal mask batch offset

reader_kernel_args[0] = src_buffer_address;
reader_kernel_args[1] = block_size;
reader_kernel_args[2] = s.u;
reader_kernel_args[3] = num_tile_rows_per_core;
reader_kernel_args[4] = tile_offset;
reader_kernel_args[5] = Wt;
reader_kernel_args[6] = Ht;
reader_kernel_args[7] = mask_buffer_address;
reader_kernel_args[8] = curr_ht;
reader_kernel_args[9] = mask_id;
reader_kernel_args[10] = 0x3f803f80;

if (causal_mask) {
SetRuntimeArgs(program, reader_kernels_id, core, { src_buffer_address, block_size, s.u, num_tile_rows_per_core, tile_offset, Wt, Ht, mask_buffer_address, curr_ht, mask_id, 0x3f803f80, mask_curr_ht, mask_offset }); // [10]=1.0f is scaler
} else {
SetRuntimeArgs(program, reader_kernels_id, core, { src_buffer_address, block_size, s.u, num_tile_rows_per_core, tile_offset, Wt, Ht, mask_buffer_address, curr_ht, mask_id, 0x3f803f80 }); // [10]=1.0f is scaler
reader_kernel_args[11] = mask_curr_ht;
reader_kernel_args[12] = mask_offset;
}

SetRuntimeArgs(program, softmax_kernels_id, core, { num_tile_rows_per_core, Ht, Wt, block_size, curr_ht, mask_padded_data });

SetRuntimeArgs(program, writer_kernels_id, core, { dst_buffer_address, num_tile_rows_per_core * Wt, tile_offset, block_size, mask_padded_data, num_datum_padded, 0xFF00FF00});
softmax_kernel_args[0] = num_tile_rows_per_core;
softmax_kernel_args[1] = Ht;
softmax_kernel_args[2] = Wt;
softmax_kernel_args[3] = block_size;
softmax_kernel_args[4] = curr_ht;
softmax_kernel_args[5] = mask_padded_data;

writer_kernel_args[0] = dst_buffer_address;
writer_kernel_args[1] = num_tile_rows_per_core * Wt;
writer_kernel_args[2] = tile_offset;
writer_kernel_args[3] = block_size;
writer_kernel_args[4] = mask_padded_data;
writer_kernel_args[5] = num_datum_padded;
writer_kernel_args[6] = 0xFF00FF00;

curr_row += num_tile_rows_per_core;
}
Expand Down

0 comments on commit b05f91e

Please sign in to comment.