diff --git a/sentence_transformers/cross_encoder/CrossEncoder.py b/sentence_transformers/cross_encoder/CrossEncoder.py index c8d8d1d84..c4345017a 100644 --- a/sentence_transformers/cross_encoder/CrossEncoder.py +++ b/sentence_transformers/cross_encoder/CrossEncoder.py @@ -123,8 +123,7 @@ def __init__( if device is None: device = get_device_name() logger.info(f"Use pytorch device: {device}") - - self._target_device = torch.device(device) + self.model.to(device) if default_activation_function is not None: self.default_activation_function = default_activation_function @@ -154,11 +153,11 @@ def smart_batching_collate(self, batch: list[InputExample]) -> tuple[BatchEncodi *texts, padding=True, truncation="longest_first", return_tensors="pt", max_length=self.max_length ) labels = torch.tensor(labels, dtype=torch.float if self.config.num_labels == 1 else torch.long).to( - self._target_device + self.model.device ) for name in tokenized: - tokenized[name] = tokenized[name].to(self._target_device) + tokenized[name] = tokenized[name].to(self.model.device) return tokenized, labels @@ -174,7 +173,7 @@ def smart_batching_collate_text_only(self, batch: list[InputExample]) -> BatchEn ) for name in tokenized: - tokenized[name] = tokenized[name].to(self._target_device) + tokenized[name] = tokenized[name].to(self.model.device) return tokenized @@ -232,7 +231,6 @@ def fit( scaler = torch.npu.amp.GradScaler() else: scaler = torch.cuda.amp.GradScaler() - self.model.to(self._target_device) if output_path is not None: os.makedirs(output_path, exist_ok=True) @@ -272,7 +270,7 @@ def fit( train_dataloader, desc="Iteration", smoothing=0.05, disable=not show_progress_bar ): if use_amp: - with torch.autocast(device_type=self._target_device.type): + with torch.autocast(device_type=self.model.device.type): model_predictions = self.model(**features, return_dict=True) logits = activation_fct(model_predictions.logits) if self.config.num_labels == 1: @@ -438,7 +436,6 @@ def predict( pred_scores = [] self.model.eval() - self.model.to(self._target_device) with torch.no_grad(): for features in iterator: model_predictions = self.model(**features, return_dict=True) @@ -604,3 +601,21 @@ def push_to_hub( tags=tags, **kwargs, ) + + def to(self, device: int | str | torch.device | None = None) -> None: + return self.model.to(device) + + @property + def _target_device(self) -> torch.device: + logger.warning( + "`CrossEncoder._target_device` has been removed, please use `CrossEncoder.device` instead.", + ) + return self.device + + @_target_device.setter + def _target_device(self, device: int | str | torch.device | None = None) -> None: + self.to(device) + + @property + def device(self) -> torch.device: + return self.model.device diff --git a/tests/test_cross_encoder.py b/tests/test_cross_encoder.py index 9c17aa093..027751d1a 100644 --- a/tests/test_cross_encoder.py +++ b/tests/test_cross_encoder.py @@ -192,3 +192,35 @@ def test_bfloat16() -> None: ranking = model.rank("Hello there!", ["Hello, World!", "Heya!"]) assert isinstance(ranking, list) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.") +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_device_assignment(device): + model = CrossEncoder("cross-encoder/stsb-distilroberta-base", device=device) + assert model.device.type == device + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.") +def test_device_switching(): + # test assignment using .to + model = CrossEncoder("cross-encoder/stsb-distilroberta-base", device="cpu") + assert model.device.type == "cpu" + assert model.model.device.type == "cpu" + + model.to("cuda") + assert model.device.type == "cuda" + assert model.model.device.type == "cuda" + + del model + torch.cuda.empty_cache() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.") +def test_target_device_backwards_compat(): + model = CrossEncoder("cross-encoder/stsb-distilroberta-base", device="cpu") + assert model.device.type == "cpu" + + assert model._target_device.type == "cpu" + model._target_device = "cuda" + assert model.device.type == "cuda"