Refactor torch device types out of od and into _types #829
Conversation
Codecov Report
Additional details and impacted files:

@@            Coverage Diff             @@
## master #829 +/- ##
==========================================
+ Coverage 81.90% 81.98% +0.08%
==========================================
Files 159 159
Lines 10338 10375 +37
==========================================
+ Hits 8467 8506 +39
+ Misses 1871 1869 -2
```diff
     Device type used. The default tries to use the GPU and falls back on CPU if needed.
-    Can be specified by passing either 'cuda', 'gpu' or 'cpu'. Only relevant for 'pytorch' backend.
+    Can be specified by passing either ``'cuda'``, ``'gpu'``, ``'cpu'`` or an instance of
+    ``torch.device``. Only relevant for 'pytorch' backend.
```
Just out of curiosity, if you update the `intersphinx_mapping` like pytorch/pytorch#10400 and then reference `torch.device` with the RST role ``:py:class:`torch.device` ``, does it work?
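For reference, a minimal sketch of what that `conf.py` change could look like (the existing `python` entry is an assumption, not copied from this repo):

```python
# docs/source/conf.py (sketch): add a PyTorch entry to intersphinx_mapping so
# that :py:class:`torch.device` references resolve against the PyTorch docs.
intersphinx_mapping = {
    'python': ('https://docs.python.org/3/', None),       # assumed existing entry
    'torch': ('https://pytorch.org/docs/stable/', None),  # new entry, as in pytorch/pytorch#10400
}
```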
Can I leave this to explore in a separate issue? This PR already has a much wider scope than initially intended! 😅
```diff
@@ -37,3 +34,5 @@
 # type aliases, for use with mypy (must be FwdRef's if involving opt. deps.)
 OptimizerTF: TypeAlias = Union['tf.keras.optimizers.Optimizer', 'tf.keras.optimizers.legacy.Optimizer',
                                Type['tf.keras.optimizers.Optimizer'], Type['tf.keras.optimizers.legacy.Optimizer']]
+
+TorchDeviceType: TypeAlias = Optional[Union[Literal['cuda', 'gpu', 'cpu'], 'torch.device']]
```
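For illustration, the values this alias admits (import path assumed from the diff above):

```python
import torch

from alibi_detect.utils._types import TorchDeviceType  # path is an assumption

d1: TorchDeviceType = 'cuda'                   # literal strings...
d2: TorchDeviceType = 'gpu'
d3: TorchDeviceType = 'cpu'
d4: TorchDeviceType = torch.device('cuda', 0)  # ...or a torch.device instance
d5: TorchDeviceType = None                     # Optional, so None is allowed too
```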
Re the forward reference `'torch.device'` in here: I can't think of a good fix at the moment, but just noting that this introduces lots of additional sphinx warnings, and is not rendered "perfectly" in the docs (we've gone from 6 to 29 warnings, which makes me sad).

I suspect the forward ref would be resolved during docs compilation if we installed `alibi-detect[all]` on read-the-docs (#499), which is now allowed, but it seems wasteful...
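For context, the usual shape of such a forward reference (a sketch, not necessarily the exact code in `_types.py`):

```python
from typing import TYPE_CHECKING, Literal, Optional, Union

from typing_extensions import TypeAlias

if TYPE_CHECKING:
    # Only evaluated by static type checkers, so torch stays an optional dependency.
    import torch

# 'torch.device' remains a string (forward ref) at runtime; mypy resolves it, but
# sphinx needs torch importable at docs build time (or intersphinx) to link it.
TorchDeviceType: TypeAlias = Optional[Union[Literal['cuda', 'gpu', 'cpu'], 'torch.device']]
```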
😮‍💨 argh... I'll open an issue. Maybe this PR needs to be reined in! Or split into two!
```
device
        Device type used. The default tries to use the GPU and falls back on CPU if needed.
        Can be specified by passing either ``'cuda'``, ``'gpu'``, ``'cpu'`` or an instance of
        ``torch.device``.
```
Unnecessary indents?
Docstring also seems slightly inaccurate? Maybe just something like `Torch device to be serialised.`?
```
    Returns
    -------
    a string with value ``'cuda'`` or ``'cpu'``.
```
`str(torch.device('cuda:0'))` will return `'cuda:0'`, which makes the `Returns` docstring slightly incorrect, but will also break our save/load. I think save/load itself would work, as `'cuda:0'` will be resolved by `get_device` just fine. However, pydantic validation will fail since we have `Literal['cpu', 'gpu', 'cuda']`.

Possible solutions to me are:

- Implement "Inconsistency in `device` kwarg between detectors and `preprocess_drift` function" #679 (comment) properly, by implementing a custom pydantic validator to properly validate `'cuda:<int>'` strings.
- Relax the pydantic validation in `schemas.py` to `device: Optional[str] = None` for now.
- Remove support for passing `torch.device` from this PR completely.
- Do nothing, except throw a warning/error in `get_device` if a `torch.device` is passed with a device index, so the user knows they cannot serialise the detector when doing this...
Shall we just format the `str(torch.device('cuda:0'))` to remove the device index and raise a warning alerting the user to the change?
I like this solution. It is simple, and prevents serialised detectors being unloadable, e.g. if saved with `cuda:8` and loaded on a 4-GPU machine. If we extend the pydantic validation to support the device index in the future, we could still save as `cuda`, and the user could manually add a device index in the `config.toml` if they desired.
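A minimal sketch of this behaviour (the helper name `_serialize_device` is hypothetical, not code from this PR):

```python
import warnings

import torch


def _serialize_device(device: torch.device) -> str:
    """Hypothetical helper: serialise a torch.device as 'cuda' or 'cpu',
    dropping any device index and warning the user about the change."""
    if device.index is not None:
        warnings.warn(f"Device index dropped when serialising '{device}'; "
                      f"saving as '{device.type}'.")
    return device.type  # torch.device('cuda:0').type == 'cuda'


assert _serialize_device(torch.device('cuda:0')) == 'cuda'  # emits a warning
assert _serialize_device(torch.device('cpu')) == 'cpu'
```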
Ah, so when saving, the detector first gets the config, then validates, and then replaces the values with the string representations... 🤔 I've added the pydantic validation as it seems like the best way of going about this. I've kept it simple for now though; it just validates the device type from `str(device)`.
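One way that simple validation could look (a sketch against pydantic v1; the model and field here are illustrative stand-ins for the real `schemas.py`):

```python
from typing import Optional

from pydantic import BaseModel, validator


class DetectorConfig(BaseModel):  # illustrative stand-in for the real schema
    device: Optional[str] = None

    @validator('device')
    def validate_device_type(cls, v):
        # Validate only the device *type* of str(device), ignoring any ':<int>' index.
        if v is not None and v.split(':')[0] not in ('cpu', 'gpu', 'cuda'):
            raise ValueError(f"{v!r} is not a recognised device type")
        return v
```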
Damn I forgot about the pre-saving validation...
```diff
@@ -188,6 +189,11 @@ def _save_detector_config(detector: ConfigurableDetector,
     if optimizer is not None:
         cfg['optimizer'] = _save_optimizer_config(optimizer)

+    # Serialize device
+    device = cfg.get('device')
+    if device is not None:
```
Instead of the `_save_device_config` wrapper, isn't it easier just to do `cfg['device'] = save_device_config_pt(device)` here?

Granted, we do have a `_save_optimizer_config` wrapper, but that is a little different since we do have some sort of `optimizer` for both tensorflow and torch. `device` is torch-only atm, so I'm not sure we need the wrapper...
P.S. Maybe `_save_device` would be more accurate than `_save_device_config`? `_save_optimizer_config` etc. are named `_config` since they do actually return a "config `dict`", whereas `_save_device_config` is only returning a `str`.
alibi_detect/saving/saving.py (Outdated)
```python
    # if device is not none then we're using pytorch
    if device is not None:
        return save_device_config_pt(device)
```
Isn't the `if device is not None` unnecessary? Can you even arrive inside `_save_device_config` if `device` is `None`? Because it's already checked here?
```diff
@@ -295,7 +295,7 @@ class PreprocessConfig(CustomBaseModel):
         Optional tokenizer for text drift. Either a string referencing a HuggingFace tokenizer model name, or a
         :class:`~alibi_detect.utils.schemas.TokenizerConfig`.
     """
-    device: Optional[Literal['cpu', 'cuda']] = None
+    device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
```
Content of this is pending the decision on https://github.com/SeldonIO/alibi-detect/pull/829/files#r1262574030
Note: agreed to format device string to remove device index prior to saving. See this comment
Few minor comments, the main one regarding serialisation.

I'll do a final pass once tests are written. Regarding tests, I reckon we could get away with a single unit test saving with `save_device_config` and then running through `get_device`? Parameterised with all the supported `device` types...
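Something along these lines, perhaps (import paths for both helpers are assumptions):

```python
import pytest
import torch

from alibi_detect.saving import save_device_config  # assumed import path
from alibi_detect.utils.pytorch import get_device   # assumed import path


@pytest.mark.parametrize('device', [
    'cpu', 'gpu', 'cuda', torch.device('cpu'), torch.device('cuda', 0),
])
def test_device_save_roundtrip(device):
    # Serialise the device, then check get_device resolves the saved string
    # back to a valid torch.device (falling back to CPU if CUDA is absent).
    serialized = save_device_config(device)
    resolved = get_device(serialized)
    assert isinstance(resolved, torch.device)
```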
LGTM bar one minor nitpick
What is this:

Defines `TorchDeviceType: TypeAlias = Optional[Union[Literal['cuda', 'gpu', 'cpu'], 'torch.device']]` in `_types.py` and refactors the typing for the device in the detectors.

Fixes #779, #679. Also fixes #763.