community[major]: Together AI embeddings #3729
Changes from 3 commits
@@ -0,0 +1,19 @@
import { test, expect } from "@jest/globals";
import { TogetherAIEmbeddings } from "../togetherai.js";

test("Test TogetherAIEmbeddings.embedQuery", async () => {
  const embeddings = new TogetherAIEmbeddings();
  const res = await embeddings.embedQuery("Hello world");
  expect(typeof res[0]).toBe("number");
  expect(res.length).toBe(768);
});

test("Test TogetherAIEmbeddings.embedDocuments", async () => {
  const embeddings = new TogetherAIEmbeddings();
  const res = await embeddings.embedDocuments(["Hello world", "Bye bye"]);
  expect(res).toHaveLength(2);
  expect(typeof res[0][0]).toBe("number");
  expect(typeof res[1][0]).toBe("number");
  expect(res[0].length).toBe(768);
  expect(res[1].length).toBe(768);
});
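For context, here is a minimal standalone sketch of what these integration tests exercise, assuming a valid TOGETHER_AI_API_KEY in the environment; the package import path is an assumption.

```typescript
import { TogetherAIEmbeddings } from "@langchain/community/embeddings/togetherai";

// Uses the default model, togethercomputer/m2-bert-80M-8k-retrieval,
// which returns 768-dimensional vectors, the dimension both tests assert.
const embeddings = new TogetherAIEmbeddings();
const vector = await embeddings.embedQuery("Hello world");
console.log(vector.length); // 768
```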
@@ -0,0 +1,188 @@
import { getEnvironmentVariable } from "@langchain/core/utils/env";
Review comment: Great work on the PR! I've flagged this change for review by the maintainers as it involves accessing environment variables for the TogetherAI API key. Keep up the good work!
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import { chunkArray } from "../utils/chunk.js";

/**
 * Interface for TogetherAIEmbeddings parameters. Extends EmbeddingsParams and
 * defines additional parameters specific to the TogetherAIEmbeddings class.
 */
export interface TogetherAIEmbeddingsParams extends EmbeddingsParams {
  /**
   * The API key to use for the TogetherAI API.
   * @default {process.env.TOGETHER_AI_API_KEY}
   */
  apiKey?: string;

  /**
   * Model name to use.
   * @default {"togethercomputer/m2-bert-80M-8k-retrieval"}
   */
  modelName?: string;

  /**
   * Timeout to use when making requests to TogetherAI.
   * @default {undefined}
   */
  timeout?: number;

  /**
   * The maximum number of documents to embed in a single request.
   * @default {512}
   */
  batchSize?: number;

  /**
   * Whether to strip new lines from the input text. May not be suitable
   * for all use cases.
   * @default {false}
   */
  stripNewLines?: boolean;
}

/** @ignore */
interface TogetherAIEmbeddingsResult {
  object: string;
  data: Array<{
    object: "embedding";
    embedding: number[];
    index: number;
  }>;
  model: string;
  request_id: string;
}
/**
 * Class for generating embeddings using the TogetherAI API. Extends the
 * Embeddings class and implements TogetherAIEmbeddingsParams.
 * @example
 * ```typescript
 * const embeddings = new TogetherAIEmbeddings({
 *   apiKey: process.env.TOGETHER_AI_API_KEY, // Default value
 *   modelName: "togethercomputer/m2-bert-80M-8k-retrieval", // Default value
 * });
 * const res = await embeddings.embedQuery(
 *   "What would be a good company name for a company that makes colorful socks?"
 * );
 * ```
 */
export class TogetherAIEmbeddings
  extends Embeddings
  implements TogetherAIEmbeddingsParams
{
  modelName = "togethercomputer/m2-bert-80M-8k-retrieval";

  apiKey: string;

  batchSize = 512;

  stripNewLines = false;

  timeout?: number;

  private embeddingsAPIUrl = "https://api.together.xyz/api/v1/embeddings";

  constructor(fields?: Partial<TogetherAIEmbeddingsParams>) {
    super(fields ?? {});

    const apiKey =
      fields?.apiKey ?? getEnvironmentVariable("TOGETHER_AI_API_KEY");
    if (!apiKey) {
      throw new Error("TOGETHER_AI_API_KEY not found.");
    }

    this.apiKey = apiKey;
    this.modelName = fields?.modelName ?? this.modelName;
    this.timeout = fields?.timeout;
    this.batchSize = fields?.batchSize ?? this.batchSize;
    this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines;
  }

  private constructHeaders() {
    return {
      accept: "application/json",
      "content-type": "application/json",
      Authorization: `Bearer ${this.apiKey}`
    };
  }

  private constructBody(input: string) {
    const body = {
      model: this.modelName,
      input
    };
    return body;
  }

  /**
   * Method to generate embeddings for an array of documents. Splits the
   * documents into batches and makes requests to the TogetherAI API to
   * generate embeddings.
   * @param texts Array of documents to generate embeddings for.
   * @returns Promise that resolves to a 2D array of embeddings for each document.
   */
  async embedDocuments(texts: string[]): Promise<number[][]> {
    const batches = chunkArray(
      this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts,
      this.batchSize
    );

    let batchResponses: TogetherAIEmbeddingsResult[] = [];
    // Each input string gets its own request; requests within a batch run
    // concurrently via Promise.all.
    for (const batch of batches) {
      const batchRequests = batch.map((item) => this.embeddingWithRetry(item));
      const response = await Promise.all(batchRequests);
      batchResponses = batchResponses.concat(response);
    }

    const embeddings: number[][] = batchResponses.map(
      (response) => response.data[0].embedding
    );
    return embeddings;
  }

  /**
   * Method to generate an embedding for a single document. Calls the
   * embeddingWithRetry method with the document as the input.
   * @param {string} text Document to generate an embedding for.
   * @returns {Promise<number[]>} Promise that resolves to an embedding for the document.
   */
  async embedQuery(text: string): Promise<number[]> {
    const { data } = await this.embeddingWithRetry(
      this.stripNewLines ? text.replace(/\n/g, " ") : text
    );
    return data[0].embedding;
  }

  /**
   * Private method to make a request to the TogetherAI API to generate
   * embeddings. Handles the retry logic and returns the response from the
   * API.
   * @param {string} input The input text to embed.
   * @returns Promise that resolves to the response from the API.
   */
  private async embeddingWithRetry(
    input: string
  ): Promise<TogetherAIEmbeddingsResult> {
    const body = JSON.stringify(this.constructBody(input));
    const headers = this.constructHeaders();

    return this.caller.call(async () => {
      const fetchResponse = await fetch(this.embeddingsAPIUrl, {
        method: "POST",
        headers,
        body
      });

      if (fetchResponse.status === 200) {
        return fetchResponse.json();
      }
      throw new Error(
        `Error getting embeddings from Together AI. ${JSON.stringify(
          await fetchResponse.json(),
          null,
          2
        )}`
      );
    });
  }
}
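A hedged configuration sketch based on the fields defined above; the import path is an assumption, and each option shown simply restates a documented default or behavior.

```typescript
import { TogetherAIEmbeddings } from "@langchain/community/embeddings/togetherai";

const embeddings = new TogetherAIEmbeddings({
  apiKey: process.env.TOGETHER_AI_API_KEY, // also the fallback env var
  modelName: "togethercomputer/m2-bert-80M-8k-retrieval", // default model
  batchSize: 512, // documents per chunkArray batch in embedDocuments
  stripNewLines: true, // replace "\n" with " " before embedding
});

// embedDocuments issues one request per input string, batched with
// Promise.all, and collects data[0].embedding from each response.
const vectors = await embeddings.embedDocuments(["Hello world", "Bye bye"]);
console.log(vectors.length); // 2
```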
@@ -139,7 +139,7 @@ export class TogetherAI extends LLM<TogetherAICallOptions> {

   private apiKey: string;

-  private inferenceUrl = "https://api.together.xyz/inference";
+  private inferenceAPIUrl = "https://api.together.xyz/inference";

   static lc_name() {
     return "TogetherAI";
@@ -197,7 +197,7 @@ export class TogetherAI extends LLM<TogetherAICallOptions> {
     options?: this["ParsedCallOptions"]
   ) {
     return this.caller.call(async () => {
-      const fetchResponse = await fetch(this.inferenceUrl, {
+      const fetchResponse = await fetch(this.inferenceAPIUrl, {
         method: "POST",
         headers: {
           ...this.constructHeaders(),
@@ -236,7 +236,7 @@ export class TogetherAI extends LLM<TogetherAICallOptions> {
     options: this["ParsedCallOptions"],
     runManager?: CallbackManagerForLLMRun
   ): AsyncGenerator<GenerationChunk> {
-    const fetchResponse = await fetch(this.inferenceUrl, {
+    const fetchResponse = await fetch(this.inferenceAPIUrl, {
       method: "POST",
       headers: {
         ...this.constructHeaders(),
Review comment: Hey team, I've reviewed the code and noticed that this PR introduces a new HTTP request using the fetch function to make a POST request to the TogetherAI API. This comment is to flag this change for your review and consideration. Let me know if you have any questions or need further clarification.
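For reference, a short usage sketch of the LLM class touched by this rename; the import path and model name are assumptions, and since inferenceAPIUrl is private, the rename does not affect caller code.

```typescript
import { TogetherAI } from "@langchain/community/llms/togetherai";

// Assumes TOGETHER_AI_API_KEY is set; the model choice is illustrative.
const llm = new TogetherAI({
  modelName: "mistralai/Mixtral-8x7B-Instruct-v0.1",
});

// stream() goes through _streamResponseChunks from the last hunk, which
// POSTs to the renamed inferenceAPIUrl field.
const stream = await llm.stream("Write a haiku about renaming variables.");
for await (const chunk of stream) {
  process.stdout.write(chunk);
}
```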