diff --git a/annotator/clipvision/__init__.py b/annotator/clipvision/__init__.py index e54540998..e468ed5d0 100644 --- a/annotator/clipvision/__init__.py +++ b/annotator/clipvision/__init__.py @@ -98,23 +98,22 @@ def __init__(self, config): if not os.path.exists(file_path): load_file_from_url(url=self.download_link, model_dir=self.model_path, file_name=self.file_name) config = CLIPVisionConfig(**self.config) - self.model = CLIPVisionModelWithProjection(config) - self.processor = CLIPImageProcessor(crop_size=224, - do_center_crop=True, - do_convert_rgb=True, - do_normalize=True, - do_resize=True, - image_mean=[0.48145466, 0.4578275, 0.40821073], - image_std=[0.26862954, 0.26130258, 0.27577711], - resample=3, - size=224) - sd = torch.load(file_path, map_location=torch.device('cpu')) - self.model.load_state_dict(sd, strict=False) - del sd - - self.model.eval() - self.model.cpu() + with self.device: + self.model = CLIPVisionModelWithProjection(config) + self.processor = CLIPImageProcessor(crop_size=224, + do_center_crop=True, + do_convert_rgb=True, + do_normalize=True, + do_resize=True, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + resample=3, + size=224) + sd = torch.load(file_path, map_location=self.device) + self.model.load_state_dict(sd, strict=False) + del sd + self.model.eval() def unload_model(self): if self.model is not None: @@ -123,10 +122,9 @@ def unload_model(self): def __call__(self, input_image): with torch.no_grad(): input_image = cv2.resize(input_image, (224, 224), interpolation=cv2.INTER_AREA) - clip_vision_model = self.model.cpu() feat = self.processor(images=input_image, return_tensors="pt") - feat['pixel_values'] = feat['pixel_values'].cpu() - result = clip_vision_model(**feat, output_hidden_states=True) - result['hidden_states'] = [v.to(devices.get_device_for("controlnet")) for v in result['hidden_states']] - result = {k: v.to(devices.get_device_for("controlnet")) if isinstance(v, torch.Tensor) else v for k, v in result.items()} + feat['pixel_values'] = feat['pixel_values'].to(self.device) + result = self.model(**feat, output_hidden_states=True) + result['hidden_states'] = [v.to(self.device) for v in result['hidden_states']] + result = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in result.items()} return result