Skip to content

Commit

Permalink
Merge pull request #160 from lgrammel/lg/fix-llama-cpp
Browse files Browse the repository at this point in the history
Add flag for parallelizable embedding model calls. #156
  • Loading branch information
lgrammel authored Nov 4, 2023
2 parents 20cb743 + b13c766 commit 1754549
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 5 deletions.
5 changes: 5 additions & 0 deletions src/model-function/embed/EmbeddingModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ export interface EmbeddingModel<
*/
readonly maxValuesPerCall: number | undefined;

/**
* True if the model can handle multiple embedding calls in parallel.
* When false, the calls for the individual value groups are awaited one after another.
*/
readonly isParallizable: boolean;

doEmbedValues(
values: VALUE[],
options?: FunctionOptions
Expand Down
20 changes: 15 additions & 5 deletions src/model-function/embed/embed.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,21 @@ export function embedMany<VALUE>(
}
}

const responses = await Promise.all(
valueGroups.map((valueGroup) =>
model.doEmbedValues(valueGroup, options)
)
);
// call the model for each group:
let responses: Array<{ response: unknown; embeddings: Vector[] }>;
if (model.isParallizable) {
responses = await Promise.all(
valueGroups.map((valueGroup) =>
model.doEmbedValues(valueGroup, options)
)
);
} else {
responses = [];
for (const valueGroup of valueGroups) {
const response = await model.doEmbedValues(valueGroup, options);
responses.push(response);
}
}

const rawResponses = responses.map((response) => response.response);
const embeddings: Array<Vector> = [];
Expand Down
1 change: 1 addition & 0 deletions src/model-provider/cohere/CohereTextEmbeddingModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ export class CohereTextEmbeddingModel
}

readonly maxValuesPerCall = 96;
readonly isParallizable = true;
readonly embeddingDimensions: number;

readonly contextWindowSize: number;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ export class HuggingFaceTextEmbeddingModel
}

readonly maxValuesPerCall;
readonly isParallizable = true;

readonly contextWindowSize = undefined;
readonly embeddingDimensions;
Expand Down
4 changes: 4 additions & 0 deletions src/model-provider/llamacpp/LlamaCppTextEmbeddingModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export interface LlamaCppTextEmbeddingModelSettings
extends EmbeddingModelSettings {
api?: ApiConfiguration;
embeddingDimensions?: number;
isParallizable?: boolean;
}

export class LlamaCppTextEmbeddingModel
Expand All @@ -38,6 +39,9 @@ export class LlamaCppTextEmbeddingModel
}

readonly maxValuesPerCall = 1;
/**
 * Reads the optional `isParallizable` setting.
 * Falls back to `false` when the setting is null or undefined.
 */
get isParallizable() {
  const { isParallizable: configured } = this.settings;
  return configured ?? false;
}

readonly contextWindowSize = undefined;
readonly embeddingDimensions;
Expand Down
1 change: 1 addition & 0 deletions src/model-provider/openai/OpenAITextEmbeddingModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ export class OpenAITextEmbeddingModel
}

readonly maxValuesPerCall = 2048;
readonly isParallizable = true;

readonly embeddingDimensions: number;

Expand Down

1 comment on commit 1754549

@vercel
Copy link

@vercel vercel bot commented on 1754549 Nov 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.