community[major]: Together AI embeddings #3729

Merged · 4 commits · Dec 20, 2023
3 changes: 3 additions & 0 deletions libs/langchain-community/.gitignore
@@ -109,6 +109,9 @@ embeddings/ollama.d.ts
embeddings/tensorflow.cjs
embeddings/tensorflow.js
embeddings/tensorflow.d.ts
embeddings/togetherai.cjs
embeddings/togetherai.js
embeddings/togetherai.d.ts
embeddings/voyage.cjs
embeddings/voyage.js
embeddings/voyage.d.ts
8 changes: 8 additions & 0 deletions libs/langchain-community/package.json
@@ -656,6 +656,11 @@
"import": "./embeddings/tensorflow.js",
"require": "./embeddings/tensorflow.cjs"
},
"./embeddings/togetherai": {
"types": "./embeddings/togetherai.d.ts",
"import": "./embeddings/togetherai.js",
"require": "./embeddings/togetherai.cjs"
},
"./embeddings/voyage": {
"types": "./embeddings/voyage.d.ts",
"import": "./embeddings/voyage.js",
@@ -1346,6 +1351,9 @@
"embeddings/tensorflow.cjs",
"embeddings/tensorflow.js",
"embeddings/tensorflow.d.ts",
"embeddings/togetherai.cjs",
"embeddings/togetherai.js",
"embeddings/togetherai.d.ts",
"embeddings/voyage.cjs",
"embeddings/voyage.js",
"embeddings/voyage.d.ts",
1 change: 1 addition & 0 deletions libs/langchain-community/scripts/create-entrypoints.js
@@ -48,6 +48,7 @@ const entrypoints = {
"embeddings/minimax": "embeddings/minimax",
"embeddings/ollama": "embeddings/ollama",
"embeddings/tensorflow": "embeddings/tensorflow",
"embeddings/togetherai": "embeddings/togetherai",
"embeddings/voyage": "embeddings/voyage",
// llms
"llms/ai21": "llms/ai21",
@@ -0,0 +1,19 @@
import { test, expect } from "@jest/globals";
import { TogetherAIEmbeddings } from "../togetherai.js";

test("Test TogetherAIEmbeddings.embedQuery", async () => {
const embeddings = new TogetherAIEmbeddings();
const res = await embeddings.embedQuery("Hello world");
expect(typeof res[0]).toBe("number");
expect(res.length).toBe(768);
});

test("Test TogetherAIEmbeddings.embedDocuments", async () => {
const embeddings = new TogetherAIEmbeddings();
const res = await embeddings.embedDocuments(["Hello world", "Bye bye"]);
expect(res).toHaveLength(2);
expect(typeof res[0][0]).toBe("number");
expect(typeof res[1][0]).toBe("number");
expect(res[0].length).toBe(768);
expect(res[1].length).toBe(768);
});
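The integration tests above hit the live API and assert a 768-dimensional result. For offline testing, the request logic can be exercised with an injected fetch stub. The sketch below is a hypothetical helper (not part of this PR) that mirrors the shape of `embeddingWithRetry`, with the endpoint, model name, and response shape taken from the implementation in this diff:

```typescript
// A fetch-like signature so a stub can stand in for the real global fetch.
type FetchLike = (
  url: string,
  init?: { method: string; body: string }
) => Promise<{ status: number; json: () => Promise<any> }>;

// Hypothetical standalone version of the request made by embeddingWithRetry.
async function embedWithFetch(
  input: string,
  fetchImpl: FetchLike
): Promise<number[]> {
  const res = await fetchImpl("https://api.together.xyz/api/v1/embeddings", {
    method: "POST",
    body: JSON.stringify({
      model: "togethercomputer/m2-bert-80M-8k-retrieval",
      input,
    }),
  });
  if (res.status !== 200) {
    throw new Error("Request failed");
  }
  const parsed = await res.json();
  // The API wraps embeddings in a data array, one entry per input.
  return parsed.data[0].embedding;
}

// A stub fetch returning a fixed 3-dimensional embedding.
const stubFetch: FetchLike = async () => ({
  status: 200,
  json: async () => ({
    data: [{ object: "embedding", embedding: [0.1, 0.2, 0.3], index: 0 }],
  }),
});

embedWithFetch("Hello world", stubFetch).then((emb) => {
  console.log(emb.length); // 3
});
```

This keeps the unit test deterministic; the real 768-dimension assertions stay in the integration tests above.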
188 changes: 188 additions & 0 deletions libs/langchain-community/src/embeddings/togetherai.ts
@@ -0,0 +1,188 @@
import { getEnvironmentVariable } from "@langchain/core/utils/env";
[Review comment] Hey team, I've reviewed the code and noticed that this PR introduces a new HTTP request using the fetch function to make a POST request to the TogetherAI API. This comment is to flag this change for your review and consideration. Let me know if you have any questions or need further clarification.

[Review comment] Great work on the PR! I've flagged this change for review by the maintainers as it involves accessing environment variables for the TogetherAI API key. Keep up the good work!
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import { chunkArray } from "../utils/chunk.js";

/**
* Interface for TogetherAIEmbeddingsParams parameters. Extends EmbeddingsParams and
* defines additional parameters specific to the TogetherAIEmbeddings class.
*/
export interface TogetherAIEmbeddingsParams extends EmbeddingsParams {
/**
* The API key to use for the TogetherAI API.
* @default {process.env.TOGETHER_AI_API_KEY}
*/
apiKey?: string;

/**
* Model name to use
* @default {"togethercomputer/m2-bert-80M-8k-retrieval"}
*/
modelName?: string;

/**
* Timeout to use when making requests to TogetherAI.
* @default {undefined}
*/
timeout?: number;

/**
* The maximum number of documents to embed in a single request.
* @default {512}
*/
batchSize?: number;

/**
* Whether to strip new lines from the input text. May not be suitable
* for all use cases.
* @default {false}
*/
stripNewLines?: boolean;
}

/** @ignore */
interface TogetherAIEmbeddingsResult {
object: string;
data: Array<{
object: "embedding";
embedding: number[];
index: number;
}>;
model: string;
request_id: string;
}

/**
* Class for generating embeddings using the TogetherAI API. Extends the
* Embeddings class and implements TogetherAIEmbeddingsParams.
* @example
* ```typescript
* const embeddings = new TogetherAIEmbeddings({
* apiKey: process.env.TOGETHER_AI_API_KEY, // Default value
* modelName: "togethercomputer/m2-bert-80M-8k-retrieval", // Default value
* });
* const res = await embeddings.embedQuery(
* "What would be a good company name for a company that makes colorful socks?"
* );
* ```
*/
export class TogetherAIEmbeddings
extends Embeddings
implements TogetherAIEmbeddingsParams
{
modelName = "togethercomputer/m2-bert-80M-8k-retrieval";

apiKey: string;

batchSize = 512;

stripNewLines = false;

timeout?: number;

private embeddingsAPIUrl = "https://api.together.xyz/api/v1/embeddings";

constructor(fields?: Partial<TogetherAIEmbeddingsParams>) {
super(fields ?? {});

const apiKey =
fields?.apiKey ?? getEnvironmentVariable("TOGETHER_AI_API_KEY");
if (!apiKey) {
throw new Error("TOGETHER_AI_API_KEY not found.");
}

this.apiKey = apiKey;
this.modelName = fields?.modelName ?? this.modelName;
this.timeout = fields?.timeout;
this.batchSize = fields?.batchSize ?? this.batchSize;
this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines;
}

private constructHeaders() {
return {
accept: "application/json",
"content-type": "application/json",
Authorization: `Bearer ${this.apiKey}`
};
}

private constructBody(input: string) {
const body = {
model: this.modelName,
input
};
return body;
}

/**
* Method to generate embeddings for an array of documents. Splits the
* documents into batches and makes requests to the TogetherAI API to generate
* embeddings.
* @param texts Array of documents to generate embeddings for.
* @returns Promise that resolves to a 2D array of embeddings for each document.
*/
async embedDocuments(texts: string[]): Promise<number[][]> {
const batches = chunkArray(
this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts,
this.batchSize
);

let batchResponses: TogetherAIEmbeddingsResult[] = [];
for (const batch of batches) {
const batchRequests = batch.map((item) => this.embeddingWithRetry(item));
const response = await Promise.all(batchRequests);
batchResponses = batchResponses.concat(response);
}

const embeddings: number[][] = batchResponses.map(
(response) => response.data[0].embedding
);
return embeddings;
}

/**
* Method to generate an embedding for a single document. Calls the
* embeddingWithRetry method with the document as the input.
* @param {string} text Document to generate an embedding for.
* @returns {Promise<number[]>} Promise that resolves to an embedding for the document.
*/
async embedQuery(text: string): Promise<number[]> {
const { data } = await this.embeddingWithRetry(
this.stripNewLines ? text.replace(/\n/g, " ") : text
);
return data[0].embedding;
}

/**
* Private method to make a request to the TogetherAI API to generate
* embeddings. Handles the retry logic and returns the response from the
* API.
* @param {string} input The input text to embed.
* @returns Promise that resolves to the response from the API.
*/
private async embeddingWithRetry(
input: string
): Promise<TogetherAIEmbeddingsResult> {
const body = JSON.stringify(this.constructBody(input));
const headers = this.constructHeaders();

return this.caller.call(async () => {
const fetchResponse = await fetch(this.embeddingsAPIUrl, {
method: "POST",
headers,
body
});

if (fetchResponse.status === 200) {
return fetchResponse.json();
}
throw new Error(
`Error getting embeddings from Together AI. ${JSON.stringify(
await fetchResponse.json(),
null,
2
)}`
);
});
}
}
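The batching in `embedDocuments` above relies on `chunkArray` from `../utils/chunk.js`. A minimal sketch of its assumed behavior (this is an illustration, not the actual helper from the repo): split an array into consecutive slices of at most `size` elements.

```typescript
// Assumed behavior of chunkArray: consecutive slices of at most `size` items.
function chunkArray<T>(arr: T[], size: number): T[][] {
  const chunks: T[][] = [];
  for (let i = 0; i < arr.length; i += size) {
    chunks.push(arr.slice(i, i + size));
  }
  return chunks;
}

// With the default batchSize of 512, 1000 documents become two batches:
// one of 512 and one of 488.
const docs = Array.from({ length: 1000 }, (_, i) => `doc ${i}`);
const batches = chunkArray(docs, 512);
console.log(batches.length); // 2
console.log(batches[0].length, batches[1].length); // 512 488
```

Each batch is then mapped to one `embeddingWithRetry` call per document, so rate limiting and retries are handled per request by the inherited `caller`.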
6 changes: 3 additions & 3 deletions libs/langchain-community/src/llms/togetherai.ts
@@ -139,7 +139,7 @@ export class TogetherAI extends LLM<TogetherAICallOptions> {

private apiKey: string;

-private inferenceUrl = "https://api.together.xyz/inference";
+private inferenceAPIUrl = "https://api.together.xyz/inference";

static lc_name() {
return "TogetherAI";
@@ -197,7 +197,7 @@ export class TogetherAI extends LLM<TogetherAICallOptions> {
options?: this["ParsedCallOptions"]
[Review comment] Hey there! I noticed that the recent change in the togetherai.ts file updates the URL for an existing fetch request. This comment is to flag the modification for maintainers to review. Keep up the great work!
) {
return this.caller.call(async () => {
-const fetchResponse = await fetch(this.inferenceUrl, {
+const fetchResponse = await fetch(this.inferenceAPIUrl, {
method: "POST",
headers: {
...this.constructHeaders(),
@@ -236,7 +236,7 @@ export class TogetherAI extends LLM<TogetherAICallOptions> {
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<GenerationChunk> {
-const fetchResponse = await fetch(this.inferenceUrl, {
+const fetchResponse = await fetch(this.inferenceAPIUrl, {
method: "POST",
headers: {
...this.constructHeaders(),
1 change: 1 addition & 0 deletions libs/langchain-community/src/load/import_map.ts
@@ -21,6 +21,7 @@
export * as agents__toolkits__base from "../agents/toolkits/base.js";
export * as agents__toolkits__connery from "../agents/toolkits/connery/index.js";
export * as embeddings__minimax from "../embeddings/minimax.js";
export * as embeddings__ollama from "../embeddings/ollama.js";
export * as embeddings__togetherai from "../embeddings/togetherai.js";
export * as embeddings__voyage from "../embeddings/voyage.js";
export * as llms__ai21 from "../llms/ai21.js";
export * as llms__aleph_alpha from "../llms/aleph_alpha.js";