Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/onnx argmax #1814

Merged
merged 8 commits into from
May 31, 2024
Merged

Feature/onnx argmax #1814

merged 8 commits into from
May 31, 2024

Conversation

will-maclean
Copy link
Contributor

ArgMax Onnx Op for Burn-Import

Checklist

  • [✅] Confirmed that run-checks all script has been executed.
  • [✅] Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Original Issue

Changes

  • Created crates/burn-import/onnx/tests/argmax/argmax.py with a simple model that calls .argmax(dim) on an input tensor, and onnx code to convert the model into a ONNX file
  • Created an ArgMaxNode in crates/burn-import/src/burn/node/argmax.rs to store ArgMax functionality. Also added necessary helper functions
  • Created tests in onnx_tests.rs

Testing

A test called argmax in onnx_tests.rs checks that the created model is able to generate the correct argmax outputs.

One thing I did notice was that the only params implemented for the existing tensor argmax function are the dim/axis - keepdims and select_last_index, which are also params for the ONNX argmax node, don't seem to exist in burn. I set them to their defauls, however this may cause issues if trying to import an ONNX model where e.g. keepdims=false. Happy to hear people's input on this.

Copy link

codecov bot commented May 27, 2024

Codecov Report

Attention: Patch coverage is 89.72603% with 15 lines in your changes are missing coverage. Please review.

Project coverage is 86.42%. Comparing base (e61b026) to head (f9d8e32).
Report is 7 commits behind head on main.

Files Patch % Lines
crates/burn-import/src/onnx/op_configuration.rs 61.76% 13 Missing ⚠️
crates/burn-import/src/onnx/dim_inference.rs 87.50% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1814      +/-   ##
==========================================
- Coverage   86.43%   86.42%   -0.02%     
==========================================
  Files         753      761       +8     
  Lines       87602    87987     +385     
==========================================
+ Hits        75723    76041     +318     
- Misses      11879    11946      +67     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for tackling the ONNX track!

Overall, the implementation looks good 🙂 Just some minor comments/changes to complete the PR.

crates/burn-import/onnx-tests/tests/argmax/argmax.py Outdated Show resolved Hide resolved
crates/burn-import/onnx-tests/tests/argmax/argmax.py Outdated Show resolved Hide resolved
crates/burn-import/src/burn/node/argmax.rs Outdated Show resolved Hide resolved
crates/burn-import/src/burn/node/argmax.rs Outdated Show resolved Hide resolved
crates/burn-import/src/burn/node/argmax.rs Outdated Show resolved Hide resolved
crates/burn-import/src/burn/node/base.rs Outdated Show resolved Hide resolved
crates/burn-import/src/onnx/dim_inference.rs Outdated Show resolved Hide resolved
crates/burn-import/src/onnx/dim_inference.rs Outdated Show resolved Hide resolved
crates/burn-import/src/onnx/op_configuration.rs Outdated Show resolved Hide resolved
@will-maclean
Copy link
Contributor Author

Hi laggui, thanks for the feedback. I addressed those points (hopefully I didn't miss any) and set any keepdims=false to panic, with warnings for trying to set select_last_index.

Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a minor request just to limit the warning based on the value for select_last_index.

Otherwise, looks good! So I'll approve in advance and we can merge when addressed.

Comment on lines 494 to 497
"select_last_index" => log::warn!(
"select_last_index param for argmax is ignored in burn (got {:?})",
value
),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can still capture the select_last_index value here, but only warn if it is 1 (because the default implementation pretty much everywhere including Burn is to return the first max value, not last).

@will-maclean
Copy link
Contributor Author

Should be all good now :)

@antimora antimora requested a review from laggui May 31, 2024 16:44
Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!!

@laggui laggui merged commit 13a6f84 into tracel-ai:main May 31, 2024
14 checks passed
LilDojd pushed a commit to LilDojd/burn that referenced this pull request Jun 5, 2024
* pre-test

* implementing argmax for burn-import from onnx

* tidying

* fixing return types and tests

* addressing feedback

* only warn when select_last_index!=0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants