Skip to content

Commit

Permalink
feat: significantly optimize the time consumption of clip vision (#2474)
Browse files: browse the repository at this point in the history.
Signed-off-by: storyicon <[email protected]>
  • Loading branch information
storyicon authored Jan 15, 2024
1 parent e7b5b60 commit 8b5f7c1
Showing 1 changed file with 19 additions and 21 deletions.
40 changes: 19 additions & 21 deletions annotator/clipvision/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -98,23 +98,22 @@ def __init__(self, config):
if not os.path.exists(file_path):
load_file_from_url(url=self.download_link, model_dir=self.model_path, file_name=self.file_name)
config = CLIPVisionConfig(**self.config)
self.model = CLIPVisionModelWithProjection(config)
self.processor = CLIPImageProcessor(crop_size=224,
do_center_crop=True,
do_convert_rgb=True,
do_normalize=True,
do_resize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
resample=3,
size=224)

sd = torch.load(file_path, map_location=torch.device('cpu'))
self.model.load_state_dict(sd, strict=False)
del sd

self.model.eval()
self.model.cpu()
with self.device:
self.model = CLIPVisionModelWithProjection(config)
self.processor = CLIPImageProcessor(crop_size=224,
do_center_crop=True,
do_convert_rgb=True,
do_normalize=True,
do_resize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
resample=3,
size=224)
sd = torch.load(file_path, map_location=self.device)
self.model.load_state_dict(sd, strict=False)
del sd
self.model.eval()

def unload_model(self):
if self.model is not None:
Expand All @@ -123,10 +122,9 @@ def unload_model(self):
def __call__(self, input_image):
with torch.no_grad():
input_image = cv2.resize(input_image, (224, 224), interpolation=cv2.INTER_AREA)
clip_vision_model = self.model.cpu()
feat = self.processor(images=input_image, return_tensors="pt")
feat['pixel_values'] = feat['pixel_values'].cpu()
result = clip_vision_model(**feat, output_hidden_states=True)
result['hidden_states'] = [v.to(devices.get_device_for("controlnet")) for v in result['hidden_states']]
result = {k: v.to(devices.get_device_for("controlnet")) if isinstance(v, torch.Tensor) else v for k, v in result.items()}
feat['pixel_values'] = feat['pixel_values'].to(self.device)
result = self.model(**feat, output_hidden_states=True)
result['hidden_states'] = [v.to(self.device) for v in result['hidden_states']]
result = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in result.items()}
return result

1 comment on commit 8b5f7c1

@zongmi
Copy link

@zongmi zongmi commented on 8b5f7c1 Jan 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, for me, the ControlNet speed has improved significantly.

Please sign in to comment.