Skip to content

Commit

Permalink
Add data formats to perf report (#15170)
Browse files Browse the repository at this point in the history
Add data formats to perf report; added to math_fidelity column which was
already there and previously showing data formats for only matmuls.
  • Loading branch information
johanna-rock-tt authored Nov 20, 2024
1 parent be20580 commit 145050a
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions models/perf/perf_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,6 @@ def analyze_matmul(row):
dram_percentage = (dram_speed_gb_s / 288) * 100 if dram_speed_gb_s is not None else None
flops_percentage = (flops / peak_flops_value) * 100

input_0_datatype = row["INPUT_0_DATATYPE"]
input_1_datatype = row["INPUT_1_DATATYPE"]
output_datatype = row["OUTPUT_0_DATATYPE"]

return (
dram_speed_gb_s,
dram_percentage,
Expand All @@ -275,10 +271,7 @@ def analyze_matmul(row):
size,
memory_info,
math_fidelity,
output_datatype,
is_dram_sharded,
input_0_datatype,
input_1_datatype,
core_count, # Return the potentially adjusted core count
)

Expand All @@ -301,6 +294,14 @@ def analyze_op(row, prev_row):
else:
dispatch_time = Cell(None, unit="us", decimals=0)

output_datatype = row["OUTPUT_0_DATATYPE"]
input_0_datatype = row["INPUT_0_DATATYPE"]
input_1_datatype = row["INPUT_1_DATATYPE"]
output_datatype_cell = Cell(output_datatype)
input_0_datatype_cell = Cell(input_0_datatype)
input_1_datatype_cell = Cell(input_1_datatype)
short_name = lambda n: {"BFLOAT16": "BF16", "BFLOAT8_B": "BFP8", "BFLOAT4_B": "BFP4"}.get(n, n)

if "Matmul" in op_code.raw_value:
(
dram_speed,
Expand All @@ -310,10 +311,7 @@ def analyze_op(row, prev_row):
size,
memory_info,
math_fidelity,
output_datatype,
is_dram_sharded,
input_0_datatype,
input_1_datatype,
adjusted_core_count, # Get the potentially adjusted core count
) = analyze_matmul(row)
op_code = Cell(f"{op_code.raw_value} {size}")
Expand All @@ -322,26 +320,24 @@ def analyze_op(row, prev_row):
flops = Cell(flops / 1e12 if pd.notna(flops) else None, unit="TFLOPs", decimals=1)
flops_percentage = Cell(flops_percentage, unit="%", decimals=1)

short_name = lambda n: {"BFLOAT16": "BF16", "BFLOAT8_B": "BFP8", "BFLOAT4_B": "BFP4"}.get(n, n)
math_fidelity_cell = Cell(
f"{math_fidelity} {short_name(input_0_datatype)} x {short_name(input_1_datatype)} => {short_name(output_datatype)}".strip()
if math_fidelity
else None
)
output_datatype_cell = Cell(output_datatype)
input_0_datatype_cell = Cell(input_0_datatype)
input_1_datatype_cell = Cell(input_1_datatype)
else:
dram_speed = Cell(None, unit="GB/s", decimals=0)
dram_percentage = Cell(None, unit="%", decimals=1)
flops = Cell(None, unit="TFLOPs", decimals=1)
flops_percentage = Cell(None, unit="%", decimals=1)
math_fidelity, output_datatype = "", None
math_fidelity_cell = Cell(None)
output_datatype_cell = Cell(None)

math_fidelity = ""
math_fidelity += f"{short_name(input_0_datatype)}" if pd.notna(input_0_datatype) else ""
math_fidelity += f", {short_name(input_1_datatype)}" if pd.notna(input_1_datatype) else ""
math_fidelity += f" => {short_name(output_datatype)}" if pd.notna(output_datatype) else ""
math_fidelity_cell = Cell(math_fidelity.strip())

is_dram_sharded = False
input_0_datatype_cell = Cell(None)
input_1_datatype_cell = Cell(None)

output = {
"ID": None,
Expand Down

0 comments on commit 145050a

Please sign in to comment.