-
Notifications
You must be signed in to change notification settings - Fork 1
/
get_embeddings.py
148 lines (123 loc) · 4.74 KB
/
get_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import argparse
import functools
import itertools
import json
import logging
import os
import random
import time
from concurrent.futures import ThreadPoolExecutor
import google.generativeai as genai
import numpy as np
# Module-level script setup: logging, Gemini API transport, and CLI arguments.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Use REST transport instead of gRPC; the API key is presumably picked up
# from the environment — TODO confirm (not visible in this file).
genai.configure(transport="rest")
parser = argparse.ArgumentParser()
parser.add_argument('--crawl_result_dir', type=str, default='outputs')
parser.add_argument('--max_thread', type=int, default=8)
args = parser.parse_args()
# Collect crawler output files named like "...result<N>.json" and sort by N
# so embedding order is deterministic across runs.
crawl_results = os.listdir(args.crawl_result_dir)
crawl_results = [os.path.join(args.crawl_result_dir, file_name) for file_name in crawl_results if file_name.endswith(".json")]
crawl_results.sort(key=lambda x: int(x.split('result')[1].split('.json')[0]))
# Parallel bookkeeping lists: entry i of each list describes embedding i.
get_embedding_funcs = []  # type: list  # zero-arg callables, each fetching one embedding
embedding_to_file_name = []  # source JSON file of embedding i
embedding_to_inner_file_indices = []  # position of the paper dict within its file
embedding_to_fields = []  # "title_embedding" or "title_abs_embedding"
def retry_with_timeout_decorator(
    max_retries: int = 3,
    base_delay: float = 10,
    factor: int = 2,
    jitter: bool = True,
):
    """
    Build a retry decorator with exponentially increasing backoff.

    :param max_retries: Maximum number of attempts before giving up
    :param base_delay: Base delay time in seconds
    :param factor: Multiplier applied to the delay after each failed attempt
    :param jitter: If True, add up to one second of random jitter to the delay
    :return: the decorator that wraps the function
    """
    def decorator(func):
        @functools.wraps(func)
        def retry_calling(*args, **kwargs):
            # *args/**kwargs pass-through keeps the wrapper usable on any
            # callable (the original only supported zero-arg partials).
            retries = 0
            while retries < max_retries:
                # Backoff for this attempt: base * factor**retries (+ jitter).
                delay = base_delay * (factor ** retries)
                if jitter:
                    delay += random.uniform(0, 1)
                try:
                    return func(*args, **kwargs)
                except Exception:
                    retries += 1
                    if retries >= max_retries:
                        # Give up: return a zero vector shaped like a real
                        # embedding-001 response. BUGFIX: key was "embeddings",
                        # but downstream code reads d["embedding"], which would
                        # have raised KeyError on an exhausted retry.
                        return {
                            "embedding": [0.0] * 768,
                        }
                    # Fetch the logger here so the wrapper has no dependency
                    # on a module-level `logger` binding.
                    logging.getLogger(__name__).info(
                        "Attempt %d/%d failed; retrying in %.2f seconds...",
                        retries, max_retries, delay,
                    )
                    time.sleep(delay)
        return retry_calling
    return decorator
def combine_title_abs(title: str, abs: str):
    """Join a paper title and abstract into one markdown-style text block."""
    heading = f"# {title}"
    body = f"Abstract: {abs}"
    return "\n\n".join((heading, body))
# For every crawled file, queue two embedding requests per paper dict:
# one for the title alone, one for title + abstract. Nothing is executed
# here — functools.partial freezes each request for the thread pool below.
for file_name in crawl_results:
    with open(file_name, "rt") as f:
        dict_list = json.load(f)
    ## title embeddings
    batch_size = len(dict_list)
    get_embedding_funcs += [
        functools.partial(
            genai.embed_content,
            model="models/embedding-001",
            content=item["title"],
            task_type="clustering",
        )
        for item in dict_list
    ]
    # Bookkeeping stays index-aligned with get_embedding_funcs.
    embedding_to_file_name += [file_name] * batch_size
    embedding_to_inner_file_indices += list(range(batch_size))
    embedding_to_fields += ["title_embedding"] * batch_size
    ## title + abstract embeddings
    batch_size = len(dict_list)
    get_embedding_funcs += [
        functools.partial(
            genai.embed_content,
            model="models/embedding-001",
            # Missing abstracts fall back to "" rather than raising KeyError.
            content=combine_title_abs(item["title"], item.get("abstract", "")),
            task_type="clustering",
        )
        for item in dict_list
    ]
    embedding_to_file_name += [file_name] * batch_size
    embedding_to_inner_file_indices += list(range(batch_size))
    embedding_to_fields += ["title_abs_embedding"] * batch_size
# Wrap every queued embedding call with exponential-backoff retries so a
# transient API failure does not abort the whole run.
# get_embedding_funcs = get_embedding_funcs[:1000] # TODO: delete it
get_embedding_funcs = [
    retry_with_timeout_decorator(
        max_retries=3,
        # Rate limiting via initial delay. NOTE(review): the original comment
        # said "1500 RPM at peak, leave some redundency" but the divisor is
        # 1200 (i.e. 1200 RPM) — confirm the intended rate.
        base_delay=60 / 1200,
        factor=2,
        jitter=True,
    )(func)
    for func in get_embedding_funcs
]
# Fan the embedding calls out over a thread pool; the work is I/O-bound
# REST requests, so threads overlap the waits. executor.map preserves order.
with ThreadPoolExecutor(max_workers=args.max_thread) as executor:
    embeddings = list(executor.map(lambda func: func(), get_embedding_funcs))
# Each response is a dict holding the vector under "embedding"; stack them.
embeddings = [d["embedding"] for d in embeddings]
embeddings = np.array(embeddings)
# float16 halves the on-disk size; precision loss assumed acceptable for
# the downstream clustering use — confirm if exact values matter.
embeddings = embeddings.astype(np.float16)
np.save(os.path.join(args.crawl_result_dir, "embeddings.npy"), embeddings)
# Write back, for each paper dict, the row index of its embedding(s) in
# embeddings.npy. Group embedding indices by source file so each file is
# read and rewritten exactly once.
grouped_by_file_indices = list(range(len(embeddings)))
# groupby requires its input pre-sorted by the same key.
grouped_by_file_indices.sort(key=lambda x: embedding_to_file_name[x])
for file_name, indices in itertools.groupby(grouped_by_file_indices, key=lambda x: embedding_to_file_name[x]):
    indices = list(indices)
    with open(file_name, "rt") as f:
        output_dicts = json.load(f)
    # BUGFIX(latent): the original iterated zip(indices, embeddings), pairing
    # each index with the wrong embedding row (the zipped value was unused,
    # so behavior was unaffected — but the pairing was wrong and misleading).
    # Only the row index is stored; the vectors themselves live in the .npy.
    for index in indices:
        inner_index = embedding_to_inner_file_indices[index]
        field = embedding_to_fields[index]
        output_dicts[inner_index][f"{field}_index"] = index
    with open(file_name + 'l', "wt") as f:  # TODO: writes JSON to a ".jsonl" path; rename to .json
        json.dump(output_dicts, f)