From 5436bd5de1930256c98cd5f2eb3cbcd0a6e73f2b Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Wed, 20 Sep 2023 20:28:38 +0300 Subject: [PATCH] [ENH]: Support for $in and $nin metadata filters (#1151) Refs: #1105 ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - JS Client support for $in and $nin ## Test plan *How are these changes tested?* - [x] Tests pass locally `yarn test` for js ## Documentation Changes TBD --- clients/js/src/types.ts | 8 +-- clients/js/test/query.collection.test.ts | 68 ++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/clients/js/src/types.ts b/clients/js/src/types.ts index 8787d5e5659..1f1dd04c4c8 100644 --- a/clients/js/src/types.ts +++ b/clients/js/src/types.ts @@ -20,13 +20,15 @@ export type IDs = ID[]; export type PositiveInteger = number; -type LiteralValue = string | number; +type LiteralValue = string | number | boolean; +type ListLiteralValue = LiteralValue[]; type LiteralNumber = number; type LogicalOperator = "$and" | "$or"; +type InclusionOperator = "$in" | "$nin"; type WhereOperator = "$gt" | "$gte" | "$lt" | "$lte" | "$ne" | "$eq"; type OperatorExpression = { - [key in WhereOperator | LogicalOperator]?: LiteralValue | LiteralNumber; + [key in WhereOperator | InclusionOperator | LogicalOperator ]?: LiteralValue | ListLiteralValue; }; type BaseWhere = { @@ -77,4 +79,4 @@ export type CollectionMetadata = Record; // see all options here: https://www.jsdocs.io/package/@types/node-fetch#RequestInit export type ConfigOptions = { options?: RequestInit; -}; \ No newline at end of file +}; diff --git a/clients/js/test/query.collection.test.ts b/clients/js/test/query.collection.test.ts index 05125a27ffa..878ed0a71df 100644 --- a/clients/js/test/query.collection.test.ts +++ b/clients/js/test/query.collection.test.ts @@ -86,3 +86,71 @@ test("it should query a collection with text", async () => { expect.arrayContaining(results.documents[0]) ); }) + + +test("it should query a collection with text and where", async () => { + await chroma.reset(); + let embeddingFunction = new TestEmbeddingFunction(); + const collection = await chroma.createCollection({ name: "test", embeddingFunction: embeddingFunction }); + await collection.add({ ids: IDS, embeddings: EMBEDDINGS, metadatas: METADATAS, documents: DOCUMENTS }); + + const results = await collection.query({ + queryTexts: ["test"], + nResults: 3, + where: { "float_value" : 2 } + }); + + expect(results).toBeDefined(); + expect(results).toBeInstanceOf(Object); + expect(results.ids.length).toBe(1); + expect(["test3"]).toEqual(expect.arrayContaining(results.ids[0])); + expect(["test2"]).not.toEqual(expect.arrayContaining(results.ids[0])); + expect(["This is a third test"]).toEqual( + expect.arrayContaining(results.documents[0]) + ); +}) + + +test("it should query a collection with text and where in", async () => { + await chroma.reset(); + let embeddingFunction = new TestEmbeddingFunction(); + const collection = await chroma.createCollection({ name: "test", embeddingFunction: embeddingFunction }); + await collection.add({ ids: IDS, embeddings: EMBEDDINGS, metadatas: METADATAS, documents: DOCUMENTS }); + + const results = await collection.query({ + queryTexts: ["test"], + nResults: 3, + where: { "float_value" : { '$in': [2,5,10] }} + }); + + expect(results).toBeDefined(); + expect(results).toBeInstanceOf(Object); + expect(results.ids.length).toBe(1); + expect(["test3"]).toEqual(expect.arrayContaining(results.ids[0])); + expect(["test2"]).not.toEqual(expect.arrayContaining(results.ids[0])); + expect(["This is a third test"]).toEqual( + expect.arrayContaining(results.documents[0]) + ); +}) + +test("it should query a collection with text and where nin", async () => { + await chroma.reset(); + let embeddingFunction = new TestEmbeddingFunction(); + const collection = await chroma.createCollection({ name: "test", embeddingFunction: embeddingFunction }); + await collection.add({ ids: IDS, embeddings: EMBEDDINGS, metadatas: METADATAS, documents: DOCUMENTS }); + + const results = await collection.query({ + queryTexts: ["test"], + nResults: 3, + where: { "float_value" : { '$nin': [-2,0] }} + }); + + expect(results).toBeDefined(); + expect(results).toBeInstanceOf(Object); + expect(results.ids.length).toBe(1); + expect(["test3"]).toEqual(expect.arrayContaining(results.ids[0])); + expect(["test2"]).not.toEqual(expect.arrayContaining(results.ids[0])); + expect(["This is a third test"]).toEqual( + expect.arrayContaining(results.documents[0]) + ); +})