Web AI embedder fixes (#929)
## Description of changes

After the recent changes in the Web AI package structure, the embedder function
stopped working: the import paths changed, and the browser version of the
library is now ESM-only, as mentioned
[here](#824 (comment)).

To fix the issue, I updated the import paths and changed the `require`
statements to dynamic `import()` calls for the browser version of the function.
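
The shape of the change, in brief (the full code is in the diff below):

```ts
// Before: CommonJS require, which breaks for the ESM-only browser build
const webAI = require("@visheratin/web-ai");

// After: dynamic import(), which loads the ESM build correctly
const webAI = await import("@visheratin/web-ai");
```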

**Important note:**

When using the embedder in the browser, you should import from
`chromadb/dist/module`, because the [browser
version](https://www.npmjs.com/package/@visheratin/web-ai) of Web AI is
ESM-only.
For Node.js, the regular `chromadb` import works, because the [Node.js
version](https://www.npmjs.com/package/@visheratin/web-ai-node) of the
library is CJS.
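
For example, a minimal sketch of the two setups (assuming
`WebAIEmbeddingFunction` is re-exported from the package root, like the other
embedding functions):

```ts
// Browser: import from the ESM build of chromadb
import { WebAIEmbeddingFunction } from "chromadb/dist/module";

// modality "text", node = false; proxy defaults to true in the browser
const embedder = new WebAIEmbeddingFunction("text", false);

// Node.js: the regular CJS entry point works
// const { WebAIEmbeddingFunction } = require("chromadb");
// const embedder = new WebAIEmbeddingFunction("text", true);
```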

Additionally, since Web AI recently added support for multimodal embeddings,
I added multimodal embedding generation using CLIP-base.
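
A minimal multimodal sketch (the default model ID `clip-base-quant` and the
URL/text split are taken from the code below; the image URL is a placeholder):

```ts
import { WebAIEmbeddingFunction } from "chromadb/dist/module";

// "multimodal" defaults to CLIP-base ("clip-base-quant")
const embedder = new WebAIEmbeddingFunction("multimodal", false);

// Values that parse as URLs are embedded as images, everything else as text,
// so texts and images land in the same embedding space. Note that results
// come back image embeddings first, then text embeddings.
const embeddings = await embedder.generate([
  "a photo of a dog",
  "https://example.com/dog.jpg", // placeholder URL
]);
```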

## Test plan

I tested the changes locally in both browser and Node.js environments. For
anyone who wants to test the changes in their own project, the built
`chromadb` package is available
[here](https://drive.google.com/file/d/1cNLsHGd1VmiFiamvsEaMVGA7ng56QQKG/view?usp=sharing).

## Documentation changes

Documentation changes will likely be needed for the new multimodal
functionality.
visheratin authored Aug 7, 2023
1 parent 608d01e commit 221bdfa
Showing 1 changed file with 190 additions and 62 deletions.

clients/js/src/embeddings/WebAIEmbeddingFunction.ts
@@ -1,82 +1,40 @@
 import { IEmbeddingFunction } from "./IEmbeddingFunction";
-let webAI: any;

 /**
  * WebAIEmbeddingFunction is a function that uses the Web AI package to generate embeddings.
+ * @remarks
+ * This embedding function can be used in both NodeJS and browser environments.
+ * Browser version of Web AI (@visheratin/web-ai) is an ESM module.
+ * NodeJS version of Web AI (@visheratin/web-ai-node) is a CommonJS module.
  */
 export class WebAIEmbeddingFunction implements IEmbeddingFunction {
-  private model;
+  private model: any;
   private proxy?: boolean;
+  private initPromise: Promise<any> | null;
+  private modality: "text" | "image" | "multimodal";

   /**
    * WebAIEmbeddingFunction constructor.
-   * @param modality - the modality of the embedding function, either "text" or "image".
+   * @param modality - the modality of the embedding function, either "text", "image", or "multimodal".
    * @param node - whether the embedding function is being used in a NodeJS environment.
    * @param proxy - whether to use web worker to avoid blocking the main thread. Works only in browser.
-   * @param wasmPath - the path/URL to the directory with ONNX runtime WebAssembly files. Has to be specified when running in NodeJS.
+   * @param wasmPath - the path/URL to the directory with ONNX runtime WebAssembly files.
    * @param modelID - the ID of the model to use, if not specified, the default model will be used.
    */
   constructor(
-    modality: "text" | "image",
+    modality: "text" | "image" | "multimodal",
     node: boolean,
     proxy?: boolean,
     wasmPath?: string,
     modelID?: string
   ) {
+    this.initPromise = null;
+    this.model = null;
+    this.modality = modality;
     if (node) {
-      this.proxy = proxy ? proxy : false;
-      try {
-        webAI = require("@visheratin/web-ai-node");
-      } catch (e) {
-        console.log(e);
-        throw new Error(
-          "Please install the @visheratin/web-ai-node package to use the WebAIEmbeddingFunction, `npm install -S @visheratin/web-ai-node`"
-        );
-      }
+      this.initNode(modality, proxy, modelID);
     } else {
-      this.proxy = proxy ? proxy : true;
-      try {
-        webAI = require("@visheratin/web-ai");
-      } catch (e) {
-        console.log(e);
-        throw new Error(
-          "Please install the @visheratin/web-ai package to use the WebAIEmbeddingFunction, `npm install -S @visheratin/web-ai`"
-        );
-      }
+      this.initPromise = this.initBrowser(modality, proxy, wasmPath, modelID);
     }
-    if (wasmPath) {
-      webAI.SessionParams.wasmRoot = wasmPath;
-    }
-    switch (modality) {
-      case "text": {
-        let id = "mini-lm-v2-quant"; //default text model
-        if (modelID) {
-          id = modelID;
-        }
-        const models = webAI.ListTextModels();
-        for (const modelMetadata of models) {
-          if (modelMetadata.id === id) {
-            this.model = new webAI.TextFeatureExtractionModel(modelMetadata);
-            return;
-          }
-        }
-        throw new Error(
-          `Could not find text model with id ${modelID} in the WebAI package`
-        );
-      }
-      case "image": {
-        let id = "efficientformer-l1-feature-quant"; //default image model
-        if (modelID) {
-          id = modelID;
-        }
-        const imageModels = webAI.ListImageModels();
-        for (const modelMetadata of imageModels) {
-          if (modelMetadata.id === id) {
-            this.model = new webAI.ImageFeatureExtractionModel(modelMetadata);
-            return;
-          }
-        }
-        throw new Error(
-          `Could not find image model with id ${modelID} in the WebAI package`
-        );
-      }
-    }
   }

@@ -87,15 +45,185 @@ export class WebAIEmbeddingFunction implements IEmbeddingFunction {
    * @returns the embeddings.
    */
   public async generate(values: string[]): Promise<number[][]> {
+    if (this.initPromise) {
+      await this.initPromise;
+    }
     if (!this.model.initialized) {
       await this.model.init(this.proxy);
     }
-    const output = await this.model.process(values);
-    const embeddings = output.result;
+    let embeddings = [];
+    if (this.modality === "text" || this.modality === "image") {
+      const output = await this.model.process(values);
+      embeddings = output.result;
+    } else {
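+      // Values that parse as URLs are embedded as images and the rest as
+      // text; the combined output is ordered image results first, then text.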
+      const urlValues = [];
+      const textValues = [];
+      for (const value of values) {
+        try {
+          new URL(value);
+          urlValues.push(value);
+        } catch {
+          textValues.push(value);
+        }
+      }
+      const urlOutput = await this.model.embedImages(urlValues);
+      const textOutput = await this.model.embedTexts(textValues);
+      embeddings = urlOutput.concat(textOutput);
+    }
     if (embeddings.length > 0 && Array.isArray(embeddings[0])) {
       return embeddings;
     } else {
       return [embeddings];
     }
   }

+  private initNode(
+    modality: "text" | "image" | "multimodal",
+    proxy?: boolean,
+    modelID?: string
+  ): void {
+    this.proxy = proxy ? proxy : false;
+    try {
+      const webAI = require("@visheratin/web-ai-node");
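+      // Use the CPU execution provider for ONNX Runtime sessions in Node.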
+      webAI.SessionParams.executionProviders = ["cpu"];
+      switch (modality) {
+        case "text": {
+          const webAIText = require("@visheratin/web-ai-node/text");
+          let id = "mini-lm-v2-quant"; // default text model
+          if (modelID) {
+            id = modelID;
+          }
+          const models = webAIText.ListTextModels();
+          for (const modelMetadata of models) {
+            if (modelMetadata.id === id) {
+              this.model = new webAIText.FeatureExtractionModel(modelMetadata);
+              return;
+            }
+          }
+          throw new Error(
+            `Could not find text model with id ${modelID} in the Web AI package`
+          );
+        }
+        case "image": {
+          const webAIImage = require("@visheratin/web-ai-node/image");
+          let id = "efficientformer-l1-feature-quant"; // default image model
+          if (modelID) {
+            id = modelID;
+          }
+          const imageModels = webAIImage.ListImageModels();
+          for (const modelMetadata of imageModels) {
+            if (modelMetadata.id === id) {
+              this.model = new webAIImage.FeatureExtractionModel(modelMetadata);
+              return;
+            }
+          }
+          throw new Error(
+            `Could not find image model with id ${modelID} in the Web AI package`
+          );
+        }
+        case "multimodal": {
+          const webAIMultimodal = require("@visheratin/web-ai-node/multimodal");
+          let id = "clip-base-quant"; // default multimodal model
+          if (modelID) {
+            id = modelID;
+          }
+          const multimodalModels = webAIMultimodal.ListMultimodalModels();
+          for (const modelMetadata of multimodalModels) {
+            if (modelMetadata.id === id) {
+              this.model = new webAIMultimodal.ZeroShotClassificationModel(
+                modelMetadata
+              );
+              return;
+            }
+          }
+          throw new Error(
+            `Could not find multimodal model with id ${modelID} in the Web AI package`
+          );
+        }
+      }
+    } catch (e) {
+      console.error(e);
+      throw new Error(
+        "Please install the @visheratin/web-ai-node package to use the WebAIEmbeddingFunction, `npm install -S @visheratin/web-ai-node`"
+      );
+    }
+  }

+  private async initBrowser(
+    modality: "text" | "image" | "multimodal",
+    proxy?: boolean,
+    wasmPath?: string,
+    modelID?: string
+  ) {
+    this.proxy = proxy ? proxy : true;
+    try {
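+      // Dynamic import: the browser build of Web AI is ESM-only.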
+      // @ts-ignore
+      const webAI = await import("@visheratin/web-ai");
+      if (wasmPath) {
+        webAI.SessionParams.wasmRoot = wasmPath;
+      }
+      switch (modality) {
+        case "text": {
+          // @ts-ignore
+          const webAIText = await import("@visheratin/web-ai/text");
+          let id = "mini-lm-v2-quant"; // default text model
+          if (modelID) {
+            id = modelID;
+          }
+          const models = webAIText.ListTextModels();
+          for (const modelMetadata of models) {
+            if (modelMetadata.id === id) {
+              this.model = new webAIText.FeatureExtractionModel(modelMetadata);
+              return;
+            }
+          }
+          throw new Error(
+            `Could not find text model with id ${modelID} in the Web AI package`
+          );
+        }
+        case "image": {
+          // @ts-ignore
+          const webAIImage = await import("@visheratin/web-ai/image");
+          let id = "efficientformer-l1-feature-quant"; // default image model
+          if (modelID) {
+            id = modelID;
+          }
+          const imageModels = webAIImage.ListImageModels();
+          for (const modelMetadata of imageModels) {
+            if (modelMetadata.id === id) {
+              this.model = new webAIImage.FeatureExtractionModel(modelMetadata);
+              return;
+            }
+          }
+          throw new Error(
+            `Could not find image model with id ${modelID} in the Web AI package`
+          );
+        }
+        case "multimodal": {
+          // @ts-ignore
+          const webAIMultimodal = await import("@visheratin/web-ai/multimodal");
+          let id = "clip-base-quant"; // default multimodal model
+          if (modelID) {
+            id = modelID;
+          }
+          const multimodalModels = webAIMultimodal.ListMultimodalModels();
+          for (const modelMetadata of multimodalModels) {
+            if (modelMetadata.id === id) {
+              this.model = new webAIMultimodal.ZeroShotClassificationModel(
+                modelMetadata
+              );
+              return;
+            }
+          }
+          throw new Error(
+            `Could not find multimodal model with id ${modelID} in the Web AI package`
+          );
+        }
+      }
+    } catch (e) {
+      console.error(e);
+      throw new Error(
+        "Please install the @visheratin/web-ai package to use the WebAIEmbeddingFunction, `npm install -S @visheratin/web-ai`"
+      );
+    }
+  }
}
