Skip to content

Commit

Permalink
feat: add brevitas channel-wise support
Browse files Browse the repository at this point in the history
  • Loading branch information
fd0r committed Jul 29, 2024
1 parent 3d76ba7 commit ce8de8a
Show file tree
Hide file tree
Showing 10 changed files with 468 additions and 111 deletions.
9 changes: 8 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def pytest_addoption(parser):
)

parser.addoption(
"--no-flaky", action="store_true", default=False, help="Don't run known flaky tests."
"--no-flaky",
action="store_true",
default=False,
help="Don't run known flaky tests.",
)


Expand Down Expand Up @@ -382,8 +385,12 @@ def check_float_array_equal_impl(
a, b, rtol=0, atol=0.001, error_information: Optional[str] = ""
):

max_atol = numpy.abs(a - b).max()
max_rtol = (numpy.abs(a - b) / numpy.abs(b)).max()

error_message = (
f"Not equal to tolerance rtol={rtol}, atol={atol}\na: {a}\nb: {b}\n"
f"Found {max_atol=}, {max_rtol=}\n"
f"{error_information}"
)

Expand Down
24 changes: 14 additions & 10 deletions deps_licenses/licenses_mac_silicon_user.txt
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
Name, Version, License
Jinja2, 3.1.4, BSD License
MarkupSafe, 2.1.5, BSD License
PyYAML, 6.0.1, MIT License
brevitas, 0.8.0, UNKNOWN
certifi, 2024.6.2, Mozilla Public License 2.0 (MPL 2.0)
brevitas, 0.10.2, UNKNOWN
certifi, 2024.7.4, Mozilla Public License 2.0 (MPL 2.0)
charset-normalizer, 3.3.2, MIT License
coloredlogs, 15.0.1, MIT License
concrete-python, 2.7.0, BSD-3-Clause
dependencies, 2.0.1, BSD License
dill, 0.3.8, BSD License
filelock, 3.15.3, The Unlicense (Unlicense)
filelock, 3.15.4, The Unlicense (Unlicense)
flatbuffers, 24.3.25, Apache Software License
fsspec, 2024.6.0, BSD License
fsspec, 2024.6.1, BSD License
huggingface-hub, 0.23.4, Apache Software License
humanfriendly, 10.0, MIT License
hummingbird-ml, 0.4.8, MIT License
hummingbird-ml, 0.4.11, MIT License
idna, 3.7, BSD License
importlib_resources, 6.4.0, Apache Software License
joblib, 1.4.2, BSD License
Expand All @@ -22,7 +24,7 @@ networkx, 3.1, BSD License
numpy, 1.23.5, BSD License
onnx, 1.16.1, Apache License v2.0
onnxconverter-common, 1.13.0, MIT License
onnxmltools, 1.11.0, Apache Software License
onnxmltools, 1.12.0, Apache Software License
onnxoptimizer, 0.3.13, Apache License v2.0
onnxruntime, 1.18.0, MIT License
packaging, 24.1, Apache Software License; BSD License
Expand All @@ -35,16 +37,18 @@ requests, 2.32.3, Apache Software License
scikit-learn, 1.1.3, BSD License
scipy, 1.10.1, BSD License
six, 1.16.0, MIT License
skl2onnx, 1.12, Apache Software License
skl2onnx, 1.17.0, Apache Software License
skops, 0.5.0, MIT
skorch, 0.11.0, new BSD 3-Clause
sympy, 1.12.1, BSD License
sympy, 1.13.0, BSD License
tabulate, 0.8.10, MIT License
threadpoolctl, 3.5.0, BSD License
torch, 1.13.1, BSD License
torch, 2.3.1, BSD License
tqdm, 4.66.4, MIT License; Mozilla Public License 2.0 (MPL 2.0)
typing_extensions, 4.5.0, Python Software Foundation License
typing_extensions, 4.12.2, Python Software Foundation License
tzdata, 2024.1, Apache Software License
unfoldNd, 0.2.2, MIT License
urllib3, 2.2.2, MIT License
xgboost, 1.6.2, Apache Software License
z3-solver, 4.13.0.0, MIT License
zipp, 3.19.2, MIT License
2 changes: 1 addition & 1 deletion deps_licenses/licenses_mac_silicon_user.txt.md5
Original file line number Diff line number Diff line change
@@ -1 +1 @@
adb925c3b7be3e651975febcf49b6543
6d367701c3ef5eff8763f4e994e03681
13 changes: 11 additions & 2 deletions src/concrete/ml/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ def fuse_matmul_bias_to_gemm(onnx_model: onnx.ModelProto):
# Create a GEMM node which combines the MatMul and Add operations
gemm_node = helper.make_node(
"Gemm", # op_type
[matmul_node.input[0], matmul_node.input[1], bias_other_input_node_name], # inputs
[
matmul_node.input[0],
matmul_node.input[1],
bias_other_input_node_name,
], # inputs
[add_node.output[0]], # outputs
name="Gemm_Node",
alpha=1.0,
Expand Down Expand Up @@ -149,9 +153,14 @@ def get_equivalent_numpy_forward_from_torch(

arguments = list(inspect.signature(torch_module.forward).parameters)

if isinstance(dummy_input, torch.Tensor):
dummy_input = dummy_input.to("cpu")
else:
dummy_input = tuple(elt.to("cpu") for elt in dummy_input)

# Export to ONNX
torch.onnx.export(
torch_module,
torch_module.to("cpu"),
dummy_input,
str(output_onnx_file_path),
opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
Expand Down
Loading

0 comments on commit ce8de8a

Please sign in to comment.