enable Pipeline to get device from model #30534

Merged · 12 commits · May 13, 2024
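In short: a pipeline built without an explicit `device` argument now inherits the device the model already lives on, instead of resolving to CPU while the model sits elsewhere. A minimal sketch of the resulting behavior (the checkpoint and the availability of a CUDA device are illustrative assumptions, not part of this diff):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Illustrative tiny checkpoint, the same one the tests below use.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-bert").to("cuda:0")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")

# No `device=` argument: the pipeline adopts the model's device rather than
# defaulting input placement to CPU while the model stays on the accelerator.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe.device)  # cuda:0, mirrored from pipe.model.device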
7 changes: 4 additions & 3 deletions src/transformers/pipelines/base.py
@@ -845,6 +845,8 @@ def __init__(
                 device = -1
 
         if is_torch_available() and self.framework == "pt":
+            if device == -1 and self.model.device is not None:
+                device = self.model.device
             if isinstance(device, torch.device):
                 if device.type == "xpu" and not is_torch_xpu_available(check_device=True):
                     raise ValueError(f'{device} is not available, you should use device="cpu" instead')
@@ -871,11 +873,10 @@ def __init__(
         self.device = device if device is not None else -1
 
         self.binary_output = binary_output
-
-        # We shouldn't call `model.to()` for models loaded with accelerate
+        # We shouldn't call `model.to()` for models loaded with accelerate as well as the case that model is already on device
         if (
             self.framework == "pt"
-            and self.device is not None
+            and self.model.device != self.device
             and not (isinstance(self.device, int) and self.device < 0)
             and hf_device_map is None
         ):
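Taken together, the two hunks implement this resolution order: an explicit `device` argument wins; otherwise the pipeline falls back to the device the model already occupies; and `model.to()` is skipped both for accelerate-dispatched models (`hf_device_map`) and when the model is already on the resolved device. A simplified standalone sketch of that logic (a paraphrase for illustration, not the actual `Pipeline.__init__` code):

import torch

def resolve_pipeline_device(device, model, hf_device_map=None):
    # Paraphrase of the post-PR behavior; the real code also normalizes
    # ints/strings to torch.device and validates xpu availability.
    if device is None:
        device = -1
    if device == -1 and model.device is not None:
        # New in this PR: inherit the model's device instead of defaulting to CPU.
        device = model.device
    should_move = (
        not (isinstance(device, int) and device < 0)  # a concrete target was resolved
        and hf_device_map is None  # model was not dispatched by accelerate
        and model.device != device  # model is not already on the target
    )
    return device, should_move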
47 changes: 47 additions & 0 deletions tests/pipelines/test_pipelines_common.py
@@ -48,6 +48,7 @@
     require_tf,
     require_torch,
     require_torch_accelerator,
+    require_torch_multi_accelerator,
     require_torch_or_tf,
     slow,
     torch_device,
@@ -519,6 +520,52 @@ def test_pipeline_negative_device(self):
         actual_output = classifier("Test input.")
         self.assertEqual(expected_output, actual_output)
 
+    @require_torch_accelerator
+    def test_pipeline_no_device(self):
+        # Test that, when no device is passed to the pipeline, it inherits the model's device
+        import torch
+
+        from transformers import AutoModelForCausalLM
+
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+        # Case 1: Model is manually moved to device
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", torch_dtype=torch.float16
+        ).to(torch_device)
+        model_device = model.device
+        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+        self.assertEqual(pipe.model.device, model_device)
+        # Case 2: Model is loaded by accelerate
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", device_map=torch_device, torch_dtype=torch.float16
+        )
+        model_device = model.device
+        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+        self.assertEqual(pipe.model.device, model_device)
+        # Case 3: device_map is passed to the model and a conflicting device is passed to the pipeline
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", device_map=torch_device, torch_dtype=torch.float16
+        )
+        with self.assertRaises(ValueError):
+            pipe = pipeline("text-generation", model=model, device="cpu", tokenizer=tokenizer)
+
+    @require_torch_multi_accelerator
+    def test_pipeline_device_not_equal_model_device(self):
+        # Test that when device ids differ, the pipeline moves the model to the passed device id
+        import torch
+
+        from transformers import AutoModelForCausalLM
+
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+        model_device = f"{torch_device}:1"
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", torch_dtype=torch.float16
+        ).to(model_device)
+        target_device = f"{torch_device}:0"
+        self.assertNotEqual(model_device, target_device)
+        pipe = pipeline("text-generation", model=model, device=target_device, tokenizer=tokenizer)
+        self.assertEqual(pipe.model.device, torch.device(target_device))
+
     @slow
     @require_torch
     def test_load_default_pipelines_pt(self):
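These tests can be exercised locally with something like `python -m pytest tests/pipelines/test_pipelines_common.py -k "pipeline_no_device or pipeline_device_not_equal_model_device"` (invocation illustrative); the second test is skipped unless at least two accelerators are visible.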