Skip to content

Commit

Permalink
ggml : improve ADD_REL_POS perf in SAM by doing it inplace + broadcas…
Browse files Browse the repository at this point in the history
…t BLAS mul_mat (#466)

* Improve ADD_REL_POS perf in SAM by doing it inplace

- Add unit tests for the ADD_REL_POS operation
- I am not sure if this is valid implementation as we reuse the src0
  memory in order to avoid copying it
- When running SAM with the "Example output" command, image, point and
  16 threads, this reduces the cumulative time of the ADD_REL_POS operation
  from 1000-1100 ms to 180-200ms
- There is further room for optimization in the access patterns used in
  the implementation of the opration

* Add non-inplace version for the GGML_OP_ADD_REL_POS

* Fix map_unary warnings and refactor LayerNorm2d + remove ggml_cont in it

* Fix Mac printf format warnings

* sam : add ggml_graph_print() comment

* ggml : add broadcast support for BLAS ggml_mul_mat() (#460)

* Remove not needed build_forward_expand from add-rel-pos unit test

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
  • Loading branch information
Yavor Ivanov and ggerganov authored Aug 21, 2023
1 parent 08c57df commit 170388d
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 183 deletions.
199 changes: 67 additions & 132 deletions examples/sam/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,15 +269,41 @@ struct sam_image_f32 {
std::vector<float> data;
};

void ggml_sam_sin(const int n, float * dst, const float * src) {
for (int i = 0; i < n; ++i) {
dst[i] = sinf(src[i]);
void ggml_sam_sin(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {
GGML_ASSERT(userdata == NULL);
GGML_ASSERT(ggml_are_same_shape(dst, src));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_is_contiguous(src));

const float * src_data = ggml_get_data_f32(src);
float * dst_data = ggml_get_data_f32(dst);

const int ne = (int)ggml_nelements(dst);
const int dr = (ne + nth - 1) / nth;
const int ie0 = dr * ith;
const int ie1 = std::min(ie0 + dr, ne);

for (int i = ie0; i < ie1; ++i) {
dst_data[i] = sinf(src_data[i]);
}
}

void ggml_sam_cos(const int n, float * dst, const float * src) {
for (int i = 0; i < n; ++i) {
dst[i] = cosf(src[i]);
void ggml_sam_cos(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {
GGML_ASSERT(userdata == NULL);
GGML_ASSERT(ggml_are_same_shape(dst, src));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_is_contiguous(src));

const float * src_data = ggml_get_data_f32(src);
float * dst_data = ggml_get_data_f32(dst);

const int ne = (int)ggml_nelements(dst);
const int dr = (ne + nth - 1) / nth;
const int ie0 = dr * ith;
const int ie1 = std::min(ie0 + dr, ne);

for (int i = ie0; i < ie1; ++i) {
dst_data[i] = cosf(src_data[i]);
}
}

Expand Down Expand Up @@ -888,13 +914,6 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
}
}

// key + value memory
{
// const auto & hparams = model.hparams;

// TODO
}

// load weights
{
int n_tensors = 0;
Expand Down Expand Up @@ -1037,8 +1056,8 @@ bool sam_fill_dense_pe(
// concat
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192
{
struct ggml_tensor * t_sin = ggml_map_unary_f32(ctx0, cur, ggml_sam_sin);
struct ggml_tensor * t_cos = ggml_map_unary_f32(ctx0, cur, ggml_sam_cos);
struct ggml_tensor * t_sin = ggml_map_custom1(ctx0, cur, ggml_sam_sin, GGML_N_TASKS_MAX, NULL);
struct ggml_tensor * t_cos = ggml_map_custom1(ctx0, cur, ggml_sam_cos, GGML_N_TASKS_MAX, NULL);

cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1], cur->ne[2]);

Expand All @@ -1059,6 +1078,28 @@ bool sam_fill_dense_pe(
return true;
}

struct ggml_tensor* sam_layer_norm_2d(
struct ggml_context * ctx0,
struct ggml_tensor * layer,
int n_channels,
struct ggml_tensor * w,
struct ggml_tensor * b) {
// LayerNorm2d
// normalize along channel dimmension
// TODO: better implementation
layer = ggml_permute(ctx0,
ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3))),
2, 0, 1, 3);

layer = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, ggml_reshape_3d(ctx0, w, 1, 1, n_channels), layer),
layer),
ggml_repeat(ctx0, ggml_reshape_3d(ctx0, b, 1, 1, n_channels), layer));

return layer;
}

