Skip to content

Commit

Permalink
feat: Update custom custom chat models classes
Browse files Browse the repository at this point in the history
  • Loading branch information
zAlweNy26 committed Oct 1, 2024
1 parent c2bbee5 commit 8d03872
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 24 deletions.
Binary file modified bun.lockb
Binary file not shown.
10 changes: 5 additions & 5 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,16 @@
"@elysiajs/server-timing": "^1.1.0",
"@elysiajs/static": "^1.1.1",
"@elysiajs/stream": "^1.1.0",
"@elysiajs/swagger": "^1.1.3",
"@elysiajs/swagger": "^1.1.1",
"@langchain/anthropic": "^0.3.3",
"@langchain/azure-openai": "^0.0.11",
"@langchain/cohere": "^0.3.0",
"@langchain/community": "^0.3.4",
"@langchain/core": "^0.3.5",
"@langchain/community": "^0.3.3",
"@langchain/core": "^0.3.3",
"@langchain/google-genai": "^0.1.0",
"@langchain/mistralai": "0.1.1",
"@langchain/ollama": "0.1.0",
"@langchain/openai": "^0.3.4",
"@langchain/openai": "^0.3.2",
"@mgcrea/pino-pretty-compact": "^1.3.0",
"@qdrant/js-client-rest": "^1.11.0",
"callsites": "^4.2.0",
Expand Down Expand Up @@ -110,7 +110,7 @@
"vitest": "^2.1.1"
},
"overrides": {
"@langchain/core": "^0.3.5"
"@langchain/core": "^0.3.3"
},
"patchedDependencies": {
"[email protected]": "patches/[email protected]"
Expand Down
81 changes: 62 additions & 19 deletions src/factory/custom_llm.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import type { BaseLLMParams } from '@langchain/core/language_models/llms'
import type { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
import type { BaseMessage } from '@langchain/core/messages'
import type { ChatResult } from '@langchain/core/outputs'
import { join } from 'node:path'
import { BaseChatModel } from '@langchain/core/language_models/chat_models'
import { BaseChatModel, type BaseChatModelParams } from '@langchain/core/language_models/chat_models'
import { ChatOllama } from '@langchain/ollama'
import { ChatOpenAI } from '@langchain/openai'
import { ofetch } from 'ofetch'

export class DefaultLLM extends BaseChatModel {
constructor(params?: BaseLLMParams) {
constructor(params?: BaseChatModelParams) {
super(params ?? {})
}

Expand All @@ -24,24 +25,66 @@ export class DefaultLLM extends BaseChatModel {
}
}

export class CustomOpenAILLM extends ChatOpenAI {
public url = ''
public openAIApiBase = ''

constructor(params?: ConstructorParameters<typeof ChatOpenAI>[0]) {
const modelKwargs = {
repeatPenalty: params?.modelKwargs?.repeatPenalty ?? 1.0,
topK: params?.modelKwargs?.topK ?? 40,
stop: params?.modelKwargs?.stop ?? [],
export class CustomLLM extends BaseChatModel {
private url!: string
private apiKey: string | undefined
private options: Record<string, any> = {}

constructor(params: BaseChatModelParams & { baseURL: string, apiKey?: string, options?: Record<string, any> }) {
const { baseURL, apiKey, options, ...rest } = params
super(rest)
this.url = baseURL
this.apiKey = apiKey
this.options = options ?? {}
}

async _generate(messages: BaseMessage[], _options: this['ParsedCallOptions'], _runManager?: CallbackManagerForLLMRun): Promise<ChatResult> {
const res = await ofetch<ChatResult>(this.url, {
method: 'POST',
body: {
messages,
apiKey: this.apiKey,
options: this.options,
},
headers: {
'Accept': 'application/json',
'Content-Type': 'application/json',
},
})

return res
}

_llmType(): string {
return 'custom'
}

_identifyingParams(): Record<string, any> {
return {
url: this.url,
apiKey: this.apiKey,
options: this.options,
}
}
}

export class CustomOpenAILLM extends ChatOpenAI {
constructor(params: ConstructorParameters<typeof ChatOpenAI>[0] & { baseURL: string }) {
const { baseURL, ...args } = params
super(args, { baseURL })
}

_llmType(): string {
return 'custom'
}
}

export class CustomOllamaLLM extends ChatOllama {
constructor(params: Omit<NonNullable<ConstructorParameters<typeof ChatOllama>[0]>, 'baseUrl'> & { baseURL: string }) {
const { baseURL, ...args } = params
super({
openAIApiKey: ' ',
modelKwargs,
...params,
...args,
baseUrl: baseURL.endsWith('/') ? baseURL.slice(0, -1) : baseURL,
})

this.url = params?.modelKwargs?.url as string ?? ''
this.openAIApiBase = join(this.url, 'v1')
}
}

0 comments on commit 8d03872

Please sign in to comment.