Skip to content

Commit

Permalink
google-common[major]: Media manager (#5835)
Browse files Browse the repository at this point in the history
  • Loading branch information
afirstenberg authored Sep 11, 2024
1 parent 6a723fa commit 08092cd
Show file tree
Hide file tree
Showing 27 changed files with 3,949 additions and 740 deletions.
8 changes: 8 additions & 0 deletions libs/langchain-google-common/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ types.cjs
types.js
types.d.ts
types.d.cts
experimental/media.cjs
experimental/media.js
experimental/media.d.ts
experimental/media.d.cts
experimental/utils/media_core.cjs
experimental/utils/media_core.js
experimental/utils/media_core.d.ts
experimental/utils/media_core.d.cts
node_modules
dist
.yarn
2 changes: 2 additions & 0 deletions libs/langchain-google-common/langchain.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ export const config = {
index: "index",
utils: "utils/index",
types: "types",
"experimental/media": "experimental/media",
"experimental/utils/media_core": "experimental/utils/media_core",
},
tsConfigPath: resolve("./tsconfig.json"),
cjsSource: "./dist-cjs",
Expand Down
28 changes: 27 additions & 1 deletion libs/langchain-google-common/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,24 @@
"import": "./types.js",
"require": "./types.cjs"
},
"./experimental/media": {
"types": {
"import": "./experimental/media.d.ts",
"require": "./experimental/media.d.cts",
"default": "./experimental/media.d.ts"
},
"import": "./experimental/media.js",
"require": "./experimental/media.cjs"
},
"./experimental/utils/media_core": {
"types": {
"import": "./experimental/utils/media_core.d.ts",
"require": "./experimental/utils/media_core.d.cts",
"default": "./experimental/utils/media_core.d.ts"
},
"import": "./experimental/utils/media_core.js",
"require": "./experimental/utils/media_core.cjs"
},
"./package.json": "./package.json"
},
"files": [
Expand All @@ -107,6 +125,14 @@
"types.cjs",
"types.js",
"types.d.ts",
"types.d.cts"
"types.d.cts",
"experimental/media.cjs",
"experimental/media.js",
"experimental/media.d.ts",
"experimental/media.d.cts",
"experimental/utils/media_core.cjs",
"experimental/utils/media_core.js",
"experimental/utils/media_core.d.ts",
"experimental/utils/media_core.d.cts"
]
}
36 changes: 28 additions & 8 deletions libs/langchain-google-common/src/auth.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { ReadableJsonStream } from "./utils/stream.js";
import { GooglePlatformType } from "./types.js";

export type GoogleAbstractedClientOpsMethod = "GET" | "POST";
export type GoogleAbstractedClientOpsMethod = "GET" | "POST" | "DELETE";

export type GoogleAbstractedClientOpsResponseType = "json" | "stream";
export type GoogleAbstractedClientOpsResponseType = "json" | "stream" | "blob";

export type GoogleAbstractedClientOps = {
url?: string;
Expand All @@ -28,6 +28,17 @@ export abstract class GoogleAbstractedFetchClient

abstract request(opts: GoogleAbstractedClientOps): unknown;

async _buildData(res: Response, opts: GoogleAbstractedClientOps) {
switch (opts.responseType) {
case "json":
return res.json();
case "stream":
return new ReadableJsonStream(res.body);
default:
return res.blob();
}
}

async _request(
url: string | undefined,
opts: GoogleAbstractedClientOps,
Expand All @@ -47,7 +58,11 @@ export abstract class GoogleAbstractedFetchClient
},
};
if (opts.data !== undefined) {
fetchOptions.body = JSON.stringify(opts.data);
if (typeof opts.data === "string") {
fetchOptions.body = opts.data;
} else {
fetchOptions.body = JSON.stringify(opts.data);
}
}

const res = await fetch(url, fetchOptions);
Expand All @@ -57,16 +72,21 @@ export abstract class GoogleAbstractedFetchClient
const error = new Error(
`Google request failed with status code ${res.status}: ${resText}`
);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
/* eslint-disable @typescript-eslint/no-explicit-any */
(error as any).response = res;
(error as any).details = {
url,
opts,
fetchOptions,
result: res,
};
/* eslint-enable @typescript-eslint/no-explicit-any */
throw error;
}

