Use .exists to check for results, rather than actually getting the count

- .count forces the query to be resolved and pulls a lot of data into memory just to check whether any results exist. Avoid that by using the .exists method instead (see the sketch below)
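A minimal sketch of the check this commit switches to, assuming a Django queryset used from async code via asgiref's sync_to_async (as in the diff below); the has_any_results helper name is illustrative, not part of the codebase:

    from asgiref.sync import sync_to_async

    async def has_any_results(queryset) -> bool:
        # .exists() asks the database whether at least one matching row exists
        # (roughly SELECT ... LIMIT 1), which is cheaper than computing a full
        # count or evaluating the queryset when only an emptiness check is needed.
        # sync_to_async wraps the synchronous ORM call so it can be awaited.
        return await sync_to_async(queryset.exists)()

    # Usage inside an async view or task, mirroring search() in the diff:
    #     if not await has_any_results(sorted_vectors):
    #         return Conversation.objects.none()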
sabaimran committed Aug 29, 2023
1 parent 529b330 commit 06c8c06
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions src/flint/embeddings/manager.py
@@ -17,16 +17,17 @@
 
 logger = logging.getLogger(__name__)
 
 
 @dataclass
 class Embedding:
     compiled: str
     vector: List[float]
 
 
-class EmbeddingsManager():
+class EmbeddingsManager:
     def __init__(self):
         model_name = "intfloat/multilingual-e5-large"
-        encode_kwargs = {'normalize_embeddings': True}
+        encode_kwargs = {"normalize_embeddings": True}
 
         if torch.cuda.is_available():
             # Use CUDA GPU
@@ -37,18 +38,16 @@ def __init__(self):
         else:
             device = torch.device("cpu")
 
-        model_kwargs = {'device': device}
+        model_kwargs = {"device": device}
         self.embeddings_model = HuggingFaceEmbeddings(
-            model_name=model_name,
-            encode_kwargs=encode_kwargs,
-            model_kwargs=model_kwargs
+            model_name=model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
         )
         self.max_tokens = 512
 
     def generate_embeddings(self, text: str):
         # Split into chunks of 512 tokens
         for chunk_index in range(0, len(text), self.max_tokens):
-            chunk = text[chunk_index:chunk_index+self.max_tokens]
+            chunk = text[chunk_index : chunk_index + self.max_tokens]
             embedding_chunk = f"passage: {chunk}"
             embeddings = self.embeddings_model.embed_documents([embedding_chunk])
             yield Embedding(chunk, embeddings[0])
@@ -57,22 +56,27 @@ async def search(self, query: str, user: User, top_n: int = 3, debug: bool = False):
         conversations_to_search = user.conversations.all()
         formatted_query = f"query: {query}"
         embedded_query = self.embeddings_model.embed_query(formatted_query)
-        sorted_vectors = ConversationVector.objects.filter(conversation__in=conversations_to_search).alias(distance=CosineDistance('vector', embedded_query)).filter(distance__lte=0.20).order_by('distance')
+        sorted_vectors = (
+            ConversationVector.objects.filter(conversation__in=conversations_to_search)
+            .alias(distance=CosineDistance("vector", embedded_query))
+            .filter(distance__lte=0.20)
+            .order_by("distance")[:top_n]
+        )
 
-        num_vectors = await sync_to_async(sorted_vectors.count)()
-        if num_vectors == 0:
+        if not await sync_to_async(sorted_vectors.exists)():
             return Conversation.objects.none()
 
-        if num_vectors > top_n:
-            sorted_vectors = sorted_vectors[:top_n]
-
         if debug:
-            annotated_result = ConversationVector.objects.filter(conversation__in=conversations_to_search).annotate(distance=CosineDistance('vector', embedded_query)).order_by('distance')[:10]
+            annotated_result = (
+                ConversationVector.objects.filter(conversation__in=conversations_to_search)
+                .annotate(distance=CosineDistance("vector", embedded_query))
+                .order_by("distance")[:10]
+            )
             debugging_vectors = await sync_to_async(list)(annotated_result.all())
 
             for vector in debugging_vectors:
                 logger.debug(f"Compiled: {vector.compiled}")
                 logger.debug(f"Distance: {vector.distance}")
 
-        n_matching_conversations = sorted_vectors.values_list('conversation', flat=True)
+        n_matching_conversations = sorted_vectors.values_list("conversation", flat=True)
         return Conversation.objects.filter(id__in=n_matching_conversations).distinct()
