Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor torch device types out of od and into _types #829

Merged
merged 13 commits into from
Jul 26, 2023
3 changes: 2 additions & 1 deletion alibi_detect/saving/_pytorch/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,5 @@ def save_device_config(device: TorchDeviceType):
-------
a string with value ``'cuda'`` or ``'cpu'``.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

str(torch.device('cuda:0')) will return 'cuda:0', which makes the Returns docstring slightly incorrect, and will also break our save/load pipeline. To be precise, save/load itself would work, as 'cuda:0' will be resolved by get_device just fine; however, pydantic validation will fail since we have Literal['cpu', 'gpu', 'cuda'].

Possible solutions to me are:

  1. Implement Inconsistency in device kwarg between detectors and preprocess_drift function #679 (comment) properly, by implementing a custom pydantic validator to properly validate 'cuda:<int>' strings.
  2. Relax the pydantic validation in schemas.py to device: Optional[str] = None for now.
  3. Remove support for passing torch.device from this PR completely.
  4. Do nothing, except throw a warning/error in get_device if torch.device passed with a device index. So user knows they cannot serialise the detector when doing this...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we just format the str(torch.device('cuda:0')) to remove the device index and raise a warning alerting the user to the change?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this solution. It is simple, and prevents serialised detectors being unloadable e.g. if saved with cuda:8 and loaded on a 4 gpu machine.

If we extend the pydantic validation to support the device index in the future, we could still save as cuda, and the user could manually add a device index in the config.toml if they desired.

Copy link
Collaborator Author

@mauicv mauicv Jul 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, so when saving the detector first gets the config, then validates and then replaces the values with the string representations... 🤔 I've added the Pydantic validation as it seems like the best way of going about this.

I've kept it simple for now though, it just validates the device type from str(device).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Damn I forgot about the pre-saving validation...

"""
return str(device)
device_str = str(device)
return device_str.split(':')[0]
67 changes: 47 additions & 20 deletions alibi_detect/saving/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,33 @@ def validate_optimizer(cls, optimizer: Any, values: dict) -> Any:
# of preprocess_drift.


class SupportedDevice:
    """
    Pydantic custom type to check the device is correct for the choice of backend (conditional on what optional deps
    are installed).
    """
    @classmethod
    def __get_validators__(cls):
        # Pydantic v1 custom-type hook: yield the validator(s) pydantic should run for this type.
        yield cls.validate_device

    @classmethod
    def validate_device(cls, device: Any, values: dict) -> Any:
        """
        Validate the `device` field against the previously validated `backend` field.

        Parameters
        ----------
        device
            The device value to validate. A string such as ``'cpu'``, ``'cuda'``, ``'gpu'`` or
            ``'cuda:<int>'``, or a ``torch.device`` (anything whose ``str()`` matches one of those forms).
        values
            Previously validated fields of the model. ``backend`` is expected to be present; if it
            failed its own validation pydantic omits it from `values`, so we use ``.get()`` to fall
            through to the generic error below rather than raising an opaque ``KeyError``.

        Returns
        -------
        The validated device, unchanged.

        Raises
        ------
        TypeError
            If a device is given for a backend that does not accept one, if the device type is
            unsupported for a torch-based backend, or if the backend itself is unrecognised.
        """
        backend = values.get('backend')
        if backend == Framework.TENSORFLOW or backend == Framework.SKLEARN:
            if device is not None:
                raise TypeError('`device` should not be specified for TensorFlow or Sklearn models. Leave as `None`.')
            else:
                return device
        elif backend == Framework.PYTORCH or backend == Framework.KEOPS:
            # Strip any device index (e.g. 'cuda:0' -> 'cuda') before checking the device type.
            device_str = str(device).split(':')[0]
            if device_str not in ['cpu', 'cuda', 'gpu']:
                raise TypeError(f'`device` should be one of `cpu`, `cuda`, `gpu` or a torch.device. Got {device}.')
            else:
                return device
        else:  # Catch any other unexpected issues
            raise TypeError('The device is not recognised as a supported type.')


# Custom BaseModel so that we can set default config
class CustomBaseModel(BaseModel):
"""
Expand Down Expand Up @@ -682,7 +709,7 @@ class MMDDriftConfig(DriftDetectorConfig):
configure_kernel_from_x_ref: bool = True
n_permutations: int = 100
batch_size_permutations: int = 1000000
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None


