diff --git a/clients/js/src/AdminClient.ts b/clients/js/src/AdminClient.ts index d6fd56c1a06..99f4974bcbe 100644 --- a/clients/js/src/AdminClient.ts +++ b/clients/js/src/AdminClient.ts @@ -74,10 +74,6 @@ export class AdminClient { ...this.api.options.headers, ...this.authProvider.authenticate(), }; - this.api.options.headers = { - ...this.api.options.headers, - ...this.authProvider.authenticate(), - }; } } diff --git a/clients/js/src/ChromaClient.ts b/clients/js/src/ChromaClient.ts index 725332eafcf..a2e85909482 100644 --- a/clients/js/src/ChromaClient.ts +++ b/clients/js/src/ChromaClient.ts @@ -1,5 +1,5 @@ import { Configuration, ApiApi as DefaultApi } from "./generated"; -import { handleSuccess } from "./utils"; +import { handleSuccess, validateTenantDatabase } from "./utils"; import { Collection } from "./Collection"; import { ChromaClientParams, @@ -25,10 +25,11 @@ export class ChromaClient { * @ignore */ private api: DefaultApi & ConfigOptions; - private tenant: string = DEFAULT_TENANT; - private database: string = DEFAULT_DATABASE; - private _adminClient?: AdminClient; + private tenant: string; + private database: string; + private _adminClient: AdminClient; private authProvider: ClientAuthProvider | undefined; + private _initPromise: Promise | undefined; /** * Creates a new ChromaClient instance. @@ -44,13 +45,12 @@ export class ChromaClient { * ``` */ constructor({ - path, + path = "http://localhost:8000", fetchOptions, auth, tenant = DEFAULT_TENANT, database = DEFAULT_DATABASE, }: ChromaClientParams = {}) { - if (path === undefined) path = "http://localhost:8000"; this.tenant = tenant; this.database = database; this.authProvider = undefined; @@ -71,17 +71,25 @@ export class ChromaClient { } this._adminClient = new AdminClient({ - path: path, - fetchOptions: fetchOptions, - auth: auth, - tenant: tenant, - database: database, + path, + fetchOptions, + auth, + tenant, + database, }); + } + + /** @ignore */ + private init(): Promise { + if (!this._initPromise) { + this._initPromise = validateTenantDatabase( + this._adminClient, + this.tenant, + this.database, + ); + } - // TODO: Validate tenant and database on client creation - // this got tricky because: - // - the constructor is sync but the generated api is async - // - we need to inject auth information so a simple rewrite/fetch does not work + return this._initPromise; } /** @@ -96,7 +104,8 @@ export class ChromaClient { * await client.reset(); * ``` */ - public async reset(): Promise { + async reset(): Promise { + await this.init(); return await this.api.reset(this.api.options); } @@ -110,7 +119,7 @@ export class ChromaClient { * const version = await client.version(); * ``` */ - public async version(): Promise { + async version(): Promise { const response = await this.api.version(this.api.options); return await handleSuccess(response); } @@ -125,7 +134,7 @@ export class ChromaClient { * const heartbeat = await client.heartbeat(); * ``` */ - public async heartbeat(): Promise { + async heartbeat(): Promise { const response = await this.api.heartbeat(this.api.options); let ret = await handleSuccess(response); return ret["nanosecond heartbeat"]; @@ -153,15 +162,12 @@ export class ChromaClient { * }); * ``` */ - public async createCollection({ + async createCollection({ name, metadata, - embeddingFunction, + embeddingFunction = new DefaultEmbeddingFunction(), }: CreateCollectionParams): Promise { - if (embeddingFunction === undefined) { - embeddingFunction = new DefaultEmbeddingFunction(); - } - + await this.init(); const newCollection = await this.api .createCollection( this.tenant, @@ -211,15 +217,12 @@ export class ChromaClient { * }); * ``` */ - public async getOrCreateCollection({ + async getOrCreateCollection({ name, metadata, - embeddingFunction, + embeddingFunction = new DefaultEmbeddingFunction(), }: GetOrCreateCollectionParams): Promise { - if (embeddingFunction === undefined) { - embeddingFunction = new DefaultEmbeddingFunction(); - } - + await this.init(); const newCollection = await this.api .createCollection( this.tenant, @@ -259,10 +262,10 @@ export class ChromaClient { * }); * ``` */ - public async listCollections({ - limit, - offset, - }: ListCollectionsParams = {}): Promise { + async listCollections({ limit, offset }: ListCollectionsParams = {}): Promise< + CollectionType[] + > { + await this.init(); const response = await this.api.listCollections( limit, offset, @@ -284,7 +287,9 @@ export class ChromaClient { * const collections = await client.countCollections(); * ``` */ - public async countCollections(): Promise { + async countCollections(): Promise { + await this.init(); + const response = await this.api.countCollections( this.tenant, this.database, @@ -308,10 +313,12 @@ export class ChromaClient { * }); * ``` */ - public async getCollection({ + async getCollection({ name, embeddingFunction, }: GetCollectionParams): Promise { + await this.init(); + const response = await this.api .getCollection(name, this.tenant, this.database, this.api.options) .then(handleSuccess); @@ -339,9 +346,9 @@ export class ChromaClient { * }); * ``` */ - public async deleteCollection({ - name, - }: DeleteCollectionParams): Promise { + async deleteCollection({ name }: DeleteCollectionParams): Promise { + await this.init(); + return await this.api .deleteCollection(name, this.tenant, this.database, this.api.options) .then(handleSuccess); diff --git a/clients/js/src/auth.ts b/clients/js/src/auth.ts index ddc1ef87163..b73aacca9d7 100644 --- a/clients/js/src/auth.ts +++ b/clients/js/src/auth.ts @@ -39,9 +39,7 @@ export class BasicAuthClientProvider implements ClientAuthProvider { * @throws {Error} If neither credentials provider or text credentials are supplied. */ constructor(textCredentials: string | undefined) { - const envVarTextCredentials = process.env.CHROMA_CLIENT_AUTH_CREDENTIALS; - - const creds = textCredentials ?? envVarTextCredentials; + const creds = textCredentials ?? process.env.CHROMA_CLIENT_AUTH_CREDENTIALS; if (creds === undefined) { throw new Error( "Credentials must be supplied via environment variable (CHROMA_CLIENT_AUTH_CREDENTIALS) or passed in as configuration.", @@ -64,9 +62,7 @@ export class TokenAuthClientProvider implements ClientAuthProvider { textCredentials: any, headerType: TokenHeaderType = "AUTHORIZATION", ) { - const envVarTextCredentials = process.env.CHROMA_CLIENT_AUTH_CREDENTIALS; - - const creds = textCredentials ?? envVarTextCredentials; + const creds = textCredentials ?? process.env.CHROMA_CLIENT_AUTH_CREDENTIALS; if (creds === undefined) { throw new Error( "Credentials must be supplied via environment variable (CHROMA_CLIENT_AUTH_CREDENTIALS) or passed in as configuration.", diff --git a/clients/js/src/utils.ts b/clients/js/src/utils.ts index e44de16ed71..dcf773cee59 100644 --- a/clients/js/src/utils.ts +++ b/clients/js/src/utils.ts @@ -1,6 +1,7 @@ import { Api } from "./generated"; import Count200Response = Api.Count200Response; import { AdminClient } from "./AdminClient"; +import { ChromaConnectionError } from "./Errors"; // a function to convert a non-Array object to an Array export function toArray(obj: T | Array): Array { @@ -72,16 +73,24 @@ export async function validateTenantDatabase( try { await adminClient.getTenant({ name: tenant }); } catch (error) { + if (error instanceof ChromaConnectionError) { + throw error; + } throw new Error( - `Error: ${error}, Could not connect to tenant ${tenant}. Are you sure it exists?`, + `Could not connect to tenant ${tenant}. Are you sure it exists? Underlying error: +${error}`, ); } try { await adminClient.getDatabase({ name: database, tenantName: tenant }); } catch (error) { + if (error instanceof ChromaConnectionError) { + throw error; + } throw new Error( - `Error: ${error}, Could not connect to database ${database} for tenant ${tenant}. Are you sure it exists?`, + `Could not connect to database ${database} for tenant ${tenant}. Are you sure it exists? Underlying error: +${error}`, ); } } diff --git a/clients/js/test/auth.basic.test.ts b/clients/js/test/auth.basic.test.ts index 056a1424012..44d8fc41080 100644 --- a/clients/js/test/auth.basic.test.ts +++ b/clients/js/test/auth.basic.test.ts @@ -1,7 +1,6 @@ import { expect, test } from "@jest/globals"; import { chromaBasic } from "./initClientWithAuth"; import chromaNoAuth from "./initClient"; -import { ChromaForbiddenError } from "../src/Errors"; test("it should get the version without auth needed", async () => { const version = await chromaNoAuth.version(); @@ -15,19 +14,23 @@ test("it should get the heartbeat without auth needed", async () => { expect(heartbeat).toBeGreaterThan(0); }); -test("it should raise error when non authenticated", async () => { - await expect(chromaNoAuth.listCollections()).rejects.toBeInstanceOf( - ChromaForbiddenError, - ); +test("it should throw error when non authenticated", async () => { + try { + await chromaNoAuth.listCollections(); + } catch (e) { + expect(e).toBeInstanceOf(Error); + } }); test("it should list collections", async () => { - await chromaBasic.reset(); - let collections = await chromaBasic.listCollections(); + const client = chromaBasic(); + + await client.reset(); + let collections = await client.listCollections(); expect(collections).toBeDefined(); expect(collections).toBeInstanceOf(Array); expect(collections.length).toBe(0); - await chromaBasic.createCollection({ name: "test" }); - collections = await chromaBasic.listCollections(); + await client.createCollection({ name: "test" }); + collections = await client.listCollections(); expect(collections.length).toBe(1); }); diff --git a/clients/js/test/auth.token.test.ts b/clients/js/test/auth.token.test.ts index 3444d1dbca4..11328569cf4 100644 --- a/clients/js/test/auth.token.test.ts +++ b/clients/js/test/auth.token.test.ts @@ -6,8 +6,6 @@ import { cloudClient, } from "./initClientWithAuth"; import chromaNoAuth from "./initClient"; -import { ChromaForbiddenError } from "../src/Errors"; - test("it should get the version without auth needed", async () => { const version = await chromaNoAuth.version(); expect(version).toBeDefined(); @@ -21,59 +19,41 @@ test("it should get the heartbeat without auth needed", async () => { }); test("it should raise error when non authenticated", async () => { - await expect(chromaNoAuth.listCollections()).rejects.toBeInstanceOf( - ChromaForbiddenError, - ); + await expect(chromaNoAuth.listCollections()).rejects.toBeInstanceOf(Error); }); -if (!process.env.XTOKEN_TEST) { - test("it should list collections with default token config", async () => { - await chromaTokenDefault.reset(); - let collections = await chromaTokenDefault.listCollections(); +if (process.env.XTOKEN_TEST) { + test.each([ + ["x-token", chromaTokenXToken], + ["cloud client", cloudClient], + ])(`it should list collections with %s`, async (_, clientBuilder) => { + const client = clientBuilder(); + await client.reset(); + let collections = await client.listCollections(); expect(collections).toBeDefined(); expect(collections).toBeInstanceOf(Array); expect(collections.length).toBe(0); - const collection = await chromaTokenDefault.createCollection({ + await client.createCollection({ name: "test", }); - collections = await chromaTokenDefault.listCollections(); - expect(collections.length).toBe(1); - }); - - test("it should list collections with explicit bearer token config", async () => { - await chromaTokenBearer.reset(); - let collections = await chromaTokenBearer.listCollections(); - expect(collections).toBeDefined(); - expect(collections).toBeInstanceOf(Array); - expect(collections.length).toBe(0); - const collection = await chromaTokenBearer.createCollection({ - name: "test", - }); - collections = await chromaTokenBearer.listCollections(); + collections = await client.listCollections(); expect(collections.length).toBe(1); }); } else { - test("it should list collections with explicit x-token token config", async () => { - await chromaTokenXToken.reset(); - let collections = await chromaTokenXToken.listCollections(); + test.each([ + ["default token", chromaTokenDefault], + ["bearer token", chromaTokenBearer], + ])(`it should list collections with %s`, async (_, clientBuilder) => { + const client = clientBuilder(); + await client.reset(); + let collections = await client.listCollections(); expect(collections).toBeDefined(); expect(collections).toBeInstanceOf(Array); expect(collections.length).toBe(0); - const collection = await chromaTokenXToken.createCollection({ + await client.createCollection({ name: "test", }); - collections = await chromaTokenXToken.listCollections(); - expect(collections.length).toBe(1); - }); - - test("it should list collections with explicit x-token token config in CloudClient", async () => { - await cloudClient.reset(); - let collections = await cloudClient.listCollections(); - expect(collections).toBeDefined(); - expect(collections).toBeInstanceOf(Array); - expect(collections.length).toBe(0); - const collection = await cloudClient.createCollection({ name: "test" }); - collections = await cloudClient.listCollections(); + collections = await client.listCollections(); expect(collections.length).toBe(1); }); } diff --git a/clients/js/test/initClientWithAuth.ts b/clients/js/test/initClientWithAuth.ts index d0eb4348a10..a57f55bdef3 100644 --- a/clients/js/test/initClientWithAuth.ts +++ b/clients/js/test/initClientWithAuth.ts @@ -3,31 +3,37 @@ import { CloudClient } from "../src/CloudClient"; const PORT = process.env.PORT || "8000"; const URL = "http://localhost:" + PORT; -export const chromaBasic = new ChromaClient({ - path: URL, - auth: { provider: "basic", credentials: "admin:admin" }, -}); -export const chromaTokenDefault = new ChromaClient({ - path: URL, - auth: { provider: "token", credentials: "test-token" }, -}); -export const chromaTokenBearer = new ChromaClient({ - path: URL, - auth: { - provider: "token", - credentials: "test-token", - }, -}); -export const chromaTokenXToken = new ChromaClient({ - path: URL, - auth: { - provider: "token", - credentials: "test-token", - tokenHeaderType: "X_CHROMA_TOKEN", - }, -}); -export const cloudClient = new CloudClient({ - apiKey: "test-token", - cloudPort: PORT, - cloudHost: "http://localhost", -}); +export const chromaBasic = () => + new ChromaClient({ + path: URL, + auth: { provider: "basic", credentials: "admin:admin" }, + }); +export const chromaTokenDefault = () => + new ChromaClient({ + path: URL, + auth: { provider: "token", credentials: "test-token" }, + }); +export const chromaTokenBearer = () => + new ChromaClient({ + path: URL, + auth: { + provider: "token", + credentials: "test-token", + tokenHeaderType: "AUTHORIZATION", + }, + }); +export const chromaTokenXToken = () => + new ChromaClient({ + path: URL, + auth: { + provider: "token", + credentials: "test-token", + tokenHeaderType: "X_CHROMA_TOKEN", + }, + }); +export const cloudClient = () => + new CloudClient({ + apiKey: "test-token", + cloudPort: PORT, + cloudHost: "http://localhost", + }); diff --git a/clients/js/test/offline.test.ts b/clients/js/test/offline.test.ts index d35fe85a417..e9ce54d13be 100644 --- a/clients/js/test/offline.test.ts +++ b/clients/js/test/offline.test.ts @@ -1,5 +1,6 @@ import { expect, test } from "@jest/globals"; import { ChromaClient } from "../src/ChromaClient"; +import { ChromaConnectionError } from "../src/Errors"; test("it fails with a nice error", async () => { const chroma = new ChromaClient({ path: "http://example.invalid" }); @@ -7,7 +8,7 @@ test("it fails with a nice error", async () => { await chroma.createCollection({ name: "test" }); throw new Error("Should have thrown an error."); } catch (e) { - expect(e instanceof Error).toBe(true); + expect(e).toBeInstanceOf(ChromaConnectionError); expect((e as Error).message).toMatchInlineSnapshot( `"Failed to connect to chromadb. Make sure your server is running and try again. If you are running from a browser, make sure that your chromadb instance is configured to allow requests from the current origin using the CHROMA_SERVER_CORS_ALLOW_ORIGINS environment variable."`, ); diff --git a/docker-compose.test-auth.yml b/docker-compose.test-auth.yml index ad0ab102414..36faec10e4d 100644 --- a/docker-compose.test-auth.yml +++ b/docker-compose.test-auth.yml @@ -1,5 +1,3 @@ -version: '3.9' - networks: test_net: driver: bridge diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 262a850b7fa..ffbf802d851 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -1,5 +1,3 @@ -version: '3.9' - services: test_server: build: