Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add inference for RunnableMap RunOutput type #3517

Merged
merged 23 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c7cb793
add RunnableMapLike to infer RunnableMap output
davidilling Dec 4, 2023
d6cfaae
remove unneeded changes
davidilling Dec 4, 2023
eab8457
fix linting
davidilling Dec 4, 2023
acdab12
format
davidilling Dec 4, 2023
095a4e8
fix runnable_stream_log.test
davidilling Dec 4, 2023
723e2ed
upgrade typescript version
davidilling Dec 6, 2023
102042b
clean types
davidilling Dec 6, 2023
954a14f
fix structured_output_runnables.int.test
davidilling Dec 6, 2023
b84f3e4
Merge remote-tracking branch 'upstream' into infer-map-types
davidilling Dec 6, 2023
a8acd75
ts version ~5.1.6
davidilling Dec 6, 2023
6e97687
Merge remote-tracking branch 'upstream/main' into infer-map-types
davidilling Dec 6, 2023
7652bf2
remove unused eslint-disable-next-line
davidilling Dec 6, 2023
5cb1edd
remove another disable no-explicit-any
davidilling Dec 6, 2023
d07ce12
remove another no-explicit-any
davidilling Dec 6, 2023
ca586cb
Merge branch 'main' into infer-map-types
dilling Dec 9, 2023
7826562
move eslint
davidilling Dec 9, 2023
1979976
Merge branch 'main' of https://github.com/hwchase17/langchainjs into …
jacoblee93 Dec 12, 2023
602be50
Format
jacoblee93 Dec 12, 2023
fd3204f
Merge branch 'main' into infer-map-types
dilling Dec 13, 2023
503ceda
Merge branch 'main' of https://github.com/hwchase17/langchainjs into …
jacoblee93 Dec 13, 2023
a81428b
Default runnable maps to any type in case inference is not possible
jacoblee93 Dec 13, 2023
361c250
Add tests
jacoblee93 Dec 13, 2023
f206dcb
Merge branch 'main' of https://github.com/hwchase17/langchainjs into …
jacoblee93 Dec 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions langchain-core/src/runnables/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@ export type RunnableFunc<RunInput, RunOutput> = (
| (Record<string, any> & { config: RunnableConfig })
) => RunOutput | Promise<RunOutput>;

export type RunnableMapLike<RunInput, RunOutput> = {
[K in keyof RunOutput]: RunnableLike<RunInput, RunOutput[K]>;
};

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type RunnableLike<RunInput = any, RunOutput = any> =
| Runnable<RunInput, RunOutput>
| RunnableFunc<RunInput, RunOutput>
| { [key: string]: RunnableLike<RunInput, RunOutput> };
| RunnableMapLike<RunInput, RunOutput>;