class MMDDriftConfigResolved(DriftDetectorConfigResolved):
Expand All @@ -702,7 +729,7 @@ class MMDDriftConfigResolved(DriftDetectorConfigResolved):
configure_kernel_from_x_ref: bool = True
n_permutations: int = 100
batch_size_permutations: int = 1000000
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None


class LSDDDriftConfig(DriftDetectorConfig):
Expand All @@ -721,7 +748,7 @@ class LSDDDriftConfig(DriftDetectorConfig):
n_permutations: int = 100
n_kernel_centers: Optional[int] = None
lambda_rd_max: float = 0.2
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None


class LSDDDriftConfigResolved(DriftDetectorConfigResolved):
Expand All @@ -740,7 +767,7 @@ class LSDDDriftConfigResolved(DriftDetectorConfigResolved):
n_permutations: int = 100
n_kernel_centers: Optional[int] = None
lambda_rd_max: float = 0.2
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None


class ClassifierDriftConfig(DriftDetectorConfig):
Expand Down Expand Up @@ -772,7 +799,7 @@ class ClassifierDriftConfig(DriftDetectorConfig):
verbose: int = 0
train_kwargs: Optional[dict] = None
dataset: Optional[str] = None
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None
dataloader: Optional[str] = None # TODO: placeholder, will need to be updated for pytorch implementation
use_calibration: bool = False
calibration_kwargs: Optional[dict] = None
Expand Down Expand Up @@ -808,7 +835,7 @@ class ClassifierDriftConfigResolved(DriftDetectorConfigResolved):
verbose: int = 0
train_kwargs: Optional[dict] = None
dataset: Optional[Callable] = None
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None
dataloader: Optional[Callable] = None # TODO: placeholder, will need to be updated for pytorch implementation
use_calibration: bool = False
calibration_kwargs: Optional[dict] = None
Expand Down Expand Up @@ -843,7 +870,7 @@ class SpotTheDiffDriftConfig(DriftDetectorConfig):
n_diffs: int = 1
initial_diffs: Optional[str] = None
l1_reg: float = 0.01
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None
dataloader: Optional[str] = None # TODO: placeholder, will need to be updated for pytorch implementation


Expand Down Expand Up @@ -875,7 +902,7 @@ class SpotTheDiffDriftConfigResolved(DriftDetectorConfigResolved):
n_diffs: int = 1
initial_diffs: Optional[np.ndarray] = None
l1_reg: float = 0.01
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None
dataloader: Optional[Callable] = None # TODO: placeholder, will need to be updated for pytorch implementation


Expand Down Expand Up @@ -909,7 +936,7 @@ class LearnedKernelDriftConfig(DriftDetectorConfig):
verbose: int = 0
train_kwargs: Optional[dict] = None
dataset: Optional[str] = None
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None
dataloader: Optional[str] = None # TODO: placeholder, will need to be updated for pytorch implementation


Expand Down Expand Up @@ -943,7 +970,7 @@ class LearnedKernelDriftConfigResolved(DriftDetectorConfigResolved):
verbose: int = 0
train_kwargs: Optional[dict] = None
dataset: Optional[Callable] = None
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None
dataloader: Optional[Callable] = None # TODO: placeholder, will need to be updated for pytorch implementation


Expand All @@ -968,7 +995,7 @@ class ContextMMDDriftConfig(DriftDetectorConfig):
n_folds: int = 5
batch_size: Optional[int] = 256
verbose: bool = False
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None


class ContextMMDDriftConfigResolved(DriftDetectorConfigResolved):
Expand All @@ -991,7 +1018,7 @@ class ContextMMDDriftConfigResolved(DriftDetectorConfigResolved):
n_folds: int = 5
batch_size: Optional[int] = 256
verbose: bool = False
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None


