Skip to content

Commit

Permalink
test proposel for segment ids that fails.
Browse files Browse the repository at this point in the history
  • Loading branch information
dudulightricks committed Nov 21, 2024
1 parent 0c79164 commit 0608900
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion test/test_pallas_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,14 @@

class PallasTest(unittest.TestCase):

def _attention(self, q, k, v):
def _attention(self, q, k, v, *, attn_mask=None, ab=None):
attn_weight = q @ k.transpose(-2, -1)
if attn_mask is not None:
# Masked out the unrelevant parts.
attn_weight = attn_weight.masked_fill(attn_mask,
torch.finfo(attn_weight.dtype).min)
if ab is not None:
attn_weight = attn_weight + ab
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output
Expand Down Expand Up @@ -98,6 +104,36 @@ def test_flash_attention_backward_spmd_data_parallel(self):
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update('jax_default_matmul_precision', "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_flash_attention_spmd_data_parallel_with_segment_ids(self):
jax.config.update('jax_default_matmul_precision', "highest")
n_devices = xr.global_runtime_device_count()
xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))

q = torch.randn(16, 32, 2048, 64).to("xla")
k = torch.randn(16, 32, 128, 64).to("xla")
v = torch.randn(16, 32, 128, 64).to("xla")
q_segment_ids = torch.ones(16, 2048, dtype=torch.float32).to("xla")
kv_segment_ids = torch.zeros(16, 1, 128, dtype=torch.float32).to("xla")
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
kv_segment_ids[:8, :, 30:] = -10000.0
kv_segment_ids[8:, :, 60:] = -10000.0

o = flash_attention(
q, k, v, q_segment_ids, kv_segment_ids, partition_spec=range(n_devices))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(o),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")

attention_mask = kv_segment_ids.repeat_interleave(32, dim=0)
attention_mask = attention_mask.view(16, 32, 1, 128)

expected_o = self._attention(q, k, v, attn_mask=attention_mask)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
jax.config.update('jax_default_matmul_precision', "default")


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down

0 comments on commit 0608900

Please sign in to comment.