-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathpredict.py
87 lines (76 loc) · 2.66 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import re
import sys
from io import BytesIO
from typing import List

import requests
import torch
from cog import BasePredictor, Path, Input, BaseModel
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
class NamedEmbedding(BaseModel):
    """One prediction result: an input line paired with its embedding."""

    # The input line (text or image URL) exactly as it was embedded.
    input: str
    # The embedding vector flattened to plain Python floats.
    embedding: List[float]
class Predictor(BasePredictor):
    """Cog predictor that embeds text lines and image URLs with a CLIP model."""

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.model = CLIPModel.from_pretrained("/weights", local_files_only=True).to(
            "cuda"
        )
        self.processor = CLIPProcessor.from_pretrained(
            "/weights", local_files_only=True
        )

    def predict(
        self,
        inputs: str = Input(
            description="Newline-separated inputs. Can either be strings of text or image URIs starting with http[s]://",
            default="a\nb",
        ),
    ) -> List[NamedEmbedding]:
        """Embed every line of ``inputs`` and return the results in input order.

        Lines starting with ``http://`` or ``https://`` are downloaded and
        embedded as images; every other line is embedded as text.

        Args:
            inputs: Newline-separated text strings and/or image URLs.

        Returns:
            One ``NamedEmbedding`` per successfully processed line, in the
            order the lines appeared. Image URLs that cannot be downloaded or
            decoded are reported on stderr and omitted from the result. An
            input with no lines yields an empty list.
        """
        lines = []
        texts = []
        image_urls = []
        images = []
        for line in inputs.strip().splitlines():
            line = line.strip()
            if re.match(r"^https?://", line):
                print(f"Downloading {line}", file=sys.stderr)
                try:
                    # A timeout keeps one unresponsive host from hanging the
                    # whole prediction; raise_for_status() stops HTTP error
                    # pages from being handed to the image decoder.
                    response = requests.get(line, timeout=30)
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content))
                    # Image.open() is lazy: decode now so that a corrupt or
                    # truncated file fails here and not inside the processor.
                    image.load()
                except Exception as e:
                    # Best effort: report the failure and leave this URL out
                    # of the results instead of aborting the whole batch.
                    print(f"Failed to load {line}: {e}", file=sys.stderr)
                    continue
                images.append(image)
                image_urls.append(line)
            else:
                texts.append(line)
            # Only lines that will actually receive an embedding are recorded.
            lines.append(line)

        if not lines:
            # The processor raises when given neither text nor images.
            return []

        batch = self.processor(
            text=texts or None,
            images=images or None,
            return_tensors="pt",
            padding=True,
        ).to("cuda")

        # Maps each input line to its embedding tensor. Text lines and URL
        # lines are disjoint by construction, so a single dict is unambiguous.
        embeddings = {}
        # Inference only: skip autograd bookkeeping to save GPU memory.
        with torch.no_grad():
            if texts:
                text_embeds = self.model.get_text_features(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                )
                embeddings.update(zip(texts, text_embeds))
            if images:
                image_embeds = self.model.get_image_features(
                    pixel_values=batch["pixel_values"]
                )
                embeddings.update(zip(image_urls, image_embeds))

        return [
            NamedEmbedding(input=line, embedding=embeddings[line].tolist())
            for line in lines
        ]