Skip to content

Commit

Permalink
Merge branch 'main' into ragna-base-poc
Browse files (browse the repository at this point in the history)
  • Loading branch information
arjxn-py authored Jun 29, 2024
2 parents ed1cfde + a0bf68c commit 079093e
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions ragna/source_storages/_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def retrieve(
str(chat_id), embedding_function=self._embedding_function
)

include = ["distances", "metadatas", "documents"]
result = collection.query(
query_texts=prompt,
n_results=min(
Expand All @@ -100,22 +101,19 @@ def retrieve(
max(int(num_tokens * 2 / chunk_size), 100),
collection.count(),
),
include=["distances", "metadatas", "documents"],
include=include, # type: ignore[arg-type]
)

num_results = len(result["ids"][0])
result = {
key: [None] * num_results if value is None else value[0] # type: ignore[index]
for key, value in result.items()
}
result = {key: result[key][0] for key in ["ids", *include]} # type: ignore[literal-required]
# dict of lists -> list of dicts
results = [
{key[:-1]: value[idx] for key, value in result.items()}
{key: value[idx] for key, value in result.items()}
for idx in range(num_results)
]

# That should be the default, but let's make extra sure here
results = sorted(results, key=lambda r: r["distance"])
results = sorted(results, key=lambda r: r["distances"])

# TODO: we should have some functionality here to remove results with a high
# distance to keep only "valid" sources. However, there are two issues:
Expand All @@ -127,11 +125,11 @@ def retrieve(
return self._take_sources_up_to_max_tokens(
(
Source(
id=result["id"],
document=document_map[result["metadata"]["document_id"]],
location=result["metadata"]["page_numbers"],
content=result["document"],
num_tokens=result["metadata"]["num_tokens"],
id=result["ids"],
document=document_map[result["metadatas"]["document_id"]],
location=result["metadatas"]["page_numbers"],
content=result["documents"],
num_tokens=result["metadatas"]["num_tokens"],
)
for result in results
),
Expand Down

0 comments on commit 079093e

Please sign in to comment.