Fix cross encoder device issue (#3104)
* pass1

* tests running

* Use parametrize in CE test

* Restore the self.model.to call, just in case

* Reintroduce _target_device as a property, so `model._target_device = ...` still works, but with a warning that it shouldn't be used

* Add tests showing that we don't break backwards compatibility

---------

Co-authored-by: Tom Aarsen <[email protected]>
susnato and tomaarsen authored Dec 2, 2024
1 parent 39b6eae commit a49ffc5
Showing 2 changed files with 55 additions and 8 deletions.
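In short, the commit drops the cached `_target_device` attribute and treats the wrapped Hugging Face model's own device as the single source of truth: collated tensors are moved to `self.model.device` at call time. A minimal sketch of the resulting behavior, using the same checkpoint as the new tests (device strings illustrative):

```python
import torch

from sentence_transformers import CrossEncoder

# The device passed at construction is applied to the underlying model directly.
model = CrossEncoder("cross-encoder/stsb-distilroberta-base", device="cpu")
print(model.device)  # cpu; simply mirrors model.model.device

# Moving the model later can no longer desynchronize a stale cached device:
# batches are sent to model.device when predict() runs.
if torch.cuda.is_available():
    model.to("cuda")
scores = model.predict([("A man is eating food.", "A man eats something.")])
```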
31 changes: 23 additions & 8 deletions sentence_transformers/cross_encoder/CrossEncoder.py
@@ -123,8 +123,7 @@ def __init__(
         if device is None:
             device = get_device_name()
             logger.info(f"Use pytorch device: {device}")
-
-        self._target_device = torch.device(device)
+        self.model.to(device)
 
         if default_activation_function is not None:
             self.default_activation_function = default_activation_function
@@ -154,11 +153,11 @@ def smart_batching_collate(self, batch: list[InputExample]) -> tuple[BatchEncodi
             *texts, padding=True, truncation="longest_first", return_tensors="pt", max_length=self.max_length
         )
         labels = torch.tensor(labels, dtype=torch.float if self.config.num_labels == 1 else torch.long).to(
-            self._target_device
+            self.model.device
         )
 
         for name in tokenized:
-            tokenized[name] = tokenized[name].to(self._target_device)
+            tokenized[name] = tokenized[name].to(self.model.device)
 
         return tokenized, labels

@@ -174,7 +173,7 @@ def smart_batching_collate_text_only(self, batch: list[InputExample]) -> BatchEn
         )
 
         for name in tokenized:
-            tokenized[name] = tokenized[name].to(self._target_device)
+            tokenized[name] = tokenized[name].to(self.model.device)
 
         return tokenized

@@ -232,7 +231,6 @@ def fit(
                 scaler = torch.npu.amp.GradScaler()
             else:
                 scaler = torch.cuda.amp.GradScaler()
-        self.model.to(self._target_device)
 
         if output_path is not None:
             os.makedirs(output_path, exist_ok=True)
@@ -272,7 +270,7 @@ def fit(
                 train_dataloader, desc="Iteration", smoothing=0.05, disable=not show_progress_bar
             ):
                 if use_amp:
-                    with torch.autocast(device_type=self._target_device.type):
+                    with torch.autocast(device_type=self.model.device.type):
                         model_predictions = self.model(**features, return_dict=True)
                         logits = activation_fct(model_predictions.logits)
                         if self.config.num_labels == 1:
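The hunk above also matters for mixed precision: `torch.autocast` now derives its `device_type` from the live model rather than the stale cached device. A standalone sketch of that pattern (toy module, not the library's code):

```python
import torch

model = torch.nn.Linear(4, 1)
inputs = torch.randn(2, 4)
if torch.cuda.is_available():
    model, inputs = model.cuda(), inputs.cuda()

# Reading the device type off the module itself keeps autocast consistent
# with wherever the model currently lives ("cuda" or "cpu").
device_type = next(model.parameters()).device.type
with torch.autocast(device_type=device_type):
    outputs = model(inputs)
```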
@@ -438,7 +436,6 @@ def predict(
 
         pred_scores = []
         self.model.eval()
-        self.model.to(self._target_device)
         with torch.no_grad():
             for features in iterator:
                 model_predictions = self.model(**features, return_dict=True)
@@ -604,3 +601,21 @@ def push_to_hub(
             tags=tags,
             **kwargs,
         )
+
+    def to(self, device: int | str | torch.device | None = None) -> None:
+        return self.model.to(device)
+
+    @property
+    def _target_device(self) -> torch.device:
+        logger.warning(
+            "`CrossEncoder._target_device` has been removed, please use `CrossEncoder.device` instead.",
+        )
+        return self.device
+
+    @_target_device.setter
+    def _target_device(self, device: int | str | torch.device | None = None) -> None:
+        self.to(device)
+
+    @property
+    def device(self) -> torch.device:
+        return self.model.device
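With the shim above, older code that assigned to the private attribute keeps working, now routed through `to()`, while reads log a warning pointing at the public property. A short example of both paths (device strings illustrative):

```python
from sentence_transformers import CrossEncoder

model = CrossEncoder("cross-encoder/stsb-distilroberta-base", device="cpu")

# Old pattern: assigning still moves the model, because the property
# setter delegates to CrossEncoder.to().
model._target_device = "cpu"

# Reading works too, but logs a warning recommending `CrossEncoder.device`.
assert model._target_device == model.device
```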
32 changes: 32 additions & 0 deletions tests/test_cross_encoder.py
@@ -192,3 +192,35 @@ def test_bfloat16() -> None:
 
     ranking = model.rank("Hello there!", ["Hello, World!", "Heya!"])
     assert isinstance(ranking, list)
+
+
[email protected](not torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.")
[email protected]("device", ["cpu", "cuda"])
+def test_device_assignment(device):
+    model = CrossEncoder("cross-encoder/stsb-distilroberta-base", device=device)
+    assert model.device.type == device
+
+
[email protected](not torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.")
+def test_device_switching():
+    # test assignment using .to
+    model = CrossEncoder("cross-encoder/stsb-distilroberta-base", device="cpu")
+    assert model.device.type == "cpu"
+    assert model.model.device.type == "cpu"
+
+    model.to("cuda")
+    assert model.device.type == "cuda"
+    assert model.model.device.type == "cuda"
+
+    del model
+    torch.cuda.empty_cache()
+
+
[email protected](not torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.")
+def test_target_device_backwards_compat():
+    model = CrossEncoder("cross-encoder/stsb-distilroberta-base", device="cpu")
+    assert model.device.type == "cpu"
+
+    assert model._target_device.type == "cpu"
+    model._target_device = "cuda"
+    assert model.device.type == "cuda"