const data = await this._buildData(res, opts);
return {
data:
opts.responseType === "json"
? await res.json()
: new ReadableJsonStream(res.body),
data,
config: {},
status: res.status,
statusText: res.statusText,
Expand Down
87 changes: 49 additions & 38 deletions libs/langchain-google-common/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,7 @@ import {
copyAndValidateModelParamsInto,
} from "./utils/common.js";
import { AbstractGoogleLLMConnection } from "./connection.js";
import {
baseMessageToContent,
safeResponseToChatGeneration,
safeResponseToChatResult,
DefaultGeminiSafetyHandler,
} from "./utils/gemini.js";
import { DefaultGeminiSafetyHandler } from "./utils/gemini.js";
import { ApiKeyGoogleAuth, GoogleAbstractedClient } from "./auth.js";
import { JsonStream } from "./utils/stream.js";
import { ensureParams } from "./utils/failed_handler.js";
Expand All @@ -55,6 +50,7 @@ import type {
GeminiFunctionDeclaration,
GeminiFunctionSchema,
GoogleAIToolType,
GeminiAPIConfig,
} from "./types.js";
import { zodToGeminiParameters } from "./utils/zod_to_gemini_parameters.js";

Expand Down Expand Up @@ -100,61 +96,69 @@ class ChatConnection<AuthOptions> extends AbstractGoogleLLMConnection<
return true;
}

formatContents(
async formatContents(
input: BaseMessage[],
_parameters: GoogleAIModelParams
): GeminiContent[] {
return input
.map((msg, i) =>
baseMessageToContent(msg, input[i - 1], this.useSystemInstruction)
): Promise<GeminiContent[]> {
const inputPromises: Promise<GeminiContent[]>[] = input.map((msg, i) =>
this.api.baseMessageToContent(
msg,
input[i - 1],
this.useSystemInstruction
)
.reduce((acc, cur) => {
// Filter out the system content
if (cur.every((content) => content.role === "system")) {
return acc;
}

// Combine adjacent function messages
if (
cur[0]?.role === "function" &&
acc.length > 0 &&
acc[acc.length - 1].role === "function"
) {
acc[acc.length - 1].parts = [
...acc[acc.length - 1].parts,
...cur[0].parts,
];
} else {
acc.push(...cur);
}
);
const inputs = await Promise.all(inputPromises);

return inputs.reduce((acc, cur) => {
// Filter out the system content
if (cur.every((content) => content.role === "system")) {
return acc;
}, [] as GeminiContent[]);
}

// Combine adjacent function messages
if (
cur[0]?.role === "function" &&
acc.length > 0 &&
acc[acc.length - 1].role === "function"
) {
acc[acc.length - 1].parts = [
...acc[acc.length - 1].parts,
...cur[0].parts,
];
} else {
acc.push(...cur);
}

return acc;
}, [] as GeminiContent[]);
}

formatSystemInstruction(
async formatSystemInstruction(
input: BaseMessage[],
_parameters: GoogleAIModelParams
): GeminiContent {
): Promise<GeminiContent> {
if (!this.useSystemInstruction) {
return {} as GeminiContent;
}

let ret = {} as GeminiContent;
input.forEach((message, index) => {
for (let index = 0; index < input.length; index += 1) {
const message = input[index];
if (message._getType() === "system") {
// For system types, we only want it if it is the first message,
// if it appears anywhere else, it should be an error.
if (index === 0) {
// eslint-disable-next-line prefer-destructuring
ret = baseMessageToContent(message, undefined, true)[0];
ret = (
await this.api.baseMessageToContent(message, undefined, true)
)[0];
} else {
throw new Error(
"System messages are only permitted as the first passed message."
);
}
}
});
}

return ret;
}
Expand All @@ -168,6 +172,7 @@ export interface ChatGoogleBaseInput<AuthOptions>
GoogleConnectionParams<AuthOptions>,
GoogleAIModelParams,
GoogleAISafetyParams,
GeminiAPIConfig,
Pick<GoogleAIBaseLanguageModelCallOptions, "streamUsage"> {}

/**
Expand Down Expand Up @@ -338,7 +343,10 @@ export abstract class ChatGoogleBase<AuthOptions>
parameters,
options
);
const ret = safeResponseToChatResult(response, this.safetyHandler);
const ret = this.connection.api.safeResponseToChatResult(
response,
this.safetyHandler
);
await runManager?.handleLLMNewToken(ret.generations[0].text);
return ret;
}
Expand Down Expand Up @@ -378,7 +386,10 @@ export abstract class ChatGoogleBase<AuthOptions>
}
const chunk =
output !== null
? safeResponseToChatGeneration({ data: output }, this.safetyHandler)
? this.connection.api.safeResponseToChatGeneration(
{ data: output },
this.safetyHandler
)
: new ChatGenerationChunk({
text: "",
generationInfo: { finishReason: "stop" },
Expand Down
Loading

0 comments on commit 08092cd

Please sign in to comment.