Compiler crash when sharding model weights #1030

Open
Corendos opened this issue Nov 9, 2024 · 1 comment
Labels: bug (Something isn't working), Inf2

Comments

Corendos commented Nov 9, 2024

Hi!

I was playing with JAX on Neuron recently and came across a bug that is quite annoying.

When sharding a very simple MLP layer, compilation fails or succeeds depending on which axis the weights are sharded along.

Here is a snippet of code that demonstrates the issue:

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

mesh = jax.sharding.Mesh(jax.devices(), ('x',))

BATCH_SIZE = 1
SIZE = 4096
HIDDEN_SIZE = 8192
DTYPE = jnp.bfloat16
SHARD_ON_HIDDEN = True

def mlp(x: jax.Array, gate_up_proj: jax.Array, down_proj: jax.Array) -> jax.Array:
    # x: (B, D)
    # gate_up_proj: (D, 2 * H)
    # down_proj: (D, H)
    hidden = jax.lax.dot_general(x, gate_up_proj, (([1], [0]), ([], [])))
    # hidden: (B, 2 * H)
    x1, x2 = jnp.split(hidden, 2, 1)
    hidden = jax.nn.gelu(x1) * x2
    # hidden: (B, H)
    hidden = jax.lax.dot_general(hidden, down_proj, (([1], [1]), ([], [])))
    # hidden: (B, D), contracting H with axis 1 of down_proj
    return hidden

if SHARD_ON_HIDDEN:
    weight_sharding = jax.sharding.NamedSharding(mesh, P(None, 'x'))
else:
    weight_sharding = jax.sharding.NamedSharding(mesh, P('x', None))

x = jax.ShapeDtypeStruct(shape=(BATCH_SIZE, SIZE), dtype=DTYPE, sharding=jax.sharding.NamedSharding(mesh, P(None, None)))
gate_up_proj = jax.ShapeDtypeStruct((SIZE, 2 * HIDDEN_SIZE), dtype=DTYPE, sharding=weight_sharding)
down_proj = jax.ShapeDtypeStruct((SIZE, HIDDEN_SIZE), dtype=DTYPE, sharding=weight_sharding)

lowered = jax.jit(mlp).lower(x, gate_up_proj, down_proj)
print(lowered.as_text())

compiled = lowered.compile()
print(compiled.as_text())
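
The repro above needs Neuron hardware. Below is a sketch for exercising the same SHARD_ON_HIDDEN = True layout on a machine without NeuronCores, assuming a CPU-only JAX install: the XLA flag `--xla_force_host_platform_device_count=4` exposes four host devices, so the program can be partitioned four ways, executed, and compared with an unsharded reference. The reduced sizes, the float32 dtype and the random weights are illustrative choices, not part of the original report.

```python
import os

# Assumption: run on the host CPU with 4 emulated devices.
# Both variables must be set before jax is imported.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative sizes, much smaller than in the repro above.
BATCH_SIZE, SIZE, HIDDEN_SIZE = 1, 64, 128


def mlp(x, gate_up_proj, down_proj):
    # x: (B, D), gate_up_proj: (D, 2 * H), down_proj: (D, H)
    hidden = jax.lax.dot_general(x, gate_up_proj, (([1], [0]), ([], [])))
    x1, x2 = jnp.split(hidden, 2, 1)
    hidden = jax.nn.gelu(x1) * x2
    # Contract H with axis 1 of down_proj: (B, H) x (D, H) -> (B, D)
    return jax.lax.dot_general(hidden, down_proj, (([1], [1]), ([], [])))


mesh = Mesh(np.array(jax.devices()), ("x",))
weight_sharding = NamedSharding(mesh, P(None, "x"))  # SHARD_ON_HIDDEN = True
replicated = NamedSharding(mesh, P(None, None))

# Random float32 inputs (float32 rather than bfloat16 keeps the comparison tight).
rng = np.random.default_rng(0)
x_host = rng.standard_normal((BATCH_SIZE, SIZE), dtype=np.float32)
w1_host = 0.1 * rng.standard_normal((SIZE, 2 * HIDDEN_SIZE), dtype=np.float32)
w2_host = 0.1 * rng.standard_normal((SIZE, HIDDEN_SIZE), dtype=np.float32)

# Sharded run: both weights are split along their hidden axis across the 4 devices.
sharded_out = jax.jit(mlp)(
    jax.device_put(x_host, replicated),
    jax.device_put(w1_host, weight_sharding),
    jax.device_put(w2_host, weight_sharding),
)
# Unsharded reference computed on a single device.
reference = mlp(jnp.asarray(x_host), jnp.asarray(w1_host), jnp.asarray(w2_host))

max_err = float(np.max(np.abs(np.asarray(sharded_out) - np.asarray(reference))))
print(sharded_out.shape, max_err)
```

If this runs and matches on CPU, the sharded program itself is well formed, which would suggest the problem lies in the Neuron compile/load path rather than in the model code. It does not reproduce the Neuron failure itself.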

If SHARD_ON_HIDDEN is set to False, everything works fine. When it is set to True (which is what you want for optimal performance), it crashes as soon as more than 2 NeuronCores are used, probably because the program then introduces transfers over the interconnect.
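
One way to see why hidden-axis sharding needs data movement between cores: under P(None, 'x'), each mesh position holds one contiguous block of the 2 * HIDDEN_SIZE columns of hidden (this assumes JAX's usual even, in-order block split of a sharded axis). jnp.split(hidden, 2, 1) then leaves x1 and x2 on disjoint sets of cores, so no core holds the matching columns of both and gelu(x1) * x2 cannot be computed locally. The helpers shard_columns and owners are hypothetical, written only for this illustration.

```python
def shard_columns(num_devices, total_cols):
    """Column range [start, stop) held by each mesh position under P(None, 'x')."""
    # Assumption: the sharded axis is split into equal contiguous blocks, in mesh order.
    assert total_cols % num_devices == 0
    block = total_cols // num_devices
    return [(i * block, (i + 1) * block) for i in range(num_devices)]


def owners(num_devices, hidden_size, half):
    """Mesh positions holding part of x1 (half=0) or x2 (half=1) after jnp.split."""
    lo, hi = half * hidden_size, (half + 1) * hidden_size
    return [
        i
        for i, (start, stop) in enumerate(shard_columns(num_devices, 2 * hidden_size))
        if start < hi and stop > lo
    ]


for n in (2, 4):
    print(f"{n} cores: x1 on {owners(n, 8192, 0)}, x2 on {owners(n, 8192, 1)}")
# 2 cores: x1 on [0], x2 on [1]
# 4 cores: x1 on [0, 1], x2 on [2, 3]
```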

Here is the error:

2024-Nov-09 16:14:06.605467 66954:67403 ERROR   ENC:enc_parse_replica_groups                [nec_dev 2] replica groups (0/1) does not have myself 2
2024-Nov-09 16:14:06.605522 66954:67403 ERROR  TDRV:build_enc_source_target_pairs           Failed to parse src_target_pairs on ENC
2024-Nov-09 16:14:06.605546 66954:67403 ERROR  TDRV:instr_col_translate_ptc2                Failed to setup pseudo trigger collectives2 instruction
2024-Nov-09 16:14:06.605670 66954:67403 ERROR  TDRV:ib_create_one_block                     failed to translate instructions
2024-Nov-09 16:14:06.605695 66954:67403 ERROR  TDRV:ib_create_eib_v2                        Failed to create instruction blocks
2024-Nov-09 16:14:06.605715 66954:67403 ERROR  TDRV:sequencer_v2_setup_instr_one_eng        Failed to allocate eib
2024-Nov-09 16:14:06.605735 66954:67403 ERROR  TDRV:kbl_model_add                           create_engine_refill_rings_v1() error
2024-Nov-09 16:14:06.605981 66954:67403 ERROR  NMGR:dlr_kelf_stage                          Failed to load subgraph
2024-Nov-09 16:14:06.606000 66954:67403 ERROR  NMGR:kmgr_load_nn_internal_v2                Failed to stage graph: kelf-0.json to NeuronCore
2024-Nov-09 16:14:06.606012 66954:67403 ERROR  NMGR:kmgr_load_nn_post_metrics               Failed to load NN: /tmp/tmpsyndo_96/file.neff, err: 4
2024-Nov-09 16:14:06.606025 66954:67403 ERROR   NRT:nrt_infodump                            Neuron runtime information - please include in any support request:
2024-Nov-09 16:14:06.606040 66954:67403 ERROR   NRT:nrt_infodump                            ------------->8------------[ cut here ]------------>8-------------
2024-Nov-09 16:14:06.606062 66954:67403 ERROR   NRT:nrt_infodump                            NRT version: 2.22.14.0 (6e27b8d5b22dea0e0b8375517f4d8a009b6de5a8)
2024-Nov-09 16:14:06.606084 66954:67403 ERROR   NRT:nrt_infodump                            Embedded FW version: 1.12.2.0 (f152b70c827a52701d6b9ee74ec7ff7a15971f7d)
2024-Nov-09 16:14:06.606112 66954:67403 ERROR   NRT:nrt_infodump                            CCOM version: 2.22.26.0- (compat 48)
2024-Nov-09 16:14:06.606134 66954:67403 ERROR   NRT:nrt_infodump                            Instance ID: i-0b19e4a1cf3fd70d9
2024-Nov-09 16:14:06.606156 66954:67403 ERROR   NRT:nrt_infodump                            Cluster ID: N/A
2024-Nov-09 16:14:06.606178 66954:67403 ERROR   NRT:nrt_infodump                            Kernel: Linux 6.8.0-1015-aws #16~22.04.1-Ubuntu SMP Mon Aug 19 19:38:17 UTC 2024
2024-Nov-09 16:14:06.606200 66954:67403 ERROR   NRT:nrt_infodump                            Nodename: ip-172-31-42-39
2024-Nov-09 16:14:06.606254 66954:67403 ERROR   NRT:nrt_infodump                            Driver version: 2.18.12.0

2024-Nov-09 16:14:06.606276 66954:67403 ERROR   NRT:nrt_infodump                            Failure: NRT_RESOURCE in nrt_load()
2024-Nov-09 16:14:06.606298 66954:67403 ERROR   NRT:nrt_infodump                            Visible cores: 0, 1, 2, 3
2024-Nov-09 16:14:06.606318 66954:67403 ERROR   NRT:nrt_infodump                            Environment:
2024-Nov-09 16:14:06.606341 66954:67403 ERROR   NRT:nrt_infodump                                NEURON_CC_FLAGS=--model-type=transformer --auto-cast=none
2024-Nov-09 16:14:06.606362 66954:67403 ERROR   NRT:nrt_infodump                                NEURON_RT_NUM_CORES=4
2024-Nov-09 16:14:06.606382 66954:67403 ERROR   NRT:nrt_infodump                                NEURON_RT_ROOT_COMM_ID=localhost:49255
2024-Nov-09 16:14:06.606401 66954:67403 ERROR   NRT:nrt_infodump                            -------------8<-----------[ cut to here ]-----------8<------------
2024-Nov-09 16:14:06.602248 66954:67406 ERROR   ENC:enc_parse_replica_groups                [nec_dev 1] replica groups (0/1) does not have myself 1
2024-Nov-09 16:14:06.603586 66954:67404 ERROR   ENC:enc_parse_replica_groups                [nec_dev 3] replica groups (0/1) does not have myself 3
2024-Nov-09 16:14:06.614656 66954:67406 ERROR  TDRV:build_enc_source_target_pairs           Failed to parse src_target_pairs on ENC
2024-Nov-09 16:14:06.626936 66954:67404 ERROR  TDRV:build_enc_source_target_pairs           Failed to parse src_target_pairs on ENC
2024-Nov-09 16:14:06.639079 66954:67406 ERROR  TDRV:instr_col_translate_ptc2                Failed to setup pseudo trigger collectives2 instruction
2024-Nov-09 16:14:06.649859 66954:67404 ERROR  TDRV:instr_col_translate_ptc2                Failed to setup pseudo trigger collectives2 instruction
2024-Nov-09 16:14:06.662122 66954:67406 ERROR  TDRV:ib_create_one_block                     failed to translate instructions
2024-Nov-09 16:14:06.672477 66954:67404 ERROR  TDRV:ib_create_one_block                     failed to translate instructions
2024-Nov-09 16:14:06.682738 66954:67406 ERROR  TDRV:ib_create_eib_v2                        Failed to create instruction blocks
2024-Nov-09 16:14:06.692175 66954:67404 ERROR  TDRV:ib_create_eib_v2                        Failed to create instruction blocks
2024-Nov-09 16:14:06.702800 66954:67406 ERROR  TDRV:sequencer_v2_setup_instr_one_eng        Failed to allocate eib
2024-Nov-09 16:14:06.712256 66954:67404 ERROR  TDRV:sequencer_v2_setup_instr_one_eng        Failed to allocate eib
2024-Nov-09 16:14:06.723884 66954:67406 ERROR  TDRV:kbl_model_add                           create_engine_refill_rings_v1() error
2024-Nov-09 16:14:06.751248 66954:67405 ERROR   ENC:enc_parse_replica_groups                [nec_dev 0] replica groups (0/1) does not have myself 0
2024-Nov-09 16:14:06.751330 66954:67405 ERROR  TDRV:build_enc_source_target_pairs           Failed to parse src_target_pairs on ENC
2024-Nov-09 16:14:06.751345 66954:67405 ERROR  TDRV:instr_col_translate_ptc2                Failed to setup pseudo trigger collectives2 instruction
2024-Nov-09 16:14:06.751929 66954:67405 ERROR  TDRV:ib_create_one_block                     failed to translate instructions
2024-Nov-09 16:14:06.751957 66954:67405 ERROR  TDRV:ib_create_eib_v2                        Failed to create instruction blocks
2024-Nov-09 16:14:06.751974 66954:67405 ERROR  TDRV:sequencer_v2_setup_instr_one_eng        Failed to allocate eib
2024-Nov-09 16:14:06.751990 66954:67405 ERROR  TDRV:kbl_model_add                           create_engine_refill_rings_v1() error
2024-Nov-09 16:14:06.753195 66954:67405 ERROR  NMGR:dlr_kelf_stage                          Failed to load subgraph
2024-Nov-09 16:14:06.753230 66954:67405 ERROR  NMGR:kmgr_load_nn_internal_v2                Failed to stage graph: kelf-0.json to NeuronCore
2024-Nov-09 16:14:06.753248 66954:67405 ERROR  NMGR:kmgr_load_nn_post_metrics               Failed to load NN: /tmp/tmpsyndo_96/file.neff, err: 4
2024-Nov-09 16:14:06.735901 66954:67404 ERROR  TDRV:kbl_model_add                           create_engine_refill_rings_v1() error
2024-Nov-09 16:14:06.749448 66954:67406 ERROR  NMGR:dlr_kelf_stage                          Failed to load subgraph
2024-Nov-09 16:14:06.762705 66954:67404 ERROR  NMGR:dlr_kelf_stage                          Failed to load subgraph
2024-Nov-09 16:14:06.775586 66954:67406 ERROR  NMGR:kmgr_load_nn_internal_v2                Failed to stage graph: kelf-0.json to NeuronCore
2024-Nov-09 16:14:06.789108 66954:67404 ERROR  NMGR:kmgr_load_nn_internal_v2                Failed to stage graph: kelf-0.json to NeuronCore
2024-Nov-09 16:14:06.799686 66954:67406 ERROR  NMGR:kmgr_load_nn_post_metrics               Failed to load NN: /tmp/tmpsyndo_96/file.neff, err: 4
2024-Nov-09 16:14:06.810033 66954:67404 ERROR  NMGR:kmgr_load_nn_post_metrics               Failed to load NN: /tmp/tmpsyndo_96/file.neff, err: 4

Environment

  • Python 3.10
  • Packages:
    • neuronx-cc==2.15.141.0+d3cfc8ca
    • libneuronxla==2.0.4986.0
    • jaxlib==0.4.31
    • jax-neuronx==0.1.1
    • jax==0.4.31
  • inf2.48xlarge instance

Thanks for the help!

AWSNB (Contributor) commented Nov 9, 2024 via email

aws-taylor added the bug and Inf2 labels on Nov 11, 2024