Skip to content

Commit

Permalink
refine heuristic mode inside
Browse files (browse the repository at this point in the history)
  • Loading branch information
carlushuang committed Mar 1, 2022
1 parent 3301314 commit 090d5c1
Showing 1 changed file with 4 additions and 34 deletions.
38 changes: 4 additions & 34 deletions driver/igemm_bwd_gtc_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,9 @@ class igemm_bwd_gtc_t : public igemm_driver_base_t{
return result;
}

if(this->driver_mode == driver_mode_heuristic)
current_gks = tunable->gemm_k_global_split ? current_gks : 0;

int hi = arg->get_int("in_h");
int wi = arg->get_int("in_w");
int n = arg->get_int("batchsize");
Expand Down Expand Up @@ -840,7 +843,7 @@ class igemm_bwd_gtc_t : public igemm_driver_base_t{
result_t result;
result.kernel_name = kernel_name;

if(this->driver_mode == driver_mode_normal){
if(true){
float min_duration = FLT_MAX;
float duration = .0;
int selected_gks = 0;
Expand Down Expand Up @@ -955,39 +958,6 @@ class igemm_bwd_gtc_t : public igemm_driver_base_t{
result.return_code = 0;
result.duration_ms = min_duration;
result.gks = selected_gks;
}else if(this->driver_mode == driver_mode_heuristic){
int gks = tunable->gemm_k_global_split ? current_gks : 0; // sync with is_tunable_predicted
size_t grid_size = get_grid_size(arg, tunable) * (1 << gks);
if(tunable->multihead){
if(tunable->tensor_layout == "nhwc"){
int gemm_m = n * h_tilda_slice * w_tilda_slice;
int gemm_n = c / group;
// This is hacky, but in MIOpen we prefer a heuristic way to set gks, so ok now.
igemm_bwd_gtc_nhwc_karg_t *karg = (igemm_bwd_gtc_nhwc_karg_t *)(karg_buffer);
magic_div_u32_t mdiv_x_tilda = magic_div_u32_gen(x_tilda);
magic_div_u32_t mdiv_y_tilda = magic_div_u32_gen(y_tilda);
magic_div_u32_t mdiv_group_mn = magic_div_u32_gen(group * utility_integer_divide_ceil(gemm_n, gemm_n_per_block) * utility_integer_divide_ceil(gemm_m, gemm_m_per_block));
karg->dtile_iy = num_of_gemm > 1 ? mdiv_x_tilda.magic : 0;
karg->dtile_ix = num_of_gemm > 1 ? mdiv_x_tilda.shift : 0;
karg->dslice_y = num_of_gemm > 1 ? mdiv_y_tilda.magic : y;
karg->dslice_x = num_of_gemm > 1 ? mdiv_y_tilda.shift : x;
karg->dtile_h = num_of_gemm > 1 ? mdiv_group_mn.magic : h_tilda;
karg->dtile_w = num_of_gemm > 1 ? mdiv_group_mn.shift : w_tilda;
karg->ks = gks;
}else{
assert(0);
}

float duration = igemm_launch_kernels({
{kernel_func, karg_buffer, karg_size, {grid_size * block_size, splits, 1}, {block_size, 1, 1}}
}, bwd_prolog, bwd_postlog, this->warmup, this->repeat);

result.return_code = 0;
result.duration_ms = duration;
result.gks = gks;
}else{
assert(0); // to be supported
}
}

#ifdef IGEMM_SPLIT_KERNEL
Expand Down

0 comments on commit 090d5c1

Please sign in to comment.