bool sam_encode_image(
const sam_model & model,
sam_state & state,
Expand Down Expand Up @@ -1228,7 +1269,7 @@ bool sam_encode_image(
0, 2, 1, 3));
struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r);

struct ggml_tensor * attn = ggml_add_rel_pos(ctx0, KQ_scaled, rel_w, rel_h);
struct ggml_tensor * attn = ggml_add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h);

struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn);

Expand Down Expand Up @@ -1306,37 +1347,11 @@ bool sam_encode_image(

cur = ggml_conv_2d_sk_p0(ctx0, enc.neck_conv_0, cur);

// LayerNorm2d
{
// normalize along channel dimmension
// TODO: better implementation
cur = ggml_cont(ctx0, ggml_permute(ctx0,
ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3))),
2, 0, 1, 3));

cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, ggml_reshape_3d(ctx0, enc.neck_norm_0_w, 1, 1, n_enc_out_chans), cur),
cur),
ggml_repeat(ctx0, ggml_reshape_3d(ctx0, enc.neck_norm_0_b, 1, 1, n_enc_out_chans), cur));
}
cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_0_w, enc.neck_norm_0_b);

cur = ggml_conv_2d_s1_ph(ctx0, enc.neck_conv_1, cur);

// LayerNorm2d
{
// normalize along channel dimmension
// TODO: better implementation
cur = ggml_cont(ctx0, ggml_permute(ctx0,
ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3))),
2, 0, 1, 3));

cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, ggml_reshape_3d(ctx0, enc.neck_norm_1_w, 1, 1, n_enc_out_chans), cur),
cur),
ggml_repeat(ctx0, ggml_reshape_3d(ctx0, enc.neck_norm_1_b, 1, 1, n_enc_out_chans), cur));
}
cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w, enc.neck_norm_1_b);

// TODO: avoid copy
cur = ggml_cpy(ctx0, cur, state.embd_img);
Expand All @@ -1349,6 +1364,8 @@ bool sam_encode_image(
ggml_build_forward_expand(&gf, cur);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

//ggml_graph_print(&gf);

ggml_free(ctx0);
return true;
}
Expand Down Expand Up @@ -1423,8 +1440,8 @@ bool sam_encode_prompt(
// concat
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192
{
struct ggml_tensor * t_sin = ggml_map_unary_f32(ctx0, cur, ggml_sam_sin);
struct ggml_tensor * t_cos = ggml_map_unary_f32(ctx0, cur, ggml_sam_cos);
struct ggml_tensor * t_sin = ggml_map_custom1(ctx0, cur, ggml_sam_sin, GGML_N_TASKS_MAX, NULL);
struct ggml_tensor * t_cos = ggml_map_custom1(ctx0, cur, ggml_sam_cos, GGML_N_TASKS_MAX, NULL);

cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1]);

