Skip to content

Commit

Permalink
Merge pull request #101 from fastmachinelearning/inference_cost_breakdown
Browse files Browse the repository at this point in the history

inference cost breakdown
  • Loading branch information
maltanar authored May 21, 2024
2 parents c5bd87f + a4e7e35 commit db969e6
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 103 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ Inference cost for CNV_2W2A.onnx
}
```

You can use the `--cost-breakdown` option to generate a more detailed report that covers per-node (by name) and per-op-type information.
You can read more about the BOPS metric in [this paper](https://www.frontiersin.org/articles/10.3389/frai.2021.676564/full), Section 4.2 Bit Operations.

### Convert between different quantization representations
Expand Down
31 changes: 24 additions & 7 deletions src/qonnx/analysis/inference_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def inference_cost_conv(model, node, discount_sparsity):
mac_op_type_str = "op_mac_%s_%s" % (idt_name, wdt_name)
w_mem_type_str = "mem_w_%s" % (wdt_name)
o_mem_type_str = "mem_o_%s" % (odt_name)
# keep in floats to remain compatible with json serialization
n_macs, w_mem, o_mem = float(n_macs), float(w_mem), float(o_mem)
ret = {mac_op_type_str: n_macs, w_mem_type_str: w_mem, o_mem_type_str: o_mem}
return ret

Expand Down Expand Up @@ -161,6 +163,8 @@ def inference_cost_matmul(model, node, discount_sparsity):
mac_op_type_str = "op_mac_%s_%s" % (idt_name, wdt_name)
w_mem_type_str = "mem_w_%s" % (wdt_name)
o_mem_type_str = "mem_o_%s" % (odt_name)
# keep in floats to remain compatible with json serialization
n_macs, w_mem, o_mem = float(n_macs), float(w_mem), float(o_mem)
ret = {mac_op_type_str: n_macs, w_mem_type_str: w_mem, o_mem_type_str: o_mem}
return ret

Expand Down Expand Up @@ -197,14 +201,16 @@ def inference_cost_upsample(model, node, discount_sparsity):
mac_op_type_str = "op_mac_%s_%s" % (idt_name, idt_name)
o_mem_type_str = "mem_o_%s" % (odt_name)

# keep in floats to remain compatible with json serialization
n_macs, o_mem = float(n_macs), float(o_mem)
ret = {mac_op_type_str: n_macs, o_mem_type_str: o_mem}
return ret


def inference_cost(model, discount_sparsity=True):
def inference_cost(model, discount_sparsity=True, cost_breakdown=False):
"Ensure all nodes have unique names prior to calling this analysis pass."

node_costs = {}
ret, node_costs, nodes_per_optype = {}, {}, {}
zero_cost_ops = [
"MaxPool",
"AveragePool",
Expand Down Expand Up @@ -240,13 +246,24 @@ def inference_cost(model, discount_sparsity=True):
if node.op_type in inference_cost_fxn_map.keys():
node_cost = inference_cost_fxn_map[node.op_type](model, node, discount_sparsity)
node_costs[node.name] = node_cost
if node.op_type not in nodes_per_optype.keys():
new_optype = {}
new_optype[node.name] = node_cost
nodes_per_optype[node.op_type] = new_optype
else:
nodes_per_optype[node.op_type][node.name] = node_cost
elif node.op_type in zero_cost_ops:
continue
else:
unsupported_ops.add(node.op_type)

ret = aggregate_dict_keys(node_costs)
ret["unsupported"] = unsupported_ops
ret["discount_sparsity"] = discount_sparsity

total = aggregate_dict_keys(node_costs)
total["unsupported"] = unsupported_ops
total["discount_sparsity"] = discount_sparsity
ret["total_cost"] = total
if cost_breakdown:
optype_cost = {}
for optype, resources in nodes_per_optype.items():
optype_cost[optype] = aggregate_dict_keys(resources)
ret["optype_cost"] = optype_cost
ret["node_cost"] = node_costs
return ret
61 changes: 42 additions & 19 deletions src/qonnx/util/inference_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,24 @@ def compute_mem_bits_and_elems(inf_cost_dict, filter_string="mem_w"):
return total_mem_bits, total_mem_elems


def assign_mem_bits_and_elems(res_dict):
    """Annotate a cost dict with total memory bit and element counts.

    For each memory category ("mem_w" weights, "mem_o" outputs), sums the
    per-datatype entries via compute_mem_bits_and_elems and stores the
    totals under "total_<category>_bits" / "total_<category>_elems".
    Mutates *res_dict* in place and returns it for convenience.
    """
    for category in ("mem_w", "mem_o"):
        n_bits, n_elems = compute_mem_bits_and_elems(res_dict, category)
        res_dict["total_%s_bits" % category] = n_bits
        res_dict["total_%s_elems" % category] = n_elems
    return res_dict


def inference_cost(
model_filename_or_wrapper, *, output_json=None, output_onnx=None, preprocess=True, discount_sparsity=True
model_filename_or_wrapper,
*,
output_json=None,
output_onnx=None,
preprocess=True,
discount_sparsity=True,
cost_breakdown=False
):
"""Return the inference cost estimate metric for given ONNX model.
Supports the Quant op for weight/activation quantization.
Expand All @@ -84,7 +100,10 @@ def inference_cost(
datatype inference and constant folding. Strongly recommended.
:param discount_sparsity: If set, will discount op cost of MAC ops with a
constant zero weight, and the mem cost of constant zero weights.
"""
:param cost_breakdown: If set, include per-node (by name) and per-node-type
breakdowns as part of the returned inference cost dict."""

combined_results = {}
if isinstance(model_filename_or_wrapper, ModelWrapper):
model = model_filename_or_wrapper
else:
Expand All @@ -104,25 +123,29 @@ def inference_cost(
model = model.transform(GiveReadableTensorNames())
if output_onnx is not None:
model.save(output_onnx)
ret = model.analysis(lambda x: infca.inference_cost(x, discount_sparsity))
bops, macs = compute_bops_and_macs(ret)
mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(ret, "mem_w")
mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(ret, "mem_o")
ret["total_bops"] = bops
ret["total_macs"] = macs
ret["total_mem_w_bits"] = mem_w_bits
ret["total_mem_w_elems"] = mem_w_elems
ret["total_mem_o_bits"] = mem_o_bits
ret["total_mem_o_elems"] = mem_o_elems

if "unsupported" in ret:
ret["unsupported"] = str(ret["unsupported"])

ret = model.analysis(lambda x: infca.inference_cost(x, discount_sparsity, cost_breakdown))
for i, res in ret.items():
if i == "total_cost":
bops, macs = compute_bops_and_macs(res)
res = assign_mem_bits_and_elems(res)
res["total_bops"] = bops
res["total_macs"] = macs
if "unsupported" in res:
res["unsupported"] = str(res["unsupported"])
combined_results[i] = res
elif i in ["optype_cost", "node_cost"]:
per_optype_or_node_breakdown = {}
for optype, op_res in res.items():
bops, macs = compute_bops_and_macs(op_res)
op_res = assign_mem_bits_and_elems(op_res)
op_res["total_bops"] = bops
op_res["total_macs"] = macs
per_optype_or_node_breakdown[optype] = op_res
combined_results[i] = per_optype_or_node_breakdown
if output_json is not None:
with open(output_json, "w") as f:
json.dump(ret, f, sort_keys=True, indent=2)

return ret
json.dump(combined_results, f, sort_keys=True, indent=2)
return combined_results


def main():
Expand Down
152 changes: 82 additions & 70 deletions tests/analysis/test_inference_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,90 +34,102 @@
model_details_infcost = {
"FINN-CNV_W2A2": {
"expected_sparse": {
"op_mac_SCALEDINT<8>_INT2": 1345500.0,
"mem_w_INT2": 908033.0,
"mem_o_SCALEDINT<32>": 57600.0,
"op_mac_INT2_INT2": 35615771.0,
"mem_o_INT32": 85002.0,
"unsupported": "set()",
"discount_sparsity": True,
"total_bops": 163991084.0,
"total_macs": 36961271.0,
"total_mem_w_bits": 1816066.0,
"total_mem_w_elems": 908033.0,
"total_mem_o_bits": 4563264.0,
"total_mem_o_elems": 142602.0,
"total_cost": {
"op_mac_SCALEDINT<8>_INT2": 1345500.0,
"mem_w_INT2": 908033.0,
"mem_o_SCALEDINT<32>": 57600.0,
"op_mac_INT2_INT2": 35615771.0,
"mem_o_INT32": 85002.0,
"unsupported": "set()",
"discount_sparsity": True,
"total_bops": 163991084.0,
"total_macs": 36961271.0,
"total_mem_w_bits": 1816066.0,
"total_mem_w_elems": 908033.0,
"total_mem_o_bits": 4563264.0,
"total_mem_o_elems": 142602.0,
}
},
"expected_dense": {
"op_mac_SCALEDINT<8>_INT2": 1555200.0,
"mem_w_INT2": 1542848.0,
"mem_o_SCALEDINT<32>": 57600.0,
"op_mac_INT2_INT2": 57906176.0,
"mem_o_INT32": 85002.0,
"unsupported": "set()",
"discount_sparsity": False,
"total_bops": 256507904.0,
"total_macs": 59461376.0,
"total_mem_w_bits": 3085696.0,
"total_mem_w_elems": 1542848.0,
"total_mem_o_bits": 4563264.0,
"total_mem_o_elems": 142602.0,
"total_cost": {
"op_mac_SCALEDINT<8>_INT2": 1555200.0,
"mem_w_INT2": 1542848.0,
"mem_o_SCALEDINT<32>": 57600.0,
"op_mac_INT2_INT2": 57906176.0,
"mem_o_INT32": 85002.0,
"unsupported": "set()",
"discount_sparsity": False,
"total_bops": 256507904.0,
"total_macs": 59461376.0,
"total_mem_w_bits": 3085696.0,
"total_mem_w_elems": 1542848.0,
"total_mem_o_bits": 4563264.0,
"total_mem_o_elems": 142602.0,
}
},
},
"FINN-TFC_W2A2": {
"expected_sparse": {
"op_mac_INT2_INT2": 22355.0,
"mem_w_INT2": 22355.0,
"mem_o_INT32": 202.0,
"unsupported": "set()",
"discount_sparsity": True,
"total_bops": 89420.0,
"total_macs": 22355.0,
"total_mem_w_bits": 44710.0,
"total_mem_w_elems": 22355.0,
"total_mem_o_bits": 6464.0,
"total_mem_o_elems": 202.0,
"total_cost": {
"op_mac_INT2_INT2": 22355.0,
"mem_w_INT2": 22355.0,
"mem_o_INT32": 202.0,
"unsupported": "set()",
"discount_sparsity": True,
"total_bops": 89420.0,
"total_macs": 22355.0,
"total_mem_w_bits": 44710.0,
"total_mem_w_elems": 22355.0,
"total_mem_o_bits": 6464.0,
"total_mem_o_elems": 202.0,
}
},
"expected_dense": {
"op_mac_INT2_INT2": 59008.0,
"mem_w_INT2": 59008.0,
"mem_o_INT32": 202.0,
"unsupported": "set()",
"discount_sparsity": False,
"total_bops": 236032.0,
"total_macs": 59008.0,
"total_mem_w_bits": 118016.0,
"total_mem_w_elems": 59008.0,
"total_mem_o_bits": 6464.0,
"total_mem_o_elems": 202.0,
"total_cost": {
"op_mac_INT2_INT2": 59008.0,
"mem_w_INT2": 59008.0,
"mem_o_INT32": 202.0,
"unsupported": "set()",
"discount_sparsity": False,
"total_bops": 236032.0,
"total_macs": 59008.0,
"total_mem_w_bits": 118016.0,
"total_mem_w_elems": 59008.0,
"total_mem_o_bits": 6464.0,
"total_mem_o_elems": 202.0,
}
},
},
"RadioML_VGG10": {
"expected_sparse": {
"op_mac_SCALEDINT<8>_SCALEDINT<8>": 12620311.0,
"mem_w_SCALEDINT<8>": 155617.0,
"mem_o_SCALEDINT<32>": 130328.0,
"unsupported": "set()",
"discount_sparsity": True,
"total_bops": 807699904.0,
"total_macs": 12620311.0,
"total_mem_w_bits": 1244936.0,
"total_mem_w_elems": 155617.0,
"total_mem_o_bits": 4170496.0,
"total_mem_o_elems": 130328.0,
"total_cost": {
"unsupported": "set()",
"discount_sparsity": True,
"op_mac_SCALEDINT<8>_SCALEDINT<8>": 12620311.0,
"mem_w_SCALEDINT<8>": 155617.0,
"mem_o_SCALEDINT<32>": 130328.0,
"total_bops": 807699904.0,
"total_macs": 12620311.0,
"total_mem_w_bits": 1244936.0,
"total_mem_w_elems": 155617.0,
"total_mem_o_bits": 4170496.0,
"total_mem_o_elems": 130328.0,
}
},
"expected_dense": {
"op_mac_SCALEDINT<8>_SCALEDINT<8>": 12864512.0,
"mem_w_SCALEDINT<8>": 159104.0,
"mem_o_SCALEDINT<32>": 130328.0,
"unsupported": "set()",
"discount_sparsity": False,
"total_bops": 823328768.0,
"total_macs": 12864512.0,
"total_mem_w_bits": 1272832.0,
"total_mem_w_elems": 159104.0,
"total_mem_o_bits": 4170496.0,
"total_mem_o_elems": 130328.0,
"total_cost": {
"unsupported": "set()",
"discount_sparsity": False,
"op_mac_SCALEDINT<8>_SCALEDINT<8>": 12864512.0,
"mem_w_SCALEDINT<8>": 159104.0,
"mem_o_SCALEDINT<32>": 130328.0,
"total_bops": 823328768.0,
"total_macs": 12864512.0,
"total_mem_w_bits": 1272832.0,
"total_mem_w_elems": 159104.0,
"total_mem_o_bits": 4170496.0,
"total_mem_o_elems": 130328.0,
}
},
},
}
Expand Down
Loading

0 comments on commit db969e6

Please sign in to comment.