Skip to content

Commit

Permalink
#13368: Fixup mesh_device when not passed FAKE_DEVICE
Browse files Browse the repository at this point in the history
  • Loading branch information
cglagovichTT committed Oct 23, 2024
1 parent a779882 commit a80cbf3
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
"mesh_device",
[{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)],
[
{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids())
)
],
indirect=True,
)
def test_llama_vision_transformer_inference(mesh_device, use_program_cache, reset_seeds):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@
)
@pytest.mark.parametrize(
"mesh_device",
[{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)],
[
{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids())
)
],
indirect=True,
)
def test_llama_block_inference(batch, num_chunks, mesh_device, gated, use_program_cache, reset_seeds, ensure_gc):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@
)
@pytest.mark.parametrize(
"mesh_device",
[{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)],
[
{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids())
)
],
indirect=True,
)
def test_llama_image_transformer_inference(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
"mesh_device",
[{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)],
[
{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids())
)
],
indirect=True,
)
def test_llama_vision_encoder_inference(mesh_device, use_program_cache, reset_seeds):
Expand Down

0 comments on commit a80cbf3

Please sign in to comment.