From 23be365fa48322646293420d7666ca1df7d8e552 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Mon, 11 Dec 2023 11:29:54 -0800 Subject: [PATCH 1/2] Use different import maps for core vs main langchain --- langchain-core/src/load/index.ts | 21 ++++++++++----------- langchain-core/src/runnables/base.ts | 2 ++ langchain-core/src/runnables/index.ts | 1 + langchain/src/load/tests/load.int.test.ts | 14 ++++++++++++++ langchain/src/load/tests/load.test.ts | 17 +++++++++++++++++ langchain/src/schema/runnable/base.ts | 1 + 6 files changed, 45 insertions(+), 11 deletions(-) create mode 100644 langchain/src/load/tests/load.int.test.ts diff --git a/langchain-core/src/load/index.ts b/langchain-core/src/load/index.ts index 870b7a374f93..babdcd43ce08 100644 --- a/langchain-core/src/load/index.ts +++ b/langchain-core/src/load/index.ts @@ -6,7 +6,7 @@ import { get_lc_unique_name, } from "./serializable.js"; import { optionalImportEntrypoints as defaultOptionalImportEntrypoints } from "./import_constants.js"; -import * as defaultImportMap from "./import_map.js"; +import * as coreImportMap from "./import_map.js"; import type { OptionalImportMap, SecretMap } from "./import_type.js"; import { type SerializedFields, keyFromJson, mapKeys } from "./map_keys.js"; import { getEnvironmentVariable } from "../utils/env.js"; @@ -98,10 +98,11 @@ async function reviver( const str = JSON.stringify(serialized); const [name, ...namespaceReverse] = serialized.id.slice().reverse(); const namespace = namespaceReverse.reverse(); - const finalImportMap = { ...defaultImportMap, ...importMap }; + const importMaps = { langchain_core: coreImportMap, langchain: importMap }; let module: - | (typeof finalImportMap)[keyof typeof finalImportMap] + | (typeof importMaps)["langchain_core"][keyof (typeof importMaps)["langchain_core"]] + | (typeof importMaps)["langchain"][keyof (typeof importMaps)["langchain"]] | OptionalImportMap[keyof OptionalImportMap] | null = null; @@ -132,14 +133,12 @@ async function reviver( ); } } else { - // Currently, we only support langchain imports. - if ( - namespace[0] === "langchain" || - namespace[0] === "langchain_core" || - namespace[0] === "langchain_community" || - namespace[0] === "langchain_anthropic" || - namespace[0] === "langchain_openai" - ) { + let finalImportMap: + | (typeof importMaps)["langchain"] + | (typeof importMaps)["langchain_core"]; + // Currently, we only support langchain and langchain_core imports. + if (namespace[0] === "langchain" || namespace[0] === "langchain_core") { + finalImportMap = importMaps[namespace[0]]; namespace.shift(); } else { throw new Error(`Invalid namespace: ${pathStr} -> ${str}`); diff --git a/langchain-core/src/runnables/base.ts b/langchain-core/src/runnables/base.ts index 477ce73716c2..cd8e50eab4f5 100644 --- a/langchain-core/src/runnables/base.ts +++ b/langchain-core/src/runnables/base.ts @@ -1487,6 +1487,8 @@ export class RunnableLambda extends Runnable< } } +export class RunnableParallel extends RunnableMap {} + /** * A Runnable that can fallback to other Runnables if it fails. */ diff --git a/langchain-core/src/runnables/index.ts b/langchain-core/src/runnables/index.ts index 1409fb0d9580..dd8dce77c3e6 100644 --- a/langchain-core/src/runnables/index.ts +++ b/langchain-core/src/runnables/index.ts @@ -10,6 +10,7 @@ export { RunnableRetry, RunnableSequence, RunnableMap, + RunnableParallel, RunnableLambda, RunnableWithFallbacks, _coerceToRunnable, diff --git a/langchain/src/load/tests/load.int.test.ts b/langchain/src/load/tests/load.int.test.ts new file mode 100644 index 000000000000..012b6e333689 --- /dev/null +++ b/langchain/src/load/tests/load.int.test.ts @@ -0,0 +1,14 @@ +import { test, expect } from "@jest/globals"; +import { RunnableSequence } from "@langchain/core/runnables"; + +import { load } from "../index.js"; + +test("Should load and invoke real-world serialized chain", async () => { + const serializedValue = `{"lc": 1, "type": "constructor", "id": ["langchain_core", "runnables", "RunnableSequence"], "kwargs": {"first": {"lc": 1, "type": "constructor", "id": ["langchain_core", "runnables", "RunnableParallel"], "kwargs": {"steps": {"equation_statement": {"lc": 1, "type": "constructor", "id": ["langchain_core", "runnables", "RunnablePassthrough"], "kwargs": {"func": null, "afunc": null, "input_type": null}}}}}, "middle": [{"lc": 1, "type": "constructor", "id": ["langchain_core", "prompts", "chat", "ChatPromptTemplate"], "kwargs": {"input_variables": ["equation_statement"], "messages": [{"lc": 1, "type": "constructor", "id": ["langchain_core", "prompts", "chat", "SystemMessagePromptTemplate"], "kwargs": {"prompt": {"lc": 1, "type": "constructor", "id": ["langchain_core", "prompts", "prompt", "PromptTemplate"], "kwargs": {"input_variables": [], "template": "Write out the following equation using algebraic symbols then solve it. Use the format\\n\\nEQUATION:...\\nSOLUTION:...\\n\\n", "template_format": "f-string", "partial_variables": {}}}}}, {"lc": 1, "type": "constructor", "id": ["langchain_core", "prompts", "chat", "HumanMessagePromptTemplate"], "kwargs": {"prompt": {"lc": 1, "type": "constructor", "id": ["langchain_core", "prompts", "prompt", "PromptTemplate"], "kwargs": {"input_variables": ["equation_statement"], "template": "{equation_statement}", "template_format": "f-string", "partial_variables": {}}}}}]}}, {"lc": 1, "type": "constructor", "id": ["langchain", "chat_models", "openai", "ChatOpenAI"], "kwargs": {"temperature": 0.0, "openai_api_key": {"lc": 1, "type": "secret", "id": ["OPENAI_API_KEY"]}}}], "last": {"lc": 1, "type": "constructor", "id": ["langchain_core", "output_parsers", "string", "StrOutputParser"], "kwargs": {}}}}`; + const chain = await load(serializedValue); + const result = await chain.invoke( + "x raised to the third plus seven equals 12" + ); + console.log(result); + expect(typeof result).toBe("string"); +}); diff --git a/langchain/src/load/tests/load.test.ts b/langchain/src/load/tests/load.test.ts index 2c78537329fa..d142750cae29 100644 --- a/langchain/src/load/tests/load.test.ts +++ b/langchain/src/load/tests/load.test.ts @@ -521,3 +521,20 @@ test("Should load traces even if the constructor name changes (minified environm console.log(JSON.stringify(llm2, null, 2)); expect(JSON.stringify(llm2, null, 2)).toBe(str); }); + +test("Should load a real-world serialized chain", async () => { + const serializedValue = `{"lc": 1, "type": "constructor", "id": ["langchain_core", "runnables", "RunnableSequence"], "kwargs": {"first": {"lc": 1, "type": "constructor", "id": ["langchain_core", "runnables", "RunnableParallel"], "kwargs": {"steps": {"equation_statement": {"lc": 1, "type": "constructor", "id": ["langchain_core", "runnables", "RunnablePassthrough"], "kwargs": {"func": null, "afunc": null, "input_type": null}}}}}, "middle": [{"lc": 1, "type": "constructor", "id": ["langchain_core", "prompts", "chat", "ChatPromptTemplate"], "kwargs": {"input_variables": ["equation_statement"], "messages": [{"lc": 1, "type": "constructor", "id": ["langchain_core", "prompts", "chat", "SystemMessagePromptTemplate"], "kwargs": {"prompt": {"lc": 1, "type": "constructor", "id": ["langchain_core", "prompts", "prompt", "PromptTemplate"], "kwargs": {"input_variables": [], "template": "Write out the following equation using algebraic symbols then solve it. Use the format\\n\\nEQUATION:...\\nSOLUTION:...\\n\\n", "template_format": "f-string", "partial_variables": {}}}}}, {"lc": 1, "type": "constructor", "id": ["langchain_core", "prompts", "chat", "HumanMessagePromptTemplate"], "kwargs": {"prompt": {"lc": 1, "type": "constructor", "id": ["langchain_core", "prompts", "prompt", "PromptTemplate"], "kwargs": {"input_variables": ["equation_statement"], "template": "{equation_statement}", "template_format": "f-string", "partial_variables": {}}}}}]}}, {"lc": 1, "type": "constructor", "id": ["langchain", "chat_models", "openai", "ChatOpenAI"], "kwargs": {"temperature": 0.0, "openai_api_key": {"lc": 1, "type": "secret", "id": ["OPENAI_API_KEY"]}}}], "last": {"lc": 1, "type": "constructor", "id": ["langchain_core", "output_parsers", "string", "StrOutputParser"], "kwargs": {}}}}`; + const chain = await load(serializedValue, { + OPENAI_API_KEY: "openai-key", + }); + // @ts-expect-error testing + expect(chain.first.constructor.lc_name()).toBe("RunnableMap"); + // @ts-expect-error testing + expect(chain.middle.length).toBe(2); + // @ts-expect-error testing + expect(chain.middle[0].constructor.lc_name()).toBe(`ChatPromptTemplate`); + // @ts-expect-error testing + expect(chain.middle[1].constructor.lc_name()).toBe(`ChatOpenAI`); + // @ts-expect-error testing + expect(chain.last.constructor.lc_name()).toBe(`StrOutputParser`); +}); diff --git a/langchain/src/schema/runnable/base.ts b/langchain/src/schema/runnable/base.ts index 9aaf2f30c5a5..e4d45ce37c5a 100644 --- a/langchain/src/schema/runnable/base.ts +++ b/langchain/src/schema/runnable/base.ts @@ -10,6 +10,7 @@ export { RunnableRetry, RunnableSequence, RunnableMap, + RunnableParallel, RunnableLambda, RunnableWithFallbacks, _coerceToRunnable, From 6b4e22385e4e1afbf31ee1a27e20c18ed079e6de Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Mon, 11 Dec 2023 11:47:40 -0800 Subject: [PATCH 2/2] Expand serialization test to include more expected entrypoints --- .../src/load/tests/cross_language.int.test.ts | 86 +++++++++++++++---- 1 file changed, 70 insertions(+), 16 deletions(-) diff --git a/langchain/src/load/tests/cross_language.int.test.ts b/langchain/src/load/tests/cross_language.int.test.ts index b657bff1490d..bc314a0c3d5a 100644 --- a/langchain/src/load/tests/cross_language.int.test.ts +++ b/langchain/src/load/tests/cross_language.int.test.ts @@ -3,7 +3,7 @@ import { fileURLToPath } from "node:url"; import { readFileSync } from "node:fs"; import * as path from "node:path"; -// import { load } from "../index.js"; +import { load } from "../index.js"; const IMPORTANT_IMPORTS = JSON.parse( readFileSync( @@ -15,24 +15,78 @@ const IMPORTANT_IMPORTS = JSON.parse( ).toString() ); +const CURRENT_KNOWN_FAILURES = [ + "langchain/schema/agent/AgentAction", + "langchain/schema/agent/AgentFinish", + "langchain/schema/prompt_template/BasePromptTemplate", + "langchain/schema/agent/AgentActionMessageLog", + "langchain/schema/agent/OpenAIToolAgentAction", + "langchain/prompts/chat/BaseMessagePromptTemplate", + "langchain/schema/output/ChatGeneration", + "langchain/schema/output/Generation", + "langchain/schema/document/Document", + "langchain/schema/runnable/DynamicRunnable", + "langchain/schema/prompt/PromptValue", + "langchain/llms/openai/BaseOpenAI", + "langchain/llms/openai/AzureOpenAI", + "langchain/schema/prompt_template/BaseChatPromptTemplate", + "langchain/prompts/few_shot_with_templates/FewShotPromptWithTemplates", + "langchain/prompts/base/StringPromptTemplate", + "langchain/prompts/chat/BaseStringMessagePromptTemplate", + "langchain/prompts/chat/ChatPromptValue", + "langchain/prompts/chat/ChatPromptValueConcrete", + "langchain/schema/runnable/HubRunnable", + "langchain/schema/runnable/RunnableBindingBase", + "langchain/schema/runnable/OpenAIFunctionsRouter", + "langchain/schema/runnable/RunnableEachBase", + "langchain/schema/runnable/RunnableConfigurableAlternatives", + "langchain/schema/runnable/RunnableConfigurableFields", + "langchain_core/agents/AgentAction", + "langchain_core/agents/AgentFinish", + "langchain_core/agents/AgentActionMessageLog", + "langchain/agents/output_parsers/openai_tools/OpenAIToolAgentAction", + "langchain_core/outputs/chat_generation/ChatGeneration", + "langchain_core/outputs/generation/Generation", + "langchain_core/runnables/configurable/DynamicRunnable", + "langchain_core/prompt_values/PromptValue", + "langchain/llms/openai/BaseOpenAI", + "langchain/llms/openai/AzureOpenAI", + "langchain_core/prompts/few_shot_with_templates/FewShotPromptWithTemplates", + "langchain_core/prompts/string/StringPromptTemplate", + "langchain_core/prompts/chat/BaseStringMessagePromptTemplate", + "langchain_core/prompt_values/ChatPromptValueConcrete", + "langchain/runnables/hub/HubRunnable", + "langchain_core/runnables/base/RunnableBindingBase", + "langchain/runnables/openai_functions/OpenAIFunctionsRouter", + "langchain_core/runnables/base/RunnableEachBase", + "langchain_core/runnables/configurable/RunnableConfigurableAlternatives", + "langchain_core/runnables/configurable/RunnableConfigurableFields", +]; + +const CROSS_LANGUAGE_ENTRYPOINTS = Object.keys(IMPORTANT_IMPORTS) + .concat(Object.values(IMPORTANT_IMPORTS)) + .filter((v) => !CURRENT_KNOWN_FAILURES.includes(v)); + describe("Test cross language serialization of important modules", () => { // https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/load/mapping.py - test.each(Object.keys(IMPORTANT_IMPORTS))( + test.each(CROSS_LANGUAGE_ENTRYPOINTS)( "Test matching serialization names for: %s", - async (_item) => { - // const idComponents = item.split("/"); - // const mockItem = { - // lc: 1, - // type: "constructor", - // id: idComponents, - // kwargs: {} - // }; - // try { - // const result = await load(JSON.stringify(mockItem)) as any; - // expect(result.constructor.name).toEqual(idComponents[idComponents.length - 1]); - // } catch (e: any) { - // expect(e.message).not.toContain("Invalid identifer: $"); - // } + async (item) => { + const idComponents = item.split("/"); + const mockItem = { + lc: 1, + type: "constructor", + id: idComponents, + kwargs: {}, + }; + try { + const result = (await load(JSON.stringify(mockItem))) as any; + expect(result.constructor.name).toEqual( + idComponents[idComponents.length - 1] + ); + } catch (e: any) { + expect(e.message).not.toContain("Invalid identifer: $"); + } } ); });