-
Notifications
You must be signed in to change notification settings - Fork 225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor torch device types out of od and into _types #829
Merged
mauicv
merged 13 commits into
SeldonIO:master
from
mauicv:feature/refactor-device-types
Jul 26, 2023
Merged
Changes from 1 commit
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
3267557
Refactor torch device types out of od and into _types
mauicv 38ca48b
Update types for device throughout detect
mauicv 24e0bc8
Update saving to account for torch.device
mauicv 65b910f
Update doc string
mauicv 3b79f49
Remove redundant logic in _types
mauicv d21a882
Add saving test for torch device logic
mauicv aaa781b
Add pydantic validation for supported torch devices
mauicv 15afda1
Merge branch 'master' into feature/refactor-device-types
mauicv 59ba0c9
Fix save device config docstrings
mauicv 69322a0
Address pr comments
mauicv 0b6cdd3
Add test for device save
mauicv 2d05b32
Fix optional dependency tests
mauicv a3519f3
Minor change
mauicv File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
str(torch.device('cuda:0'))
will return'cuda:0'
, which makes theReturns
docstring slightly incorrect, but will also break our save/load. I think save/load itself would work, as'cuda:0'
will be resolved byget_device
just fine. However, pydantic validation will fail since we haveLiteral['cpu', 'gpu', 'cuda']
.Possible solutions to me are:
device
kwarg between detectors andpreprocess_drift
function #679 (comment) properly, by implementing a custom pydantic validator to properly validate'cuda:<int>'
strings.schemas.py
todevice: Optional[str] = None
for now.torch.device
from this PR completely.get_device
iftorch.device
passed with a device index. So user knows they cannot serialise the detector when doing this...There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall we just format the
str(torch.device('cuda:0'))
to remove the device index and raise a warning alerting the user to the change?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this solution. It is simple, and prevents serialised detectors being unloadable e.g. if saved with
cuda:8
and loaded on a 4 gpu machine.If we extend the pydantic validation to support the device index in the future, we could still save as
cuda
, and the user could manually add a device index in theconfig.toml
if they desired.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, so when saving the detector first gets the config, then validates and then replaces the values with the string representations... 🤔 I've added the Pydantic validation as It seems like the best way of going about this.
I've kept it simple for now though, it just validates the device type from
str(device)
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Damn I forgot about the pre-saving validation...