#11512: Improve isfinite sharded sweeps
1 parent 4020280 · commit 31a517b
Showing 2 changed files with 175 additions and 141 deletions.
@@ -0,0 +1,137 @@
import ttnn
import itertools
import random
from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import _gen_reshape_args_from_volume


def gen_sharded_spec_unary(num_shapes, y, x, max_tensor_size=4 * 1024 * 1024):
    # ["BLOCK", "WIDTH", "HEIGHT", "tensor_wh"]
    sharding_strategy_list = ["BLOCK", "WIDTH", "HEIGHT", "tensor_wh"]
    shard_orientation_list = ["COL_MAJOR", "ROW_MAJOR"]
    spec_list = []

    for sharding_strategy, shard_orientation, rank, layout in itertools.product(
        sharding_strategy_list, shard_orientation_list, [4, 3, 2], ["TILE_LAYOUT", "ROW_MAJOR_LAYOUT"]
    ):
        if sharding_strategy == "tensor_wh":
            tensor_hw_as_shard_shape = True
            sharding_strategy = "BLOCK"
        else:
            tensor_hw_as_shard_shape = False

        for _ in range(num_shapes):
            if tensor_hw_as_shard_shape:
                # Gets stuck:
                # X 8 Y 8 input_shape [1, 17792, 8] DataType.BFLOAT8_B Layout.TILE ShardStrategy.BLOCK ShardOrientation.COL_MAJOR tensor_hw_as_shard_shape True

                if layout == "TILE_LAYOUT":
                    # In shard mode ShardMode::PHYSICAL, physical shard shape {12, 13312} is not compatible with alignment Alignment([32, 32])!
                    min_shard_size_x = 32
                    min_shard_size_y = 32
                else:  # layout == "ROW_MAJOR_LAYOUT"
                    # Shard Size must be multiple of input_tile_size (width * height is multiple of 1024)
                    min_shard_size_x = random.choice([1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024])
                    min_shard_size_y = 1024 // min_shard_size_x

                rest_volume = random.randint(1, max_tensor_size // (min_shard_size_x * min_shard_size_y * x * y))
                input_shape = random.choice(_gen_reshape_args_from_volume(rest_volume, step=1, out_dims=rank))
                input_shape = list(input_shape["reshape_dims"])
                input_shape[-2] = input_shape[-2] * min_shard_size_x
                input_shape[-1] = input_shape[-1] * min_shard_size_y

                # Shard width should be multiple of 16 to satisfy L1 alignment (width = multiple of 8 for bfloat16)
                while input_shape[-1] % 16 != 0:
                    input_shape[-1] *= 2
                    input_shape[-2] //= 2

                if shard_orientation == "COL_MAJOR":
                    tmp = input_shape[-2]
                    input_shape[-2] = input_shape[-1]
                    input_shape[-1] = tmp

            elif sharding_strategy == "BLOCK":
                min_shard_size_y = 32 * y
                min_shard_size_x = 32 * x
                mul_x = random.randint(1, 10)
                mul_y = random.randint(1, 64 // mul_x)

                input_shape = random.choice(
                    _gen_reshape_args_from_volume(mul_y * min_shard_size_y, step=1, out_dims=rank - 1)
                )
                input_shape = list(input_shape["reshape_dims"])
                input_shape.append(mul_x * min_shard_size_x)

            elif sharding_strategy == "WIDTH" or sharding_strategy == "HEIGHT":
                # if shard_width % total_cores != 0: raise RuntimeError("Invalid sharding core_grid")
                # Shard Size must be multiple of input_tile_size

                if layout == "TILE_LAYOUT":
                    # In shard mode ShardMode::PHYSICAL, physical shard shape {12, 13312} is not compatible with alignment Alignment([32, 32])!
                    min_shard_size_x = 32
                    min_shard_size_y = 32 * x * y
                else:  # layout == "ROW_MAJOR_LAYOUT"
                    # Shard Size must be multiple of input_tile_size
                    # Shard width should be multiple of 16 to satisfy L1 alignment
                    mul_32_y = random.choice([16, 32, 64, 128, 256, 512, 1024])
                    mul_32_x = 1024 // mul_32_y

                    if sharding_strategy == "HEIGHT":
                        # Shard width should be multiple of 16 to satisfy L1 alignment
                        while mul_32_x % 16 != 0:
                            mul_32_x *= 2
                            mul_32_y //= 2

                    min_shard_size_x = mul_32_x
                    min_shard_size_y = mul_32_y * x * y

                rest_volume = random.randint(1, max_tensor_size // (min_shard_size_x * min_shard_size_y))
                input_shape = random.choice(_gen_reshape_args_from_volume(rest_volume, step=1, out_dims=rank))
                input_shape = list(input_shape["reshape_dims"])
                input_shape[-2] = input_shape[-2] * min_shard_size_x
                input_shape[-1] = input_shape[-1] * min_shard_size_y

                if sharding_strategy == "HEIGHT":
                    tmp = input_shape[-2]
                    input_shape[-2] = input_shape[-1]
                    input_shape[-1] = tmp

            # print(input_shape)

            spec_list.append(
                {
                    "input_shape": input_shape,
                    "sharding_strategy": sharding_strategy,
                    "shard_orientation": shard_orientation,
                    "tensor_hw_as_shard_shape": tensor_hw_as_shard_shape,
                    "input_layout": layout,
                }
            )

    return spec_list


def parse_sharding_spec(input_spec):
    input_shape = input_spec["input_shape"]
    sharding_strategy = input_spec["sharding_strategy"]
    shard_orientation = input_spec["shard_orientation"]
    tensor_hw_as_shard_shape = input_spec["tensor_hw_as_shard_shape"]
    input_layout = input_spec["input_layout"]

    if sharding_strategy == "HEIGHT":
        sharding_strategy = ttnn.ShardStrategy.HEIGHT
    elif sharding_strategy == "WIDTH":
        sharding_strategy = ttnn.ShardStrategy.WIDTH
    else:  # sharding_strategy == "BLOCK"
        sharding_strategy = ttnn.ShardStrategy.BLOCK

    if shard_orientation == "COL_MAJOR":
        shard_orientation = ttnn.ShardOrientation.COL_MAJOR
    else:
        shard_orientation = ttnn.ShardOrientation.ROW_MAJOR

    if input_layout == "TILE_LAYOUT":
        input_layout = ttnn.TILE_LAYOUT
    else:
        input_layout = ttnn.ROW_MAJOR_LAYOUT

    return input_shape, sharding_strategy, shard_orientation, tensor_hw_as_shard_shape, input_layout
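
For context, a minimal sketch of how these helpers might drive a sharded isfinite check. This is not the sweep file from this commit (the second changed file is not shown above); the device id, the 8x8 core grid, the bfloat16 dtype, and the pass/fail comparison are illustrative assumptions. `gen_sharded_spec_unary` yields string-keyed specs, `parse_sharding_spec` maps them to ttnn enums, and `ttnn.create_sharded_memory_config` turns each spec into a sharded memory config for the op under test.

# Illustrative usage only (not part of this commit). Device id 0, the 8x8 core
# grid, and the bfloat16 dtype are assumptions.
import torch
import ttnn

device = ttnn.open_device(device_id=0)

for input_spec in gen_sharded_spec_unary(num_shapes=4, y=8, x=8):
    (
        input_shape,
        sharding_strategy,
        shard_orientation,
        tensor_hw_as_shard_shape,
        input_layout,
    ) = parse_sharding_spec(input_spec)

    # Build a sharded memory config matching the generated spec.
    sharded_config = ttnn.create_sharded_memory_config(
        shape=input_shape,
        core_grid=ttnn.CoreGrid(y=8, x=8),
        strategy=sharding_strategy,
        orientation=shard_orientation,
        use_height_and_width_as_shard_shape=tensor_hw_as_shard_shape,
    )

    torch_input = torch.randn(input_shape, dtype=torch.bfloat16)
    torch_expected = torch.isfinite(torch_input)

    input_tensor = ttnn.from_torch(
        torch_input,
        dtype=ttnn.bfloat16,
        layout=input_layout,
        device=device,
        memory_config=sharded_config,
    )
    output_tensor = ttnn.isfinite(input_tensor, memory_config=sharded_config)

    # Simplified check (assumes ttnn.isfinite returns 1.0/0.0 in the input dtype,
    # while torch.isfinite returns bools).
    assert torch.equal(ttnn.to_torch(output_tensor).bool(), torch_expected)

ttnn.close_device(device)

The assert above is a simplified comparison; in an actual sweep, the framework's own result-comparison utilities would be used instead.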