Expand Down Expand Up @@ -1462,74 +1479,6 @@ bool sam_encode_prompt(
// run the computation
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

// print
{
// auto print_t_f32 = [&](struct ggml_tensor * t) {
// float * data = (float *)t->data;
// printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
// printf("data: ");
// for (int i = 0; i < std::min((int) t->ne[0], 256); i++) {
// printf("%f ", data[i]);
// }
// printf("\n");
// //for (int y = 0; y < 64; ++y) {
// // for (int x = 0; x < 64; ++x) {
// // printf("%5.2f ", data[y*64 + x]);
// // }
// // printf("\n");
// //}
// //printf("\n");
// // for (int y = 0; y < 64; ++y) {
// // for (int x = 0; x < 64; ++x) {
// // printf("%5.2f ", data[255*64*64 + y*64 + x]);
// // }
// // printf("\n");
// // }
// // printf("\n");
// //for (int y = 0; y < 64; ++y) {
// // for (int x = 0; x < 64; ++x) {
// // printf("%5.2f ", data[(y*64 + x)*768 + 231]);
// // }
// // printf("\n");
// //}
// //printf("\n");
// double sum = 0.0;
// for (int i = 0; i < ggml_nelements(t); i++) {
// sum += data[i];
// }
// printf("sum: %f\n", sum);
// };

// auto print_t_f16 = [&](struct ggml_tensor * t) {
// ggml_fp16_t * data = (ggml_fp16_t *)t->data;
// printf("dims: %jd %jd %jd %jd f16\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
// printf("data: ");
// for (int i = 0; i < std::min((int) t->ne[0], 256); i++) {
// printf("%f ", ggml_fp16_to_fp32(data[i]));
// }
// printf("\n");
// for (int y = 0; y < 14; ++y) {
// for (int x = 0; x < 14; ++x) {
// printf("%7.4f ", ggml_fp16_to_fp32(data[(y*14 + x)*64 + 23]));
// }
// printf("\n");
// }
// printf("\n");
// double sum = 0.0;
// for (int i = 0; i < ggml_nelements(t); i++) {
// sum += ggml_fp16_to_fp32(data[i]);
// }
// printf("sum: %f\n", sum);
// };

// auto * t = ggml_get_tensor(ctx0, "check");
// if (t->type == GGML_TYPE_F32) {
// print_t_f32(t);
// } else {
// print_t_f16(t);
// }
}

//printf("used_mem = %zu\n", ggml_used_mem(ctx0));

ggml_free(ctx0);
Expand Down Expand Up @@ -1595,7 +1544,7 @@ struct ggml_tensor* sam_decode_mask_transformer_attn(

struct ggml_tensor * KQV_merged = ggml_cont(ctx0, ggml_transpose(ctx0, KQV));
KQV_merged = ggml_cont(ctx0, ggml_permute(ctx0, KQV_merged, 0, 2, 1, 3));
KQV_merged = ggml_cont(ctx0, ggml_reshape_3d(ctx0, KQV_merged, KQV_merged->ne[0]*KQV_merged->ne[1], KQV_merged->ne[2], KQV_merged->ne[3]));
KQV_merged = ggml_reshape_3d(ctx0, KQV_merged, KQV_merged->ne[0]*KQV_merged->ne[1], KQV_merged->ne[2], KQV_merged->ne[3]);
KQV_merged = ggml_mul_mat(ctx0, attn.out_w, KQV_merged);
KQV_merged = ggml_add(ctx0,
ggml_repeat(ctx0, attn.out_b, KQV_merged),
Expand Down Expand Up @@ -1859,21 +1808,7 @@ bool sam_decode_mask(
// ConvTranspose2d
keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_0_w, keys, 2);
keys = ggml_add(ctx0, ggml_repeat(ctx0, dec.output_upscaling_0_b, keys), keys);

// LayerNorm2d
{
// normalize along channel dimmension
// TODO: better implementation
keys = ggml_cont(ctx0, ggml_permute(ctx0,
ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, keys, 1, 2, 0, 3))),
2, 0, 1, 3));

keys = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, ggml_reshape_3d(ctx0, dec.output_upscaling_1_w, 1, 1, n_img_embd), keys),
keys),
ggml_repeat(ctx0, ggml_reshape_3d(ctx0, dec.output_upscaling_1_b, 1, 1, n_img_embd), keys));
}
keys = sam_layer_norm_2d(ctx0, keys, n_img_embd, dec.output_upscaling_1_w, dec.output_upscaling_1_b);

// GELU activation
keys = ggml_gelu(ctx0, keys);
Expand All @@ -1898,7 +1833,7 @@ bool sam_decode_mask(

struct ggml_tensor * masks = ggml_mul_mat(ctx0, hyper_in, upscaled_embedding);
masks = ggml_cont(ctx0, ggml_transpose(ctx0, masks)); // TODO: Shouldn't be needed
masks = ggml_cont(ctx0, ggml_reshape_4d(ctx0, masks, keys->ne[0], keys->ne[1], masks->ne[1], keys->ne[3]));
masks = ggml_reshape_4d(ctx0, masks, keys->ne[0], keys->ne[1], masks->ne[1], keys->ne[3]);

// Generate mask quality predictions
// ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L146
Expand Down Expand Up @@ -1941,7 +1876,7 @@ bool sam_decode_mask(
bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state) {
if (state.low_res_masks->ne[2] == 0) return true;
if (state.low_res_masks->ne[2] != state.iou_predictions->ne[0]) {
printf("Error: number of masks (%jd) does not match number of iou predictions (%jd)\n", state.low_res_masks->ne[2], state.iou_predictions->ne[0]);
printf("Error: number of masks (%d) does not match number of iou predictions (%d)\n", (int)state.low_res_masks->ne[2], (int)state.iou_predictions->ne[0]);
return false;
}

Expand Down
7 changes: 7 additions & 0 deletions include/ggml/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1384,12 +1384,19 @@ extern "C" {
int kh);

// used in sam

GGML_API struct ggml_tensor * ggml_add_rel_pos(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * pw,
struct ggml_tensor * ph);

GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * pw,
struct ggml_tensor * ph);

// custom operators

typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
Expand Down
Loading

0 comments on commit 170388d

Please sign in to comment.