class MMDDriftOnlineConfig(DriftDetectorConfig):
Expand All @@ -1009,7 +1036,7 @@ class MMDDriftOnlineConfig(DriftDetectorConfig):
kernel: Optional[Union[str, KernelConfig]] = None
sigma: Optional[np.ndarray] = None
n_bootstraps: int = 1000
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None
verbose: bool = True


Expand All @@ -1028,7 +1055,7 @@ class MMDDriftOnlineConfigResolved(DriftDetectorConfigResolved):
kernel: Optional[Callable] = None
sigma: Optional[np.ndarray] = None
n_bootstraps: int = 1000
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None
verbose: bool = True


Expand All @@ -1048,7 +1075,7 @@ class LSDDDriftOnlineConfig(DriftDetectorConfig):
n_bootstraps: int = 1000
n_kernel_centers: Optional[int] = None
lambda_rd_max: float = 0.2
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None
verbose: bool = True


Expand All @@ -1068,7 +1095,7 @@ class LSDDDriftOnlineConfigResolved(DriftDetectorConfigResolved):
n_bootstraps: int = 1000
n_kernel_centers: Optional[int] = None
lambda_rd_max: float = 0.2
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None
verbose: bool = True


Expand Down Expand Up @@ -1178,7 +1205,7 @@ class ClassifierUncertaintyDriftConfig(DetectorConfig):
margin_width: float = 0.1
batch_size: int = 32
preprocess_batch_fn: Optional[str] = None
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None
tokenizer: Optional[Union[str, TokenizerConfig]] = None
max_len: Optional[int] = None
input_shape: Optional[tuple] = None
Expand All @@ -1205,7 +1232,7 @@ class ClassifierUncertaintyDriftConfigResolved(DetectorConfig):
margin_width: float = 0.1
batch_size: int = 32
preprocess_batch_fn: Optional[Callable] = None
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None
tokenizer: Optional[Union[str, Callable]] = None
max_len: Optional[int] = None
input_shape: Optional[tuple] = None
Expand All @@ -1231,7 +1258,7 @@ class RegressorUncertaintyDriftConfig(DetectorConfig):
n_evals: int = 25
batch_size: int = 32
preprocess_batch_fn: Optional[str] = None
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None
tokenizer: Optional[Union[str, TokenizerConfig]] = None
max_len: Optional[int] = None
input_shape: Optional[tuple] = None
Expand All @@ -1257,7 +1284,7 @@ class RegressorUncertaintyDriftConfigResolved(DetectorConfig):
n_evals: int = 25
batch_size: int = 32
preprocess_batch_fn: Optional[Callable] = None
device: Optional[Literal['cpu', 'cuda', 'gpu']] = None
device: Optional[SupportedDevice] = None
tokenizer: Optional[Callable] = None
max_len: Optional[int] = None
input_shape: Optional[tuple] = None
Expand Down
10 changes: 7 additions & 3 deletions alibi_detect/saving/tests/test_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,6 +1374,7 @@ def test_cleanup(tmp_path):
('pytorch', 'cuda:0'),
('pytorch', torch.device('cuda')),
('pytorch', torch.device('cuda:0')),
('tensorflow', None),
])
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
def test_save_detector_device(backend, device, data, tmp_path, classifier_model): # noqa: F811
Expand All @@ -1392,6 +1393,9 @@ def test_save_detector_device(backend, device, data, tmp_path, classifier_model)
)
save_detector(detector, tmp_path)
detector_config = toml.load(tmp_path / 'config.toml')
assert detector_config['device'] in {'cpu', 'gpu', 'cuda', 'cuda:0'}
detector = load_detector(tmp_path)
assert detector._detector.device in {torch.device('cpu'), torch.device('cuda')}
loaded_detector = load_detector(tmp_path)
if backend == 'tensorflow':
assert detector_config['device'] == 'None'
else:
assert detector_config['device'] in {'cpu', 'gpu', 'cuda'}
assert loaded_detector._detector.device in {torch.device('cpu'), torch.device('cuda')}