Skip to content

Commit

Permalink
fix: Fix Langchain performance issue (#221)
Browse files Browse the repository at this point in the history
* feat: Adjust embedding types and access

* Adjust type

* fix performance issues

* fix tests

* add docs

* add await everywhere

* try again

* gf

* add chunks notion to readme

---------

Co-authored-by: Matthias Kuhr <[email protected]>
  • Loading branch information
marikaner and MatKuhr authored Oct 16, 2024
1 parent 4b8eb9f commit 99498cd
Show file tree
Hide file tree
Showing 18 changed files with 201 additions and 126 deletions.
5 changes: 5 additions & 0 deletions .changeset/grumpy-apes-beam.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@sap-ai-sdk/foundation-models': minor
---

[New Functionality] Add convenience method to access all embeddings in an Azure OpenAI response (`AzureOpenAiEmbeddingResponse`).
5 changes: 5 additions & 0 deletions .changeset/many-zebras-repair.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@sap-ai-sdk/foundation-models': minor
---

[Compatibility Note] Adjust `AzureOpenAiEmbeddingOutput` type to include multiple embedding responses as opposed to one.
5 changes: 5 additions & 0 deletions .changeset/smooth-cameras-care.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@sap-ai-sdk/langchain': patch
---

[Fixed Issue] Fix performance issues when creating embeddings for split documents by sending all documents in one request instead of splitting them up into separate requests.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ describe('Azure OpenAI chat client', () => {
};

const mockResponse =
parseMockResponse<AzureOpenAiCreateChatCompletionResponse>(
await parseMockResponse<AzureOpenAiCreateChatCompletionResponse>(
'foundation-models',
'azure-openai-chat-completion-success-response.json'
);
Expand All @@ -58,7 +58,7 @@ describe('Azure OpenAI chat client', () => {

it('throws on bad request', async () => {
const prompt = { messages: [] };
const mockResponse = parseMockResponse(
const mockResponse = await parseMockResponse(
'foundation-models',
'azure-openai-error-response.json'
);
Expand All @@ -85,7 +85,7 @@ describe('Azure OpenAI chat client', () => {
};

const mockResponse =
parseMockResponse<AzureOpenAiCreateChatCompletionResponse>(
await parseMockResponse<AzureOpenAiCreateChatCompletionResponse>(
'foundation-models',
'azure-openai-chat-completion-success-response.json'
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,30 @@ import { createLogger } from '@sap-cloud-sdk/util';
import { jest } from '@jest/globals';
import { parseMockResponse } from '../../../../test-util/mock-http.js';
import { AzureOpenAiChatCompletionResponse } from './azure-openai-chat-completion-response.js';
import type { AzureOpenAiCreateChatCompletionResponse } from './client/inference/schema';
import type { HttpResponse } from '@sap-cloud-sdk/http-client';
import type { AzureOpenAiCreateChatCompletionResponse } from './client/inference/schema/index.js';

describe('OpenAI chat completion response', () => {
const mockResponse =
parseMockResponse<AzureOpenAiCreateChatCompletionResponse>(
'foundation-models',
'azure-openai-chat-completion-success-response.json'
let mockResponse: AzureOpenAiCreateChatCompletionResponse;
let rawResponse: HttpResponse;
let azureOpenAiChatResponse: AzureOpenAiChatCompletionResponse;

beforeAll(async () => {
mockResponse =
await parseMockResponse<AzureOpenAiCreateChatCompletionResponse>(
'foundation-models',
'azure-openai-chat-completion-success-response.json'
);
rawResponse = {
data: mockResponse,
status: 200,
headers: {},
request: {}
};
azureOpenAiChatResponse = new AzureOpenAiChatCompletionResponse(
rawResponse
);
const rawResponse = {
data: mockResponse,
status: 200,
headers: {},
request: {}
};
const azureOpenAiChatResponse = new AzureOpenAiChatCompletionResponse(
rawResponse
);
});

it('should return the chat completion response', () => {
expect(azureOpenAiChatResponse.data).toStrictEqual(mockResponse);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ describe('Azure OpenAI embedding client', () => {
const prompt = {
input: ['AI is fascinating']
} as AzureOpenAiEmbeddingParameters;
const mockResponse = parseMockResponse<AzureOpenAiEmbeddingOutput>(
const mockResponse = await parseMockResponse<AzureOpenAiEmbeddingOutput>(
'foundation-models',
'azure-openai-embeddings-success-response.json'
);
Expand All @@ -52,7 +52,7 @@ describe('Azure OpenAI embedding client', () => {

it('throws on bad request', async () => {
const prompt = { input: [] };
const mockResponse = parseMockResponse(
const mockResponse = await parseMockResponse(
'foundation-models',
'azure-openai-error-response.json'
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,57 @@ import { createLogger } from '@sap-cloud-sdk/util';
import { jest } from '@jest/globals';
import { parseMockResponse } from '../../../../test-util/mock-http.js';
import { AzureOpenAiEmbeddingResponse } from './azure-openai-embedding-response.js';
import type { HttpResponse } from '@sap-cloud-sdk/http-client';
import type { AzureOpenAiEmbeddingOutput } from './azure-openai-embedding-types.js';

describe('Azure OpenAI embedding response', () => {
const mockResponse = parseMockResponse<AzureOpenAiEmbeddingResponse>(
'foundation-models',
'azure-openai-embeddings-success-response.json'
);
const rawResponse = {
data: mockResponse,
status: 200,
headers: {},
request: {}
};
const embeddingResponse = new AzureOpenAiEmbeddingResponse(rawResponse);
let embeddingResponse: AzureOpenAiEmbeddingResponse;
let rawResponse: HttpResponse;
let mockedData: AzureOpenAiEmbeddingOutput;
beforeAll(async () => {
mockedData = await parseMockResponse<AzureOpenAiEmbeddingOutput>(
'foundation-models',
'azure-openai-embeddings-success-response.json'
);

rawResponse = {
data: mockedData,
status: 200,
headers: {},
request: {}
};
embeddingResponse = new AzureOpenAiEmbeddingResponse(rawResponse);
});

it('should return the embedding response', () => {
expect(embeddingResponse.data).toStrictEqual(mockResponse);
expect(embeddingResponse.data).toStrictEqual(mockedData);
});

it('should return raw response', () => {
expect(embeddingResponse.rawResponse).toBe(rawResponse);
});

it('should return the first embedding', () => {
expect(embeddingResponse.getEmbedding()).toEqual(
mockedData.data[0].embedding
);
});

it('should return undefined when convenience function is called with incorrect index', () => {
const logger = createLogger({
package: 'foundation-models',
messageContext: 'azure-openai-embedding-response'
});
const errorSpy = jest.spyOn(logger, 'error');
expect(embeddingResponse.getEmbedding(1)).toBeUndefined();
expect(errorSpy).toHaveBeenCalledWith('Data index 1 is out of bounds.');
expect(embeddingResponse.getEmbedding(2)).toBeUndefined();
expect(errorSpy).toHaveBeenCalledWith('Data index 2 is out of bounds.');
expect(errorSpy).toHaveBeenCalledTimes(1);
});

it('should return all embeddings', () => {
expect(embeddingResponse.getEmbeddings()).toEqual([
mockedData.data[0].embedding,
mockedData.data[1]?.embedding
]);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ export class AzureOpenAiEmbeddingResponse {
return this.data.data[dataIndex]?.embedding;
}

/**
 * Parses the Azure OpenAI response and returns the embedding vector of every
 * result entry, in the order they appear in the response payload.
 * @returns The embedding vectors.
 */
getEmbeddings(): number[][] {
  const vectors: number[][] = [];
  for (const entry of this.data.data) {
    vectors.push(entry.embedding);
  }
  return vectors;
}

private logInvalidDataIndex(dataIndex: number): void {
if (dataIndex < 0 || dataIndex >= this.data.data.length) {
logger.error(`Data index ${dataIndex} is out of bounds.`);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,20 @@ export interface AzureOpenAiEmbeddingOutput {
/**
* Array of result candidates.
*/
data: [
{
/**
* Embedding object.
*/
object: 'embedding';
/**
* Array of size `1536` (Azure OpenAI's embedding size) containing embedding vector.
*/
embedding: number[];
/**
* Index of choice.
*/
index: number;
}
];
data: {
/**
* Embedding object.
*/
object: 'embedding';
/**
* Array of size `1536` (Azure OpenAI's embedding size) containing embedding vector.
*/
embedding: number[];
/**
* Index of choice.
*/
index: number;
}[];
/**
* Token Usage.
*/
Expand Down
33 changes: 19 additions & 14 deletions packages/langchain/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,22 @@ This package provides LangChain model clients built on top of the foundation mod

## Table of Contents

- [Table of Contents](#table-of-contents)
- [Installation](#installation)
- [Prerequisites](#prerequisites)
- [Relationship between Models and Deployment ID](#relationship-between-models-and-deployment-id)
- [Usage](#usage)
- [Client Initialization](#client-initialization)
- [Chat Client](#chat-client)
- [Embedding Client](#embedding-client)
- [Local Testing](#local-testing)
- [Support, Feedback, Contribution](#support-feedback-contribution)
- [License](#license)
- [@sap-ai-sdk/langchain](#sap-ai-sdklangchain)
- [Table of Contents](#table-of-contents)
- [Installation](#installation)
- [Prerequisites](#prerequisites)
- [Relationship between Models and Deployment ID](#relationship-between-models-and-deployment-id)
- [Usage](#usage)
- [Client Initialization](#client-initialization)
- [Chat Client](#chat-client)
- [Advanced Example with Templating and Output Parsing](#advanced-example-with-templating-and-output-parsing)
- [Embedding Client](#embedding-client)
- [Embed Text](#embed-text)
- [Embed Document Chunks](#embed-document-chunks)
- [Preprocess, embed, and store documents](#preprocess-embed-and-store-documents)
- [Local Testing](#local-testing)
- [Support, Feedback, Contribution](#support-feedback-contribution)
- [License](#license)

## Installation

Expand Down Expand Up @@ -128,7 +133,7 @@ return llmChain.invoke({

### Embedding Client

Embedding clients allow embedding either text or documents (represented as arrays of strings).
Embedding clients allow embedding either text or document chunks (represented as arrays of strings).
While you can use them standalone, they are usually used in combination with other LangChain utilities, like a text splitter for preprocessing and a vector store for storage and retrieval of the relevant embeddings.
For a complete example of how to implement RAG with our LangChain client, take a look at our [sample code](https://github.com/SAP/ai-sdk-js/blob/main/sample-code/src/langchain-azure-openai.ts).

Expand All @@ -140,10 +145,10 @@ const embeddedText = await embeddingClient.embedQuery(
);
```

#### Embed Documents
#### Embed Document Chunks

```ts
const embeddedDocument = await embeddingClient.embedDocuments([
const embeddedDocuments = await embeddingClient.embedDocuments([
'Page 1: Paris is the capital of France.',
'Page 2: It is a beautiful city.'
]);
Expand Down
29 changes: 18 additions & 11 deletions packages/langchain/src/openai/embedding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ import { AzureOpenAiEmbeddingClient as AzureOpenAiEmbeddingClientBase } from '@s
import { Embeddings } from '@langchain/core/embeddings';
import type {
AzureOpenAiEmbeddingModel,
AzureOpenAiEmbeddingParameters
AzureOpenAiEmbeddingParameters,
AzureOpenAiEmbeddingResponse
} from '@sap-ai-sdk/foundation-models';
import type { AzureOpenAiEmbeddingModelParams } from './types.js';

Expand All @@ -24,23 +25,29 @@ export class AzureOpenAiEmbeddingClient extends Embeddings {
this.resourceGroup = fields.resourceGroup;
}

/**
* Embed a list of document chunks. All chunks are embedded in one batch.
* @param documents - Document chunks to embed.
* @returns Embeddings.
*/
override async embedDocuments(documents: string[]): Promise<number[][]> {
return Promise.all(
documents.map(document => this.createEmbedding({ input: document }))
);
return (await this.createEmbeddings({ input: documents })).getEmbeddings();
}

/**
* Embed a single string.
* @param input - Input string to embed.
* @returns Embedding.
*/
override async embedQuery(input: string): Promise<number[]> {
return this.createEmbedding({ input });
return (await this.createEmbeddings({ input })).getEmbedding() ?? [];
}

private async createEmbedding(
private async createEmbeddings(
query: AzureOpenAiEmbeddingParameters
): Promise<number[]> {
return this.caller.callWithOptions(
{},
async () =>
(await this.openAiEmbeddingClient.run(query)).getEmbedding() ?? []
): Promise<AzureOpenAiEmbeddingResponse> {
return this.caller.callWithOptions({}, async () =>
this.openAiEmbeddingClient.run(query)
);
}
}
11 changes: 5 additions & 6 deletions packages/langchain/src/openai/util.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ import type {
} from '@sap-ai-sdk/foundation-models';
import type { AzureOpenAiChatCallOptions } from './types.js';

const openAiMockResponse =
parseMockResponse<AzureOpenAiCreateChatCompletionResponse>(
'foundation-models',
'azure-openai-chat-completion-success-response.json'
);

// Signal and Prompt Index are provided by the super class in every call
const defaultOptions = {
signal: undefined,
Expand All @@ -29,6 +23,11 @@ const defaultOptions = {

describe('Mapping Functions', () => {
it('should parse an OpenAI response to a (LangChain) chat response', async () => {
const openAiMockResponse =
await parseMockResponse<AzureOpenAiCreateChatCompletionResponse>(
'foundation-models',
'azure-openai-chat-completion-success-response.json'
);
const result = mapOutputToChatResult(openAiMockResponse);
expect(result).toMatchSnapshot();
});
Expand Down
Loading

0 comments on commit 99498cd

Please sign in to comment.