Skip to content

Commit

Permalink
feat: Aggregation result query for metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed Feb 27, 2024
1 parent 3908b7b commit b07c740
Show file tree
Hide file tree
Showing 3 changed files with 425 additions and 27 deletions.
100 changes: 73 additions & 27 deletions chromadb/segment/impl/metadata/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import orjson as json
from typing import Optional, Sequence, Any, Tuple, cast, Generator, Union, Dict, List
from chromadb.segment import MetadataReader
from chromadb.ingest import Consumer
Expand Down Expand Up @@ -30,7 +31,7 @@
from pypika import Table, Tables
from pypika.queries import QueryBuilder
import pypika.functions as fn
from pypika.terms import Criterion
from pypika.terms import Criterion, CustomFunction, Function
from itertools import groupby
from functools import reduce
import sqlite3
Expand Down Expand Up @@ -127,6 +128,36 @@ def get_metadata(
if limit < 0:
raise ValueError("Limit cannot be negative")

class CustomStringFunction(Function):
    """pypika term that renders a caller-supplied SQL fragment verbatim.

    ``params`` is emitted exactly as given -- it is neither quoted nor
    parameterised -- so it must only ever be built from trusted text.
    """

    def __init__(self, params: str, alias: Optional[str] = None) -> None:
        """Store the raw SQL fragment and an optional column alias.

        Args:
            params: SQL expression text to emit unchanged.
            alias: Optional alias to attach to the expression.
        """
        super().__init__("CustomStringFunction", params)
        self.params = params
        self._alias = alias

    def get_function_sql(self, **kwargs: Any) -> str:
        """Return the raw fragment; ``kwargs`` are accepted but not used."""
        return f"{self.params}"

    def _as(self, alias: str) -> "CustomStringFunction":
        """Set the alias and return ``self`` so the call can be chained."""
        self._alias = alias
        return self

    def get_sql(self, **kwargs: Any) -> str:
        """Return the fragment, followed by `` AS <alias>`` when requested.

        The alias used to be stored but never rendered. It is now appended
        when an alias is set and the caller asks for it via ``with_alias``.
        """
        sql = self.get_function_sql(**kwargs)
        # NOTE(review): assumes pypika passes ``with_alias=True`` only for
        # SELECT-list terms -- confirm against the pinned pypika version.
        if self._alias and kwargs.get("with_alias"):
            sql += f" AS {self._alias}"
        return sql
class MyCustomFunction(Function):
    """pypika term rendering ``<function_name>(<params>)`` from raw SQL text.

    ``params`` is placed between the parentheses exactly as given -- it is
    neither quoted nor parameterised -- so it must only be built from
    trusted text.
    """

    def __init__(
        self, function_name: str, params: str, alias: Optional[str] = None
    ) -> None:
        """Store the function name, raw argument text and optional alias.

        Args:
            function_name: SQL function name to emit, e.g. ``"MAX"``.
            params: SQL text emitted unchanged as the function's arguments.
            alias: Optional alias to attach to the expression.
        """
        super().__init__(function_name, params)
        self.function_name = function_name
        self.params = params
        self._alias = alias

    def get_function_sql(self, **kwargs: Any) -> str:
        """Return the bare call expression, without any alias."""
        return f"{self.function_name}({self.params})"

    def get_sql(self, **kwargs: Any) -> str:
        """Return the call expression, followed by `` AS <alias>`` when requested.

        The alias used to be appended unconditionally, which would yield
        invalid SQL wherever an alias is not allowed (for example in a
        WHERE or GROUP BY clause). It is now appended only when an alias is
        set and the caller asks for it via ``with_alias``.
        """
        sql = self.get_function_sql(**kwargs)
        # NOTE(review): assumes pypika passes ``with_alias=True`` only for
        # SELECT-list terms -- confirm against the pinned pypika version.
        if self._alias and kwargs.get("with_alias"):
            sql += f" AS {self._alias}"
        return sql

    def _as(self, alias: str) -> "MyCustomFunction":
        """Set the alias and return ``self`` so the call can be chained."""
        self._alias = alias
        return self
q = (
(
self._db.querybuilder()
Expand All @@ -138,13 +169,20 @@ def get_metadata(
embeddings_t.id,
embeddings_t.embedding_id,
embeddings_t.seq_id,
metadata_t.key,
metadata_t.string_value,
metadata_t.int_value,
metadata_t.float_value,
metadata_t.bool_value,
MyCustomFunction("MAX", "CASE WHEN embedding_metadata.key = 'chroma:document' THEN embedding_metadata.string_value END", "document_content"),
CustomStringFunction("""'{' || GROUP_CONCAT(DISTINCT CASE WHEN embedding_metadata.key <> 'chroma:document' THEN '"' || embedding_metadata.key || '": ' || CASE
WHEN embedding_metadata.string_value IS NOT NULL
THEN '"' || replace(embedding_metadata.string_value, '"','\"')|| '"'
WHEN embedding_metadata.int_value IS NOT NULL
THEN embedding_metadata.int_value
WHEN embedding_metadata.float_value IS NOT NULL
THEN embedding_metadata.float_value
ELSE embedding_metadata.bool_value
END END) || '}'""",alias="metadata")

)
.orderby(embeddings_t.embedding_id)
.groupby(embeddings_t.id)
)

# If there is a query that touches the metadata table, it uses
Expand Down Expand Up @@ -222,31 +260,39 @@ def _records(
sql, params = get_sql(q)
cur.execute(sql, params)

cur_iterator = iter(cur.fetchone, None)
group_iterator = groupby(cur_iterator, lambda r: int(r[0]))

for _, group in group_iterator:
yield self._record(list(group))
# cur_iterator = iter(cur.fetchone, None)
# # group_iterator = groupby(cur_iterator, lambda r: int(r[0]))
#
# for record in cur_iterator:
# yield self._record(record)

batch_size = 100 # Adjust batch size as needed
while True:
records = cur.fetchmany(batch_size)
if not records:
break
for record in records:
yield self._record(record)

@trace_method("SqliteMetadataSegment._record", OpenTelemetryGranularity.ALL)
def _record(self, rows: Sequence[Tuple[Any, ...]]) -> MetadataEmbeddingRecord:
def _record(self, rows: Tuple[Any, ...]) -> MetadataEmbeddingRecord:
"""Given a list of DB rows with the same ID, construct a
MetadataEmbeddingRecord"""
_, embedding_id, seq_id = rows[0][:3]
metadata = {}
for row in rows:
key, string_value, int_value, float_value, bool_value = row[3:]
if string_value is not None:
metadata[key] = string_value
elif int_value is not None:
metadata[key] = int_value
elif float_value is not None:
metadata[key] = float_value
elif bool_value is not None:
if bool_value == 1:
metadata[key] = True
else:
metadata[key] = False
_, embedding_id, seq_id = rows[:3]
metadata = json.loads(rows[4])
# for row in rows:
# key, string_value, int_value, float_value, bool_value = row[3:]
# if string_value is not None:
# metadata[key] = string_value
# elif int_value is not None:
# metadata[key] = int_value
# elif float_value is not None:
# metadata[key] = float_value
# elif bool_value is not None:
# if bool_value == 1:
# metadata[key] = True
# else:
# metadata[key] = False

return MetadataEmbeddingRecord(
id=embedding_id,
Expand Down
219 changes: 219 additions & 0 deletions query_aggr.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-02-24T14:20:19.396517Z",
"start_time": "2024-02-24T14:20:18.329658Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Add of existing embedding ID: 0\n",
"Add of existing embedding ID: 0\n",
"Add of existing embedding ID: 0\n",
"Add of existing embedding ID: 0\n",
"Add of existing embedding ID: 0\n",
"Add of existing embedding ID: 0\n",
"Add of existing embedding ID: 0\n",
"Add of existing embedding ID: 0\n",
"Add of existing embedding ID: 0\n",
"Insert of existing embedding ID: 0\n",
"Add of existing embedding ID: 0\n"
]
}
],
"source": [
"import chromadb\n",
"\n",
"client = chromadb.PersistentClient(\"aggr\")\n",
"\n",
"col = client.get_or_create_collection(\"test\")\n",
"\n",
"col.add(ids=[str(i) for i in range(1)],documents=[\"doc 1\"],metadatas=[{\"a\":\"1\",\"b\":\"2\"}])"
]
},
{
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SELECT \"embeddings\".\"id\",\"embeddings\".\"embedding_id\",\"embeddings\".\"seq_id\",MAX(CASE WHEN embedding_metadata.key = 'chroma:document' THEN embedding_metadata.string_value END) AS document_content,'{' || GROUP_CONCAT(DISTINCT CASE WHEN embedding_metadata.key <> 'chroma:document' THEN '\"' || embedding_metadata.key || '\": ' || CASE\n",
" WHEN embedding_metadata.string_value IS NOT NULL\n",
" THEN '\"' || replace(embedding_metadata.string_value, '\"','\"')|| '\"'\n",
" WHEN embedding_metadata.int_value IS NOT NULL\n",
" THEN embedding_metadata.int_value\n",
" WHEN embedding_metadata.float_value IS NOT NULL\n",
" THEN embedding_metadata.float_value\n",
" ELSE embedding_metadata.bool_value\n",
" END END) || '}' FROM \"embeddings\" LEFT JOIN \"embedding_metadata\" ON \"embeddings\".\"id\"=\"embedding_metadata\".\"id\" WHERE \"embeddings\".\"id\" IN (SELECT \"id\" FROM \"embeddings\" WHERE \"segment_id\"=? ORDER BY \"embedding_id\" LIMIT 9223372036854775807) GROUP BY \"embeddings\".\"id\" ORDER BY \"embeddings\".\"embedding_id\" ('d3ac4e3a-db42-429f-89a9-61f2181fcc80',)\n",
"[(1, '0', b'\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x01', 'doc 1', '{\"a\": \"1\",\"b\": \"2\"}')]\n"
]
},
{
"data": {
"text/plain": "{'ids': ['0'],\n 'embeddings': None,\n 'metadatas': [{'a': '1', 'b': '2'}],\n 'documents': [None],\n 'uris': None,\n 'data': None}"
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"col.get()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-02-24T14:20:20.498487Z",
"start_time": "2024-02-24T14:20:20.426113Z"
}
},
"id": "57d374a1d5313d0a",
"execution_count": 2
},
{
"cell_type": "code",
"outputs": [],
"source": [
"from typing import Any\n",
"from pypika.utils import builder\n",
"from pypika.terms import Function\n",
"\n",
"\n",
"class CustomMax(Function):\n",
" is_aggregate = True\n",
"\n",
" def __init__(self, name, *args, **kwargs):\n",
" super(CustomMax, self).__init__(name, *args, **kwargs)\n",
"\n",
" self._filters = []\n",
" self._include_filter = False\n",
"\n",
" # @builder\n",
" # def filter(self, *filters: Any) -> \"AnalyticFunction\":\n",
" # self._include_filter = True\n",
" # self._filters += filters\n",
" # \n",
" # def get_filter_sql(self, **kwargs: Any) -> str:\n",
" # if self._include_filter:\n",
" # return \"WHERE {criterions}\".format(criterions=Criterion.all(self._filters).get_sql(**kwargs))\n",
"\n",
" def get_function_sql(self, **kwargs: Any):\n",
" sql = super(CustomMax, self).get_function_sql(**kwargs)\n",
"\n",
" return sql\n",
"\n",
"custom_max = CustomMax(\"max\", \"column\")\n",
"\n",
"print(custom_max.get_sql())"
],
"metadata": {
"collapsed": false
},
"id": "23df5d0d67f046f3",
"execution_count": null
},
{
"cell_type": "code",
"outputs": [],
"source": [
"from pypika import CustomFunction\n",
"\n",
"\n",
"class MyCustomFunction(CustomFunction):\n",
" def __init__(self, function_name, params, alias=None):\n",
" super().__init__(function_name, params)\n",
" self.function_name = function_name\n",
" self._alias = alias\n",
" \n",
" def get_function_sql(self, **kwargs: Any):\n",
" sql = f\"{self.function_name}({self.params})\"\n",
" if self._alias:\n",
" sql += f\" AS {self._alias}\"\n",
" return sql\n",
" \n",
" def _as(self, alias):\n",
" self._alias = alias\n",
" \n",
"print(MyCustomFunction(\"max\", \"CASE WHEN emd.key = 'chroma:document' THEN emd.string_value END\", \"document_content\").get_function_sql())\n",
" \n"
],
"metadata": {
"collapsed": false
},
"id": "850774b7c68bc8c9",
"execution_count": null
},
{
"cell_type": "code",
"outputs": [],
"source": [
"class CustomStringFunction(CustomFunction):\n",
" def __init__(self, params, alias=None):\n",
" super().__init__(\"CustomStringFunction\", params)\n",
" self._alias = alias\n",
"\n",
" def get_function_sql(self, **kwargs: Any):\n",
" return f\"{self.params}\"\n",
"\n",
" def _as(self, alias):\n",
" self._alias = alias\n",
"\n",
"print(CustomStringFunction(\"\"\"'{' || GROUP_CONCAT(DISTINCT CASE WHEN emd.key <> 'chroma:document' THEN '\"' || emd.key || '\": ' || CASE\n",
" WHEN emd.string_value IS NOT NULL\n",
" THEN '\"' || replace(emd.string_value, '\"','\\\"')|| '\"'\n",
" WHEN emd.int_value IS NOT NULL\n",
" THEN emd.int_value\n",
" WHEN float_value IS NOT NULL\n",
" THEN float_value\n",
" ELSE bool_value\n",
" END END) || '}'\"\"\",alias=\"metadata\").get_function_sql())"
],
"metadata": {
"collapsed": false
},
"id": "d7775896077e9bb4",
"execution_count": null
},
{
"cell_type": "code",
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "e6ec8967a68d32dc"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit b07c740

Please sign in to comment.