diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 3b828375e8adb..db9da24594cb2 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5313,7 +5313,7 @@ template static __global__ void template static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) { const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par; @@ -5352,7 +5352,7 @@ static __global__ void mul_mat_vec_q( tmp[j] = warp_reduce_sum(tmp[j]); if (threadIdx.x == 0) { - dst[j*nrows_x + row] = tmp[j]; + dst[j*nrows_dst + row] = tmp[j]; } } } @@ -6828,7 +6828,7 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa template static void mul_mat_vec_q_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { GGML_ASSERT(ncols_x % qk == 0); GGML_ASSERT(ncols_y <= 4); @@ -6839,40 +6839,40 @@ static void mul_mat_vec_q_cuda( switch (ncols_y) { case 1: mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); break; case 2: mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); break; case 3: mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); break; case 4: mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); break; // case 5: // mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot> - // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); + // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); // break; // case 6: // mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot> - // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); + // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); // break; // case 7: // mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot> - // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); + // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); // break; // case 8: // mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot> - // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); + // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); // break; default: GGML_ASSERT(false); // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot> - // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); + // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); break; } } @@ -8391,7 +8391,7 @@ static void ggml_cuda_op_mul_mat_q( CUDA_CHECK(cudaGetDevice(&id)); // the main device has a larger memory buffer to hold the results from all GPUs - // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into + // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff; switch (src0->type) { @@ -8525,58 +8525,70 @@ static void ggml_cuda_op_mul_mat_vec_q( const int64_t ne00 = src0->ne[0]; const int64_t row_diff = row_high - row_low; + const int64_t ne10 = src1->ne[0]; + GGML_ASSERT(ne10 % QK8_1 == 0); + + const int64_t ne0 = dst->ne[0]; + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + + // the main device has a larger memory buffer to hold the results from all GPUs + // nrows_dst == nrows of the matrix that the kernel writes into + const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff; + switch (src0->type) { case GGML_TYPE_Q4_0: mul_mat_vec_q_cuda - (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream); + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; case GGML_TYPE_Q4_1: mul_mat_vec_q_cuda - (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream); + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; case GGML_TYPE_Q5_0: mul_mat_vec_q_cuda - (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream); + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; case GGML_TYPE_Q5_1: mul_mat_vec_q_cuda - (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream); + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; case GGML_TYPE_Q8_0: mul_mat_vec_q_cuda - (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream); + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; case GGML_TYPE_Q2_K: mul_mat_vec_q_cuda - (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream); + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; case GGML_TYPE_Q3_K: mul_mat_vec_q_cuda - (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream); + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; case GGML_TYPE_Q4_K: mul_mat_vec_q_cuda - (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream); + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; case GGML_TYPE_Q5_K: mul_mat_vec_q_cuda - (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream); + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; case GGML_TYPE_Q6_K: mul_mat_vec_q_cuda - (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream); + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; case GGML_TYPE_IQ2_XXS: mul_mat_vec_q_cuda - (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream); + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; case GGML_TYPE_IQ2_XS: mul_mat_vec_q_cuda - (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream); + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; case GGML_TYPE_IQ3_XXS: mul_mat_vec_q_cuda - (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream); + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; default: GGML_ASSERT(false); @@ -9909,7 +9921,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); } } else { - if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) { + if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true); } else if (use_mul_mat_q) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);