ggml : implementation of xPos RoPE (#441); also extends ggml_rope_back with additional parameters (breaking API change); does not include CUDA version (#442)
jploski authored Aug 22, 2023
1 parent f03d74a commit 896b089
Showing 4 changed files with 195 additions and 18 deletions.
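
In brief, the xPos variant implemented here multiplies each rotated pair by a position-dependent decay factor ζ. The formula below is read off the new ggml_compute_forward_rope_f32 code in src/ggml.c further down; the shorthand d (= ne0), m (= n_past + i2) and b (= xpos_base) is ours, not the source's.

    \zeta_{i_0}(m) = \left(\frac{i_0 + 0.4\,d}{1.4\,d}\right)^{m/b},
    \qquad
    x_0' = \zeta_{i_0}(m)\,(x_0\cos\theta - x_1\sin\theta),
    \quad
    x_1' = \zeta_{i_0}(m)\,(x_0\sin\theta + x_1\cos\theta),
    \qquad
    \zeta \mapsto 1/\zeta \ \text{when xpos\_downscale is set}
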
15 changes: 14 additions & 1 deletion include/ggml/ggml.h
@@ -1217,6 +1217,15 @@ extern "C" {
float freq_base,
float freq_scale);

// xPos RoPE, in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
float scale_base,
bool downscale);

// rotary position embedding backward, i.e compute dx from dy
// a - dy
GGML_API struct ggml_tensor * ggml_rope_back(
@@ -1225,7 +1234,11 @@ extern "C" {
int n_past,
int n_dims,
int mode,
int n_ctx);
int n_ctx,
float freq_base,
float freq_scale,
float xpos_base,
bool xpos_downscale);

// alibi position embedding
// in-place, returns view(a)
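
A minimal call sketch for the new ggml_rope_xpos_inplace entry point, condensed from tests/test-xpos.c added below; the head dimension, sequence length and the 512.0f scale base are taken from that test and are illustrative, not recommended defaults. The extended ggml_rope_back signature above is used internally by the backward pass (see ggml_compute_backward in src/ggml.c) rather than called directly.

    #include "ggml/ggml.h"

    int main(void) {
        // small scratch context, as in the test
        struct ggml_init_params params = { .mem_size = 16*1024*1024, .mem_buffer = NULL };
        struct ggml_context * ctx = ggml_init(params);

        // head_dim = 4, n_head = 1, seq_len = 8
        struct ggml_tensor * Q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 4, 1, 8);
        for (int i = 0; i < ggml_nelements(Q); i++) ((float *) Q->data)[i] = 2.0f;

        // n_past = 1, n_dims = head_dim, scale_base = 512.0f, downscale = false for queries
        // (keys would pass downscale = true; see the note after tests/test-xpos.c)
        struct ggml_tensor * Qx = ggml_rope_xpos_inplace(ctx, Q, 1, 4, 512.0f, false);

        struct ggml_cgraph gf = ggml_build_forward(Qx);
        ggml_graph_compute_with_ctx(ctx, &gf, /*n_threads=*/1);

        ggml_free(ctx);
        return 0;
    }
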
102 changes: 85 additions & 17 deletions src/ggml.c
@@ -6715,6 +6715,8 @@ static struct ggml_tensor * ggml_rope_impl(
int n_ctx,
float freq_base,
float freq_scale,
float xpos_base,
bool xpos_downscale,
bool inplace) {
GGML_ASSERT(n_past >= 0);
bool is_node = false;
@@ -6725,9 +6727,11 @@ static struct ggml_tensor * ggml_rope_impl(

struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

int32_t params[6] = { n_past, n_dims, mode, n_ctx };
int32_t params[8] = { n_past, n_dims, mode, n_ctx };
memcpy(params + 4, &freq_base, sizeof(float));
memcpy(params + 5, &freq_scale, sizeof(float));
memcpy(params + 6, &xpos_base, sizeof(float));
memcpy(params + 7, &xpos_downscale, sizeof(bool));
ggml_set_op_params(result, params, sizeof(params));

result->op = GGML_OP_ROPE;
@@ -6744,7 +6748,7 @@ struct ggml_tensor * ggml_rope(
int n_dims,
int mode,
int n_ctx) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false);
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
}

struct ggml_tensor * ggml_rope_inplace(
@@ -6754,7 +6758,7 @@ struct ggml_tensor * ggml_rope_inplace(
int n_dims,
int mode,
int n_ctx) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
}

struct ggml_tensor * ggml_rope_custom(
@@ -6766,7 +6770,7 @@ struct ggml_tensor * ggml_rope_custom(
int n_ctx,
float freq_base,
float freq_scale) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, false);
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
}

struct ggml_tensor * ggml_rope_custom_inplace(
@@ -6778,7 +6782,17 @@ struct ggml_tensor * ggml_rope_custom_inplace(
int n_ctx,
float freq_base,
float freq_scale) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true);
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
}

struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
float scale_base,
bool downscale) {
return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, scale_base, downscale, true);
}

// ggml_rope_back
@@ -6789,7 +6803,11 @@ struct ggml_tensor * ggml_rope_back(
int n_past,
int n_dims,
int mode,
int n_ctx) {
int n_ctx,
float freq_base,
float freq_scale,
float xpos_base,
bool xpos_downscale) {
GGML_ASSERT(n_past >= 0);
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");

@@ -6801,7 +6819,11 @@ struct ggml_tensor * ggml_rope_back(

struct ggml_tensor * result = ggml_dup_tensor(ctx, a);

int32_t params[] = { n_past, n_dims, mode, n_ctx };
int32_t params[8] = { n_past, n_dims, mode, n_ctx };
memcpy(params + 4, &freq_base, sizeof(float));
memcpy(params + 5, &freq_scale, sizeof(float));
memcpy(params + 6, &xpos_base, sizeof(float));
memcpy(params + 7, &xpos_downscale, sizeof(bool));
ggml_set_op_params(result, params, sizeof(params));

result->op = GGML_OP_ROPE_BACK;
@@ -12065,7 +12087,6 @@ static void ggml_compute_forward_alibi(
}
}


// ggml_compute_forward_clamp

static void ggml_compute_forward_clamp_f32(
@@ -12154,12 +12175,18 @@ static void ggml_compute_forward_rope_f32(
float freq_base;
float freq_scale;

// these two only relevant for xPos RoPE:
float xpos_base;
bool xpos_downscale;

const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3];
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&xpos_downscale, (int32_t *) dst->op_params + 7, sizeof(bool));

assert(n_past >= 0);

@@ -12231,6 +12258,9 @@ static void ggml_compute_forward_rope_f32(
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
// zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
if (xpos_downscale) zeta = 1.0f / zeta;

theta *= theta_scale;

@@ -12240,8 +12270,8 @@
const float x0 = src[0];
const float x1 = src[1];

dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta;
dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
}
} else {
// TODO: this is probably wrong, but I can't figure it out ..
@@ -12435,9 +12465,21 @@ static void ggml_compute_forward_rope_back_f32(
// dx = rope_back(dy, src1)
// src0 is dy, src1 contains options

float freq_base;
float freq_scale;

// these two only relevant for xPos RoPE:
float xpos_base;
bool xpos_downscale;

const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3];
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&xpos_downscale, (int32_t *) dst->op_params + 7, sizeof(bool));

assert(n_past >= 0);

@@ -12463,7 +12505,7 @@ static void ggml_compute_forward_rope_back_f32(
// row index used to determine which thread to use
int ir = 0;

const float theta_scale = powf(10000.0, -2.0f/n_dims);
const float theta_scale = powf(freq_base, -2.0f/n_dims);

const bool is_neox = mode & 2;

@@ -12474,12 +12516,15 @@ static void ggml_compute_forward_rope_back_f32(
if (ir++ < ir0) continue;
if (ir > ir1) break;

float theta = (float)p;
float theta = freq_scale * (float)p;

if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
// zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
if (xpos_downscale) zeta = 1.0f / zeta;

theta *= theta_scale;

@@ -12489,8 +12534,8 @@ static void ggml_compute_forward_rope_back_f32(
const float dy0 = dy[0];
const float dy1 = dy[1];

dx[0] = dy0*cos_theta + dy1*sin_theta;
dx[1] = - dy0*sin_theta + dy1*cos_theta;
dx[0] = dy0*cos_theta*zeta + dy1*sin_theta*zeta;
dx[1] = - dy0*sin_theta*zeta + dy1*cos_theta*zeta;
}
} else {
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
@@ -15967,14 +16012,25 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2];
const int n_ctx = ((int32_t *) tensor->op_params)[3];
float freq_base, freq_scale, xpos_base;
bool xpos_downscale;
memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
memcpy(&xpos_downscale, (int32_t *) tensor->op_params + 7, sizeof(bool));

src0->grad = ggml_add_impl(ctx,
src0->grad,
ggml_rope_back(ctx,
tensor->grad,
n_past,
n_dims,
mode,
n_ctx),
n_ctx,
freq_base,
freq_scale,
xpos_base,
xpos_downscale),
inplace);
}
} break;
@@ -15985,14 +16041,26 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2];
const int n_ctx = ((int32_t *) tensor->op_params)[3];
float freq_base, freq_scale, xpos_base;
bool xpos_downscale;
memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
memcpy(&xpos_downscale, (int32_t *) tensor->op_params + 7, sizeof(bool));

src0->grad = ggml_add_impl(ctx,
src0->grad,
ggml_rope(ctx,
ggml_rope_impl(ctx,
tensor->grad,
n_past,
n_dims,
mode,
n_ctx),
n_ctx,
freq_base,
freq_scale,
xpos_base,
xpos_downscale,
false),
inplace);
}
} break;
9 changes: 9 additions & 0 deletions tests/CMakeLists.txt
@@ -325,3 +325,12 @@ if (MSVC)
endif()
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")

#
# test-xpos

set(TEST_TARGET test-xpos)
add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
target_link_libraries(${TEST_TARGET} PRIVATE ggml)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
87 changes: 87 additions & 0 deletions tests/test-xpos.c
@@ -0,0 +1,87 @@
#include "ggml/ggml.h"

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

bool is_close(float a, float b, float epsilon) {
return fabs(a - b) < epsilon;
}

int main(int argc, char ** argv) {
const int n_threads = 1;
const int n_embd_head = 4; // aka head_dim
const int n_head = 1;
const int N = 8;

struct ggml_init_params params = {
.mem_size = 16*1024*1024,
.mem_buffer = NULL,
};

// memory allocation happens here
struct ggml_context * ctx = ggml_init(params);

struct ggml_tensor * Q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, N);
struct ggml_tensor * K = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, N);

for (int i = 0; i < ggml_nelements(Q); i++) {
((float*) Q->data)[i] = 2.0f;
((float*) K->data)[i] = 2.0f;
}

struct ggml_tensor * Qx = ggml_rope_xpos_inplace(ctx, Q, 1, n_embd_head, 512.0f, false);
struct ggml_tensor * Kx = ggml_rope_xpos_inplace(ctx, K, 1, n_embd_head, 512.0f, true);

struct ggml_cgraph gf = ggml_build_forward(Qx);
ggml_build_forward_expand(&gf, Kx);
ggml_graph_compute_with_ctx(ctx, &gf, n_threads);

// expected output for Qx:
// -0.6009 2.7568 1.9782 2.0182
// -2.6379 0.9815 1.9562 2.0361
// -2.2457 -1.6853 1.9341 2.0538
// 0.2043 -2.7934 1.9118 2.0712
// 2.4550 -1.3341 1.8894 2.0884
// 2.4430 1.3417 1.8668 2.1054
// 0.1905 2.7739 1.8440 2.1221
// -2.2257 1.6550 1.8212 2.1386

for (int i = 0; i < ggml_nelements(Q); i++) {
if (((float*) Qx->data)[i] > 0) printf(" ");
printf("%.4f ", ((float*) Qx->data)[i]);
if ((i+1) % n_embd_head == 0) printf("\n");
}
printf("\n");

GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 0], -2.2257f, 0.0001f));
GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 1], 1.6550f, 0.0001f));
GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 2], 1.8212f, 0.0001f));
GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 3], 2.1386f, 0.0001f));

// expected output for Kx:
// -0.6038 2.7703 1.9816 2.0216
// -2.6639 0.9911 1.9630 2.0431
// -2.2789 -1.7103 1.9441 2.0644
// 0.2083 -2.8486 1.9251 2.0856
// 2.5158 -1.3671 1.9057 2.1065
// 2.5158 1.3816 1.8862 2.1273
// 0.1972 2.8705 1.8665 2.1479
// -2.3146 1.7211 1.8465 2.1684

for (int i = 0; i < ggml_nelements(K); i++) {
if (((float*) Kx->data)[i] > 0) printf(" ");
printf("%.4f ", ((float*) Kx->data)[i]);
if ((i+1) % n_embd_head == 0) printf("\n");
}
printf("\n");

GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 0], -2.3146f, 0.0001f));
GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 1], 1.7211f, 0.0001f));
GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 2], 1.8465f, 0.0001f));
GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 3], 2.1684f, 0.0001f));

ggml_free(ctx);

return 0;
}
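
A note on the false/true pair in the test (our reading of the xPos scheme; the commit itself does not spell this out): queries keep the plain ζ factor while keys are built with downscale = true, i.e. 1/ζ, so a query at position m attending to a key at position n contributes

    \zeta_{i_0}(m)\,\zeta_{i_0}(n)^{-1} = \left(\frac{i_0 + 0.4\,d}{1.4\,d}\right)^{(m-n)/b}

which depends only on the relative offset m - n and, since the base is below 1, decays as the offset grows; this is the long-range damping that xPos adds on top of plain RoPE.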
