Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cross encoder device issue #3104

Merged
merged 6 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions sentence_transformers/cross_encoder/CrossEncoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ def __init__(
if device is None:
device = get_device_name()
logger.info(f"Use pytorch device: {device}")

self._target_device = torch.device(device)
self.model.to(device)

if default_activation_function is not None:
self.default_activation_function = default_activation_function
Expand Down Expand Up @@ -154,11 +153,11 @@ def smart_batching_collate(self, batch: list[InputExample]) -> tuple[BatchEncodi
*texts, padding=True, truncation="longest_first", return_tensors="pt", max_length=self.max_length
)
labels = torch.tensor(labels, dtype=torch.float if self.config.num_labels == 1 else torch.long).to(
self._target_device
self.model.device
)

for name in tokenized:
tokenized[name] = tokenized[name].to(self._target_device)
tokenized[name] = tokenized[name].to(self.model.device)

return tokenized, labels

Expand All @@ -174,7 +173,7 @@ def smart_batching_collate_text_only(self, batch: list[InputExample]) -> BatchEn
)

for name in tokenized:
tokenized[name] = tokenized[name].to(self._target_device)
tokenized[name] = tokenized[name].to(self.model.device)

return tokenized

Expand Down Expand Up @@ -232,7 +231,6 @@ def fit(
scaler = torch.npu.amp.GradScaler()
else:
scaler = torch.cuda.amp.GradScaler()
self.model.to(self._target_device)

if output_path is not None:
os.makedirs(output_path, exist_ok=True)
Expand Down Expand Up @@ -272,7 +270,7 @@ def fit(
train_dataloader, desc="Iteration", smoothing=0.05, disable=not show_progress_bar
):
if use_amp:
with torch.autocast(device_type=self._target_device.type):
with torch.autocast(device_type=self.model.device.type):
model_predictions = self.model(**features, return_dict=True)
logits = activation_fct(model_predictions.logits)
if self.config.num_labels == 1:
Expand Down Expand Up @@ -438,7 +436,6 @@ def predict(

pred_scores = []
self.model.eval()
self.model.to(self._target_device)
with torch.no_grad():
for features in iterator:
model_predictions = self.model(**features, return_dict=True)
Expand Down Expand Up @@ -604,3 +601,21 @@ def push_to_hub(
tags=tags,
**kwargs,
)

def to(self, device: int | str | torch.device | None = None) -> None:
return self.model.to(device)

@property
def _target_device(self) -> torch.device:
logger.warning(
"`CrossEncoder._target_device` has been removed, please use `CrossEncoder.device` instead.",
)
return self.device

@_target_device.setter
def _target_device(self, device: int | str | torch.device | None = None) -> None:
self.to(device)

@property
def device(self) -> torch.device:
return self.model.device
32 changes: 32 additions & 0 deletions tests/test_cross_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,35 @@ def test_bfloat16() -> None:

ranking = model.rank("Hello there!", ["Hello, World!", "Heya!"])
assert isinstance(ranking, list)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.")
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_device_assignment(device):
model = CrossEncoder("cross-encoder/stsb-distilroberta-base", device=device)
assert model.device.type == device


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.")
def test_device_switching():
# test assignment using .to
model = CrossEncoder("cross-encoder/stsb-distilroberta-base", device="cpu")
assert model.device.type == "cpu"
assert model.model.device.type == "cpu"

model.to("cuda")
assert model.device.type == "cuda"
assert model.model.device.type == "cuda"

del model
torch.cuda.empty_cache()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.")
def test_target_device_backwards_compat():
model = CrossEncoder("cross-encoder/stsb-distilroberta-base", device="cpu")
assert model.device.type == "cpu"

assert model._target_device.type == "cpu"
model._target_device = "cuda"
assert model.device.type == "cuda"