chore: adding math tests
ariG23498 committed Oct 5, 2024
1 parent 55abdee commit 4459e8d
Showing 4 changed files with 181 additions and 96 deletions.
42 changes: 3 additions & 39 deletions jflux/math.py
@@ -1,41 +1,21 @@
import typing

import jax
from flax import nnx
from chex import Array
from einops import rearrange
from jax import numpy as jnp


@typing.no_type_check
def attention(q: Array, k: Array, v: Array, pe: Array) -> Array:
# TODO (ariG23498): Change all usage of attention to use this function
q, k = apply_rope(q, k, pe)

# jax expects this shape
x = rearrange(x, "B H L D -> B L H D") # noqa
x = jax.nn.dot_product_attention(q, k, v)
x = rearrange(x, "B L H D -> B L (H D)") # reshape again
x = nnx.dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")

return x


def rope(pos: Array, dim: int, theta: int) -> Array:
"""
Generate Rotary Position Embedding (RoPE) for positional encoding.
Args:
pos (Array): Positional values, typically a sequence of positions in an array format.
dim (int): The embedding dimension, which must be an even number.
theta (int): A scaling parameter for RoPE that controls the frequency range of rotations.
Returns:
Array: Rotary embeddings with cosine and sine components for each position and dimension.
"""

# Embedding dimension must be an even number
assert dim % 2 == 0

# Generate the RoPE embeddings
scale = jnp.arange(0, dim, 2, dtype=jnp.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = jnp.einsum("...n,d->...nd", pos, omega)
@@ -45,26 +25,10 @@ def rope(pos: Array, dim: int, theta: int) -> Array:


def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
"""
Apply RoPE to the input query and key tensors.
Args:
xq (Array): Query tensor.
xk (Array): Key tensor.
freqs_cis (Array): RoPE frequencies.
Returns:
tuple[Array, Array]: Query and key tensors with RoPE applied.
"""
# Reshape and typecast the input tensors
xq_ = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 1, 2)

# Apply RoPE to the input tensors
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]

# Reshape and typecast the output tensors
return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(
xk.dtype
)
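
Not part of the commit: a minimal usage sketch, assuming the jflux.math API shown above, of how rope, apply_rope, and attention fit together. The shapes in the comments follow the docstrings and the tests later in this diff.

import jax.numpy as jnp
from jflux.math import rope, apply_rope, attention

B, H, L, D = 2, 4, 8, 64                             # batch, heads, sequence length, head dim
q = k = v = jnp.ones((B, H, L, D))

pos = jnp.repeat(jnp.arange(L)[None, :], B, axis=0)  # [B, L] position indices
pe = rope(pos, D, theta=10_000)                      # [B, L, D // 2, 2, 2]
pe = jnp.repeat(pe[:, None], H, axis=1)              # [B, H, L, D // 2, 2, 2], broadcast over heads

q_rot, k_rot = apply_rope(q, k, pe)                  # shapes preserved: [B, H, L, D]
out = attention(q, k, v, pe)                         # [B, L, H * D]
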
38 changes: 21 additions & 17 deletions jflux/modules/layers.py
@@ -10,32 +10,36 @@
from jflux.math import rope


class Embed(nnx.Module):
"""
Embedding module for Positional Embeddings.
Args:
dim (int): Dimension of the embedding.
theta (int): theta parameter for the RoPE embedding
axes_dim (list[int]): List of axes dimensions.
Returns:
RoPE embeddings
"""

def __init__(self, dim: int, theta: int, axes_dim: list[int]) -> None:
class EmbedND(nnx.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim

def __call__(self, ids: Array) -> Array:
def forward(self, ids: Array) -> Array:
n_axes = ids.shape[-1]
emb = jnp.concat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
axis=-3,
dim=-3,
)

return jnp.expand_dims(emb, 1)
return emb.unsqueeze(1)


# class Embed(nnx.Module):
# def __init__(self, dim: int, theta: int, axes_dim: list[int]) -> None:
# self.dim = dim
# self.theta = theta
# self.axes_dim = axes_dim

# def __call__(self, ids: Array) -> Array:
# n_axes = ids.shape[-1]
# emb = jnp.concat(
# [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
# axis=-3,
# )

# return jnp.expand_dims(emb, 1)


@partial(jax.jit, static_argnums=(1, 2, 3))
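
For reference, and not part of the commit: a JAX-idiomatic sketch of what EmbedND computes, mirroring the commented-out Embed class above. The class name EmbedNDSketch is mine; per-axis RoPE tables are concatenated along the table axis, and a singleton head axis is added for broadcasting.

import jax.numpy as jnp
from chex import Array
from flax import nnx

from jflux.math import rope


class EmbedNDSketch(nnx.Module):  # hypothetical name, not in the repository
    def __init__(self, dim: int, theta: int, axes_dim: list[int]) -> None:
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def __call__(self, ids: Array) -> Array:
        n_axes = ids.shape[-1]
        # One RoPE table per id axis, concatenated along the (dim // 2) axis.
        emb = jnp.concatenate(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            axis=-3,
        )
        # Insert a singleton head axis so the table broadcasts over attention heads.
        return jnp.expand_dims(emb, 1)
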
Empty file added tests/modules/test_layers.py
197 changes: 157 additions & 40 deletions tests/test_math.py
@@ -1,61 +1,178 @@
import unittest

import numpy as np
import torch
import jax.numpy as jnp
import pytest

from jflux.math import apply_rope, attention, rope
from flux.math import rope as torch_rope
from flux.math import apply_rope as torch_apply_rope
from flux.math import attention as torch_attention

from jflux.math import rope as jax_rope
from jflux.math import apply_rope as jax_apply_rope
from jflux.math import attention as jax_attention

class TestAttentionMechanism(unittest.TestCase):
def setUp(self):
self.batch_size = 2
self.num_heads = 4
self.seq_len = 8
self.dim = 64
self.theta = 10000

self.q = jnp.ones((self.batch_size, self.num_heads, self.seq_len, self.dim))
self.k = jnp.ones((self.batch_size, self.num_heads, self.seq_len, self.dim))
self.v = jnp.ones((self.batch_size, self.num_heads, self.seq_len, self.dim))

class TestMath(np.testing.TestCase):
def test_rope(self):
pos = jnp.expand_dims(jnp.arange(self.seq_len), axis=0)
pos = jnp.repeat(pos, self.batch_size, axis=0)
B, L, H, D = (
2,
4,
2,
8,
) # Batch size, sequence length, number of heads, embedding dimension
theta = 10000

# Position indices (e.g., positions in the sequence)
np_positions = (
np.expand_dims(np.arange(L), 0).repeat(B, 1).astype(np.int32)
) # Shape: [B, L]
torch_positions = torch.from_numpy(np_positions).to(torch.int32)
jax_positions = jnp.array(np_positions, dtype=jnp.int32)

rope_output = rope(pos, self.dim, self.theta)
expected_shape = (self.batch_size, self.seq_len, self.dim // 2, 2, 2)
np.testing.assert_allclose(np.array(jax_positions), torch_positions.numpy())

self.assertEqual(
rope_output.shape, expected_shape, "rope function output shape is incorrect"
torch_pe = torch_rope(pos=torch_positions, dim=D, theta=theta)
jax_pe = jax_rope(
pos=jax_positions, dim=D, theta=theta
) # Shape: [B, L, D/2, 2, 2]

np.testing.assert_allclose(
np.array(jax_pe),
torch_pe.numpy(),
rtol=1e-5,
atol=1e-5,
)

@pytest.mark.xfail
def test_apply_rope(self):
pos = jnp.expand_dims(jnp.arange(self.seq_len), axis=0)
pos = jnp.repeat(pos, self.batch_size, axis=0)
B, L, H, D = (
2,
4,
2,
8,
) # Batch size, sequence length, number of heads, embedding dimension
theta = 10000

# Inputs
np_q = np.random.randn(B, H, L, D).astype(np.float32)
np_k = np.random.randn(B, H, L, D).astype(np.float32)
np_v = np.random.randn(B, H, L, D).astype(np.float32)

jax_q = jnp.array(np_q, dtype=jnp.float32)
jax_k = jnp.array(np_k, dtype=jnp.float32)
jax_v = jnp.array(np_v, dtype=jnp.float32)

torch_q = torch.from_numpy(np_q).to(torch.float32)
torch_k = torch.from_numpy(np_k).to(torch.float32)
torch_v = torch.from_numpy(np_v).to(torch.float32)

np.testing.assert_allclose(np.array(jax_q), torch_q.numpy())
np.testing.assert_allclose(np.array(jax_k), torch_k.numpy())
np.testing.assert_allclose(np.array(jax_v), torch_v.numpy())

# Position indices (e.g., positions in the sequence)
np_positions = np.repeat(np.expand_dims(np.arange(L), 0), repeats=B, axis=1)
torch_positions = torch.from_numpy(np_positions).to(torch.int32)
jax_positions = jnp.array(np_positions, dtype=jnp.int32)

np.testing.assert_allclose(np.array(jax_positions), torch_positions.numpy())

torch_pe = torch_rope(pos=torch_positions, dim=D, theta=theta)
jax_pe = jax_rope(
pos=jax_positions, dim=D, theta=theta
) # Shape: [B, L, D/2, 2, 2]

np.testing.assert_allclose(
np.array(jax_pe),
torch_pe.numpy(),
rtol=1e-5,
atol=1e-5,
)

torch_pe = torch_pe.unsqueeze(1).expand(
-1, H, -1, -1, -1, -1
) # Shape: [B, H, L, D//2, 2, 2]
jax_pe = jnp.repeat(jnp.expand_dims(jax_pe, axis=1), repeats=H, axis=1)

freqs_cis = rope(pos, self.dim, self.theta)
xq_out, xk_out = apply_rope(self.q, self.k, freqs_cis)
# Apply RoPE to q and k
torch_q_rotated, torch_k_rotated = torch_apply_rope(
xq=torch_q, xk=torch_k, freqs_cis=torch_pe
)
jax_q_rotated, jax_k_rotated = jax_apply_rope(
xq=jax_q, xk=jax_k, freqs_cis=jax_pe
)

self.assertEqual(
xq_out.shape, self.q.shape, "apply_rope xq output shape is incorrect"
np.testing.assert_allclose(
np.array(jax_q_rotated),
torch_q_rotated.numpy(),
rtol=1e-5,
atol=1e-5,
)
self.assertEqual(
xk_out.shape, self.k.shape, "apply_rope xk output shape is incorrect"
np.testing.assert_allclose(
np.array(jax_k_rotated),
torch_k_rotated.numpy(),
rtol=1e-5,
atol=1e-5,
)

@pytest.mark.xfail
def test_attention(self):
pos = jnp.expand_dims(jnp.arange(self.seq_len), axis=0)
pos = jnp.repeat(pos, self.batch_size, axis=0)
# def test_attention(self):
# # Generate random inputs
# np_input = np.random.randn(2, 32, 4, 4).astype(np.float32)
# jax_input = jnp.array(np_input, dtype=jnp.float32)
# torch_input = torch.from_numpy(np_input).to(torch.float32)

freqs_cis = rope(pos, self.dim, self.theta)
attention_output = attention(self.q, self.k, self.v, freqs_cis)
# np.testing.assert_allclose(np.array(jax_input), torch_input.numpy())

expected_shape = (self.batch_size, self.seq_len, self.num_heads * self.dim)
# # Forward pass
# torch_output = torch_downsample(torch_input)
# jax_output = jax_downsample(rearrange(jax_input, "b c h w -> b h w c"))

self.assertEqual(
attention_output.shape,
expected_shape,
"attention function output shape is incorrect",
)
# # Assertions
# np.testing.assert_allclose(
# np.array(rearrange(jax_output, "b h w c -> b c h w")),
# torch_output.detach().numpy(),
# rtol=1e-5,
# atol=1e-5,
# )

# def test_rope(self):
# pos = jnp.expand_dims(jnp.arange(self.seq_len), axis=0)
# pos = jnp.repeat(pos, self.batch_size, axis=0)

# rope_output = rope(pos, self.dim, self.theta)
# expected_shape = (self.batch_size, self.seq_len, self.dim // 2, 2, 2)

# self.assertEqual(
# rope_output.shape, expected_shape, "rope function output shape is incorrect"
# )

# @pytest.mark.xfail
# def test_apply_rope(self):
# pos = jnp.expand_dims(jnp.arange(self.seq_len), axis=0)
# pos = jnp.repeat(pos, self.batch_size, axis=0)

# freqs_cis = rope(pos, self.dim, self.theta)
# xq_out, xk_out = apply_rope(self.q, self.k, freqs_cis)

# self.assertEqual(
# xq_out.shape, self.q.shape, "apply_rope xq output shape is incorrect"
# )
# self.assertEqual(
# xk_out.shape, self.k.shape, "apply_rope xk output shape is incorrect"
# )

# @pytest.mark.xfail
# def test_attention(self):
# pos = jnp.expand_dims(jnp.arange(self.seq_len), axis=0)
# pos = jnp.repeat(pos, self.batch_size, axis=0)

# freqs_cis = rope(pos, self.dim, self.theta)
# attention_output = attention(self.q, self.k, self.v, freqs_cis)

# expected_shape = (self.batch_size, self.seq_len, self.num_heads * self.dim)

# self.assertEqual(
# attention_output.shape,
# expected_shape,
# "attention function output shape is incorrect",
# )
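
Assuming the reference PyTorch flux package and jflux are both importable, the parity tests above can be run with pytest (invocation assumed, not part of the commit):

import pytest

# Run only the math parity tests, with verbose output.
pytest.main(["-v", "tests/test_math.py"])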
