Skip to content

Commit

Permalink
Handle .to for 4bit quantized models (#186)
Browse files Browse the repository at this point in the history
  • Loading branch information
g8a9 authored May 25, 2023
1 parent 1789acf commit e3b1f59
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions inseq/models/huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,18 @@ def load(
def device(self, new_device: str) -> None:
    """Set the model's target device, skipping the move for quantized models.

    Validates ``new_device`` and records it on the wrapper. For regular
    models the underlying HF model is moved with ``model.to``; for
    bitsandbytes-quantized models (8-bit or 4-bit) the move is skipped
    with a warning, since such models cannot change device after loading.

    Args:
        new_device: Device identifier (e.g. ``"cpu"``, ``"cuda"``,
            ``"cuda:0"``). Presumably validated by ``check_device`` —
            defined elsewhere in this module.
    """
    check_device(new_device)
    self._device = new_device
    # bitsandbytes sets these flags on quantized models; default to False
    # so plain (non-quantized) models fall through to the normal .to() path.
    is_loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
    is_loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False)
    is_quantized = is_loaded_in_8bit or is_loaded_in_4bit

    # Enable compatibility with 8bit models
    if self.model:
        if not is_quantized:
            self.model.to(self._device)
        else:
            mode = "8bit" if is_loaded_in_8bit else "4bit"
            logger.warning(
                f"The model is loaded in {mode} mode. The device cannot be changed after loading the model."
            )

@abstractmethod
Expand Down

0 comments on commit e3b1f59

Please sign in to comment.