diff --git a/libs/langchain-google-common/.gitignore b/libs/langchain-google-common/.gitignore index df014a2d426b..c4537f159680 100644 --- a/libs/langchain-google-common/.gitignore +++ b/libs/langchain-google-common/.gitignore @@ -10,6 +10,14 @@ types.cjs types.js types.d.ts types.d.cts +experimental/media.cjs +experimental/media.js +experimental/media.d.ts +experimental/media.d.cts +experimental/utils/media_core.cjs +experimental/utils/media_core.js +experimental/utils/media_core.d.ts +experimental/utils/media_core.d.cts node_modules dist .yarn diff --git a/libs/langchain-google-common/langchain.config.js b/libs/langchain-google-common/langchain.config.js index df02f88bd793..458e4abdc1a4 100644 --- a/libs/langchain-google-common/langchain.config.js +++ b/libs/langchain-google-common/langchain.config.js @@ -16,6 +16,8 @@ export const config = { index: "index", utils: "utils/index", types: "types", + "experimental/media": "experimental/media", + "experimental/utils/media_core": "experimental/utils/media_core", }, tsConfigPath: resolve("./tsconfig.json"), cjsSource: "./dist-cjs", diff --git a/libs/langchain-google-common/package.json b/libs/langchain-google-common/package.json index 242093637b64..2bafc1cbcc5e 100644 --- a/libs/langchain-google-common/package.json +++ b/libs/langchain-google-common/package.json @@ -92,6 +92,24 @@ "import": "./types.js", "require": "./types.cjs" }, + "./experimental/media": { + "types": { + "import": "./experimental/media.d.ts", + "require": "./experimental/media.d.cts", + "default": "./experimental/media.d.ts" + }, + "import": "./experimental/media.js", + "require": "./experimental/media.cjs" + }, + "./experimental/utils/media_core": { + "types": { + "import": "./experimental/utils/media_core.d.ts", + "require": "./experimental/utils/media_core.d.cts", + "default": "./experimental/utils/media_core.d.ts" + }, + "import": "./experimental/utils/media_core.js", + "require": "./experimental/utils/media_core.cjs" + }, "./package.json": "./package.json" }, "files": [ @@ -107,6 +125,14 @@ "types.cjs", "types.js", "types.d.ts", - "types.d.cts" + "types.d.cts", + "experimental/media.cjs", + "experimental/media.js", + "experimental/media.d.ts", + "experimental/media.d.cts", + "experimental/utils/media_core.cjs", + "experimental/utils/media_core.js", + "experimental/utils/media_core.d.ts", + "experimental/utils/media_core.d.cts" ] } diff --git a/libs/langchain-google-common/src/auth.ts b/libs/langchain-google-common/src/auth.ts index 9e278a9605d2..60e0fab4f998 100644 --- a/libs/langchain-google-common/src/auth.ts +++ b/libs/langchain-google-common/src/auth.ts @@ -1,9 +1,9 @@ import { ReadableJsonStream } from "./utils/stream.js"; import { GooglePlatformType } from "./types.js"; -export type GoogleAbstractedClientOpsMethod = "GET" | "POST"; +export type GoogleAbstractedClientOpsMethod = "GET" | "POST" | "DELETE"; -export type GoogleAbstractedClientOpsResponseType = "json" | "stream"; +export type GoogleAbstractedClientOpsResponseType = "json" | "stream" | "blob"; export type GoogleAbstractedClientOps = { url?: string; @@ -28,6 +28,17 @@ export abstract class GoogleAbstractedFetchClient abstract request(opts: GoogleAbstractedClientOps): unknown; + async _buildData(res: Response, opts: GoogleAbstractedClientOps) { + switch (opts.responseType) { + case "json": + return res.json(); + case "stream": + return new ReadableJsonStream(res.body); + default: + return res.blob(); + } + } + async _request( url: string | undefined, opts: GoogleAbstractedClientOps, @@ -47,7 +58,11 @@ export abstract class GoogleAbstractedFetchClient }, }; if (opts.data !== undefined) { - fetchOptions.body = JSON.stringify(opts.data); + if (typeof opts.data === "string") { + fetchOptions.body = opts.data; + } else { + fetchOptions.body = JSON.stringify(opts.data); + } } const res = await fetch(url, fetchOptions); @@ -57,16 +72,21 @@ export abstract class GoogleAbstractedFetchClient const error = new Error( `Google request failed with status code ${res.status}: ${resText}` ); - // eslint-disable-next-line @typescript-eslint/no-explicit-any + /* eslint-disable @typescript-eslint/no-explicit-any */ (error as any).response = res; + (error as any).details = { + url, + opts, + fetchOptions, + result: res, + }; + /* eslint-enable @typescript-eslint/no-explicit-any */ throw error; } + const data = await this._buildData(res, opts); return { - data: - opts.responseType === "json" - ? await res.json() - : new ReadableJsonStream(res.body), + data, config: {}, status: res.status, statusText: res.statusText, diff --git a/libs/langchain-google-common/src/chat_models.ts b/libs/langchain-google-common/src/chat_models.ts index 2fa539657ec0..4ee4e0f6ed05 100644 --- a/libs/langchain-google-common/src/chat_models.ts +++ b/libs/langchain-google-common/src/chat_models.ts @@ -39,12 +39,7 @@ import { copyAndValidateModelParamsInto, } from "./utils/common.js"; import { AbstractGoogleLLMConnection } from "./connection.js"; -import { - baseMessageToContent, - safeResponseToChatGeneration, - safeResponseToChatResult, - DefaultGeminiSafetyHandler, -} from "./utils/gemini.js"; +import { DefaultGeminiSafetyHandler } from "./utils/gemini.js"; import { ApiKeyGoogleAuth, GoogleAbstractedClient } from "./auth.js"; import { JsonStream } from "./utils/stream.js"; import { ensureParams } from "./utils/failed_handler.js"; @@ -55,6 +50,7 @@ import type { GeminiFunctionDeclaration, GeminiFunctionSchema, GoogleAIToolType, + GeminiAPIConfig, } from "./types.js"; import { zodToGeminiParameters } from "./utils/zod_to_gemini_parameters.js"; @@ -100,61 +96,69 @@ class ChatConnection extends AbstractGoogleLLMConnection< return true; } - formatContents( + async formatContents( input: BaseMessage[], _parameters: GoogleAIModelParams - ): GeminiContent[] { - return input - .map((msg, i) => - baseMessageToContent(msg, input[i - 1], this.useSystemInstruction) + ): Promise { + const inputPromises: Promise[] = input.map((msg, i) => + this.api.baseMessageToContent( + msg, + input[i - 1], + this.useSystemInstruction ) - .reduce((acc, cur) => { - // Filter out the system content - if (cur.every((content) => content.role === "system")) { - return acc; - } - - // Combine adjacent function messages - if ( - cur[0]?.role === "function" && - acc.length > 0 && - acc[acc.length - 1].role === "function" - ) { - acc[acc.length - 1].parts = [ - ...acc[acc.length - 1].parts, - ...cur[0].parts, - ]; - } else { - acc.push(...cur); - } + ); + const inputs = await Promise.all(inputPromises); + return inputs.reduce((acc, cur) => { + // Filter out the system content + if (cur.every((content) => content.role === "system")) { return acc; - }, [] as GeminiContent[]); + } + + // Combine adjacent function messages + if ( + cur[0]?.role === "function" && + acc.length > 0 && + acc[acc.length - 1].role === "function" + ) { + acc[acc.length - 1].parts = [ + ...acc[acc.length - 1].parts, + ...cur[0].parts, + ]; + } else { + acc.push(...cur); + } + + return acc; + }, [] as GeminiContent[]); } - formatSystemInstruction( + async formatSystemInstruction( input: BaseMessage[], _parameters: GoogleAIModelParams - ): GeminiContent { + ): Promise { if (!this.useSystemInstruction) { return {} as GeminiContent; } let ret = {} as GeminiContent; - input.forEach((message, index) => { + for (let index = 0; index < input.length; index += 1) { + const message = input[index]; if (message._getType() === "system") { // For system types, we only want it if it is the first message, // if it appears anywhere else, it should be an error. if (index === 0) { // eslint-disable-next-line prefer-destructuring - ret = baseMessageToContent(message, undefined, true)[0]; + ret = ( + await this.api.baseMessageToContent(message, undefined, true) + )[0]; } else { throw new Error( "System messages are only permitted as the first passed message." ); } } - }); + } return ret; } @@ -168,6 +172,7 @@ export interface ChatGoogleBaseInput GoogleConnectionParams, GoogleAIModelParams, GoogleAISafetyParams, + GeminiAPIConfig, Pick {} /** @@ -338,7 +343,10 @@ export abstract class ChatGoogleBase parameters, options ); - const ret = safeResponseToChatResult(response, this.safetyHandler); + const ret = this.connection.api.safeResponseToChatResult( + response, + this.safetyHandler + ); await runManager?.handleLLMNewToken(ret.generations[0].text); return ret; } @@ -378,7 +386,10 @@ export abstract class ChatGoogleBase } const chunk = output !== null - ? safeResponseToChatGeneration({ data: output }, this.safetyHandler) + ? this.connection.api.safeResponseToChatGeneration( + { data: output }, + this.safetyHandler + ) : new ChatGenerationChunk({ text: "", generationInfo: { finishReason: "stop" }, diff --git a/libs/langchain-google-common/src/connection.ts b/libs/langchain-google-common/src/connection.ts index 4a90e9f795af..7e7da9daa304 100644 --- a/libs/langchain-google-common/src/connection.ts +++ b/libs/langchain-google-common/src/connection.ts @@ -20,6 +20,7 @@ import type { GeminiTool, GeminiFunctionDeclaration, GoogleAIModelRequestParams, + GoogleRawResponse, GoogleAIToolType, } from "./types.js"; import { @@ -28,6 +29,7 @@ import { GoogleAbstractedClientOpsMethod, } from "./auth.js"; import { zodToGeminiParameters } from "./utils/zod_to_gemini_parameters.js"; +import { getGeminiAPI } from "./utils/index.js"; export abstract class GoogleConnection< CallOptions extends AsyncCallerCallOptions, @@ -84,15 +86,23 @@ export abstract class GoogleConnection< return this.constructor.name; } - async _request( + async additionalHeaders(): Promise> { + return {}; + } + + async _buildOpts( data: unknown | undefined, - options: CallOptions - ): Promise { + _options: CallOptions, + requestHeaders: Record = {} + ): Promise { const url = await this.buildUrl(); const method = this.buildMethod(); const infoHeaders = (await this._clientInfoHeaders()) ?? {}; + const additionalHeaders = (await this.additionalHeaders()) ?? {}; const headers = { ...infoHeaders, + ...additionalHeaders, + ...requestHeaders, }; const opts: GoogleAbstractedClientOps = { @@ -108,7 +118,15 @@ export abstract class GoogleConnection< } else { opts.responseType = "json"; } + return opts; + } + async _request( + data: unknown | undefined, + options: CallOptions, + requestHeaders: Record = {} + ): Promise { + const opts = await this._buildOpts(data, options, requestHeaders); const callResponse = await this.caller.callWithOptions( { signal: options?.signal }, async () => this.client.request(opts) @@ -165,6 +183,21 @@ export abstract class GoogleHostConnection< } } +export abstract class GoogleRawConnection< + CallOptions extends AsyncCallerCallOptions, + AuthOptions +> extends GoogleHostConnection { + async _buildOpts( + data: unknown | undefined, + _options: CallOptions, + requestHeaders: Record = {} + ): Promise { + const opts = await super._buildOpts(data, _options, requestHeaders); + opts.responseType = "blob"; + return opts; + } +} + export abstract class GoogleAIConnection< CallOptions extends AsyncCallerCallOptions, InputType, @@ -180,6 +213,9 @@ export abstract class GoogleAIConnection< client: GoogleAbstractedClient; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + api: any; // FIXME: Make this a real type + constructor( fields: GoogleAIBaseLLMInput | undefined, caller: AsyncCaller, @@ -190,6 +226,7 @@ export abstract class GoogleAIConnection< this.client = client; this.modelName = fields?.model ?? fields?.modelName ?? this.model; this.model = this.modelName; + this.api = getGeminiAPI(fields); } get modelFamily(): GoogleLLMModelFamily { @@ -235,14 +272,14 @@ export abstract class GoogleAIConnection< abstract formatData( input: InputType, parameters: GoogleAIModelRequestParams - ): unknown; + ): Promise; async request( input: InputType, parameters: GoogleAIModelRequestParams, options: CallOptions - ): Promise { - const data = this.formatData(input, parameters); + ): Promise { + const data = await this.formatData(input, parameters); const response = await this._request(data, options); return response; } @@ -273,7 +310,7 @@ export abstract class AbstractGoogleLLMConnection< abstract formatContents( input: MessageType, parameters: GoogleAIModelRequestParams - ): GeminiContent[]; + ): Promise; formatGenerationConfig( _input: MessageType, @@ -296,10 +333,10 @@ export abstract class AbstractGoogleLLMConnection< return parameters.safetySettings ?? []; } - formatSystemInstruction( + async formatSystemInstruction( _input: MessageType, _parameters: GoogleAIModelRequestParams - ): GeminiContent { + ): Promise { return {} as GeminiContent; } @@ -362,16 +399,19 @@ export abstract class AbstractGoogleLLMConnection< }; } - formatData( + async formatData( input: MessageType, parameters: GoogleAIModelRequestParams - ): GeminiRequest { - const contents = this.formatContents(input, parameters); + ): Promise { + const contents = await this.formatContents(input, parameters); const generationConfig = this.formatGenerationConfig(input, parameters); const tools = this.formatTools(input, parameters); const toolConfig = this.formatToolConfig(parameters); const safetySettings = this.formatSafetySettings(input, parameters); - const systemInstruction = this.formatSystemInstruction(input, parameters); + const systemInstruction = await this.formatSystemInstruction( + input, + parameters + ); const ret: GeminiRequest = { contents, diff --git a/libs/langchain-google-common/src/embeddings.ts b/libs/langchain-google-common/src/embeddings.ts index 4bf568959efe..d1e2549c631a 100644 --- a/libs/langchain-google-common/src/embeddings.ts +++ b/libs/langchain-google-common/src/embeddings.ts @@ -38,10 +38,10 @@ class EmbeddingsConnection< return "predict"; } - formatData( + async formatData( input: GoogleEmbeddingsInstance[], parameters: GoogleAIModelRequestParams - ): unknown { + ): Promise { return { instances: input, parameters, @@ -172,7 +172,8 @@ export abstract class BaseGoogleEmbeddings ?.map( (response) => response?.data?.predictions?.map( - (result) => result.embeddings.values + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (result: any) => result.embeddings?.values ) ?? [] ) .flat() ?? []; diff --git a/libs/langchain-google-common/src/experimental/media.ts b/libs/langchain-google-common/src/experimental/media.ts new file mode 100644 index 000000000000..89646482ca70 --- /dev/null +++ b/libs/langchain-google-common/src/experimental/media.ts @@ -0,0 +1,803 @@ +import { + AsyncCaller, + AsyncCallerCallOptions, + AsyncCallerParams, +} from "@langchain/core/utils/async_caller"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { + MediaBlob, + BlobStore, + BlobStoreOptions, + MediaBlobData, +} from "./utils/media_core.js"; +import { + GoogleConnectionParams, + GoogleRawResponse, + GoogleResponse, +} from "../types.js"; +import { GoogleHostConnection, GoogleRawConnection } from "../connection.js"; +import { + ApiKeyGoogleAuth, + GoogleAbstractedClient, + GoogleAbstractedClientOpsMethod, +} from "../auth.js"; + +export interface GoogleUploadConnectionParams + extends GoogleConnectionParams {} + +export abstract class GoogleMultipartUploadConnection< + CallOptions extends AsyncCallerCallOptions, + ResponseType extends GoogleResponse, + AuthOptions +> extends GoogleHostConnection { + constructor( + fields: GoogleConnectionParams | undefined, + caller: AsyncCaller, + client: GoogleAbstractedClient + ) { + super(fields, caller, client); + } + + async _body( + separator: string, + data: MediaBlob, + metadata: Record + ): Promise { + const contentType = data.mimetype; + const { encoded, encoding } = await data.encode(); + const body = [ + `--${separator}`, + "Content-Type: application/json; charset=UTF-8", + "", + JSON.stringify(metadata), + "", + `--${separator}`, + `Content-Type: ${contentType}`, + `Content-Transfer-Encoding: ${encoding}`, + "", + encoded, + `--${separator}--`, + ]; + return body.join("\n"); + } + + async request( + data: MediaBlob, + metadata: Record, + options: CallOptions + ): Promise { + const separator = `separator-${Date.now()}`; + const body = await this._body(separator, data, metadata); + const requestHeaders = { + "Content-Type": `multipart/related; boundary=${separator}`, + "X-Goog-Upload-Protocol": "multipart", + }; + const response = this._request(body, options, requestHeaders); + return response; + } +} + +export abstract class GoogleDownloadConnection< + CallOptions extends AsyncCallerCallOptions, + ResponseType extends GoogleResponse, + AuthOptions +> extends GoogleHostConnection { + async request(options: CallOptions): Promise { + return this._request(undefined, options); + } +} + +export abstract class GoogleDownloadRawConnection< + CallOptions extends AsyncCallerCallOptions, + AuthOptions +> extends GoogleRawConnection { + buildMethod(): GoogleAbstractedClientOpsMethod { + return "GET"; + } + + async request(options: CallOptions): Promise { + return this._request(undefined, options); + } +} + +export interface BlobStoreGoogleParams + extends GoogleConnectionParams, + AsyncCallerParams, + BlobStoreOptions {} + +export abstract class BlobStoreGoogle< + ResponseType extends GoogleResponse, + AuthOptions +> extends BlobStore { + caller: AsyncCaller; + + client: GoogleAbstractedClient; + + constructor(fields?: BlobStoreGoogleParams) { + super(fields); + this.caller = new AsyncCaller(fields ?? {}); + this.client = this.buildClient(fields); + } + + abstract buildClient( + fields?: BlobStoreGoogleParams + ): GoogleAbstractedClient; + + abstract buildSetMetadata([key, blob]: [string, MediaBlob]): Record< + string, + unknown + >; + + abstract buildSetConnection([key, blob]: [ + string, + MediaBlob + ]): GoogleMultipartUploadConnection< + AsyncCallerCallOptions, + ResponseType, + AuthOptions + >; + + async _set(keyValuePair: [string, MediaBlob]): Promise { + const [, blob] = keyValuePair; + const setMetadata = this.buildSetMetadata(keyValuePair); + const metadata = setMetadata; + const options = {}; + const connection = this.buildSetConnection(keyValuePair); + const response = await connection.request(blob, metadata, options); + return response; + } + + async mset(keyValuePairs: [string, MediaBlob][]): Promise { + const ret = keyValuePairs.map((keyValue) => this._set(keyValue)); + await Promise.all(ret); + } + + abstract buildGetMetadataConnection( + key: string + ): GoogleDownloadConnection< + AsyncCallerCallOptions, + ResponseType, + AuthOptions + >; + + async _getMetadata(key: string): Promise> { + const connection = this.buildGetMetadataConnection(key); + const options = {}; + const response = await connection.request(options); + return response.data; + } + + abstract buildGetDataConnection( + key: string + ): GoogleDownloadRawConnection; + + async _getData(key: string): Promise { + const connection = this.buildGetDataConnection(key); + const options = {}; + const response = await connection.request(options); + return response.data; + } + + _getMimetypeFromMetadata(metadata: Record): string { + return metadata.contentType as string; + } + + async _get(key: string): Promise { + const metadata = await this._getMetadata(key); + const data = await this._getData(key); + if (data && metadata) { + const ret = await MediaBlob.fromBlob(data, { metadata, path: key }); + return ret; + } else { + return undefined; + } + } + + async mget(keys: string[]): Promise<(MediaBlob | undefined)[]> { + const ret = keys.map((key) => this._get(key)); + return await Promise.all(ret); + } + + abstract buildDeleteConnection( + key: string + ): GoogleDownloadConnection< + AsyncCallerCallOptions, + GoogleResponse, + AuthOptions + >; + + async _del(key: string): Promise { + const connection = this.buildDeleteConnection(key); + const options = {}; + await connection.request(options); + } + + async mdelete(keys: string[]): Promise { + const ret = keys.map((key) => this._del(key)); + await Promise.all(ret); + } + + // eslint-disable-next-line require-yield + async *yieldKeys(_prefix: string | undefined): AsyncGenerator { + // TODO: Implement. Most have an implementation that uses nextToken. + throw new Error("yieldKeys is not implemented"); + } +} + +/** + * Based on https://cloud.google.com/storage/docs/json_api/v1/objects#resource + */ +export interface GoogleCloudStorageObject extends Record { + id?: string; + name?: string; + contentType?: string; + metadata?: Record; + // This is incomplete. +} + +export interface GoogleCloudStorageResponse extends GoogleResponse { + data: GoogleCloudStorageObject; +} + +export type BucketAndPath = { + bucket: string; + path: string; +}; + +export class GoogleCloudStorageUri { + static uriRegexp = /gs:\/\/([a-z0-9][a-z0-9._-]+[a-z0-9])\/(.*)/; + + bucket: string; + + path: string; + + constructor(uri: string) { + const bucketAndPath = GoogleCloudStorageUri.uriToBucketAndPath(uri); + this.bucket = bucketAndPath.bucket; + this.path = bucketAndPath.path; + } + + get uri() { + return `gs://${this.bucket}/${this.path}`; + } + + get isValid() { + return ( + typeof this.bucket !== "undefined" && typeof this.path !== "undefined" + ); + } + + static uriToBucketAndPath(uri: string): BucketAndPath { + const match = this.uriRegexp.exec(uri); + if (!match) { + throw new Error(`Invalid gs:// URI: ${uri}`); + } + return { + bucket: match[1], + path: match[2], + }; + } + + static isValidUri(uri: string): boolean { + return this.uriRegexp.test(uri); + } +} + +export interface GoogleCloudStorageConnectionParams { + uri: string; +} + +export interface GoogleCloudStorageUploadConnectionParams + extends GoogleUploadConnectionParams, + GoogleCloudStorageConnectionParams {} + +export class GoogleCloudStorageUploadConnection< + AuthOptions +> extends GoogleMultipartUploadConnection< + AsyncCallerCallOptions, + GoogleCloudStorageResponse, + AuthOptions +> { + uri: GoogleCloudStorageUri; + + constructor( + fields: GoogleCloudStorageUploadConnectionParams, + caller: AsyncCaller, + client: GoogleAbstractedClient + ) { + super(fields, caller, client); + this.uri = new GoogleCloudStorageUri(fields.uri); + } + + async buildUrl(): Promise { + return `https://storage.googleapis.com/upload/storage/${this.apiVersion}/b/${this.uri.bucket}/o?uploadType=multipart`; + } +} + +export interface GoogleCloudStorageDownloadConnectionParams + extends GoogleCloudStorageConnectionParams, + GoogleConnectionParams { + method: GoogleAbstractedClientOpsMethod; + alt: "media" | undefined; +} + +export class GoogleCloudStorageDownloadConnection< + ResponseType extends GoogleResponse, + AuthOptions +> extends GoogleDownloadConnection< + AsyncCallerCallOptions, + ResponseType, + AuthOptions +> { + uri: GoogleCloudStorageUri; + + method: GoogleAbstractedClientOpsMethod; + + alt: "media" | undefined; + + constructor( + fields: GoogleCloudStorageDownloadConnectionParams, + caller: AsyncCaller, + client: GoogleAbstractedClient + ) { + super(fields, caller, client); + this.uri = new GoogleCloudStorageUri(fields.uri); + this.method = fields.method; + this.alt = fields.alt; + } + + buildMethod(): GoogleAbstractedClientOpsMethod { + return this.method; + } + + async buildUrl(): Promise { + const path = encodeURIComponent(this.uri.path); + const ret = `https://storage.googleapis.com/storage/${this.apiVersion}/b/${this.uri.bucket}/o/${path}`; + return this.alt ? `${ret}?alt=${this.alt}` : ret; + } +} + +export interface GoogleCloudStorageRawConnectionParams + extends GoogleCloudStorageConnectionParams, + GoogleConnectionParams {} + +export class GoogleCloudStorageRawConnection< + AuthOptions +> extends GoogleDownloadRawConnection { + uri: GoogleCloudStorageUri; + + constructor( + fields: GoogleCloudStorageRawConnectionParams, + caller: AsyncCaller, + client: GoogleAbstractedClient + ) { + super(fields, caller, client); + this.uri = new GoogleCloudStorageUri(fields.uri); + } + + async buildUrl(): Promise { + const path = encodeURIComponent(this.uri.path); + const ret = `https://storage.googleapis.com/storage/${this.apiVersion}/b/${this.uri.bucket}/o/${path}?alt=media`; + return ret; + } +} + +export interface BlobStoreGoogleCloudStorageBaseParams + extends BlobStoreGoogleParams { + uriPrefix: GoogleCloudStorageUri; +} + +export abstract class BlobStoreGoogleCloudStorageBase< + AuthOptions +> extends BlobStoreGoogle { + params: BlobStoreGoogleCloudStorageBaseParams; + + constructor(fields: BlobStoreGoogleCloudStorageBaseParams) { + super(fields); + this.params = fields; + this.defaultStoreOptions = { + ...this.defaultStoreOptions, + pathPrefix: fields.uriPrefix.uri, + }; + } + + buildSetConnection([key, _blob]: [ + string, + MediaBlob + ]): GoogleMultipartUploadConnection< + AsyncCallerCallOptions, + GoogleCloudStorageResponse, + AuthOptions + > { + const params: GoogleCloudStorageUploadConnectionParams = { + ...this.params, + uri: key, + }; + return new GoogleCloudStorageUploadConnection( + params, + this.caller, + this.client + ); + } + + buildSetMetadata([key, blob]: [string, MediaBlob]): Record { + const uri = new GoogleCloudStorageUri(key); + const ret: GoogleCloudStorageObject = { + name: uri.path, + metadata: blob.metadata, + contentType: blob.mimetype, + }; + return ret; + } + + buildGetMetadataConnection( + key: string + ): GoogleDownloadConnection< + AsyncCallerCallOptions, + GoogleCloudStorageResponse, + AuthOptions + > { + const params: GoogleCloudStorageDownloadConnectionParams = { + uri: key, + method: "GET", + alt: undefined, + }; + return new GoogleCloudStorageDownloadConnection< + GoogleCloudStorageResponse, + AuthOptions + >(params, this.caller, this.client); + } + + buildGetDataConnection( + key: string + ): GoogleDownloadRawConnection { + const params: GoogleCloudStorageRawConnectionParams = { + uri: key, + }; + return new GoogleCloudStorageRawConnection( + params, + this.caller, + this.client + ); + } + + buildDeleteConnection( + key: string + ): GoogleDownloadConnection< + AsyncCallerCallOptions, + GoogleResponse, + AuthOptions + > { + const params: GoogleCloudStorageDownloadConnectionParams = { + uri: key, + method: "DELETE", + alt: undefined, + }; + return new GoogleCloudStorageDownloadConnection< + GoogleResponse, + AuthOptions + >(params, this.caller, this.client); + } +} + +export type AIStudioFileState = + | "PROCESSING" + | "ACTIVE" + | "FAILED" + | "STATE_UNSPECIFIED"; + +export type AIStudioFileVideoMetadata = { + videoMetadata: { + videoDuration: string; // Duration in seconds, possibly with fractional, ending in "s" + }; +}; + +export type AIStudioFileMetadata = AIStudioFileVideoMetadata; + +export interface AIStudioFileObject { + name?: string; + displayName?: string; + mimeType?: string; + sizeBytes?: string; // int64 format + createTime?: string; // timestamp format + updateTime?: string; // timestamp format + expirationTime?: string; // timestamp format + sha256Hash?: string; // base64 encoded + uri?: string; + state?: AIStudioFileState; + error?: { + code: number; + message: string; + details: Record[]; + }; + metadata?: AIStudioFileMetadata; +} + +export class AIStudioMediaBlob extends MediaBlob { + _valueAsDate(value: string): Date { + if (!value) { + return new Date(0); + } + return new Date(value); + } + + _metadataFieldAsDate(field: string): Date { + return this._valueAsDate(this.metadata?.[field]); + } + + get createDate(): Date { + return this._metadataFieldAsDate("createTime"); + } + + get updateDate(): Date { + return this._metadataFieldAsDate("updateTime"); + } + + get expirationDate(): Date { + return this._metadataFieldAsDate("expirationTime"); + } + + get isExpired(): boolean { + const now = new Date().toISOString(); + const exp = this.metadata?.expirationTime ?? now; + return exp <= now; + } +} + +export interface AIStudioFileGetResponse extends GoogleResponse { + data: AIStudioFileObject; +} + +export interface AIStudioFileSaveResponse extends GoogleResponse { + data: { + file: AIStudioFileObject; + }; +} + +export interface AIStudioFileListResponse extends GoogleResponse { + data: { + files: AIStudioFileObject[]; + nextPageToken: string; + }; +} + +export type AIStudioFileResponse = + | AIStudioFileGetResponse + | AIStudioFileSaveResponse + | AIStudioFileListResponse; + +export interface AIStudioFileConnectionParams {} + +export interface AIStudioFileUploadConnectionParams + extends GoogleUploadConnectionParams, + AIStudioFileConnectionParams {} + +export class AIStudioFileUploadConnection< + AuthOptions +> extends GoogleMultipartUploadConnection< + AsyncCallerCallOptions, + AIStudioFileSaveResponse, + AuthOptions +> { + apiVersion = "v1beta"; + + async buildUrl(): Promise { + return `https://generativelanguage.googleapis.com/upload/${this.apiVersion}/files`; + } +} + +export interface AIStudioFileDownloadConnectionParams + extends AIStudioFileConnectionParams, + GoogleConnectionParams { + method: GoogleAbstractedClientOpsMethod; + name: string; +} + +export class AIStudioFileDownloadConnection< + ResponseType extends GoogleResponse, + AuthOptions +> extends GoogleDownloadConnection< + AsyncCallerCallOptions, + ResponseType, + AuthOptions +> { + method: GoogleAbstractedClientOpsMethod; + + name: string; + + apiVersion = "v1beta"; + + constructor( + fields: AIStudioFileDownloadConnectionParams, + caller: AsyncCaller, + client: GoogleAbstractedClient + ) { + super(fields, caller, client); + this.method = fields.method; + this.name = fields.name; + } + + buildMethod(): GoogleAbstractedClientOpsMethod { + return this.method; + } + + async buildUrl(): Promise { + return `https://generativelanguage.googleapis.com/${this.apiVersion}/files/${this.name}`; + } +} + +export interface BlobStoreAIStudioFileBaseParams + extends BlobStoreGoogleParams { + retryTime?: number; +} + +export abstract class BlobStoreAIStudioFileBase< + AuthOptions +> extends BlobStoreGoogle { + params?: BlobStoreAIStudioFileBaseParams; + + retryTime: number = 1000; + + constructor(fields?: BlobStoreAIStudioFileBaseParams) { + const params: BlobStoreAIStudioFileBaseParams = { + defaultStoreOptions: { + pathPrefix: "https://generativelanguage.googleapis.com/v1beta/files/", + actionIfInvalid: "removePath", + }, + ...fields, + }; + super(params); + this.params = params; + this.retryTime = params?.retryTime ?? this.retryTime ?? 1000; + } + + _pathToName(path: string): string { + return path.split("/").pop() ?? path; + } + + abstract buildAbstractedClient( + fields?: BlobStoreGoogleParams + ): GoogleAbstractedClient; + + buildApiKeyClient(apiKey: string): GoogleAbstractedClient { + return new ApiKeyGoogleAuth(apiKey); + } + + buildApiKey(fields?: BlobStoreGoogleParams): string | undefined { + return fields?.apiKey ?? getEnvironmentVariable("GOOGLE_API_KEY"); + } + + buildClient( + fields?: BlobStoreGoogleParams + ): GoogleAbstractedClient { + const apiKey = this.buildApiKey(fields); + if (apiKey) { + return this.buildApiKeyClient(apiKey); + } else { + // TODO: Test that you can use OAuth to access + return this.buildAbstractedClient(fields); + } + } + + async _regetMetadata(key: string): Promise { + // Sleep for some time period + // eslint-disable-next-line no-promise-executor-return + await new Promise((resolve) => setTimeout(resolve, this.retryTime)); + + // Fetch the latest metadata + return this._getMetadata(key); + } + + async _set([key, blob]: [ + string, + MediaBlob + ]): Promise { + const response = (await super._set([ + key, + blob, + ])) as AIStudioFileSaveResponse; + + let file = response.data?.file ?? { state: "FAILED" }; + while (file.state === "PROCESSING" && file.uri && this.retryTime > 0) { + file = await this._regetMetadata(file.uri); + } + + // The response should contain the name (and valid URI), so we need to + // update the blob with this. We can't return a new blob, since mset() + // doesn't return anything. + /* eslint-disable no-param-reassign */ + blob.path = file.uri; + blob.metadata = { + ...blob.metadata, + ...file, + }; + /* eslint-enable no-param-reassign */ + + return response; + } + + buildSetConnection([_key, _blob]: [ + string, + MediaBlob + ]): GoogleMultipartUploadConnection< + AsyncCallerCallOptions, + AIStudioFileResponse, + AuthOptions + > { + return new AIStudioFileUploadConnection( + this.params, + this.caller, + this.client + ); + } + + buildSetMetadata([_key, _blob]: [string, MediaBlob]): Record< + string, + unknown + > { + return {}; + } + + buildGetMetadataConnection( + key: string + ): GoogleDownloadConnection< + AsyncCallerCallOptions, + AIStudioFileResponse, + AuthOptions + > { + const params: AIStudioFileDownloadConnectionParams = { + ...this.params, + method: "GET", + name: this._pathToName(key), + }; + return new AIStudioFileDownloadConnection< + AIStudioFileResponse, + AuthOptions + >(params, this.caller, this.client); + } + + buildGetDataConnection( + _key: string + ): GoogleDownloadRawConnection { + throw new Error("AI Studio File API does not provide data"); + } + + async _get(key: string): Promise { + const metadata = await this._getMetadata(key); + if (metadata) { + const contentType = + (metadata?.mimeType as string) ?? "application/octet-stream"; + // TODO - Get the actual data (and other metadata) from an optional backing store + const data: MediaBlobData = { + value: "", + type: contentType, + }; + + return new MediaBlob({ + path: key, + data, + metadata, + }); + } else { + return undefined; + } + } + + buildDeleteConnection( + key: string + ): GoogleDownloadConnection< + AsyncCallerCallOptions, + AIStudioFileResponse, + AuthOptions + > { + const params: AIStudioFileDownloadConnectionParams = { + ...this.params, + method: "DELETE", + name: this._pathToName(key), + }; + return new AIStudioFileDownloadConnection< + AIStudioFileResponse, + AuthOptions + >(params, this.caller, this.client); + } +} diff --git a/libs/langchain-google-common/src/experimental/utils/media_core.ts b/libs/langchain-google-common/src/experimental/utils/media_core.ts new file mode 100644 index 000000000000..f27d5c55ed52 --- /dev/null +++ b/libs/langchain-google-common/src/experimental/utils/media_core.ts @@ -0,0 +1,669 @@ +import { v1, v4 } from "uuid"; // FIXME - it is importing the wrong uuid, so v6 and v7 aren't implemented +import { BaseStore } from "@langchain/core/stores"; +import { Serializable } from "@langchain/core/load/serializable"; + +export type MediaBlobData = { + value: string; // In Base64 encoding + type: string; // The mime type and possibly encoding +}; + +export interface MediaBlobParameters { + data?: MediaBlobData; + + metadata?: Record; + + path?: string; +} + +function bytesToString(dataArray: Uint8Array): string { + // Need to handle the array in smaller chunks to deal with stack size limits + let ret = ""; + const chunkSize = 102400; + for (let i = 0; i < dataArray.length; i += chunkSize) { + const chunk = dataArray.subarray(i, i + chunkSize); + ret += String.fromCharCode(...chunk); + } + + return ret; +} + +/** + * Represents a chunk of data that can be identified by the path where the + * data is (or will be) located, along with optional metadata about the data. + */ +export class MediaBlob extends Serializable implements MediaBlobParameters { + lc_serializable = true; + + lc_namespace = [ + "langchain", + "google_common", + "experimental", + "utils", + "media_core", + ]; + + data: MediaBlobData = { + value: "", + type: "text/plain", + }; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + metadata?: Record; + + path?: string; + + constructor(params: MediaBlobParameters) { + super(params); + + this.data = params.data ?? this.data; + this.metadata = params.metadata; + this.path = params.path; + } + + get size(): number { + return this.asBytes.length; + } + + get dataType(): string { + return this.data?.type ?? ""; + } + + get encoding(): string { + const charsetEquals = this.dataType.indexOf("charset="); + return charsetEquals === -1 + ? "utf-8" + : this.dataType.substring(charsetEquals + 8); + } + + get mimetype(): string { + const semicolon = this.dataType.indexOf(";"); + return semicolon === -1 + ? this.dataType + : this.dataType.substring(0, semicolon); + } + + get asBytes(): Uint8Array { + if (!this.data) { + return Uint8Array.from([]); + } + const binString = atob(this.data?.value); + const ret = new Uint8Array(binString.length); + for (let co = 0; co < binString.length; co += 1) { + ret[co] = binString.charCodeAt(co); + } + return ret; + } + + async asString(): Promise { + return bytesToString(this.asBytes); + } + + async asBase64(): Promise { + return this.data?.value ?? ""; + } + + async asDataUrl(): Promise { + return `data:${this.mimetype};base64,${await this.asBase64()}`; + } + + async asUri(): Promise { + return this.path ?? (await this.asDataUrl()); + } + + async encode(): Promise<{ encoded: string; encoding: string }> { + const dataUrl = await this.asDataUrl(); + const comma = dataUrl.indexOf(","); + const encoded = dataUrl.substring(comma + 1); + const encoding: string = dataUrl.indexOf("base64") > -1 ? "base64" : "8bit"; + return { + encoded, + encoding, + }; + } + + static fromDataUrl(url: string): MediaBlob { + if (!url.startsWith("data:")) { + throw new Error("Not a data: URL"); + } + const colon = url.indexOf(":"); + const semicolon = url.indexOf(";"); + const mimeType = url.substring(colon + 1, semicolon); + + const comma = url.indexOf(","); + const base64Data = url.substring(comma + 1); + + const data: MediaBlobData = { + type: mimeType, + value: base64Data, + }; + + return new MediaBlob({ + data, + path: url, + }); + } + + static async fromBlob( + blob: Blob, + other?: Omit + ): Promise { + const valueBuffer = await blob.arrayBuffer(); + const valueArray = new Uint8Array(valueBuffer); + const valueStr = bytesToString(valueArray); + const value = btoa(valueStr); + + return new MediaBlob({ + ...other, + data: { + value, + type: blob.type, + }, + }); + } +} + +export type ActionIfInvalidAction = + | "ignore" + | "prefixPath" + | "prefixUuid1" + | "prefixUuid4" + | "prefixUuid6" + | "prefixUuid7" + | "removePath"; + +export interface BlobStoreStoreOptions { + /** + * If the path is missing or invalid in the blob, how should we create + * a new path? + * Subclasses may define their own methods, but the following are supported + * by default: + * - Undefined or an emtpy string: Reject the blob + * - "ignore": Attempt to store it anyway (but this may fail) + * - "prefixPath": Use the default prefix for the BlobStore and get the + * unique portion from the URL. The original path is stored in the metadata + * - "prefixUuid": Use the default prefix for the BlobStore and get the + * unique portion from a generated UUID. The original path is stored + * in the metadata + */ + actionIfInvalid?: ActionIfInvalidAction; + + /** + * The expected prefix for URIs that are stored. + * This may be used to test if a MediaBlob is valid and used to create a new + * path if "prefixPath" or "prefixUuid" is set for actionIfInvalid. + */ + pathPrefix?: string; +} + +export type ActionIfBlobMissingAction = "emptyBlob"; + +export interface BlobStoreFetchOptions { + /** + * If the blob is not found when fetching, what should we do? + * Subclasses may define their own methods, but the following are supported + * by default: + * - Undefined or an empty string: return undefined + * - "emptyBlob": return a new MediaBlob that has the path set, but nothing else. + */ + actionIfBlobMissing?: ActionIfBlobMissingAction; +} + +export interface BlobStoreOptions { + defaultStoreOptions?: BlobStoreStoreOptions; + + defaultFetchOptions?: BlobStoreFetchOptions; +} + +/** + * A specialized Store that is designed to handle MediaBlobs and use the + * key that is included in the blob to determine exactly how it is stored. + * + * The full details of a MediaBlob may be changed when it is stored. + * For example, it may get additional or different Metadata. This should be + * what is returned when the store() method is called. + * + * Although BlobStore extends BaseStore, not all of the methods from + * BaseStore may be implemented (or even possible). Those that are not + * implemented should be documented and throw an Error if called. + */ +export abstract class BlobStore extends BaseStore { + lc_namespace = ["langchain", "google-common"]; // FIXME - What should this be? And why? + + defaultStoreOptions: BlobStoreStoreOptions; + + defaultFetchOptions: BlobStoreFetchOptions; + + constructor(opts?: BlobStoreOptions) { + super(opts); + this.defaultStoreOptions = opts?.defaultStoreOptions ?? {}; + this.defaultFetchOptions = opts?.defaultFetchOptions ?? {}; + } + + protected async _realKey(key: string | MediaBlob): Promise { + return typeof key === "string" ? key : await key.asUri(); + } + + /** + * Is the path supported by this BlobStore? + * + * Although this is async, this is expected to be a relatively fast operation + * (ie - you shouldn't make network calls). + * + * @param path The path to check + * @param opts Any options (if needed) that may be used to determine if it is valid + * @return If the path is supported + */ + hasValidPath( + path: string | undefined, + opts?: BlobStoreStoreOptions + ): Promise { + const prefix = opts?.pathPrefix ?? ""; + const isPrefixed = typeof path !== "undefined" && path.startsWith(prefix); + return Promise.resolve(isPrefixed); + } + + protected _blobPathSuffix(blob: MediaBlob): string { + // Get the path currently set and make sure we treat it as a string + const blobPath = `${blob.path}`; + + // Advance past the first set of / + let pathStart = blobPath.indexOf("/") + 1; + while (blobPath.charAt(pathStart) === "/") { + pathStart += 1; + } + + // We will use the rest as the path for a replacement + return blobPath.substring(pathStart); + } + + protected async _newBlob( + oldBlob: MediaBlob, + newPath?: string + ): Promise { + const oldPath = oldBlob.path; + const metadata = oldBlob?.metadata ?? {}; + metadata.langchainOldPath = oldPath; + const newBlob = new MediaBlob({ + ...oldBlob, + metadata, + }); + if (newPath) { + newBlob.path = newPath; + } else if (newBlob.path) { + delete newBlob.path; + } + return newBlob; + } + + protected async _validBlobPrefixPath( + blob: MediaBlob, + opts?: BlobStoreStoreOptions + ): Promise { + const prefix = opts?.pathPrefix ?? ""; + const suffix = this._blobPathSuffix(blob); + const newPath = `${prefix}${suffix}`; + return this._newBlob(blob, newPath); + } + + protected _validBlobPrefixUuidFunction( + name: ActionIfInvalidAction | string + ): string { + switch (name) { + case "prefixUuid1": + return v1(); + case "prefixUuid4": + return v4(); + // case "prefixUuid6": return v6(); + // case "prefixUuid7": return v7(); + default: + throw new Error(`Unknown uuid function: ${name}`); + } + } + + protected async _validBlobPrefixUuid( + blob: MediaBlob, + opts?: BlobStoreStoreOptions + ): Promise { + const prefix = opts?.pathPrefix ?? ""; + const suffix = this._validBlobPrefixUuidFunction( + opts?.actionIfInvalid ?? "prefixUuid4" + ); + const newPath = `${prefix}${suffix}`; + return this._newBlob(blob, newPath); + } + + protected async _validBlobRemovePath( + blob: MediaBlob, + _opts?: BlobStoreStoreOptions + ): Promise { + return this._newBlob(blob, undefined); + } + + /** + * Based on the blob and options, return a blob that has a valid path + * that can be saved. + * @param blob + * @param opts + */ + protected async _validStoreBlob( + blob: MediaBlob, + opts?: BlobStoreStoreOptions + ): Promise { + if (await this.hasValidPath(blob.path, opts)) { + return blob; + } + switch (opts?.actionIfInvalid) { + case "ignore": + return blob; + case "prefixPath": + return this._validBlobPrefixPath(blob, opts); + case "prefixUuid1": + case "prefixUuid4": + case "prefixUuid6": + case "prefixUuid7": + return this._validBlobPrefixUuid(blob, opts); + case "removePath": + return this._validBlobRemovePath(blob, opts); + default: + return undefined; + } + } + + async store( + blob: MediaBlob, + opts: BlobStoreStoreOptions = {} + ): Promise { + const allOpts: BlobStoreStoreOptions = { + ...this.defaultStoreOptions, + ...opts, + }; + const validBlob = await this._validStoreBlob(blob, allOpts); + if (typeof validBlob !== "undefined") { + const validKey = await validBlob.asUri(); + await this.mset([[validKey, validBlob]]); + const savedKey = await validBlob.asUri(); + return await this.fetch(savedKey); + } + return undefined; + } + + protected async _missingFetchBlobEmpty( + path: string, + _opts?: BlobStoreFetchOptions + ): Promise { + return new MediaBlob({ path }); + } + + protected async _missingFetchBlob( + path: string, + opts?: BlobStoreFetchOptions + ): Promise { + switch (opts?.actionIfBlobMissing) { + case "emptyBlob": + return this._missingFetchBlobEmpty(path, opts); + default: + return undefined; + } + } + + async fetch( + key: string | MediaBlob, + opts: BlobStoreFetchOptions = {} + ): Promise { + const allOpts: BlobStoreFetchOptions = { + ...this.defaultFetchOptions, + ...opts, + }; + const realKey = await this._realKey(key); + const ret = await this.mget([realKey]); + return ret?.[0] ?? (await this._missingFetchBlob(realKey, allOpts)); + } +} + +export interface BackedBlobStoreOptions extends BlobStoreOptions { + backingStore: BaseStore; +} + +export class BackedBlobStore extends BlobStore { + backingStore: BaseStore; + + constructor(opts: BackedBlobStoreOptions) { + super(opts); + this.backingStore = opts.backingStore; + } + + mdelete(keys: string[]): Promise { + return this.backingStore.mdelete(keys); + } + + mget(keys: string[]): Promise<(MediaBlob | undefined)[]> { + return this.backingStore.mget(keys); + } + + mset(keyValuePairs: [string, MediaBlob][]): Promise { + return this.backingStore.mset(keyValuePairs); + } + + yieldKeys(prefix: string | undefined): AsyncGenerator { + return this.backingStore.yieldKeys(prefix); + } +} + +export interface ReadThroughBlobStoreOptions extends BlobStoreOptions { + baseStore: BlobStore; + backingStore: BlobStore; +} + +export class ReadThroughBlobStore extends BlobStore { + baseStore: BlobStore; + + backingStore: BlobStore; + + constructor(opts: ReadThroughBlobStoreOptions) { + super(opts); + this.baseStore = opts.baseStore; + this.backingStore = opts.backingStore; + } + + async store( + blob: MediaBlob, + opts: BlobStoreStoreOptions = {} + ): Promise { + const originalUri = await blob.asUri(); + const newBlob = await this.backingStore.store(blob, opts); + if (newBlob) { + await this.baseStore.mset([[originalUri, newBlob]]); + } + return newBlob; + } + + mdelete(keys: string[]): Promise { + return this.baseStore.mdelete(keys); + } + + mget(keys: string[]): Promise<(MediaBlob | undefined)[]> { + return this.baseStore.mget(keys); + } + + mset(_keyValuePairs: [string, MediaBlob][]): Promise { + throw new Error("Do not call ReadThroughBlobStore.mset directly"); + } + + yieldKeys(prefix: string | undefined): AsyncGenerator { + return this.baseStore.yieldKeys(prefix); + } +} + +export class SimpleWebBlobStore extends BlobStore { + _notImplementedException() { + throw new Error("Not implemented for SimpleWebBlobStore"); + } + + async hasValidPath( + path: string | undefined, + _opts?: BlobStoreStoreOptions + ): Promise { + return ( + (await super.hasValidPath(path, { pathPrefix: "https://" })) || + (await super.hasValidPath(path, { pathPrefix: "http://" })) + ); + } + + async _fetch(url: string): Promise { + const ret = new MediaBlob({ + path: url, + }); + const metadata: Record = {}; + const fetchOptions = { + method: "GET", + }; + const res = await fetch(url, fetchOptions); + metadata.status = res.status; + + const headers: Record = {}; + for (const [key, value] of res.headers.entries()) { + headers[key] = value; + } + metadata.headers = headers; + + metadata.ok = res.ok; + if (res.ok) { + const resMediaBlob = await MediaBlob.fromBlob(await res.blob()); + ret.data = resMediaBlob.data; + } + + ret.metadata = metadata; + return ret; + } + + async mget(keys: string[]): Promise<(MediaBlob | undefined)[]> { + const blobMap = keys.map(this._fetch); + return await Promise.all(blobMap); + } + + async mdelete(_keys: string[]): Promise { + this._notImplementedException(); + } + + async mset(_keyValuePairs: [string, MediaBlob][]): Promise { + this._notImplementedException(); + } + + async *yieldKeys(_prefix: string | undefined): AsyncGenerator { + this._notImplementedException(); + yield ""; + } +} + +/** + * A blob "store" that works with data: URLs that will turn the URL into + * a blob. + */ +export class DataBlobStore extends BlobStore { + _notImplementedException() { + throw new Error("Not implemented for DataBlobStore"); + } + + hasValidPath(path: string, _opts?: BlobStoreStoreOptions): Promise { + return super.hasValidPath(path, { pathPrefix: "data:" }); + } + + _fetch(url: string): MediaBlob { + return MediaBlob.fromDataUrl(url); + } + + async mget(keys: string[]): Promise<(MediaBlob | undefined)[]> { + const blobMap = keys.map(this._fetch); + return blobMap; + } + + async mdelete(_keys: string[]): Promise { + this._notImplementedException(); + } + + async mset(_keyValuePairs: [string, MediaBlob][]): Promise { + this._notImplementedException(); + } + + async *yieldKeys(_prefix: string | undefined): AsyncGenerator { + this._notImplementedException(); + yield ""; + } +} + +export interface MediaManagerConfiguration { + /** + * A store that, given a common URI, returns the corresponding MediaBlob. + * The returned MediaBlob may have a different URI. + * In many cases, this will be a ReadThroughStore or something similar + * that has a cached version of the MediaBlob, but also a way to get + * a new (or refreshed) version. + */ + store: BlobStore; + + /** + * BlobStores that can resolve a URL into the MediaBlob to save + * in the canonical store. This list is evaluated in order. + * If not provided, a default list (which involves a DataBlobStore + * and a SimpleWebBlobStore) will be used. + */ + resolvers?: BlobStore[]; +} + +/** + * Responsible for converting a URI (typically a web URL) into a MediaBlob. + * Allows for aliasing / caching of the requested URI and what it resolves to. + * This MediaBlob is expected to be usable to provide to an LLM, either + * through the Base64 of the media or through a canonical URI that the LLM + * supports. + */ +export class MediaManager { + store: BlobStore; + + resolvers: BlobStore[] | undefined; + + constructor(config: MediaManagerConfiguration) { + this.store = config.store; + this.resolvers = config.resolvers; + } + + defaultResolvers(): BlobStore[] { + return [new DataBlobStore({}), new SimpleWebBlobStore({})]; + } + + async _isInvalid(blob: MediaBlob | undefined): Promise { + return typeof blob === "undefined"; + } + + /** + * Given the public URI, load what is at this URI and save it + * in the store. + * @param uri The URI to resolve using the resolver + * @return A canonical MediaBlob for this URI + */ + async _resolveAndSave(uri: string): Promise { + let resolvedBlob: MediaBlob | undefined; + + const resolvers = this.resolvers || this.defaultResolvers(); + for (let co = 0; co < resolvers.length; co += 1) { + const resolver = resolvers[co]; + if (await resolver.hasValidPath(uri)) { + resolvedBlob = await resolver.fetch(uri); + } + } + + if (resolvedBlob) { + return await this.store.store(resolvedBlob); + } else { + return new MediaBlob({}); + } + } + + async getMediaBlob(uri: string): Promise { + const aliasBlob = await this.store.fetch(uri); + const ret = (await this._isInvalid(aliasBlob)) + ? await this._resolveAndSave(uri) + : (aliasBlob as MediaBlob); + return ret; + } +} diff --git a/libs/langchain-google-common/src/llms.ts b/libs/langchain-google-common/src/llms.ts index 347098177186..e9c267aaaf65 100644 --- a/libs/langchain-google-common/src/llms.ts +++ b/libs/langchain-google-common/src/llms.ts @@ -21,13 +21,7 @@ import { copyAIModelParams, copyAndValidateModelParamsInto, } from "./utils/common.js"; -import { - chunkToString, - messageContentToParts, - safeResponseToBaseMessage, - safeResponseToString, - DefaultGeminiSafetyHandler, -} from "./utils/gemini.js"; +import { DefaultGeminiSafetyHandler } from "./utils/gemini.js"; import { ApiKeyGoogleAuth, GoogleAbstractedClient } from "./auth.js"; import { ensureParams } from "./utils/failed_handler.js"; import { ChatGoogleBase } from "./chat_models.js"; @@ -39,11 +33,11 @@ class GoogleLLMConnection extends AbstractGoogleLLMConnection< MessageContent, AuthOptions > { - formatContents( + async formatContents( input: MessageContent, _parameters: GoogleAIModelParams - ): GeminiContent[] { - const parts = messageContentToParts(input); + ): Promise { + const parts = await this.api.messageContentToParts(input); const contents: GeminiContent[] = [ { role: "user", // Required by Vertex AI @@ -189,7 +183,10 @@ export abstract class GoogleBaseLLM ): Promise { const parameters = copyAIModelParams(this, options); const result = await this.connection.request(prompt, parameters, options); - const ret = safeResponseToString(result, this.safetyHandler); + const ret = this.connection.api.safeResponseToString( + result, + this.safetyHandler + ); return ret; } @@ -234,7 +231,7 @@ export abstract class GoogleBaseLLM const proxyChat = this.createProxyChat(); try { for await (const chunk of proxyChat._streamIterator(input, options)) { - const stringValue = chunkToString(chunk); + const stringValue = this.connection.api.chunkToString(chunk); const generationChunk = new GenerationChunk({ text: stringValue, }); @@ -267,7 +264,10 @@ export abstract class GoogleBaseLLM {}, options as BaseLanguageModelCallOptions ); - const ret = safeResponseToBaseMessage(result, this.safetyHandler); + const ret = this.connection.api.safeResponseToBaseMessage( + result, + this.safetyHandler + ); return ret; } diff --git a/libs/langchain-google-common/src/tests/chat_models.test.ts b/libs/langchain-google-common/src/tests/chat_models.test.ts index dda4b68033ce..9da477df3e0e 100644 --- a/libs/langchain-google-common/src/tests/chat_models.test.ts +++ b/libs/langchain-google-common/src/tests/chat_models.test.ts @@ -9,6 +9,7 @@ import { SystemMessage, ToolMessage, } from "@langchain/core/messages"; +import { InMemoryStore } from "@langchain/core/stores"; import { z } from "zod"; import { zodToJsonSchema } from "zod-to-json-schema"; @@ -17,6 +18,12 @@ import { authOptions, MockClient, MockClientAuthInfo, mockId } from "./mock.js"; import { GeminiTool, GoogleAIBaseLLMInput } from "../types.js"; import { GoogleAbstractedClient } from "../auth.js"; import { GoogleAISafetyError } from "../utils/safety.js"; +import { + BackedBlobStore, + MediaBlob, + MediaManager, + ReadThroughBlobStore, +} from "../experimental/utils/media_core.js"; import { removeAdditionalProperties } from "../utils/zod_to_gemini_parameters.js"; class ChatGoogle extends ChatGoogleBase { @@ -502,10 +509,6 @@ describe("Mock ChatGoogle", () => { expect(caught).toEqual(true); }); - /* - * Images aren't supported (yet) by Gemini, but a one-round with - * image should work ok. - */ test("3. invoke - images", async () => { // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; @@ -517,7 +520,7 @@ describe("Mock ChatGoogle", () => { }; const model = new ChatGoogle({ authOptions, - model: "gemini-pro-vision", + model: "gemini-1.5-flash", }); const message: MessageContentComplex[] = [ @@ -552,6 +555,200 @@ describe("Mock ChatGoogle", () => { expect(result.content).toBe("A blue square."); }); + test("3. invoke - media - invalid", async () => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-3-mock.json", + }; + const model = new ChatGoogle({ + authOptions, + model: "gemini-1.5-flash", + }); + + const message: MessageContentComplex[] = [ + { + type: "text", + text: "What is in this image?", + }, + { + type: "media", + fileUri: "mock://example.com/blue-box.png", + }, + ]; + + const messages: BaseMessage[] = [ + new HumanMessageChunk({ content: message }), + ]; + + try { + const result = await model.invoke(messages); + expect(result).toBeUndefined(); + } catch (e) { + expect((e as Error).message).toEqual("Invalid media content"); + } + }); + + test("3. invoke - media - no manager", async () => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-3-mock.json", + }; + const model = new ChatGoogle({ + authOptions, + model: "gemini-1.5-flash", + }); + + const message: MessageContentComplex[] = [ + { + type: "text", + text: "What is in this image?", + }, + { + type: "media", + fileUri: "mock://example.com/blue-box.png", + mimeType: "image/png", + }, + ]; + + const messages: BaseMessage[] = [ + new HumanMessageChunk({ content: message }), + ]; + + const result = await model.invoke(messages); + + console.log(JSON.stringify(record.opts, null, 1)); + + expect(record.opts).toHaveProperty("data"); + expect(record.opts.data).toHaveProperty("contents"); + expect(record.opts.data.contents).toHaveLength(1); + expect(record.opts.data.contents[0]).toHaveProperty("parts"); + + const parts = record?.opts?.data?.contents[0]?.parts; + expect(parts).toHaveLength(2); + expect(parts[0]).toHaveProperty("text"); + expect(parts[1]).toHaveProperty("fileData"); + expect(parts[1].fileData).toHaveProperty("mimeType"); + expect(parts[1].fileData).toHaveProperty("fileUri"); + + expect(result.content).toBe("A blue square."); + }); + + test("3. invoke - media - manager", async () => { + class MemStore extends InMemoryStore { + get length() { + return Object.keys(this.store).length; + } + } + + const aliasMemory = new MemStore(); + const aliasStore = new BackedBlobStore({ + backingStore: aliasMemory, + defaultFetchOptions: { + actionIfBlobMissing: undefined, + }, + }); + const canonicalMemory = new MemStore(); + const canonicalStore = new BackedBlobStore({ + backingStore: canonicalMemory, + defaultStoreOptions: { + pathPrefix: "canonical://store/", + actionIfInvalid: "prefixPath", + }, + defaultFetchOptions: { + actionIfBlobMissing: undefined, + }, + }); + const blobStore = new ReadThroughBlobStore({ + baseStore: aliasStore, + backingStore: canonicalStore, + }); + const resolverMemory = new MemStore(); + const resolver = new BackedBlobStore({ + backingStore: resolverMemory, + defaultFetchOptions: { + actionIfBlobMissing: "emptyBlob", + }, + }); + const mediaManager = new MediaManager({ + store: blobStore, + resolvers: [resolver], + }); + + async function store(path: string, text: string): Promise { + const type = path.endsWith(".png") ? "image/png" : "text/plain"; + const blob = new MediaBlob({ + data: { + value: text, + type, + }, + path, + }); + await resolver.store(blob); + } + await store("resolve://host/foo", "fooing"); + await store("resolve://host2/bar/baz", "barbazing"); + await store("resolve://host/foo/blue-box.png", "png"); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-3-mock.json", + }; + const model = new ChatGoogle({ + authOptions, + model: "gemini-1.5-flash", + mediaManager, + }); + + const message: MessageContentComplex[] = [ + { + type: "text", + text: "What is in this image?", + }, + { + type: "media", + fileUri: "resolve://host/foo/blue-box.png", + }, + ]; + + const messages: BaseMessage[] = [ + new HumanMessageChunk({ content: message }), + ]; + + const result = await model.invoke(messages); + + console.log(JSON.stringify(record.opts, null, 1)); + + expect(record.opts).toHaveProperty("data"); + expect(record.opts.data).toHaveProperty("contents"); + expect(record.opts.data.contents).toHaveLength(1); + expect(record.opts.data.contents[0]).toHaveProperty("parts"); + + const parts = record?.opts?.data?.contents[0]?.parts; + expect(parts).toHaveLength(2); + expect(parts[0]).toHaveProperty("text"); + expect(parts[1]).toHaveProperty("fileData"); + expect(parts[1].fileData).toHaveProperty("mimeType"); + expect(parts[1].fileData.mimeType).toEqual("image/png"); + expect(parts[1].fileData).toHaveProperty("fileUri"); + expect(parts[1].fileData.fileUri).toEqual( + "canonical://store/host/foo/blue-box.png" + ); + + expect(result.content).toBe("A blue square."); + }); + test("4. Functions Bind - Gemini format request", async () => { // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; diff --git a/libs/langchain-google-common/src/tests/utils.test.ts b/libs/langchain-google-common/src/tests/utils.test.ts index c0398a218c67..547392c397e7 100644 --- a/libs/langchain-google-common/src/tests/utils.test.ts +++ b/libs/langchain-google-common/src/tests/utils.test.ts @@ -1,85 +1,419 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { expect, test } from "@jest/globals"; +import { beforeEach, expect, test } from "@jest/globals"; +import { InMemoryStore } from "@langchain/core/stores"; +import { SerializedConstructor } from "@langchain/core/load/serializable"; +import { load } from "@langchain/core/load"; import { z } from "zod"; import { zodToGeminiParameters } from "../utils/zod_to_gemini_parameters.js"; +import { + BackedBlobStore, + BlobStore, + MediaBlob, + MediaManager, + ReadThroughBlobStore, + SimpleWebBlobStore, +} from "../experimental/utils/media_core.js"; import { ReadableJsonStream } from "../utils/stream.js"; -test("zodToGeminiParameters can convert zod schema to gemini schema", () => { - const zodSchema = z - .object({ - operation: z - .enum(["add", "subtract", "multiply", "divide"]) - .describe("The type of operation to execute"), - number1: z.number().describe("The first number to operate on."), - number2: z.number().describe("The second number to operate on."), - childObject: z.object({}), - }) - .describe("A simple calculator tool"); - - const convertedSchema = zodToGeminiParameters(zodSchema); - - expect(convertedSchema.type).toBe("object"); - expect(convertedSchema.description).toBe("A simple calculator tool"); - expect((convertedSchema as any).additionalProperties).toBeUndefined(); - expect(convertedSchema.properties).toEqual({ - operation: { - type: "string", - enum: ["add", "subtract", "multiply", "divide"], - description: "The type of operation to execute", - }, - number1: { - type: "number", - description: "The first number to operate on.", - }, - number2: { - type: "number", - description: "The second number to operate on.", - }, - childObject: { - type: "object", - properties: {}, - }, +describe("zodToGeminiParameters", () => { + test("can convert zod schema to gemini schema", () => { + const zodSchema = z + .object({ + operation: z + .enum(["add", "subtract", "multiply", "divide"]) + .describe("The type of operation to execute"), + number1: z.number().describe("The first number to operate on."), + number2: z.number().describe("The second number to operate on."), + childObject: z.object({}), + }) + .describe("A simple calculator tool"); + + const convertedSchema = zodToGeminiParameters(zodSchema); + + expect(convertedSchema.type).toBe("object"); + expect(convertedSchema.description).toBe("A simple calculator tool"); + expect((convertedSchema as any).additionalProperties).toBeUndefined(); + expect(convertedSchema.properties).toEqual({ + operation: { + type: "string", + enum: ["add", "subtract", "multiply", "divide"], + description: "The type of operation to execute", + }, + number1: { + type: "number", + description: "The first number to operate on.", + }, + number2: { + type: "number", + description: "The second number to operate on.", + }, + childObject: { + type: "object", + properties: {}, + }, + }); + expect(convertedSchema.required).toEqual([ + "operation", + "number1", + "number2", + "childObject", + ]); + }); + + test("removes additional properties from arrays", () => { + const zodSchema = z + .object({ + people: z + .object({ + name: z.string().describe("The name of a person"), + }) + .array() + .describe("person elements"), + }) + .describe("A list of people"); + + const convertedSchema = zodToGeminiParameters(zodSchema); + expect(convertedSchema.type).toBe("object"); + expect(convertedSchema.description).toBe("A list of people"); + expect((convertedSchema as any).additionalProperties).toBeUndefined(); + + const peopleSchema = convertedSchema?.properties?.people; + expect(peopleSchema).not.toBeUndefined(); + + if (peopleSchema !== undefined) { + expect(peopleSchema.type).toBe("array"); + expect((peopleSchema as any).additionalProperties).toBeUndefined(); + expect(peopleSchema.description).toBe("person elements"); + } + + const arrayItemsSchema = peopleSchema?.items; + expect(arrayItemsSchema).not.toBeUndefined(); + if (arrayItemsSchema !== undefined) { + expect(arrayItemsSchema.type).toBe("object"); + expect((arrayItemsSchema as any).additionalProperties).toBeUndefined(); + } }); - expect(convertedSchema.required).toEqual([ - "operation", - "number1", - "number2", - "childObject", - ]); }); -test("zodToGeminiParameters removes additional properties from arrays", () => { - const zodSchema = z - .object({ - people: z - .object({ - name: z.string().describe("The name of a person"), - }) - .array() - .describe("person elements"), - }) - .describe("A list of people"); - - const convertedSchema = zodToGeminiParameters(zodSchema); - expect(convertedSchema.type).toBe("object"); - expect(convertedSchema.description).toBe("A list of people"); - expect((convertedSchema as any).additionalProperties).toBeUndefined(); - - const peopleSchema = convertedSchema?.properties?.people; - expect(peopleSchema).not.toBeUndefined(); - - if (peopleSchema !== undefined) { - expect(peopleSchema.type).toBe("array"); - expect((peopleSchema as any).additionalProperties).toBeUndefined(); - expect(peopleSchema.description).toBe("person elements"); - } - - const arrayItemsSchema = peopleSchema?.items; - expect(arrayItemsSchema).not.toBeUndefined(); - if (arrayItemsSchema !== undefined) { - expect(arrayItemsSchema.type).toBe("object"); - expect((arrayItemsSchema as any).additionalProperties).toBeUndefined(); - } +describe("media core", () => { + test("MediaBlob plain", async () => { + const blob = new Blob(["This is a test"], { type: "text/plain" }); + const mblob = await MediaBlob.fromBlob(blob); + expect(mblob.dataType).toEqual("text/plain"); + expect(mblob.mimetype).toEqual("text/plain"); + expect(mblob.encoding).toEqual("utf-8"); + expect(await mblob.asString()).toEqual("This is a test"); + }); + + test("MediaBlob charset", async () => { + const blob = new Blob(["This is a test"], { + type: "text/plain; charset=US-ASCII", + }); + const mblob = await MediaBlob.fromBlob(blob); + expect(mblob.dataType).toEqual("text/plain; charset=us-ascii"); + expect(mblob.mimetype).toEqual("text/plain"); + expect(mblob.encoding).toEqual("us-ascii"); + expect(await mblob.asString()).toEqual("This is a test"); + }); + + test("MediaBlob fromDataUrl", async () => { + const blobData = "This is a test"; + const blobMimeType = "text/plain"; + const blobDataType = `${blobMimeType}; charset=US-ASCII`; + const blob = new Blob([blobData], { + type: blobDataType, + }); + const mblob = await MediaBlob.fromBlob(blob); + const dataUrl = await mblob.asDataUrl(); + const dblob = MediaBlob.fromDataUrl(dataUrl); + expect(await dblob.asString()).toEqual(blobData); + expect(dblob.mimetype).toEqual(blobMimeType); + }); + + test("MediaBlob serialize", async () => { + const blob = new Blob(["This is a test"], { type: "text/plain" }); + const mblob = await MediaBlob.fromBlob(blob); + console.log("serialize mblob", mblob); + const serialized = mblob.toJSON() as SerializedConstructor; + console.log("serialized", serialized); + expect(serialized.kwargs).toHaveProperty("data"); + expect(serialized.kwargs.data.value).toEqual("VGhpcyBpcyBhIHRlc3Q="); + }); + + test("MediaBlob deserialize", async () => { + const serialized: SerializedConstructor = { + lc: 1, + type: "constructor", + id: [ + "langchain", + "google_common", + "experimental", + "utils", + "media_core", + "MediaBlob", + ], + kwargs: { + data: { + value: "VGhpcyBpcyBhIHRlc3Q=", + type: "text/plain", + }, + }, + }; + const mblob: MediaBlob = await load(JSON.stringify(serialized), { + importMap: { + google_common__experimental__utils__media_core: await import( + "../experimental/utils/media_core.js" + ), + }, + }); + console.log("deserialize mblob", mblob); + expect(mblob.dataType).toEqual("text/plain"); + expect(await mblob.asString()).toEqual("This is a test"); + }); + + test("SimpleWebBlobStore fetch", async () => { + const webStore = new SimpleWebBlobStore(); + const exampleBlob = await webStore.fetch("http://example.com/"); + console.log(exampleBlob); + expect(exampleBlob?.mimetype).toEqual("text/html"); + expect(exampleBlob?.encoding).toEqual("utf-8"); + expect(exampleBlob?.size).toBeGreaterThan(0); + expect(exampleBlob?.metadata).toBeDefined(); + expect(exampleBlob?.metadata?.ok).toBeTruthy(); + expect(exampleBlob?.metadata?.status).toEqual(200); + }); + + describe("BackedBlobStore", () => { + test("simple", async () => { + const backingStore = new InMemoryStore(); + const store = new BackedBlobStore({ + backingStore, + }); + const data = new Blob(["This is a test"], { type: "text/plain" }); + const path = "simple://foo"; + const blob = await MediaBlob.fromBlob(data, { path }); + const storedBlob = await store.store(blob); + expect(storedBlob).toBeDefined(); + const fetchedBlob = await store.fetch(path); + expect(fetchedBlob).toBeDefined(); + }); + + test("missing undefined", async () => { + const backingStore = new InMemoryStore(); + const store = new BackedBlobStore({ + backingStore, + }); + const path = "simple://foo"; + const fetchedBlob = await store.fetch(path); + expect(fetchedBlob).toBeUndefined(); + }); + + test("missing emptyBlob defaultConfig", async () => { + const backingStore = new InMemoryStore(); + const store = new BackedBlobStore({ + backingStore, + defaultFetchOptions: { + actionIfBlobMissing: "emptyBlob", + }, + }); + const path = "simple://foo"; + const fetchedBlob = await store.fetch(path); + expect(fetchedBlob).toBeDefined(); + expect(fetchedBlob?.size).toEqual(0); + expect(fetchedBlob?.path).toEqual(path); + }); + + test("missing undefined fetch", async () => { + const backingStore = new InMemoryStore(); + const store = new BackedBlobStore({ + backingStore, + defaultFetchOptions: { + actionIfBlobMissing: "emptyBlob", + }, + }); + const path = "simple://foo"; + const fetchedBlob = await store.fetch(path, { + actionIfBlobMissing: undefined, + }); + expect(fetchedBlob).toBeUndefined(); + }); + + test("invalid undefined", async () => { + const backingStore = new InMemoryStore(); + const store = new BackedBlobStore({ + backingStore, + defaultStoreOptions: { + pathPrefix: "example://bar/", + }, + }); + const path = "simple://foo"; + const data = new Blob(["This is a test"], { type: "text/plain" }); + const blob = await MediaBlob.fromBlob(data, { path }); + const storedBlob = await store.store(blob); + expect(storedBlob).toBeUndefined(); + }); + + test("invalid ignore", async () => { + const backingStore = new InMemoryStore(); + const store = new BackedBlobStore({ + backingStore, + defaultStoreOptions: { + actionIfInvalid: "ignore", + pathPrefix: "example://bar/", + }, + }); + const path = "simple://foo"; + const data = new Blob(["This is a test"], { type: "text/plain" }); + const blob = await MediaBlob.fromBlob(data, { path }); + const storedBlob = await store.store(blob); + expect(storedBlob).toBeDefined(); + expect(storedBlob?.path).toEqual(path); + expect(storedBlob?.metadata).toBeUndefined(); + }); + + test("invalid prefixPath", async () => { + const backingStore = new InMemoryStore(); + const store = new BackedBlobStore({ + backingStore, + defaultStoreOptions: { + actionIfInvalid: "prefixPath", + pathPrefix: "example://bar/", + }, + }); + const path = "simple://foo"; + const data = new Blob(["This is a test"], { type: "text/plain" }); + const blob = await MediaBlob.fromBlob(data, { path }); + const storedBlob = await store.store(blob); + expect(storedBlob?.path).toEqual("example://bar/foo"); + expect(await storedBlob?.asString()).toEqual("This is a test"); + expect(storedBlob?.metadata?.langchainOldPath).toEqual(path); + }); + + test("invalid prefixUuid", async () => { + const backingStore = new InMemoryStore(); + const store = new BackedBlobStore({ + backingStore, + defaultStoreOptions: { + actionIfInvalid: "prefixUuid4", + pathPrefix: "example://bar/", + }, + }); + const path = "simple://foo"; + const data = new Blob(["This is a test"], { type: "text/plain" }); + const metadata = { + alpha: "one", + bravo: "two", + }; + const blob = await MediaBlob.fromBlob(data, { path, metadata }); + const storedBlob = await store.store(blob); + expect(storedBlob?.path).toMatch( + /example:\/\/bar\/[a-f0-9]{8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}$/i + ); + expect(storedBlob?.size).toEqual(14); + expect(await storedBlob?.asString()).toEqual("This is a test"); + expect(storedBlob?.metadata?.alpha).toEqual("one"); + expect(storedBlob?.metadata?.langchainOldPath).toEqual(path); + }); + }); + + describe("MediaManager", () => { + class MemStore extends InMemoryStore { + get length() { + return Object.keys(this.store).length; + } + } + + let mediaManager: MediaManager; + let aliasMemory: MemStore; + let canonicalMemory: MemStore; + let resolverMemory: MemStore; + + async function store( + blobStore: BlobStore, + path: string, + text: string + ): Promise { + const data = new Blob([text], { type: "text/plain" }); + const blob = await MediaBlob.fromBlob(data, { path }); + await blobStore.store(blob); + } + + beforeEach(async () => { + aliasMemory = new MemStore(); + const aliasStore = new BackedBlobStore({ + backingStore: aliasMemory, + defaultFetchOptions: { + actionIfBlobMissing: undefined, + }, + }); + canonicalMemory = new MemStore(); + const canonicalStore = new BackedBlobStore({ + backingStore: canonicalMemory, + defaultStoreOptions: { + pathPrefix: "canonical://store/", + actionIfInvalid: "prefixPath", + }, + defaultFetchOptions: { + actionIfBlobMissing: undefined, + }, + }); + resolverMemory = new MemStore(); + const resolver = new BackedBlobStore({ + backingStore: resolverMemory, + defaultFetchOptions: { + actionIfBlobMissing: "emptyBlob", + }, + }); + const mediaStore = new ReadThroughBlobStore({ + baseStore: aliasStore, + backingStore: canonicalStore, + }); + mediaManager = new MediaManager({ + store: mediaStore, + resolvers: [resolver], + }); + await store(resolver, "resolve://host/foo", "fooing"); + await store(resolver, "resolve://host2/bar/baz", "barbazing"); + }); + + test("environment", async () => { + expect(resolverMemory.length).toEqual(2); + const fooBlob = await mediaManager.resolvers?.[0]?.fetch( + "resolve://host/foo" + ); + expect(await fooBlob?.asString()).toEqual("fooing"); + }); + + test("simple", async () => { + const uri = "resolve://host/foo"; + const curi = "canonical://store/host/foo"; + const blob = await mediaManager.getMediaBlob(uri); + expect(await blob?.asString()).toEqual("fooing"); + expect(blob?.path).toEqual(curi); + + // In the alias store, + // we should be able to fetch it by the resolve uri, but the + // path in the blob itself should be the canonical uri + expect(aliasMemory.length).toEqual(1); + const mediaStore: ReadThroughBlobStore = + mediaManager.store as ReadThroughBlobStore; + const aliasBlob = await mediaStore.baseStore.fetch(uri); + expect(aliasBlob).toBeDefined(); + expect(aliasBlob?.path).toEqual(curi); + expect(await aliasBlob?.asString()).toEqual("fooing"); + + // For the canonical store, + // fetching it by the resolve uri should fail + // but fetching it by the canonical uri should succeed + expect(canonicalMemory.length).toEqual(1); + const canonicalBlobU = await mediaStore.backingStore.fetch(uri); + expect(canonicalBlobU).toBeUndefined(); + const canonicalBlob = await mediaStore.backingStore.fetch(curi); + expect(canonicalBlob).toBeDefined(); + expect(canonicalBlob?.path).toEqual(curi); + expect(await canonicalBlob?.asString()).toEqual("fooing"); + }); + }); }); function toUint8Array(data: string): Uint8Array { diff --git a/libs/langchain-google-common/src/types.ts b/libs/langchain-google-common/src/types.ts index 7721b5136704..4fecd254693b 100644 --- a/libs/langchain-google-common/src/types.ts +++ b/libs/langchain-google-common/src/types.ts @@ -4,6 +4,7 @@ import type { BindToolsInput, } from "@langchain/core/language_models/chat_models"; import type { JsonStream } from "./utils/stream.js"; +import { MediaManager } from "./experimental/utils/media_core.js"; /** * Parameters needed to setup the client connection. @@ -147,7 +148,8 @@ export interface GoogleAIBaseLLMInput extends BaseLLMParams, GoogleConnectionParams, GoogleAIModelParams, - GoogleAISafetyParams {} + GoogleAISafetyParams, + GeminiAPIConfig {} export interface GoogleAIBaseLanguageModelCallOptions extends BaseChatModelCallOptions, @@ -172,6 +174,10 @@ export interface GoogleResponse { data: any; } +export interface GoogleRawResponse extends GoogleResponse { + data: Blob; +} + export interface GeminiPartText { text: string; } @@ -183,7 +189,6 @@ export interface GeminiPartInlineData { }; } -// Vertex AI only export interface GeminiPartFileData { fileData: { mimeType: string; @@ -342,3 +347,7 @@ export interface GeminiJsonSchemaDirty extends GeminiJsonSchema { properties?: Record; additionalProperties?: boolean; } + +export interface GeminiAPIConfig { + mediaManager?: MediaManager; +} diff --git a/libs/langchain-google-common/src/utils/gemini.ts b/libs/langchain-google-common/src/utils/gemini.ts index d230203a85de..aab0699cec56 100644 --- a/libs/langchain-google-common/src/utils/gemini.ts +++ b/libs/langchain-google-common/src/utils/gemini.ts @@ -19,7 +19,6 @@ import { ChatGeneration, ChatGenerationChunk, ChatResult, - Generation, } from "@langchain/core/outputs"; import { ToolCallChunk } from "@langchain/core/messages/tool"; import type { @@ -34,8 +33,32 @@ import type { GenerateContentResponseData, GoogleAISafetyHandler, GeminiPartFunctionCall, + GeminiAPIConfig, } from "../types.js"; import { GoogleAISafetyError } from "./safety.js"; +import { MediaBlob } from "../experimental/utils/media_core.js"; + +export interface FunctionCall { + name: string; + arguments: string; +} + +export interface ToolCall { + id: string; + type: "function"; + function: FunctionCall; +} + +export interface FunctionCallRaw { + name: string; + arguments: object; +} + +export interface ToolCallRaw { + id: string; + type: "function"; + function: FunctionCallRaw; +} const extractMimeType = ( str: string @@ -49,671 +72,686 @@ const extractMimeType = ( return null; }; -function messageContentText( - content: MessageContentText -): GeminiPartText | null { - if (content?.text && content?.text.length > 0) { - return { - text: content.text, - }; - } else { - return null; +export function getGeminiAPI(config?: GeminiAPIConfig) { + function messageContentText( + content: MessageContentText + ): GeminiPartText | null { + if (content?.text && content?.text.length > 0) { + return { + text: content.text, + }; + } else { + return null; + } } -} -function messageContentImageUrl( - content: MessageContentImageUrl -): GeminiPartInlineData | GeminiPartFileData { - const url: string = - typeof content.image_url === "string" - ? content.image_url - : content.image_url.url; - if (!url) { - throw new Error("Missing Image URL"); - } + function messageContentImageUrl( + content: MessageContentImageUrl + ): GeminiPartInlineData | GeminiPartFileData { + const url: string = + typeof content.image_url === "string" + ? content.image_url + : content.image_url.url; + if (!url) { + throw new Error("Missing Image URL"); + } - const mineTypeAndData = extractMimeType(url); - if (mineTypeAndData) { - return { - inlineData: mineTypeAndData, - }; - } else { - // FIXME - need some way to get mime type - return { - fileData: { - mimeType: "image/png", - fileUri: url, - }, - }; + const mineTypeAndData = extractMimeType(url); + if (mineTypeAndData) { + return { + inlineData: mineTypeAndData, + }; + } else { + // FIXME - need some way to get mime type + return { + fileData: { + mimeType: "image/png", + fileUri: url, + }, + }; + } } -} -function messageContentMedia( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - content: Record -): GeminiPartInlineData | GeminiPartFileData { - if ("mimeType" in content && "data" in content) { - return { - inlineData: { - mimeType: content.mimeType, - data: content.data, - }, - }; - } else if ("mimeType" in content && "fileUri" in content) { + async function blobToFileData(blob: MediaBlob): Promise { return { fileData: { - mimeType: content.mimeType, - fileUri: content.fileUri, + fileUri: blob.path!, + mimeType: blob.mimetype, }, }; } - throw new Error("Invalid media content"); -} + async function fileUriContentToBlob( + uri: string + ): Promise { + return config?.mediaManager?.getMediaBlob(uri); + } -export function messageContentToParts(content: MessageContent): GeminiPart[] { - // Convert a string to a text type MessageContent if needed - const messageContent: MessageContent = - typeof content === "string" - ? [ - { - type: "text", - text: content, - }, - ] - : content; - - // eslint-disable-next-line array-callback-return - const parts: GeminiPart[] = messageContent - .map((content) => { - switch (content.type) { - case "text": - if ("text" in content) { - return messageContentText(content as MessageContentText); - } - break; - case "image_url": - if ("image_url" in content) { - // Type guard for MessageContentImageUrl - return messageContentImageUrl(content as MessageContentImageUrl); - } - break; - case "media": - return messageContentMedia(content); - default: - throw new Error( - `Unsupported type received while converting message to message parts` - ); - } - throw new Error( - `Cannot coerce "${content.type}" message part into a string.` - ); - }) - .reduce((acc: GeminiPart[], val: GeminiPart | null | undefined) => { - if (val) { - return [...acc, val]; - } else { - return acc; + async function messageContentMedia( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + content: Record + ): Promise { + if ("mimeType" in content && "data" in content) { + return { + inlineData: { + mimeType: content.mimeType, + data: content.data, + }, + }; + } else if ("mimeType" in content && "fileUri" in content) { + return { + fileData: { + mimeType: content.mimeType, + fileUri: content.fileUri, + }, + }; + } else { + const uri = content.fileUri; + const blob = await fileUriContentToBlob(uri); + if (blob) { + return await blobToFileData(blob); } - }, []); - - return parts; -} + } -function messageToolCallsToParts(toolCalls: ToolCall[]): GeminiPart[] { - if (!toolCalls || toolCalls.length === 0) { - return []; + throw new Error("Invalid media content"); } - return toolCalls.map((tool: ToolCall) => { - let args = {}; - if (tool?.function?.arguments) { - const argStr = tool.function.arguments; - args = JSON.parse(argStr); + async function messageContentComplexToPart( + content: MessageContentComplex + ): Promise { + switch (content.type) { + case "text": + if ("text" in content) { + return messageContentText(content as MessageContentText); + } + break; + case "image_url": + if ("image_url" in content) { + // Type guard for MessageContentImageUrl + return messageContentImageUrl(content as MessageContentImageUrl); + } + break; + case "media": + return await messageContentMedia(content); + default: + throw new Error( + `Unsupported type received while converting message to message parts` + ); } - return { - functionCall: { - name: tool.function.name, - args, - }, - }; - }); -} + throw new Error( + `Cannot coerce "${content.type}" message part into a string.` + ); + } + + async function messageContentComplexToParts( + content: MessageContentComplex[] + ): Promise<(GeminiPart | null)[]> { + const contents = content.map(messageContentComplexToPart); + return Promise.all(contents); + } -function messageKwargsToParts(kwargs: Record): GeminiPart[] { - const ret: GeminiPart[] = []; + async function messageContentToParts( + content: MessageContent + ): Promise { + // Convert a string to a text type MessageContent if needed + const messageContent: MessageContentComplex[] = + typeof content === "string" + ? [ + { + type: "text", + text: content, + }, + ] + : content; + + // Get all of the parts, even those that don't correctly resolve + const allParts = await messageContentComplexToParts(messageContent); + + // Remove any invalid parts + const parts: GeminiPart[] = allParts.reduce( + (acc: GeminiPart[], val: GeminiPart | null | undefined) => { + if (val) { + return [...acc, val]; + } else { + return acc; + } + }, + [] + ); - if (kwargs?.tool_calls) { - ret.push(...messageToolCallsToParts(kwargs.tool_calls as ToolCall[])); + return parts; } - return ret; -} + function messageToolCallsToParts(toolCalls: ToolCall[]): GeminiPart[] { + if (!toolCalls || toolCalls.length === 0) { + return []; + } -function roleMessageToContent( - role: GeminiRole, - message: BaseMessage -): GeminiContent[] { - const contentParts: GeminiPart[] = messageContentToParts(message.content); - let toolParts: GeminiPart[]; - if (isAIMessage(message) && !!message.tool_calls?.length) { - toolParts = message.tool_calls.map( - (toolCall): GeminiPart => ({ + return toolCalls.map((tool: ToolCall) => { + let args = {}; + if (tool?.function?.arguments) { + const argStr = tool.function.arguments; + args = JSON.parse(argStr); + } + return { functionCall: { - name: toolCall.name, - args: toolCall.args, + name: tool.function.name, + args, }, - }) - ); - } else { - toolParts = messageKwargsToParts(message.additional_kwargs); - } - const parts: GeminiPart[] = [...contentParts, ...toolParts]; - return [ - { - role, - parts, - }, - ]; -} + }; + }); + } -function systemMessageToContent( - message: SystemMessage, - useSystemInstruction: boolean -): GeminiContent[] { - return useSystemInstruction - ? roleMessageToContent("system", message) - : [ - ...roleMessageToContent("user", message), - ...roleMessageToContent("model", new AIMessage("Ok")), - ]; -} + function messageKwargsToParts(kwargs: Record): GeminiPart[] { + const ret: GeminiPart[] = []; + + if (kwargs?.tool_calls) { + ret.push(...messageToolCallsToParts(kwargs.tool_calls as ToolCall[])); + } -function toolMessageToContent( - message: ToolMessage, - prevMessage: BaseMessage -): GeminiContent[] { - const contentStr = - typeof message.content === "string" - ? message.content - : message.content.reduce( - (acc: string, content: MessageContentComplex) => { - if (content.type === "text") { - return acc + content.text; - } else { - return acc; - } + return ret; + } + + async function roleMessageToContent( + role: GeminiRole, + message: BaseMessage + ): Promise { + const contentParts: GeminiPart[] = await messageContentToParts( + message.content + ); + let toolParts: GeminiPart[]; + if (isAIMessage(message) && !!message.tool_calls?.length) { + toolParts = message.tool_calls.map( + (toolCall): GeminiPart => ({ + functionCall: { + name: toolCall.name, + args: toolCall.args, }, - "" - ); - // Hacky :( - const responseName = - (isAIMessage(prevMessage) && !!prevMessage.tool_calls?.length - ? prevMessage.tool_calls[0].name - : prevMessage.name) ?? message.tool_call_id; - try { - const content = JSON.parse(contentStr); + }) + ); + } else { + toolParts = messageKwargsToParts(message.additional_kwargs); + } + const parts: GeminiPart[] = [...contentParts, ...toolParts]; return [ { - role: "function", - parts: [ - { - functionResponse: { - name: responseName, - response: { content }, - }, - }, - ], + role, + parts, }, ]; - } catch (_) { - return [ - { - role: "function", - parts: [ - { - functionResponse: { - name: responseName, - response: { content: contentStr }, + } + + async function systemMessageToContent( + message: SystemMessage, + useSystemInstruction: boolean + ): Promise { + return useSystemInstruction + ? roleMessageToContent("system", message) + : [ + ...(await roleMessageToContent("user", message)), + ...(await roleMessageToContent("model", new AIMessage("Ok"))), + ]; + } + + function toolMessageToContent( + message: ToolMessage, + prevMessage: BaseMessage + ): GeminiContent[] { + const contentStr = + typeof message.content === "string" + ? message.content + : message.content.reduce( + (acc: string, content: MessageContentComplex) => { + if (content.type === "text") { + return acc + content.text; + } else { + return acc; + } }, - }, - ], - }, - ]; + "" + ); + // Hacky :( + const responseName = + (isAIMessage(prevMessage) && !!prevMessage.tool_calls?.length + ? prevMessage.tool_calls[0].name + : prevMessage.name) ?? message.tool_call_id; + try { + const content = JSON.parse(contentStr); + return [ + { + role: "function", + parts: [ + { + functionResponse: { + name: responseName, + response: { content }, + }, + }, + ], + }, + ]; + } catch (_) { + return [ + { + role: "function", + parts: [ + { + functionResponse: { + name: responseName, + response: { content: contentStr }, + }, + }, + ], + }, + ]; + } } -} -export function baseMessageToContent( - message: BaseMessage, - prevMessage: BaseMessage | undefined, - useSystemInstruction: boolean -): GeminiContent[] { - const type = message._getType(); - switch (type) { - case "system": - return systemMessageToContent( - message as SystemMessage, - useSystemInstruction - ); - case "human": - return roleMessageToContent("user", message); - case "ai": - return roleMessageToContent("model", message); - case "tool": - if (!prevMessage) { - throw new Error( - "Tool messages cannot be the first message passed to the model." + async function baseMessageToContent( + message: BaseMessage, + prevMessage: BaseMessage | undefined, + useSystemInstruction: boolean + ): Promise { + const type = message._getType(); + switch (type) { + case "system": + return systemMessageToContent( + message as SystemMessage, + useSystemInstruction ); - } - return toolMessageToContent(message as ToolMessage, prevMessage); - default: - console.log(`Unsupported message type: ${type}`); - return []; + case "human": + return roleMessageToContent("user", message); + case "ai": + return roleMessageToContent("model", message); + case "tool": + if (!prevMessage) { + throw new Error( + "Tool messages cannot be the first message passed to the model." + ); + } + return toolMessageToContent(message as ToolMessage, prevMessage); + default: + console.log(`Unsupported message type: ${type}`); + return []; + } } -} -function textPartToMessageContent(part: GeminiPartText): MessageContentText { - return { - type: "text", - text: part.text, - }; -} + function textPartToMessageContent(part: GeminiPartText): MessageContentText { + return { + type: "text", + text: part.text, + }; + } -function inlineDataPartToMessageContent( - part: GeminiPartInlineData -): MessageContentImageUrl { - return { - type: "image_url", - image_url: `data:${part.inlineData.mimeType};base64,${part.inlineData.data}`, - }; -} + function inlineDataPartToMessageContent( + part: GeminiPartInlineData + ): MessageContentImageUrl { + return { + type: "image_url", + image_url: `data:${part.inlineData.mimeType};base64,${part.inlineData.data}`, + }; + } -function fileDataPartToMessageContent( - part: GeminiPartFileData -): MessageContentImageUrl { - return { - type: "image_url", - image_url: part.fileData.fileUri, - }; -} + function fileDataPartToMessageContent( + part: GeminiPartFileData + ): MessageContentImageUrl { + return { + type: "image_url", + image_url: part.fileData.fileUri, + }; + } -export function partsToMessageContent(parts: GeminiPart[]): MessageContent { - return parts - .map((part) => { - if (part === undefined || part === null) { - return null; - } else if ("text" in part) { - return textPartToMessageContent(part); - } else if ("inlineData" in part) { - return inlineDataPartToMessageContent(part); - } else if ("fileData" in part) { - return fileDataPartToMessageContent(part); - } else { - return null; - } - }) - .reduce((acc, content) => { - if (content) { - acc.push(content); - } - return acc; - }, [] as MessageContentComplex[]); -} + function partsToMessageContent(parts: GeminiPart[]): MessageContent { + return parts + .map((part) => { + if (part === undefined || part === null) { + return null; + } else if ("text" in part) { + return textPartToMessageContent(part); + } else if ("inlineData" in part) { + return inlineDataPartToMessageContent(part); + } else if ("fileData" in part) { + return fileDataPartToMessageContent(part); + } else { + return null; + } + }) + .reduce((acc, content) => { + if (content) { + acc.push(content); + } + return acc; + }, [] as MessageContentComplex[]); + } -interface FunctionCall { - name: string; - arguments: string; -} + function toolRawToTool(raw: ToolCallRaw): ToolCall { + return { + id: raw.id, + type: raw.type, + function: { + name: raw.function.name, + arguments: JSON.stringify(raw.function.arguments), + }, + }; + } -interface ToolCall { - id: string; - type: "function"; - function: FunctionCall; -} + function functionCallPartToToolRaw( + part: GeminiPartFunctionCall + ): ToolCallRaw { + return { + id: uuidv4().replace(/-/g, ""), + type: "function", + function: { + name: part.functionCall.name, + arguments: part.functionCall.args ?? {}, + }, + }; + } -interface FunctionCallRaw { - name: string; - arguments: object; -} + function partsToToolsRaw(parts: GeminiPart[]): ToolCallRaw[] { + return parts + .map((part: GeminiPart) => { + if (part === undefined || part === null) { + return null; + } else if ("functionCall" in part) { + return functionCallPartToToolRaw(part); + } else { + return null; + } + }) + .reduce((acc, content) => { + if (content) { + acc.push(content); + } + return acc; + }, [] as ToolCallRaw[]); + } -interface ToolCallRaw { - id: string; - type: "function"; - function: FunctionCallRaw; -} + function toolsRawToTools(raws: ToolCallRaw[]): ToolCall[] { + return raws.map((raw) => toolRawToTool(raw)); + } -function toolRawToTool(raw: ToolCallRaw): ToolCall { - return { - id: raw.id, - type: raw.type, - function: { - name: raw.function.name, - arguments: JSON.stringify(raw.function.arguments), - }, - }; -} + function responseToGenerateContentResponseData( + response: GoogleLLMResponse + ): GenerateContentResponseData { + if ("nextChunk" in response.data) { + throw new Error("Cannot convert Stream to GenerateContentResponseData"); + } else if (Array.isArray(response.data)) { + // Collapse the array of response data as if it was a single one + return response.data.reduce( + ( + acc: GenerateContentResponseData, + val: GenerateContentResponseData + ): GenerateContentResponseData => { + // Add all the parts + // FIXME: Handle other candidates? + const valParts = val?.candidates?.[0]?.content?.parts ?? []; + acc.candidates[0].content.parts.push(...valParts); + + // FIXME: Merge promptFeedback and safety settings + acc.promptFeedback = val.promptFeedback; + return acc; + } + ); + } else { + return response.data as GenerateContentResponseData; + } + } -function functionCallPartToToolRaw(part: GeminiPartFunctionCall): ToolCallRaw { - return { - id: uuidv4().replace(/-/g, ""), - type: "function", - function: { - name: part.functionCall.name, - arguments: part.functionCall.args ?? {}, - }, - }; -} + function responseToParts(response: GoogleLLMResponse): GeminiPart[] { + const responseData = responseToGenerateContentResponseData(response); + const parts = responseData?.candidates?.[0]?.content?.parts ?? []; + return parts; + } -export function partsToToolsRaw(parts: GeminiPart[]): ToolCallRaw[] { - return parts - .map((part: GeminiPart) => { - if (part === undefined || part === null) { - return null; - } else if ("functionCall" in part) { - return functionCallPartToToolRaw(part); - } else { - return null; - } - }) - .reduce((acc, content) => { - if (content) { - acc.push(content); - } - return acc; - }, [] as ToolCallRaw[]); -} + function partToText(part: GeminiPart): string { + return "text" in part ? part.text : ""; + } -export function toolsRawToTools(raws: ToolCallRaw[]): ToolCall[] { - return raws.map((raw) => toolRawToTool(raw)); -} + function responseToString(response: GoogleLLMResponse): string { + const parts = responseToParts(response); + const ret: string = parts.reduce((acc, part) => { + const val = partToText(part); + return acc + val; + }, ""); + return ret; + } -export function responseToGenerateContentResponseData( - response: GoogleLLMResponse -): GenerateContentResponseData { - if ("nextChunk" in response.data) { - throw new Error("Cannot convert Stream to GenerateContentResponseData"); - } else if (Array.isArray(response.data)) { - // Collapse the array of response data as if it was a single one - return response.data.reduce( - ( - acc: GenerateContentResponseData, - val: GenerateContentResponseData - ): GenerateContentResponseData => { - // Add all the parts - // FIXME: Handle other candidates? - const valParts = val?.candidates?.[0]?.content?.parts ?? []; - acc.candidates[0].content.parts.push(...valParts); - - // FIXME: Merge promptFeedback and safety settings - acc.promptFeedback = val.promptFeedback; - return acc; + function safeResponseTo( + response: GoogleLLMResponse, + safetyHandler: GoogleAISafetyHandler, + responseTo: (response: GoogleLLMResponse) => RetType + ): RetType { + try { + const safeResponse = safetyHandler.handle(response); + return responseTo(safeResponse); + } catch (xx) { + // eslint-disable-next-line no-instanceof/no-instanceof + if (xx instanceof GoogleAISafetyError) { + const ret = responseTo(xx.response); + xx.reply = ret; } - ); - } else { - return response.data as GenerateContentResponseData; + throw xx; + } } -} -export function responseToParts(response: GoogleLLMResponse): GeminiPart[] { - const responseData = responseToGenerateContentResponseData(response); - const parts = responseData?.candidates?.[0]?.content?.parts ?? []; - return parts; -} - -export function partToText(part: GeminiPart): string { - return "text" in part ? part.text : ""; -} - -export function responseToString(response: GoogleLLMResponse): string { - const parts = responseToParts(response); - const ret: string = parts.reduce((acc, part) => { - const val = partToText(part); - return acc + val; - }, ""); - return ret; -} + function safeResponseToString( + response: GoogleLLMResponse, + safetyHandler: GoogleAISafetyHandler + ): string { + return safeResponseTo(response, safetyHandler, responseToString); + } -function safeResponseTo( - response: GoogleLLMResponse, - safetyHandler: GoogleAISafetyHandler, - responseTo: (response: GoogleLLMResponse) => RetType -): RetType { - try { - const safeResponse = safetyHandler.handle(response); - return responseTo(safeResponse); - } catch (xx) { - // eslint-disable-next-line no-instanceof/no-instanceof - if (xx instanceof GoogleAISafetyError) { - const ret = responseTo(xx.response); - xx.reply = ret; + function responseToGenerationInfo(response: GoogleLLMResponse) { + if (!Array.isArray(response.data)) { + return {}; } - throw xx; + const data = response.data[0]; + return { + usage_metadata: { + prompt_token_count: data.usageMetadata?.promptTokenCount, + candidates_token_count: data.usageMetadata?.candidatesTokenCount, + total_token_count: data.usageMetadata?.totalTokenCount, + }, + safety_ratings: data.candidates[0]?.safetyRatings?.map((rating) => ({ + category: rating.category, + probability: rating.probability, + probability_score: rating.probabilityScore, + severity: rating.severity, + severity_score: rating.severityScore, + })), + finish_reason: data.candidates[0]?.finishReason, + }; } -} - -export function safeResponseToString( - response: GoogleLLMResponse, - safetyHandler: GoogleAISafetyHandler -): string { - return safeResponseTo(response, safetyHandler, responseToString); -} -export function responseToGenerationInfo(response: GoogleLLMResponse) { - if (!Array.isArray(response.data)) { - return {}; + function responseToChatGeneration( + response: GoogleLLMResponse + ): ChatGenerationChunk { + return new ChatGenerationChunk({ + text: responseToString(response), + message: partToMessageChunk(responseToParts(response)[0]), + generationInfo: responseToGenerationInfo(response), + }); } - const data = response.data[0]; - return { - usage_metadata: { - prompt_token_count: data.usageMetadata?.promptTokenCount, - candidates_token_count: data.usageMetadata?.candidatesTokenCount, - total_token_count: data.usageMetadata?.totalTokenCount, - }, - safety_ratings: data.candidates[0]?.safetyRatings?.map((rating) => ({ - category: rating.category, - probability: rating.probability, - probability_score: rating.probabilityScore, - severity: rating.severity, - severity_score: rating.severityScore, - })), - finish_reason: data.candidates[0]?.finishReason, - }; -} - -export function responseToGeneration(response: GoogleLLMResponse): Generation { - return { - text: responseToString(response), - generationInfo: responseToGenerationInfo(response), - }; -} - -export function safeResponseToGeneration( - response: GoogleLLMResponse, - safetyHandler: GoogleAISafetyHandler -): Generation { - return safeResponseTo(response, safetyHandler, responseToGeneration); -} -export function responseToChatGeneration( - response: GoogleLLMResponse -): ChatGenerationChunk { - return new ChatGenerationChunk({ - text: responseToString(response), - message: partToMessageChunk(responseToParts(response)[0]), - generationInfo: responseToGenerationInfo(response), - }); -} - -export function safeResponseToChatGeneration( - response: GoogleLLMResponse, - safetyHandler: GoogleAISafetyHandler -): ChatGenerationChunk { - return safeResponseTo(response, safetyHandler, responseToChatGeneration); -} + function safeResponseToChatGeneration( + response: GoogleLLMResponse, + safetyHandler: GoogleAISafetyHandler + ): ChatGenerationChunk { + return safeResponseTo(response, safetyHandler, responseToChatGeneration); + } -export function chunkToString(chunk: BaseMessageChunk): string { - if (chunk === null) { - return ""; - } else if (typeof chunk.content === "string") { - return chunk.content; - } else if (chunk.content.length === 0) { - return ""; - } else if (chunk.content[0].type === "text") { - return chunk.content[0].text; - } else { - throw new Error(`Unexpected chunk: ${chunk}`); + function chunkToString(chunk: BaseMessageChunk): string { + if (chunk === null) { + return ""; + } else if (typeof chunk.content === "string") { + return chunk.content; + } else if (chunk.content.length === 0) { + return ""; + } else if (chunk.content[0].type === "text") { + return chunk.content[0].text; + } else { + throw new Error(`Unexpected chunk: ${chunk}`); + } } -} -export function partToMessageChunk(part: GeminiPart): BaseMessageChunk { - const fields = partsToBaseMessageChunkFields([part]); - if (typeof fields.content === "string") { + function partToMessageChunk(part: GeminiPart): BaseMessageChunk { + const fields = partsToBaseMessageChunkFields([part]); + if (typeof fields.content === "string") { + return new AIMessageChunk(fields); + } else if (fields.content.every((item) => item.type === "text")) { + const newContent = fields.content + .map((item) => ("text" in item ? item.text : "")) + .join(""); + return new AIMessageChunk({ + ...fields, + content: newContent, + }); + } return new AIMessageChunk(fields); - } else if (fields.content.every((item) => item.type === "text")) { - const newContent = fields.content - .map((item) => ("text" in item ? item.text : "")) - .join(""); - return new AIMessageChunk({ - ...fields, - content: newContent, - }); } - return new AIMessageChunk(fields); -} -export function partToChatGeneration(part: GeminiPart): ChatGeneration { - const message = partToMessageChunk(part); - const text = partToText(part); - return new ChatGenerationChunk({ - text, - message, - }); -} + function partToChatGeneration(part: GeminiPart): ChatGeneration { + const message = partToMessageChunk(part); + const text = partToText(part); + return new ChatGenerationChunk({ + text, + message, + }); + } -export function responseToChatGenerations( - response: GoogleLLMResponse -): ChatGeneration[] { - const parts = responseToParts(response); - let ret = parts.map((part) => partToChatGeneration(part)); - if (ret.every((item) => typeof item.message.content === "string")) { - const combinedContent = ret.map((item) => item.message.content).join(""); - const combinedText = ret.map((item) => item.text).join(""); - const toolCallChunks: ToolCallChunk[] | undefined = ret[ - ret.length - 1 - ]?.message.additional_kwargs?.tool_calls?.map((toolCall, i) => ({ - name: toolCall.function.name, - args: toolCall.function.arguments, - id: toolCall.id, - index: i, - type: "tool_call_chunk", - })); - let usageMetadata: UsageMetadata | undefined; - if ("usageMetadata" in response.data) { - usageMetadata = { - input_tokens: response.data.usageMetadata.promptTokenCount as number, - output_tokens: response.data.usageMetadata - .candidatesTokenCount as number, - total_tokens: response.data.usageMetadata.totalTokenCount as number, - }; - } - ret = [ - new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: combinedContent, - additional_kwargs: ret[ret.length - 1]?.message.additional_kwargs, - tool_call_chunks: toolCallChunks, - usage_metadata: usageMetadata, + function responseToChatGenerations( + response: GoogleLLMResponse + ): ChatGeneration[] { + const parts = responseToParts(response); + let ret = parts.map((part) => partToChatGeneration(part)); + if (ret.every((item) => typeof item.message.content === "string")) { + const combinedContent = ret.map((item) => item.message.content).join(""); + const combinedText = ret.map((item) => item.text).join(""); + const toolCallChunks: ToolCallChunk[] | undefined = ret[ + ret.length - 1 + ]?.message.additional_kwargs?.tool_calls?.map((toolCall, i) => ({ + name: toolCall.function.name, + args: toolCall.function.arguments, + id: toolCall.id, + index: i, + type: "tool_call_chunk", + })); + let usageMetadata: UsageMetadata | undefined; + if ("usageMetadata" in response.data) { + usageMetadata = { + input_tokens: response.data.usageMetadata.promptTokenCount as number, + output_tokens: response.data.usageMetadata + .candidatesTokenCount as number, + total_tokens: response.data.usageMetadata.totalTokenCount as number, + }; + } + ret = [ + new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: combinedContent, + additional_kwargs: ret[ret.length - 1]?.message.additional_kwargs, + tool_call_chunks: toolCallChunks, + usage_metadata: usageMetadata, + }), + text: combinedText, + generationInfo: ret[ret.length - 1].generationInfo, }), - text: combinedText, - generationInfo: ret[ret.length - 1].generationInfo, - }), - ]; + ]; + } + return ret; } - return ret; -} - -export function responseToBaseMessageFields( - response: GoogleLLMResponse -): BaseMessageFields { - const parts = responseToParts(response); - return partsToBaseMessageChunkFields(parts); -} -export function partsToBaseMessageChunkFields( - parts: GeminiPart[] -): AIMessageChunkFields { - const fields: AIMessageChunkFields = { - content: partsToMessageContent(parts), - tool_call_chunks: [], - tool_calls: [], - invalid_tool_calls: [], - }; + function responseToBaseMessageFields( + response: GoogleLLMResponse + ): BaseMessageFields { + const parts = responseToParts(response); + return partsToBaseMessageChunkFields(parts); + } - const rawTools = partsToToolsRaw(parts); - if (rawTools.length > 0) { - const tools = toolsRawToTools(rawTools); - for (const tool of tools) { - fields.tool_call_chunks?.push({ - name: tool.function.name, - args: tool.function.arguments, - id: tool.id, - type: "tool_call_chunk", - }); + function partsToBaseMessageChunkFields( + parts: GeminiPart[] + ): AIMessageChunkFields { + const fields: AIMessageChunkFields = { + content: partsToMessageContent(parts), + tool_call_chunks: [], + tool_calls: [], + invalid_tool_calls: [], + }; - try { - fields.tool_calls?.push({ - name: tool.function.name, - args: JSON.parse(tool.function.arguments), - id: tool.id, - type: "tool_call", - }); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } catch (e: any) { - fields.invalid_tool_calls?.push({ + const rawTools = partsToToolsRaw(parts); + if (rawTools.length > 0) { + const tools = toolsRawToTools(rawTools); + for (const tool of tools) { + fields.tool_call_chunks?.push({ name: tool.function.name, args: tool.function.arguments, id: tool.id, - error: e.message, - type: "invalid_tool_call", + type: "tool_call_chunk", }); + + try { + fields.tool_calls?.push({ + name: tool.function.name, + args: JSON.parse(tool.function.arguments), + id: tool.id, + }); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (e: any) { + fields.invalid_tool_calls?.push({ + name: tool.function.name, + args: tool.function.arguments, + id: tool.id, + error: e.message, + type: "invalid_tool_call", + }); + } } + fields.additional_kwargs = { + tool_calls: tools, + }; } - fields.additional_kwargs = { - tool_calls: tools, - }; + return fields; } - return fields; -} -export function responseToBaseMessage( - response: GoogleLLMResponse -): BaseMessage { - const fields = responseToBaseMessageFields(response); - return new AIMessage(fields); -} + function responseToBaseMessage(response: GoogleLLMResponse): BaseMessage { + const fields = responseToBaseMessageFields(response); + return new AIMessage(fields); + } -export function safeResponseToBaseMessage( - response: GoogleLLMResponse, - safetyHandler: GoogleAISafetyHandler -): BaseMessage { - return safeResponseTo(response, safetyHandler, responseToBaseMessage); -} + function safeResponseToBaseMessage( + response: GoogleLLMResponse, + safetyHandler: GoogleAISafetyHandler + ): BaseMessage { + return safeResponseTo(response, safetyHandler, responseToBaseMessage); + } + + function responseToChatResult(response: GoogleLLMResponse): ChatResult { + const generations = responseToChatGenerations(response); + return { + generations, + llmOutput: responseToGenerationInfo(response), + }; + } + + function safeResponseToChatResult( + response: GoogleLLMResponse, + safetyHandler: GoogleAISafetyHandler + ): ChatResult { + return safeResponseTo(response, safetyHandler, responseToChatResult); + } -export function responseToChatResult(response: GoogleLLMResponse): ChatResult { - const generations = responseToChatGenerations(response); return { - generations, - llmOutput: responseToGenerationInfo(response), + messageContentToParts, + baseMessageToContent, + safeResponseToString, + safeResponseToChatGeneration, + chunkToString, + safeResponseToBaseMessage, + safeResponseToChatResult, }; } -export function safeResponseToChatResult( - response: GoogleLLMResponse, - safetyHandler: GoogleAISafetyHandler -): ChatResult { - return safeResponseTo(response, safetyHandler, responseToChatResult); -} - export function validateGeminiParams(params: GoogleAIModelParams): void { if (params.maxOutputTokens && params.maxOutputTokens < 0) { throw new Error("`maxOutputTokens` must be a positive integer"); diff --git a/libs/langchain-google-gauth/src/index.ts b/libs/langchain-google-gauth/src/index.ts index 7f420a4ed6d0..4cd6bd0176ab 100644 --- a/libs/langchain-google-gauth/src/index.ts +++ b/libs/langchain-google-gauth/src/index.ts @@ -1,3 +1,5 @@ export * from "./chat_models.js"; export * from "./llms.js"; export * from "./embeddings.js"; + +export * from "./media.js"; diff --git a/libs/langchain-google-gauth/src/media.ts b/libs/langchain-google-gauth/src/media.ts new file mode 100644 index 000000000000..b44aadd0f858 --- /dev/null +++ b/libs/langchain-google-gauth/src/media.ts @@ -0,0 +1,31 @@ +import { GoogleAbstractedClient } from "@langchain/google-common"; +import { + BlobStoreGoogleCloudStorageBase, + BlobStoreGoogleCloudStorageBaseParams, + BlobStoreAIStudioFileBase, + BlobStoreAIStudioFileBaseParams, +} from "@langchain/google-common/experimental/media"; +import { GoogleAuthOptions } from "google-auth-library"; +import { GAuthClient } from "./auth.js"; + +export interface BlobStoreGoogleCloudStorageParams + extends BlobStoreGoogleCloudStorageBaseParams {} + +export class BlobStoreGoogleCloudStorage extends BlobStoreGoogleCloudStorageBase { + buildClient( + fields?: BlobStoreGoogleCloudStorageParams + ): GoogleAbstractedClient { + return new GAuthClient(fields); + } +} + +export interface BlobStoreAIStudioFileParams + extends BlobStoreAIStudioFileBaseParams {} + +export class BlobStoreAIStudioFile extends BlobStoreAIStudioFileBase { + buildAbstractedClient( + fields?: BlobStoreAIStudioFileParams + ): GoogleAbstractedClient { + return new GAuthClient(fields); + } +} diff --git a/libs/langchain-google-gauth/src/tests/chat_models.int.test.ts b/libs/langchain-google-gauth/src/tests/chat_models.int.test.ts new file mode 100644 index 000000000000..5baa6aaf1973 --- /dev/null +++ b/libs/langchain-google-gauth/src/tests/chat_models.int.test.ts @@ -0,0 +1,315 @@ +import { test } from "@jest/globals"; +import { BaseLanguageModelInput } from "@langchain/core/language_models/base"; +import { ChatPromptValue } from "@langchain/core/prompt_values"; +import { + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + BaseMessageLike, + HumanMessage, + HumanMessageChunk, + MessageContentComplex, + SystemMessage, + ToolMessage, +} from "@langchain/core/messages"; +import { + BackedBlobStore, + MediaBlob, + MediaManager, + ReadThroughBlobStore, + SimpleWebBlobStore, +} from "@langchain/google-common/experimental/utils/media_core"; +import { GoogleCloudStorageUri } from "@langchain/google-common/experimental/media"; +import { InMemoryStore } from "@langchain/core/stores"; +import { GeminiTool } from "../types.js"; +import { ChatGoogle } from "../chat_models.js"; +import { BlobStoreGoogleCloudStorage } from "../media.js"; + +describe("GAuth Chat", () => { + test("invoke", async () => { + const model = new ChatGoogle(); + try { + const res = await model.invoke("What is 1 + 1?"); + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); + + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); + + expect(typeof aiMessage.content).toBe("string"); + const text = aiMessage.content as string; + expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/); + + /* + expect(aiMessage.content.length).toBeGreaterThan(0); + expect(aiMessage.content[0]).toBeDefined(); + const content = aiMessage.content[0] as MessageContentComplex; + expect(content).toHaveProperty("type"); + expect(content.type).toEqual("text"); + + const textContent = content as MessageContentText; + expect(textContent.text).toBeDefined(); + expect(textContent.text).toEqual("2"); + */ + } catch (e) { + console.error(e); + throw e; + } + }); + + test("generate", async () => { + const model = new ChatGoogle(); + try { + const messages: BaseMessage[] = [ + new SystemMessage( + "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." + ), + new HumanMessage("Flip it"), + new AIMessage("T"), + new HumanMessage("Flip the coin again"), + ]; + const res = await model.predictMessages(messages); + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); + + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); + + expect(typeof aiMessage.content).toBe("string"); + const text = aiMessage.content as string; + expect(["H", "T"]).toContainEqual(text); + + /* + expect(aiMessage.content.length).toBeGreaterThan(0); + expect(aiMessage.content[0]).toBeDefined(); + + const content = aiMessage.content[0] as MessageContentComplex; + expect(content).toHaveProperty("type"); + expect(content.type).toEqual("text"); + + const textContent = content as MessageContentText; + expect(textContent.text).toBeDefined(); + expect(["H", "T"]).toContainEqual(textContent.text); + */ + } catch (e) { + console.error(e); + throw e; + } + }); + + test("stream", async () => { + const model = new ChatGoogle(); + try { + const input: BaseLanguageModelInput = new ChatPromptValue([ + new SystemMessage( + "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." + ), + new HumanMessage("Flip it"), + new AIMessage("T"), + new HumanMessage("Flip the coin again"), + ]); + const res = await model.stream(input); + const resArray: BaseMessageChunk[] = []; + for await (const chunk of res) { + resArray.push(chunk); + } + expect(resArray).toBeDefined(); + expect(resArray.length).toBeGreaterThanOrEqual(1); + + const lastChunk = resArray[resArray.length - 1]; + expect(lastChunk).toBeDefined(); + expect(lastChunk._getType()).toEqual("ai"); + const aiChunk = lastChunk as AIMessageChunk; + console.log(aiChunk); + + console.log(JSON.stringify(resArray, null, 2)); + } catch (e) { + console.error(e); + throw e; + } + }); + + test("function", async () => { + const tools: GeminiTool[] = [ + { + functionDeclarations: [ + { + name: "test", + description: + "Run a test with a specific name and get if it passed or failed", + parameters: { + type: "object", + properties: { + testName: { + type: "string", + description: "The name of the test that should be run.", + }, + }, + required: ["testName"], + }, + }, + ], + }, + ]; + const model = new ChatGoogle().bind({ tools }); + const result = await model.invoke("Run a test on the cobalt project"); + expect(result).toHaveProperty("content"); + expect(result.content).toBe(""); + const args = result?.lc_kwargs?.additional_kwargs; + expect(args).toBeDefined(); + expect(args).toHaveProperty("tool_calls"); + expect(Array.isArray(args.tool_calls)).toBeTruthy(); + expect(args.tool_calls).toHaveLength(1); + const call = args.tool_calls[0]; + expect(call).toHaveProperty("type"); + expect(call.type).toBe("function"); + expect(call).toHaveProperty("function"); + const func = call.function; + expect(func).toBeDefined(); + expect(func).toHaveProperty("name"); + expect(func.name).toBe("test"); + expect(func).toHaveProperty("arguments"); + expect(typeof func.arguments).toBe("string"); + expect(func.arguments.replaceAll("\n", "")).toBe('{"testName":"cobalt"}'); + }); + + test("function reply", async () => { + const tools: GeminiTool[] = [ + { + functionDeclarations: [ + { + name: "test", + description: + "Run a test with a specific name and get if it passed or failed", + parameters: { + type: "object", + properties: { + testName: { + type: "string", + description: "The name of the test that should be run.", + }, + }, + required: ["testName"], + }, + }, + ], + }, + ]; + const model = new ChatGoogle().bind({ tools }); + const toolResult = { + testPassed: true, + }; + const messages: BaseMessageLike[] = [ + new HumanMessage("Run a test on the cobalt project."), + new AIMessage("", { + tool_calls: [ + { + id: "test", + type: "function", + function: { + name: "test", + arguments: '{"testName":"cobalt"}', + }, + }, + ], + }), + new ToolMessage(JSON.stringify(toolResult), "test"), + ]; + const res = await model.stream(messages); + const resArray: BaseMessageChunk[] = []; + for await (const chunk of res) { + resArray.push(chunk); + } + console.log(JSON.stringify(resArray, null, 2)); + }); + + test("withStructuredOutput", async () => { + const tool = { + name: "get_weather", + description: + "Get the weather of a specific location and return the temperature in Celsius.", + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The name of city to get the weather for.", + }, + }, + required: ["location"], + }, + }; + const model = new ChatGoogle().withStructuredOutput(tool); + const result = await model.invoke("What is the weather in Paris?"); + expect(result).toHaveProperty("location"); + }); + + test("media - fileData", async () => { + class MemStore extends InMemoryStore { + get length() { + return Object.keys(this.store).length; + } + } + const aliasMemory = new MemStore(); + const aliasStore = new BackedBlobStore({ + backingStore: aliasMemory, + defaultFetchOptions: { + actionIfBlobMissing: undefined, + }, + }); + const canonicalStore = new BlobStoreGoogleCloudStorage({ + uriPrefix: new GoogleCloudStorageUri("gs://test-langchainjs/mediatest/"), + defaultStoreOptions: { + actionIfInvalid: "prefixPath", + }, + }); + const blobStore = new ReadThroughBlobStore({ + baseStore: aliasStore, + backingStore: canonicalStore, + }); + const resolver = new SimpleWebBlobStore(); + const mediaManager = new MediaManager({ + store: blobStore, + resolvers: [resolver], + }); + const model = new ChatGoogle({ + modelName: "gemini-1.5-flash", + mediaManager, + }); + + const message: MessageContentComplex[] = [ + { + type: "text", + text: "What is in this image?", + }, + { + type: "media", + fileUri: "https://js.langchain.com/v0.2/img/brand/wordmark.png", + }, + ]; + + const messages: BaseMessage[] = [ + new HumanMessageChunk({ content: message }), + ]; + + try { + const res = await model.invoke(messages); + + console.log(res); + + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); + + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); + + expect(typeof aiMessage.content).toBe("string"); + const text = aiMessage.content as string; + expect(text).toMatch(/LangChain/); + } catch (e) { + console.error(e); + throw e; + } + }); +}); diff --git a/libs/langchain-google-gauth/src/tests/data/blue-square.png b/libs/langchain-google-gauth/src/tests/data/blue-square.png new file mode 100644 index 000000000000..c64355dc335b Binary files /dev/null and b/libs/langchain-google-gauth/src/tests/data/blue-square.png differ diff --git a/libs/langchain-google-gauth/src/tests/media.int.test.ts b/libs/langchain-google-gauth/src/tests/media.int.test.ts new file mode 100644 index 000000000000..665dd4a75761 --- /dev/null +++ b/libs/langchain-google-gauth/src/tests/media.int.test.ts @@ -0,0 +1,137 @@ +import fs from "fs/promises"; +import { test } from "@jest/globals"; +import { GoogleCloudStorageUri } from "@langchain/google-common/experimental/media"; +import { MediaBlob } from "@langchain/google-common/experimental/utils/media_core"; +import { + BlobStoreGoogleCloudStorage, + BlobStoreGoogleCloudStorageParams, +} from "../media.js"; + +describe("GAuth GCS store", () => { + test("save text no-metadata", async () => { + const uriPrefix = new GoogleCloudStorageUri("gs://test-langchainjs/"); + const uri = `gs://test-langchainjs/text/test-${Date.now()}-nm`; + const content = "This is a test"; + const blob = await MediaBlob.fromBlob( + new Blob([content], { type: "text/plain" }), + { + path: uri, + } + ); + const config: BlobStoreGoogleCloudStorageParams = { + uriPrefix, + }; + const blobStore = new BlobStoreGoogleCloudStorage(config); + const storedBlob = await blobStore.store(blob); + // console.log(storedBlob); + expect(storedBlob?.path).toEqual(uri); + expect(await storedBlob?.asString()).toEqual(content); + expect(storedBlob?.mimetype).toEqual("text/plain"); + expect(storedBlob?.metadata).not.toHaveProperty("metadata"); + expect(storedBlob?.size).toEqual(content.length); + expect(storedBlob?.metadata?.kind).toEqual("storage#object"); + }); + + test("save text with-metadata", async () => { + const uriPrefix = new GoogleCloudStorageUri("gs://test-langchainjs/"); + const uri = `gs://test-langchainjs/text/test-${Date.now()}-wm`; + const content = "This is a test"; + const blob = await MediaBlob.fromBlob( + new Blob([content], { type: "text/plain" }), + { + path: uri, + metadata: { + alpha: "one", + bravo: "two", + }, + } + ); + const config: BlobStoreGoogleCloudStorageParams = { + uriPrefix, + }; + const blobStore = new BlobStoreGoogleCloudStorage(config); + const storedBlob = await blobStore.store(blob); + // console.log(storedBlob); + expect(storedBlob?.path).toEqual(uri); + expect(await storedBlob?.asString()).toEqual(content); + expect(storedBlob?.mimetype).toEqual("text/plain"); + expect(storedBlob?.metadata).toHaveProperty("metadata"); + expect(storedBlob?.metadata?.metadata?.alpha).toEqual("one"); + expect(storedBlob?.metadata?.metadata?.bravo).toEqual("two"); + expect(storedBlob?.size).toEqual(content.length); + expect(storedBlob?.metadata?.kind).toEqual("storage#object"); + }); + + test("save image no-metadata", async () => { + const filename = `src/tests/data/blue-square.png`; + const dataBuffer = await fs.readFile(filename); + const data = new Blob([dataBuffer], { type: "image/png" }); + + const uriPrefix = new GoogleCloudStorageUri("gs://test-langchainjs/"); + const uri = `gs://test-langchainjs/image/test-${Date.now()}-nm`; + const blob = await MediaBlob.fromBlob(data, { + path: uri, + }); + const config: BlobStoreGoogleCloudStorageParams = { + uriPrefix, + }; + const blobStore = new BlobStoreGoogleCloudStorage(config); + const storedBlob = await blobStore.store(blob); + // console.log(storedBlob); + expect(storedBlob?.path).toEqual(uri); + expect(storedBlob?.size).toEqual(176); + expect(storedBlob?.mimetype).toEqual("image/png"); + expect(storedBlob?.metadata?.kind).toEqual("storage#object"); + }); + + test("get text no-metadata", async () => { + const uriPrefix = new GoogleCloudStorageUri("gs://test-langchainjs/"); + const uri: string = "gs://test-langchainjs/text/test-nm"; + const config: BlobStoreGoogleCloudStorageParams = { + uriPrefix, + }; + const blobStore = new BlobStoreGoogleCloudStorage(config); + const blob = await blobStore.fetch(uri); + // console.log(blob); + expect(blob?.path).toEqual(uri); + expect(await blob?.asString()).toEqual("This is a test"); + expect(blob?.mimetype).toEqual("text/plain"); + expect(blob?.metadata).not.toHaveProperty("metadata"); + expect(blob?.size).toEqual(14); + expect(blob?.metadata?.kind).toEqual("storage#object"); + }); + + test("get text with-metadata", async () => { + const uriPrefix = new GoogleCloudStorageUri("gs://test-langchainjs/"); + const uri: string = "gs://test-langchainjs/text/test-wm"; + const config: BlobStoreGoogleCloudStorageParams = { + uriPrefix, + }; + const blobStore = new BlobStoreGoogleCloudStorage(config); + const blob = await blobStore.fetch(uri); + // console.log(blob); + expect(blob?.path).toEqual(uri); + expect(await blob?.asString()).toEqual("This is a test"); + expect(blob?.mimetype).toEqual("text/plain"); + expect(blob?.metadata).toHaveProperty("metadata"); + expect(blob?.metadata?.metadata?.alpha).toEqual("one"); + expect(blob?.metadata?.metadata?.bravo).toEqual("two"); + expect(blob?.size).toEqual(14); + expect(blob?.metadata?.kind).toEqual("storage#object"); + }); + + test("get image no-metadata", async () => { + const uriPrefix = new GoogleCloudStorageUri("gs://test-langchainjs/"); + const uri: string = "gs://test-langchainjs/image/test-nm"; + const config: BlobStoreGoogleCloudStorageParams = { + uriPrefix, + }; + const blobStore = new BlobStoreGoogleCloudStorage(config); + const blob = await blobStore.fetch(uri); + // console.log(storedBlob); + expect(blob?.path).toEqual(uri); + expect(blob?.size).toEqual(176); + expect(blob?.mimetype).toEqual("image/png"); + expect(blob?.metadata?.kind).toEqual("storage#object"); + }); +}); diff --git a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts index a5b9b1001218..8eeed770fee2 100644 --- a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts @@ -9,16 +9,31 @@ import { BaseMessageChunk, BaseMessageLike, HumanMessage, + HumanMessageChunk, + MessageContentComplex, SystemMessage, ToolMessage, } from "@langchain/core/messages"; +import { + BlobStoreGoogleCloudStorage, + ChatGoogle, +} from "@langchain/google-gauth"; import { tool } from "@langchain/core/tools"; import { z } from "zod"; import { concat } from "@langchain/core/utils/stream"; +import { + BackedBlobStore, + MediaBlob, + MediaManager, + ReadThroughBlobStore, + SimpleWebBlobStore, +} from "@langchain/google-common/experimental/utils/media_core"; +import { GoogleCloudStorageUri } from "@langchain/google-common/experimental/media"; import { ChatPromptTemplate, MessagesPlaceholder, } from "@langchain/core/prompts"; +import { InMemoryStore } from "@langchain/core/stores"; import { GeminiTool } from "../types.js"; import { ChatVertexAI } from "../chat_models.js"; @@ -196,6 +211,74 @@ describe("GAuth Chat", () => { const result = await model.invoke("What is the weather in Paris?"); expect(result).toHaveProperty("location"); }); + + test("media - fileData", async () => { + class MemStore extends InMemoryStore { + get length() { + return Object.keys(this.store).length; + } + } + const aliasMemory = new MemStore(); + const aliasStore = new BackedBlobStore({ + backingStore: aliasMemory, + defaultFetchOptions: { + actionIfBlobMissing: undefined, + }, + }); + const canonicalStore = new BlobStoreGoogleCloudStorage({ + uriPrefix: new GoogleCloudStorageUri("gs://test-langchainjs/mediatest/"), + defaultStoreOptions: { + actionIfInvalid: "prefixPath", + }, + }); + const blobStore = new ReadThroughBlobStore({ + baseStore: aliasStore, + backingStore: canonicalStore, + }); + const resolver = new SimpleWebBlobStore(); + const mediaManager = new MediaManager({ + store: blobStore, + resolvers: [resolver], + }); + const model = new ChatGoogle({ + modelName: "gemini-1.5-flash", + mediaManager, + }); + + const message: MessageContentComplex[] = [ + { + type: "text", + text: "What is in this image?", + }, + { + type: "media", + fileUri: "https://js.langchain.com/v0.2/img/brand/wordmark.png", + }, + ]; + + const messages: BaseMessage[] = [ + new HumanMessageChunk({ content: message }), + ]; + + try { + const res = await model.invoke(messages); + + console.log(res); + + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); + + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); + + expect(typeof aiMessage.content).toBe("string"); + const text = aiMessage.content as string; + expect(text).toMatch(/LangChain/); + } catch (e) { + console.error(e); + throw e; + } + }); }); test("Stream token count usage_metadata", async () => { diff --git a/libs/langchain-google-webauth/package.json b/libs/langchain-google-webauth/package.json index 320f731a088b..a6830b7eac5d 100644 --- a/libs/langchain-google-webauth/package.json +++ b/libs/langchain-google-webauth/package.json @@ -58,7 +58,8 @@ "release-it": "^17.6.0", "rollup": "^4.5.2", "ts-jest": "^29.1.0", - "typescript": "<5.2.0" + "typescript": "<5.2.0", + "zod": "^3.23.8" }, "publishConfig": { "access": "public" diff --git a/libs/langchain-google-webauth/src/media.ts b/libs/langchain-google-webauth/src/media.ts new file mode 100644 index 000000000000..4b631b519d31 --- /dev/null +++ b/libs/langchain-google-webauth/src/media.ts @@ -0,0 +1,33 @@ +import { + GoogleAbstractedClient, + GoogleBaseLLMInput, +} from "@langchain/google-common"; +import { + BlobStoreAIStudioFileBase, + BlobStoreAIStudioFileBaseParams, + BlobStoreGoogleCloudStorageBase, + BlobStoreGoogleCloudStorageBaseParams, +} from "@langchain/google-common/experimental/media"; +import { WebGoogleAuth, WebGoogleAuthOptions } from "./auth.js"; + +export interface BlobStoreGoogleCloudStorageParams + extends BlobStoreGoogleCloudStorageBaseParams {} + +export class BlobStoreGoogleCloudStorage extends BlobStoreGoogleCloudStorageBase { + buildClient( + fields?: GoogleBaseLLMInput + ): GoogleAbstractedClient { + return new WebGoogleAuth(fields); + } +} + +export interface BlobStoreAIStudioFileParams + extends BlobStoreAIStudioFileBaseParams {} + +export class BlobStoreAIStudioFile extends BlobStoreAIStudioFileBase { + buildAbstractedClient( + fields?: BlobStoreAIStudioFileParams + ): GoogleAbstractedClient { + return new WebGoogleAuth(fields); + } +} diff --git a/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts b/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts new file mode 100644 index 000000000000..60fcba032d18 --- /dev/null +++ b/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts @@ -0,0 +1,247 @@ +/* eslint-disable import/no-extraneous-dependencies */ +import { StructuredTool } from "@langchain/core/tools"; +import { z } from "zod"; +import { test } from "@jest/globals"; +import { + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + HumanMessage, + HumanMessageChunk, + MessageContentComplex, + SystemMessage, + ToolMessage, +} from "@langchain/core/messages"; +import { BaseLanguageModelInput } from "@langchain/core/language_models/base"; +import { ChatPromptValue } from "@langchain/core/prompt_values"; +import { + MediaManager, + SimpleWebBlobStore, +} from "@langchain/google-common/experimental/utils/media_core"; +import { ChatGoogle } from "../chat_models.js"; +import { BlobStoreAIStudioFile } from "../media.js"; + +class WeatherTool extends StructuredTool { + schema = z.object({ + locations: z + .array(z.object({ name: z.string() })) + .describe("The name of cities to get the weather for."), + }); + + description = + "Get the weather of a specific location and return the temperature in Celsius."; + + name = "get_weather"; + + async _call(input: z.infer) { + console.log(`WeatherTool called with input: ${input}`); + return `The weather in ${JSON.stringify(input.locations)} is 25°C`; + } +} + +describe("Google APIKey Chat", () => { + test("invoke", async () => { + const model = new ChatGoogle(); + try { + const res = await model.invoke("What is 1 + 1?"); + console.log(res); + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); + + const aiMessage = res as AIMessageChunk; + console.log(aiMessage); + expect(aiMessage.content).toBeDefined(); + expect(aiMessage.content.length).toBeGreaterThan(0); + expect(aiMessage.content[0]).toBeDefined(); + + // const content = aiMessage.content[0] as MessageContentComplex; + // expect(content).toHaveProperty("type"); + // expect(content.type).toEqual("text"); + + // const textContent = content as MessageContentText; + // expect(textContent.text).toBeDefined(); + // expect(textContent.text).toEqual("2"); + } catch (e) { + console.error(e); + throw e; + } + }); + + test("generate", async () => { + const model = new ChatGoogle(); + try { + const messages: BaseMessage[] = [ + new SystemMessage( + "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." + ), + new HumanMessage("Flip it"), + new AIMessage("T"), + new HumanMessage("Flip the coin again"), + ]; + const res = await model.predictMessages(messages); + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); + + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); + expect(aiMessage.content.length).toBeGreaterThan(0); + expect(aiMessage.content[0]).toBeDefined(); + console.log(aiMessage); + + // const content = aiMessage.content[0] as MessageContentComplex; + // expect(content).toHaveProperty("type"); + // expect(content.type).toEqual("text"); + + // const textContent = content as MessageContentText; + // expect(textContent.text).toBeDefined(); + // expect(["H", "T"]).toContainEqual(textContent.text); + } catch (e) { + console.error(e); + throw e; + } + }); + + test("stream", async () => { + const model = new ChatGoogle(); + try { + const input: BaseLanguageModelInput = new ChatPromptValue([ + new SystemMessage( + "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." + ), + new HumanMessage("Flip it"), + new AIMessage("T"), + new HumanMessage("Flip the coin again"), + ]); + const res = await model.stream(input); + const resArray: BaseMessageChunk[] = []; + for await (const chunk of res) { + resArray.push(chunk); + } + expect(resArray).toBeDefined(); + expect(resArray.length).toBeGreaterThanOrEqual(1); + + const lastChunk = resArray[resArray.length - 1]; + expect(lastChunk).toBeDefined(); + expect(lastChunk._getType()).toEqual("ai"); + const aiChunk = lastChunk as AIMessageChunk; + console.log(aiChunk); + + console.log(JSON.stringify(resArray, null, 2)); + } catch (e) { + console.error(e); + throw e; + } + }); + + test.skip("Tool call", async () => { + const chat = new ChatGoogle().bindTools([new WeatherTool()]); + const res = await chat.invoke("What is the weather in SF and LA"); + console.log(res); + expect(res.tool_calls?.length).toEqual(1); + expect(res.tool_calls?.[0].args).toEqual( + JSON.parse(res.additional_kwargs.tool_calls?.[0].function.arguments ?? "") + ); + }); + + test.skip("Few shotting with tool calls", async () => { + const chat = new ChatGoogle().bindTools([new WeatherTool()]); + const res = await chat.invoke("What is the weather in SF"); + console.log(res); + const res2 = await chat.invoke([ + new HumanMessage("What is the weather in SF?"), + new AIMessage({ + content: "", + tool_calls: [ + { + id: "12345", + name: "get_current_weather", + args: { + location: "SF", + }, + }, + ], + }), + new ToolMessage({ + tool_call_id: "12345", + content: "It is currently 24 degrees with hail in SF.", + }), + new AIMessage("It is currently 24 degrees in SF with hail in SF."), + new HumanMessage("What did you say the weather was?"), + ]); + console.log(res2); + expect(res2.content).toContain("24"); + }); + + test.skip("withStructuredOutput", async () => { + const tool = { + name: "get_weather", + description: + "Get the weather of a specific location and return the temperature in Celsius.", + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The name of city to get the weather for.", + }, + }, + required: ["location"], + }, + }; + const model = new ChatGoogle().withStructuredOutput(tool); + const result = await model.invoke("What is the weather in Paris?"); + expect(result).toHaveProperty("location"); + }); + + test("media - fileData", async () => { + const canonicalStore = new BlobStoreAIStudioFile({}); + const resolver = new SimpleWebBlobStore(); + const mediaManager = new MediaManager({ + store: canonicalStore, + resolvers: [resolver], + }); + const model = new ChatGoogle({ + modelName: "gemini-1.5-flash", + apiVersion: "v1beta", + mediaManager, + }); + + const message: MessageContentComplex[] = [ + { + type: "text", + text: "What is in this image?", + }, + { + type: "media", + fileUri: "https://js.langchain.com/v0.2/img/brand/wordmark.png", + }, + ]; + + const messages: BaseMessage[] = [ + new HumanMessageChunk({ content: message }), + ]; + + try { + const res = await model.invoke(messages); + + // console.log(res); + + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); + + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); + + expect(typeof aiMessage.content).toBe("string"); + const text = aiMessage.content as string; + expect(text).toMatch(/LangChain/); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (e: any) { + console.error(e); + console.error(JSON.stringify(e.details, null, 1)); + throw e; + } + }); +}); diff --git a/libs/langchain-google-webauth/src/tests/data/blue-square.png b/libs/langchain-google-webauth/src/tests/data/blue-square.png new file mode 100644 index 000000000000..c64355dc335b Binary files /dev/null and b/libs/langchain-google-webauth/src/tests/data/blue-square.png differ diff --git a/libs/langchain-google-webauth/src/tests/data/rainbow.mp4 b/libs/langchain-google-webauth/src/tests/data/rainbow.mp4 new file mode 100644 index 000000000000..13779560fe9f Binary files /dev/null and b/libs/langchain-google-webauth/src/tests/data/rainbow.mp4 differ diff --git a/libs/langchain-google-webauth/src/tests/media.int.test.ts b/libs/langchain-google-webauth/src/tests/media.int.test.ts new file mode 100644 index 000000000000..b1678dbd1945 --- /dev/null +++ b/libs/langchain-google-webauth/src/tests/media.int.test.ts @@ -0,0 +1,201 @@ +import fs from "fs/promises"; +import { test } from "@jest/globals"; +import { GoogleCloudStorageUri } from "@langchain/google-common/experimental/media"; +import { MediaBlob } from "@langchain/google-common/experimental/utils/media_core"; +import { + BlobStoreAIStudioFile, + BlobStoreGoogleCloudStorage, + BlobStoreGoogleCloudStorageParams, +} from "../media.js"; + +describe("Google Webauth GCS store", () => { + test("save text no-metadata", async () => { + const uriPrefix = new GoogleCloudStorageUri("gs://test-langchainjs/"); + const uri = `gs://test-langchainjs/text/test-${Date.now()}-nm`; + const content = "This is a test"; + const blob = await MediaBlob.fromBlob( + new Blob([content], { type: "text/plain" }), + { + path: uri, + } + ); + const config: BlobStoreGoogleCloudStorageParams = { + uriPrefix, + }; + const blobStore = new BlobStoreGoogleCloudStorage(config); + const storedBlob = await blobStore.store(blob); + // console.log(storedBlob); + expect(storedBlob?.path).toEqual(uri); + expect(await storedBlob?.asString()).toEqual(content); + expect(storedBlob?.mimetype).toEqual("text/plain"); + expect(storedBlob?.metadata).not.toHaveProperty("metadata"); + expect(storedBlob?.size).toEqual(content.length); + expect(storedBlob?.metadata?.kind).toEqual("storage#object"); + }); + + test("save text with-metadata", async () => { + const uriPrefix = new GoogleCloudStorageUri("gs://test-langchainjs/"); + const uri = `gs://test-langchainjs/text/test-${Date.now()}-wm`; + const content = "This is a test"; + const blob = await MediaBlob.fromBlob( + new Blob([content], { type: "text/plain" }), + { + path: uri, + metadata: { + alpha: "one", + bravo: "two", + }, + } + ); + const config: BlobStoreGoogleCloudStorageParams = { + uriPrefix, + }; + const blobStore = new BlobStoreGoogleCloudStorage(config); + const storedBlob = await blobStore.store(blob); + // console.log(storedBlob); + expect(storedBlob?.path).toEqual(uri); + expect(await storedBlob?.asString()).toEqual(content); + expect(storedBlob?.mimetype).toEqual("text/plain"); + expect(storedBlob?.metadata).toHaveProperty("metadata"); + expect(storedBlob?.metadata?.metadata?.alpha).toEqual("one"); + expect(storedBlob?.metadata?.metadata?.bravo).toEqual("two"); + expect(storedBlob?.size).toEqual(content.length); + expect(storedBlob?.metadata?.kind).toEqual("storage#object"); + }); + + test("save image no-metadata", async () => { + const filename = `src/tests/data/blue-square.png`; + const dataBuffer = await fs.readFile(filename); + const data = new Blob([dataBuffer], { type: "image/png" }); + + const uriPrefix = new GoogleCloudStorageUri("gs://test-langchainjs/"); + const uri = `gs://test-langchainjs/image/test-${Date.now()}-nm`; + const blob = await MediaBlob.fromBlob(data, { + path: uri, + }); + const config: BlobStoreGoogleCloudStorageParams = { + uriPrefix, + }; + const blobStore = new BlobStoreGoogleCloudStorage(config); + const storedBlob = await blobStore.store(blob); + // console.log(storedBlob); + expect(storedBlob?.path).toEqual(uri); + expect(storedBlob?.size).toEqual(176); + expect(storedBlob?.mimetype).toEqual("image/png"); + expect(storedBlob?.metadata?.kind).toEqual("storage#object"); + }); + + test("get text no-metadata", async () => { + const uriPrefix = new GoogleCloudStorageUri("gs://test-langchainjs/"); + const uri: string = "gs://test-langchainjs/text/test-nm"; + const config: BlobStoreGoogleCloudStorageParams = { + uriPrefix, + }; + const blobStore = new BlobStoreGoogleCloudStorage(config); + const blob = await blobStore.fetch(uri); + // console.log(blob); + expect(blob?.path).toEqual(uri); + expect(await blob?.asString()).toEqual("This is a test"); + expect(blob?.mimetype).toEqual("text/plain"); + expect(blob?.metadata).not.toHaveProperty("metadata"); + expect(blob?.size).toEqual(14); + expect(blob?.metadata?.kind).toEqual("storage#object"); + }); + + test("get text with-metadata", async () => { + const uriPrefix = new GoogleCloudStorageUri("gs://test-langchainjs/"); + const uri: string = "gs://test-langchainjs/text/test-wm"; + const config: BlobStoreGoogleCloudStorageParams = { + uriPrefix, + }; + const blobStore = new BlobStoreGoogleCloudStorage(config); + const blob = await blobStore.fetch(uri); + // console.log(blob); + expect(blob?.path).toEqual(uri); + expect(await blob?.asString()).toEqual("This is a test"); + expect(blob?.mimetype).toEqual("text/plain"); + expect(blob?.metadata).toHaveProperty("metadata"); + expect(blob?.metadata?.metadata?.alpha).toEqual("one"); + expect(blob?.metadata?.metadata?.bravo).toEqual("two"); + expect(blob?.size).toEqual(14); + expect(blob?.metadata?.kind).toEqual("storage#object"); + }); + + test("get image no-metadata", async () => { + const uriPrefix = new GoogleCloudStorageUri("gs://test-langchainjs/"); + const uri: string = "gs://test-langchainjs/image/test-nm"; + const config: BlobStoreGoogleCloudStorageParams = { + uriPrefix, + }; + const blobStore = new BlobStoreGoogleCloudStorage(config); + const blob = await blobStore.fetch(uri); + // console.log(storedBlob); + expect(blob?.path).toEqual(uri); + expect(blob?.size).toEqual(176); + expect(blob?.mimetype).toEqual("image/png"); + expect(blob?.metadata?.kind).toEqual("storage#object"); + }); +}); + +describe("Google APIKey AIStudioBlobStore", () => { + test("save image no metadata", async () => { + const filename = `src/tests/data/blue-square.png`; + const dataBuffer = await fs.readFile(filename); + const data = new Blob([dataBuffer], { type: "image/png" }); + const blob = await MediaBlob.fromBlob(data, { + path: filename, + }); + const blobStore = new BlobStoreAIStudioFile(); + const storedBlob = await blobStore.store(blob); + console.log(storedBlob); + + // The blob itself is expected to have no data right now, + // but this will hopefully change in the future. + expect(storedBlob?.size).toEqual(0); + expect(storedBlob?.dataType).toEqual("image/png"); + expect(storedBlob?.metadata?.sizeBytes).toEqual("176"); + expect(storedBlob?.metadata?.state).toEqual("ACTIVE"); + }); + + test("save video with retry", async () => { + const filename = `src/tests/data/rainbow.mp4`; + const dataBuffer = await fs.readFile(filename); + const data = new Blob([dataBuffer], { type: "video/mp4" }); + const blob = await MediaBlob.fromBlob(data, { + path: filename, + }); + const blobStore = new BlobStoreAIStudioFile(); + const storedBlob = await blobStore.store(blob); + console.log(storedBlob); + + // The blob itself is expected to have no data right now, + // but this will hopefully change in the future. + expect(storedBlob?.size).toEqual(0); + expect(storedBlob?.dataType).toEqual("video/mp4"); + expect(storedBlob?.metadata?.sizeBytes).toEqual("1020253"); + expect(storedBlob?.metadata?.state).toEqual("ACTIVE"); + expect(storedBlob?.metadata?.videoMetadata?.videoDuration).toEqual("8s"); + }); + + test("save video no retry", async () => { + const filename = `src/tests/data/rainbow.mp4`; + const dataBuffer = await fs.readFile(filename); + const data = new Blob([dataBuffer], { type: "video/mp4" }); + const blob = await MediaBlob.fromBlob(data, { + path: filename, + }); + const blobStore = new BlobStoreAIStudioFile({ + retryTime: -1, + }); + const storedBlob = await blobStore.store(blob); + console.log(storedBlob); + + // The blob itself is expected to have no data right now, + // but this will hopefully change in the future. + expect(storedBlob?.size).toEqual(0); + expect(storedBlob?.dataType).toEqual("video/mp4"); + expect(storedBlob?.metadata?.sizeBytes).toEqual("1020253"); + expect(storedBlob?.metadata?.state).toEqual("PROCESSING"); + expect(storedBlob?.metadata?.videoMetadata).toBeUndefined(); + }); +}); diff --git a/yarn.lock b/yarn.lock index eabf31ae1d3b..d2f441828856 100644 --- a/yarn.lock +++ b/yarn.lock @@ -11968,6 +11968,7 @@ __metadata: ts-jest: ^29.1.0 typescript: <5.2.0 web-auth-library: ^1.0.3 + zod: ^3.23.8 languageName: unknown linkType: soft