forked from chroma-core/chroma
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Aggregation result query for metadata
- Loading branch information
Showing
3 changed files
with
425 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,219 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "initial_id", | ||
"metadata": { | ||
"collapsed": true, | ||
"ExecuteTime": { | ||
"end_time": "2024-02-24T14:20:19.396517Z", | ||
"start_time": "2024-02-24T14:20:18.329658Z" | ||
} | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Add of existing embedding ID: 0\n", | ||
"Add of existing embedding ID: 0\n", | ||
"Add of existing embedding ID: 0\n", | ||
"Add of existing embedding ID: 0\n", | ||
"Add of existing embedding ID: 0\n", | ||
"Add of existing embedding ID: 0\n", | ||
"Add of existing embedding ID: 0\n", | ||
"Add of existing embedding ID: 0\n", | ||
"Add of existing embedding ID: 0\n", | ||
"Insert of existing embedding ID: 0\n", | ||
"Add of existing embedding ID: 0\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import chromadb\n", | ||
"\n", | ||
"client = chromadb.PersistentClient(\"aggr\")\n", | ||
"\n", | ||
"col = client.get_or_create_collection(\"test\")\n", | ||
"\n", | ||
"col.add(ids=[str(i) for i in range(1)],documents=[\"doc 1\"],metadatas=[{\"a\":\"1\",\"b\":\"2\"}])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"SELECT \"embeddings\".\"id\",\"embeddings\".\"embedding_id\",\"embeddings\".\"seq_id\",MAX(CASE WHEN embedding_metadata.key = 'chroma:document' THEN embedding_metadata.string_value END) AS document_content,'{' || GROUP_CONCAT(DISTINCT CASE WHEN embedding_metadata.key <> 'chroma:document' THEN '\"' || embedding_metadata.key || '\": ' || CASE\n", | ||
" WHEN embedding_metadata.string_value IS NOT NULL\n", | ||
" THEN '\"' || replace(embedding_metadata.string_value, '\"','\"')|| '\"'\n", | ||
" WHEN embedding_metadata.int_value IS NOT NULL\n", | ||
" THEN embedding_metadata.int_value\n", | ||
" WHEN embedding_metadata.float_value IS NOT NULL\n", | ||
" THEN embedding_metadata.float_value\n", | ||
" ELSE embedding_metadata.bool_value\n", | ||
" END END) || '}' FROM \"embeddings\" LEFT JOIN \"embedding_metadata\" ON \"embeddings\".\"id\"=\"embedding_metadata\".\"id\" WHERE \"embeddings\".\"id\" IN (SELECT \"id\" FROM \"embeddings\" WHERE \"segment_id\"=? ORDER BY \"embedding_id\" LIMIT 9223372036854775807) GROUP BY \"embeddings\".\"id\" ORDER BY \"embeddings\".\"embedding_id\" ('d3ac4e3a-db42-429f-89a9-61f2181fcc80',)\n", | ||
"[(1, '0', b'\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x01', 'doc 1', '{\"a\": \"1\",\"b\": \"2\"}')]\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": "{'ids': ['0'],\n 'embeddings': None,\n 'metadatas': [{'a': '1', 'b': '2'}],\n 'documents': [None],\n 'uris': None,\n 'data': None}" | ||
}, | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"col.get()" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2024-02-24T14:20:20.498487Z", | ||
"start_time": "2024-02-24T14:20:20.426113Z" | ||
} | ||
}, | ||
"id": "57d374a1d5313d0a", | ||
"execution_count": 2 | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"outputs": [], | ||
"source": [ | ||
"from typing import Any\n", | ||
"from pypika.utils import builder\n", | ||
"from pypika.terms import Function\n", | ||
"\n", | ||
"\n", | ||
"class CustomMax(Function):\n", | ||
" is_aggregate = True\n", | ||
"\n", | ||
" def __init__(self, name, *args, **kwargs):\n", | ||
" super(CustomMax, self).__init__(name, *args, **kwargs)\n", | ||
"\n", | ||
" self._filters = []\n", | ||
" self._include_filter = False\n", | ||
"\n", | ||
" # @builder\n", | ||
" # def filter(self, *filters: Any) -> \"AnalyticFunction\":\n", | ||
" # self._include_filter = True\n", | ||
" # self._filters += filters\n", | ||
" # \n", | ||
" # def get_filter_sql(self, **kwargs: Any) -> str:\n", | ||
" # if self._include_filter:\n", | ||
" # return \"WHERE {criterions}\".format(criterions=Criterion.all(self._filters).get_sql(**kwargs))\n", | ||
"\n", | ||
" def get_function_sql(self, **kwargs: Any):\n", | ||
" sql = super(CustomMax, self).get_function_sql(**kwargs)\n", | ||
"\n", | ||
" return sql\n", | ||
"\n", | ||
"custom_max = CustomMax(\"max\", \"column\")\n", | ||
"\n", | ||
"print(custom_max.get_sql())" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "23df5d0d67f046f3", | ||
"execution_count": null | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"outputs": [], | ||
"source": [ | ||
"from pypika import CustomFunction\n", | ||
"\n", | ||
"\n", | ||
"class MyCustomFunction(CustomFunction):\n", | ||
" def __init__(self, function_name, params, alias=None):\n", | ||
" super().__init__(function_name, params)\n", | ||
" self.function_name = function_name\n", | ||
" self._alias = alias\n", | ||
" \n", | ||
" def get_function_sql(self, **kwargs: Any):\n", | ||
" sql = f\"{self.function_name}({self.params})\"\n", | ||
" if self._alias:\n", | ||
" sql += f\" AS {self._alias}\"\n", | ||
" return sql\n", | ||
" \n", | ||
" def _as(self, alias):\n", | ||
" self._alias = alias\n", | ||
" \n", | ||
"print(MyCustomFunction(\"max\", \"CASE WHEN emd.key = 'chroma:document' THEN emd.string_value END\", \"document_content\").get_function_sql())\n", | ||
" \n" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "850774b7c68bc8c9", | ||
"execution_count": null | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"outputs": [], | ||
"source": [ | ||
"class CustomStringFunction(CustomFunction):\n", | ||
" def __init__(self, params, alias=None):\n", | ||
" super().__init__(\"CustomStringFunction\", params)\n", | ||
" self._alias = alias\n", | ||
"\n", | ||
" def get_function_sql(self, **kwargs: Any):\n", | ||
" return f\"{self.params}\"\n", | ||
"\n", | ||
" def _as(self, alias):\n", | ||
" self._alias = alias\n", | ||
"\n", | ||
"print(CustomStringFunction(\"\"\"'{' || GROUP_CONCAT(DISTINCT CASE WHEN emd.key <> 'chroma:document' THEN '\"' || emd.key || '\": ' || CASE\n", | ||
" WHEN emd.string_value IS NOT NULL\n", | ||
" THEN '\"' || replace(emd.string_value, '\"','\\\"')|| '\"'\n", | ||
" WHEN emd.int_value IS NOT NULL\n", | ||
" THEN emd.int_value\n", | ||
" WHEN float_value IS NOT NULL\n", | ||
" THEN float_value\n", | ||
" ELSE bool_value\n", | ||
" END END) || '}'\"\"\",alias=\"metadata\").get_function_sql())" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "d7775896077e9bb4", | ||
"execution_count": null | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"outputs": [], | ||
"source": [], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "e6ec8967a68d32dc" | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.