Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

all[patch]: Fix typing across different core versions by using interfaces instead of abstract classes #3709

Merged
merged 11 commits into from
Dec 19, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ In the above code snippet, the fromLLM method of the `ConversationalRetrievalQAC

```typescript
static fromLLM(
llm: BaseLanguageModel,
retriever: BaseRetriever,
llm: BaseLanguageModelInterface,
retriever: BaseRetrieverInterface,
options?: {
questionGeneratorChainOptions?: {
llm?: BaseLanguageModel;
llm?: BaseLanguageModelInterface;
template?: string;
};
qaChainOptions?: QAChainParams;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,13 @@ abstract class BaseVectorStore implements VectorStore {
static fromTexts(
texts: string[],
metadatas: object[] | object,
embeddings: Embeddings,
embeddings: EmbeddingsInterface,
dbConfig: Record<string, any>
): Promise<VectorStore>;

static fromDocuments(
docs: Document[],
embeddings: Embeddings,
embeddings: EmbeddingsInterface,
dbConfig: Record<string, any>
): Promise<VectorStore>;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The interface for prompt selectors is quite simple:

```typescript
abstract class BasePromptSelector {
abstract getPrompt(llm: BaseLanguageModel): BasePromptTemplate;
abstract getPrompt(llm: BaseLanguageModelInterface): BasePromptTemplate;
}
```

Expand All @@ -31,7 +31,7 @@ The example below shows how to use a prompt selector when loading a chain:

```typescript
const loadQAStuffChain = (
llm: BaseLanguageModel,
llm: BaseLanguageModelInterface,
params: StuffQAChainParams = {}
) => {
const { prompt = QA_PROMPT_SELECTOR.getPrompt(llm) } = params;
Expand Down
6 changes: 3 additions & 3 deletions examples/src/chains/advanced_subclass_call.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import { BasePromptTemplate, PromptTemplate } from "langchain/prompts";
import { BaseLanguageModel } from "langchain/base_language";
import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base";
import { CallbackManagerForChainRun } from "langchain/callbacks";
import { BaseChain, ChainInputs } from "langchain/chains";
import { ChainValues } from "langchain/schema";

export interface MyCustomChainInputs extends ChainInputs {
llm: BaseLanguageModel;
llm: BaseLanguageModelInterface;
promptTemplate: string;
}

export class MyCustomChain extends BaseChain implements MyCustomChainInputs {
llm: BaseLanguageModel;
llm: BaseLanguageModelInterface;

promptTemplate: string;

Expand Down
4 changes: 2 additions & 2 deletions langchain-core/src/callbacks/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import {
get_lc_unique_name,
} from "../load/serializable.js";
import type { SerializedFields } from "../load/map_keys.js";
import { Document } from "../documents/document.js";
import type { DocumentInterface } from "../documents/document.js";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type Error = any;
Expand Down Expand Up @@ -242,7 +242,7 @@ abstract class BaseCallbackHandlerMethodsClass {
Promise<any> | any;

handleRetrieverEnd?(
documents: Document[],
documents: DocumentInterface[],
runId: string,
parentRunId?: string,
tags?: string[]
Expand Down
4 changes: 2 additions & 2 deletions langchain-core/src/callbacks/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import {
} from "../tracers/tracer_langchain.js";
import { consumeCallback } from "./promises.js";
import { Serialized } from "../load/serializable.js";
import { Document } from "../documents/document.js";
import type { DocumentInterface } from "../documents/document.js";

type BaseCallbackManagerMethods = {
[K in keyof CallbackHandlerMethods]?: (
Expand Down Expand Up @@ -153,7 +153,7 @@ export class CallbackManagerForRetrieverRun
return manager;
}

async handleRetrieverEnd(documents: Document[]): Promise<void> {
async handleRetrieverEnd(documents: DocumentInterface[]): Promise<void> {
await Promise.all(
this.handlers.map((handler) =>
consumeCallback(async () => {
Expand Down
11 changes: 10 additions & 1 deletion langchain-core/src/documents/document.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,22 @@ export interface DocumentInput<
metadata?: Metadata;
}

export interface DocumentInterface<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Metadata extends Record<string, any> = Record<string, any>
> {
pageContent: string;

metadata: Metadata;
}

/**
* Interface for interacting with a document.
*/
export class Document<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Metadata extends Record<string, any> = Record<string, any>
> implements DocumentInput
> implements DocumentInput, DocumentInterface
{
pageContent: string;

Expand Down
14 changes: 9 additions & 5 deletions langchain-core/src/documents/transformers.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Runnable } from "../runnables/base.js";
import type { BaseCallbackConfig } from "../callbacks/manager.js";
import type { Document } from "./document.js";
import type { DocumentInterface } from "./document.js";

/**
* Abstract base class for document transformation systems.
Expand All @@ -13,8 +13,8 @@ import type { Document } from "./document.js";
* many smaller documents.
*/
export abstract class BaseDocumentTransformer<
RunInput extends Document[] = Document[],
RunOutput extends Document[] = Document[]
RunInput extends DocumentInterface[] = DocumentInterface[],
RunOutput extends DocumentInterface[] = DocumentInterface[]
> extends Runnable<RunInput, RunOutput> {
lc_namespace = ["langchain_core", "documents", "transformers"];

Expand Down Expand Up @@ -42,7 +42,9 @@ export abstract class BaseDocumentTransformer<
* for each input document.
*/
export abstract class MappingDocumentTransformer extends BaseDocumentTransformer {
async transformDocuments(documents: Document[]): Promise<Document[]> {
async transformDocuments(
documents: DocumentInterface[]
): Promise<DocumentInterface[]> {
const newDocuments = [];
for (const document of documents) {
const transformedDocument = await this._transformDocument(document);
Expand All @@ -51,5 +53,7 @@ export abstract class MappingDocumentTransformer extends BaseDocumentTransformer
return newDocuments;
}

abstract _transformDocument(document: Document): Promise<Document>;
abstract _transformDocument(
document: DocumentInterface
): Promise<DocumentInterface>;
}
21 changes: 20 additions & 1 deletion langchain-core/src/embeddings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,30 @@ import { AsyncCaller, AsyncCallerParams } from "./utils/async_caller.js";
*/
export type EmbeddingsParams = AsyncCallerParams;

export interface EmbeddingsInterface {
/**
* An abstract method that takes an array of documents as input and
* returns a promise that resolves to an array of vectors for each
* document.
* @param documents An array of documents to be embedded.
* @returns A promise that resolves to an array of vectors for each document.
*/
embedDocuments(documents: string[]): Promise<number[][]>;

/**
* An abstract method that takes a single document as input and returns a
* promise that resolves to a vector for the query document.
* @param document A single document to be embedded.
* @returns A promise that resolves to a vector for the query document.
*/
embedQuery(document: string): Promise<number[]>;
}

/**
* An abstract class that provides methods for embedding documents and
* queries using LangChain.
*/
export abstract class Embeddings {
export abstract class Embeddings implements EmbeddingsInterface {
/**
* The async caller should be used by subclasses to make any async calls,
* which will thus benefit from the concurrency and retry logic.
Expand Down
21 changes: 13 additions & 8 deletions langchain-core/src/example_selectors/conditional.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { BaseChatModel } from "../language_models/chat_models.js";
import type { BasePromptTemplate } from "../prompts/base.js";
import type { BaseLanguageModel } from "../language_models/base.js";
import type { BaseLanguageModelInterface } from "../language_models/base.js";
import type { BaseLLM } from "../language_models/llms.js";
import type { PartialValues } from "../utils/types.js";

Expand All @@ -20,7 +20,7 @@ export abstract class BasePromptSelector {
* @param llm The language model for which to get a prompt.
* @returns A prompt template.
*/
abstract getPrompt(llm: BaseLanguageModel): BasePromptTemplate;
abstract getPrompt(llm: BaseLanguageModelInterface): BasePromptTemplate;

/**
* Asynchronous version of `getPrompt` that also accepts an options object
Expand All @@ -30,7 +30,7 @@ export abstract class BasePromptSelector {
* @returns A Promise that resolves to a prompt template.
*/
async getPromptAsync(
llm: BaseLanguageModel,
llm: BaseLanguageModelInterface,
options?: BaseGetPromptAsyncOptions
): Promise<BasePromptTemplate> {
const prompt = this.getPrompt(llm);
Expand All @@ -47,14 +47,17 @@ export class ConditionalPromptSelector extends BasePromptSelector {
defaultPrompt: BasePromptTemplate;

conditionals: Array<
[condition: (llm: BaseLanguageModel) => boolean, prompt: BasePromptTemplate]
[
condition: (llm: BaseLanguageModelInterface) => boolean,
prompt: BasePromptTemplate
]
>;

constructor(
default_prompt: BasePromptTemplate,
conditionals: Array<
[
condition: (llm: BaseLanguageModel) => boolean,
condition: (llm: BaseLanguageModelInterface) => boolean,
prompt: BasePromptTemplate
]
> = []
Expand All @@ -70,7 +73,7 @@ export class ConditionalPromptSelector extends BasePromptSelector {
* @param llm The language model for which to get a prompt.
* @returns A prompt template.
*/
getPrompt(llm: BaseLanguageModel): BasePromptTemplate {
getPrompt(llm: BaseLanguageModelInterface): BasePromptTemplate {
for (const [condition, prompt] of this.conditionals) {
if (condition(llm)) {
return prompt;
Expand All @@ -84,14 +87,16 @@ export class ConditionalPromptSelector extends BasePromptSelector {
* Type guard function that checks if a given language model is of type
* `BaseLLM`.
*/
export function isLLM(llm: BaseLanguageModel): llm is BaseLLM {
export function isLLM(llm: BaseLanguageModelInterface): llm is BaseLLM {
return llm._modelType() === "base_llm";
}

/**
* Type guard function that checks if a given language model is of type
* `BaseChatModel`.
*/
export function isChatModel(llm: BaseLanguageModel): llm is BaseChatModel {
export function isChatModel(
llm: BaseLanguageModelInterface
): llm is BaseChatModel {
return llm._modelType() === "base_chat_model";
}
14 changes: 9 additions & 5 deletions langchain-core/src/example_selectors/semantic_similarity.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import type { Embeddings } from "../embeddings.js";
import type { VectorStore, VectorStoreRetriever } from "../vectorstores.js";
import type {
VectorStoreInterface,
VectorStoreRetrieverInterface,
VectorStore,
} from "../vectorstores.js";
import type { Example } from "../prompts/base.js";
import { Document } from "../documents/document.js";
import { BaseExampleSelector } from "./base.js";
Expand All @@ -15,7 +19,7 @@ function sortedValues<T>(values: Record<string, T>): T[] {
* class.
*/
export type SemanticSimilarityExampleSelectorInput<
V extends VectorStore = VectorStore
V extends VectorStoreInterface = VectorStoreInterface
> =
| {
vectorStore: V;
Expand All @@ -26,7 +30,7 @@ export type SemanticSimilarityExampleSelectorInput<
vectorStoreRetriever?: never;
}
| {
vectorStoreRetriever: VectorStoreRetriever<V>;
vectorStoreRetriever: VectorStoreRetrieverInterface<V>;
exampleKeys?: string[];
inputKeys?: string[];
vectorStore?: never;
Expand Down Expand Up @@ -64,9 +68,9 @@ export type SemanticSimilarityExampleSelectorInput<
* ```
*/
export class SemanticSimilarityExampleSelector<
V extends VectorStore = VectorStore
V extends VectorStoreInterface = VectorStoreInterface
> extends BaseExampleSelector {
vectorStoreRetriever: VectorStoreRetriever<V>;
vectorStoreRetriever: VectorStoreRetrieverInterface<V>;

exampleKeys?: string[];

Expand Down
Loading