Merge pull request #85 from cpacker/fast-embeddings
Parallelize embedding generation
cpacker authored Oct 22, 2023
2 parents e7cb16f + f925735 commit 5c96d6e
Showing 1 changed file with 27 additions and 9 deletions.
memgpt/utils.py: 36 changes (27 additions & 9 deletions)
@@ -1,5 +1,6 @@
 from datetime import datetime
 
+import asyncio
 import csv
 import difflib
 import demjson3 as demjson
@@ -194,6 +195,31 @@ def chunk_files_for_jsonl(files, tkns_per_chunk=300, model='gpt-4'):
         ret.append(curr_file)
     return ret
 
+async def process_chunk(i, chunk, model):
+    try:
+        return i, await async_get_embedding_with_backoff(chunk['content'], model=model)
+    except Exception as e:
+        print(chunk)
+        raise e
+
+async def process_concurrently(archival_database, model, concurrency=10):
+    # Create a semaphore to limit the number of concurrent tasks
+    semaphore = asyncio.Semaphore(concurrency)
+
+    async def bounded_process_chunk(i, chunk):
+        async with semaphore:
+            return await process_chunk(i, chunk, model)
+
+    # Pre-allocate the results list, then create one bounded embedding task per chunk
+    embedding_data = [0 for _ in archival_database]
+    tasks = [bounded_process_chunk(i, chunk) for i, chunk in enumerate(archival_database)]
+
+    for future in tqdm(asyncio.as_completed(tasks), total=len(archival_database), desc="Processing file chunks"):
+        i, result = await future
+        embedding_data[i] = result
+
+    return embedding_data
+
 async def prepare_archival_index_from_files_compute_embeddings(glob_pattern, tkns_per_chunk=300, model='gpt-4', embeddings_model='text-embedding-ada-002'):
     files = sorted(glob.glob(glob_pattern))
     save_dir = "archival_index_from_files_" + get_local_time().replace(' ', '_').replace(':', '_')
@@ -206,15 +232,7 @@ async def prepare_archival_index_from_files_compute_embeddings(glob_pattern, tkn
 
     # chunk the files, make embeddings
    archival_database = chunk_files(files, tkns_per_chunk, model)
-    embedding_data = []
-    for chunk in tqdm(archival_database, desc="Processing file chunks", total=len(archival_database)):
-        # for chunk in tqdm(f, desc=f"Embedding file {i+1}/{len(chunks_by_file)}", total=len(f), leave=False):
-        try:
-            embedding = await async_get_embedding_with_backoff(chunk['content'], model=embeddings_model)
-        except Exception as e:
-            print(chunk)
-            raise e
-        embedding_data.append(embedding)
+    embedding_data = await process_concurrently(archival_database, embeddings_model)
     embeddings_file = os.path.join(save_dir, "embeddings.json")
     with open(embeddings_file, 'w') as f:
         print(f"Saving embeddings to {embeddings_file}")
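For readers skimming the diff, here is a minimal, self-contained sketch of the bounded-concurrency pattern the new process_concurrently helper relies on: a semaphore caps how many embedding requests are in flight, each task carries its input index, and asyncio.as_completed collects results in whatever order they finish while the index slots them back into input order. The fake_embed coroutine and the toy chunk dicts are stand-ins invented for this illustration (the commit itself calls async_get_embedding_with_backoff on archival_database entries), so treat this as a sketch of the pattern rather than the repository's code:

import asyncio
import random

async def fake_embed(text, model="text-embedding-ada-002"):
    # Stand-in for async_get_embedding_with_backoff: simulate a slow embedding API call
    await asyncio.sleep(random.uniform(0.05, 0.2))
    return [0.0, 0.0, 0.0]  # dummy embedding vector

async def process_chunk(i, chunk, model):
    # Carry the index along so the result can be slotted back into input order
    return i, await fake_embed(chunk["content"], model=model)

async def process_concurrently(archival_database, model, concurrency=10):
    semaphore = asyncio.Semaphore(concurrency)  # cap the number of in-flight requests

    async def bounded_process_chunk(i, chunk):
        async with semaphore:
            return await process_chunk(i, chunk, model)

    embedding_data = [None] * len(archival_database)
    tasks = [bounded_process_chunk(i, chunk) for i, chunk in enumerate(archival_database)]

    # Results arrive in completion order; the stored index restores input order
    for future in asyncio.as_completed(tasks):
        i, result = await future
        embedding_data[i] = result
    return embedding_data

if __name__ == "__main__":
    chunks = [{"content": f"chunk {n}"} for n in range(25)]
    embeddings = asyncio.run(process_concurrently(chunks, "text-embedding-ada-002", concurrency=5))
    print(f"{len(embeddings)} embeddings computed")

Dropping tqdm keeps the sketch dependency-free; wrapping asyncio.as_completed(tasks) in tqdm(...) as the commit does adds a progress bar over the same loop without changing the concurrency behavior.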