export type RunnableBatchOptions = {
maxConcurrency?: number;
Expand Down Expand Up @@ -1368,11 +1372,12 @@ export class RunnableSequence<
* const result = await mapChain.invoke({ topic: "bear" });
* ```
*/
export class RunnableMap<RunInput> extends Runnable<
RunInput,
export class RunnableMap<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Record<string, any>
> {
RunInput = any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
> extends Runnable<RunInput, RunOutput> {
static lc_name() {
return "RunnableMap";
}
Expand All @@ -1387,23 +1392,28 @@ export class RunnableMap<RunInput> extends Runnable<
return Object.keys(this.steps);
}

constructor(fields: { steps: Record<string, RunnableLike<RunInput>> }) {
constructor(fields: { steps: RunnableMapLike<RunInput, RunOutput> }) {
super(fields);
this.steps = {};
for (const [key, value] of Object.entries(fields.steps)) {
this.steps[key] = _coerceToRunnable(value);
}
}

static from<RunInput>(steps: Record<string, RunnableLike<RunInput>>) {
return new RunnableMap<RunInput>({ steps });
static from<
RunInput,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
>(
steps: RunnableMapLike<RunInput, RunOutput>
): RunnableMap<RunInput, RunOutput> {
return new RunnableMap<RunInput, RunOutput>({ steps });
}

async invoke(
input: RunInput,
options?: Partial<BaseCallbackConfig>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
): Promise<Record<string, any>> {
): Promise<RunOutput> {
const callbackManager_ = await getCallbackMangerForConfig(options);
const runManager = await callbackManager_?.handleChainStart(
this.toJSON(),
Expand Down Expand Up @@ -1432,7 +1442,7 @@ export class RunnableMap<RunInput> extends Runnable<
throw e;
}
await runManager?.handleChainEnd(output);
return output;
return output as RunOutput;
}
}

Expand Down Expand Up @@ -1665,9 +1675,9 @@ export function _coerceToRunnable<RunInput, RunOutput>(
} else if (!Array.isArray(coerceable) && typeof coerceable === "object") {
const runnables: Record<string, Runnable<RunInput>> = {};
for (const [key, value] of Object.entries(coerceable)) {
runnables[key] = _coerceToRunnable(value);
runnables[key] = _coerceToRunnable(value as RunnableLike);
}
return new RunnableMap<RunInput>({
return new RunnableMap({
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you drop the generic passing here?

Copy link
Contributor Author

@dilling dilling Dec 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So when passing as RunnableMap<RunInput, RunOutput>, the type of RunOutput does not line up until the line below

    }) as unknown as Runnable<RunInput, Exclude<RunOutput, Error>>;

Passing just the RunInput defaults the RunOutput to Record<string, any>. So you end up having to pass both values (which is not possible to the reason above) or neither

steps: runnables,
}) as unknown as Runnable<RunInput, Exclude<RunOutput, Error>>;
} else {
Expand Down
34 changes: 1 addition & 33 deletions langchain-core/src/runnables/tests/runnable.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,18 @@ import { StringOutputParser } from "../../output_parsers/string.js";
import {
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
} from "../../prompts/chat.js";
import { PromptTemplate } from "../../prompts/prompt.js";
import {
FakeLLM,
FakeChatModel,
FakeRetriever,
FakeStreamingLLM,
FakeSplitIntoListParser,
FakeRunnable,
FakeListChatModel,
} from "../../utils/testing/index.js";
import { RunnableSequence, RunnableMap, RunnableLambda } from "../base.js";
import { RunnableSequence, RunnableLambda } from "../base.js";
import { RouterRunnable } from "../router.js";
import { Document } from "../../documents/document.js";

test("Test batch", async () => {
const llm = new FakeLLM({});
Expand Down Expand Up @@ -70,35 +67,6 @@ test("Pipe from one runnable to the next", async () => {
expect(result).toBe("Hello world!");
});

test("Create a runnable sequence with a runnable map", async () => {
const promptTemplate = ChatPromptTemplate.fromMessages<{
documents: string;
question: string;
}>([
SystemMessagePromptTemplate.fromTemplate(`You are a nice assistant.`),
HumanMessagePromptTemplate.fromTemplate(
`Context:\n{documents}\n\nQuestion:\n{question}`
),
]);
const llm = new FakeChatModel({});
const inputs = {
question: (input: string) => input,
documents: RunnableSequence.from([
new FakeRetriever(),
(docs: Document[]) => JSON.stringify(docs),
]),
extraField: new FakeLLM({}),
};
const runnable = new RunnableMap({ steps: inputs })
jacoblee93 marked this conversation as resolved.
Show resolved Hide resolved
.pipe(promptTemplate)
.pipe(llm);
const result = await runnable.invoke("Do you know the Muffin Man?");
console.log(result);
expect(result.content).toEqual(
`You are a nice assistant.\nContext:\n[{"pageContent":"foo","metadata":{}},{"pageContent":"bar","metadata":{}}]\n\nQuestion:\nDo you know the Muffin Man?`
);
});

test("Stream the entire way through", async () => {
const llm = new FakeStreamingLLM({});
const stream = await llm.pipe(new StringOutputParser()).stream("Hi there!");
Expand Down
105 changes: 105 additions & 0 deletions langchain-core/src/runnables/tests/runnable_map.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/* eslint-disable no-promise-executor-return */
/* eslint-disable @typescript-eslint/no-explicit-any */

import { StringOutputParser } from "../../output_parsers/string.js";
import {
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
} from "../../prompts/chat.js";
import {
FakeLLM,
FakeChatModel,
FakeRetriever,
} from "../../utils/testing/index.js";
import { RunnableSequence, RunnableMap } from "../base.js";
import { RunnablePassthrough } from "../passthrough.js";

test("Create a runnable sequence with a runnable map", async () => {
const promptTemplate = ChatPromptTemplate.fromMessages<{
documents: string;
question: string;
}>([
SystemMessagePromptTemplate.fromTemplate(`You are a nice assistant.`),
HumanMessagePromptTemplate.fromTemplate(
`Context:\n{documents}\n\nQuestion:\n{question}`
),
]);
const llm = new FakeChatModel({});
const inputs = {
question: (input: string) => input,
documents: RunnableSequence.from([
new FakeRetriever(),
(docs: Document[]) => JSON.stringify(docs),
]),
extraField: new FakeLLM({}),
};
const runnable = new RunnableMap({ steps: inputs })
.pipe(promptTemplate)
.pipe(llm);
const result = await runnable.invoke("Do you know the Muffin Man?");
console.log(result);
expect(result.content).toEqual(
`You are a nice assistant.\nContext:\n[{"pageContent":"foo","metadata":{}},{"pageContent":"bar","metadata":{}}]\n\nQuestion:\nDo you know the Muffin Man?`
);
});

test("Test map inference in a sequence", async () => {
const prompt = ChatPromptTemplate.fromTemplate(
"context: {context}, question: {question}"
);
const chain = RunnableSequence.from([
{
question: new RunnablePassthrough(),
context: async () => "SOME STUFF",
},
prompt,
new FakeLLM({}),
new StringOutputParser(),
]);
const response = await chain.invoke("Just passing through.");
console.log(response);
expect(response).toBe(
`Human: context: SOME STUFF, question: Just passing through.`
);
});

test("Should not allow mismatched inputs", async () => {
const prompt = ChatPromptTemplate.fromTemplate(
"context: {context}, question: {question}"
);
const badChain = RunnableSequence.from([
{
// @ts-expect-error TS compiler should flag mismatched input types
question: new FakeLLM({}),
context: async (input: number) => input,
},
prompt,
new FakeLLM({}),
new StringOutputParser(),
]);
console.log(badChain);
});

test("Should not allow improper inputs into a map in a sequence", async () => {
const prompt = ChatPromptTemplate.fromTemplate(
"context: {context}, question: {question}"
);
const map = RunnableMap.from({
question: new FakeLLM({}),
context: async (_input: string) => 9,
});
// @ts-expect-error TS compiler should flag mismatched output types
const runnable = prompt.pipe(map);
console.log(runnable);
});

test("Should not allow improper outputs from a map into the next item in a sequence", async () => {
const map = RunnableMap.from({
question: new FakeLLM({}),
context: async (_input: string) => 9,
});
// @ts-expect-error TS compiler should flag mismatched output types
const runnable = map.pipe(new FakeLLM({}));
console.log(runnable);
});
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ test("Runnable streamLog method with a more complicated sequence", async () => {
response: "testing",
}).withConfig({ tags: ["only_one"] }),
};
const runnable = new RunnableMap({ steps: inputs })

const runnable = new RunnableMap({
steps: inputs,
})
.pipe(promptTemplate)
.pipe(llm);
const stream = await runnable.streamLog(
Expand Down