Skip to content

Commit

Permalink
refactor: throw error if model does not support tools
Browse files Browse the repository at this point in the history
  • Loading branch information
sam-goldman committed Jul 12, 2024
1 parent 0ab2506 commit 93909a0
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 49 deletions.
5 changes: 2 additions & 3 deletions scripts/parallel-function-calls/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { ChatCompletionToolMessageParam } from 'openai/resources/index.mjs'

import { CompletionParams } from '../../src/chat'
import { TokenJS } from '../../src/index'
import { models } from '../../src/models'
import { SUPPORTED_MODELS } from '../../src/models'
import { getCurrentWeather } from './utils'

dotenv.config()
Expand All @@ -27,7 +27,7 @@ async function runConversation() {
)
}

const model: CompletionParams['model'] = models[provider].models[0]
const model: CompletionParams['model'] = SUPPORTED_MODELS[provider].models[0]

const messages: CompletionParams['messages'] = [
{
Expand Down Expand Up @@ -98,7 +98,6 @@ async function runConversation() {
}
messages.push(message)
}
messages.push({ content: 'Hi', role: 'user' })
const secondResponse = await client.chat.completions.create({
provider,
model,
Expand Down
23 changes: 12 additions & 11 deletions src/chat/index.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
import { ChatCompletionCreateParamsBase } from 'openai/resources/chat/completions'

import { getHandler } from '../handlers/utils'
import { models } from '../models'
import { SUPPORTED_MODELS } from '../models'
import {
CompletionResponse,
ConfigOptions,
StreamCompletionResponse,
} from '../userTypes'

export type OpenAIModel = (typeof models.openai.models)[number]
export type AI21Model = (typeof models.ai21.models)[number]
export type AnthropicModel = (typeof models.anthropic.models)[number]
export type GeminiModel = (typeof models.gemini.models)[number]
export type CohereModel = (typeof models.cohere.models)[number]
export type BedrockModel = (typeof models.bedrock.models)[number]
export type MistralModel = (typeof models.mistral.models)[number]
export type PerplexityModel = (typeof models.perplexity.models)[number]
export type GroqModel = (typeof models.groq.models)[number]
export type OpenAIModel = (typeof SUPPORTED_MODELS.openai.models)[number]
export type AI21Model = (typeof SUPPORTED_MODELS.ai21.models)[number]
export type AnthropicModel = (typeof SUPPORTED_MODELS.anthropic.models)[number]
export type GeminiModel = (typeof SUPPORTED_MODELS.gemini.models)[number]
export type CohereModel = (typeof SUPPORTED_MODELS.cohere.models)[number]
export type BedrockModel = (typeof SUPPORTED_MODELS.bedrock.models)[number]
export type MistralModel = (typeof SUPPORTED_MODELS.mistral.models)[number]
export type PerplexityModel =
(typeof SUPPORTED_MODELS.perplexity.models)[number]
export type GroqModel = (typeof SUPPORTED_MODELS.groq.models)[number]

export type LLMChatModel =
| OpenAIModel
Expand All @@ -29,7 +30,7 @@ export type LLMChatModel =
| PerplexityModel
| GroqModel

export type LLMProvider = keyof typeof models
export type LLMProvider = keyof typeof SUPPORTED_MODELS

type ProviderModelMap = {
openai: OpenAIModel
Expand Down
23 changes: 22 additions & 1 deletion src/handlers/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@ export abstract class BaseHandler<T extends LLMChatModel> {
protected models: readonly T[]
protected supportsJSON: readonly T[]
protected supportsImages: readonly T[]
protected supportsToolCalls: readonly T[]

constructor(
opts: ConfigOptions,
models: readonly T[],
supportsJSON: readonly T[],
supportsImages: readonly T[]
supportsImages: readonly T[],
supportsToolCalls: readonly T[]
) {
this.opts = opts
this.models = models
this.supportsJSON = supportsJSON
this.supportsImages = supportsImages
this.supportsToolCalls = supportsToolCalls
}

abstract create(
Expand All @@ -33,6 +36,24 @@ export abstract class BaseHandler<T extends LLMChatModel> {
throw new InputError(`Invalid 'model' field: ${body.model}.`)
}

if (
body.tools !== undefined &&
!this.supportsToolCalls.includes(body.model)
) {
throw new InputError(
`Detected a 'tools' parameter, but the following model does not support tools: ${body.model}`
)
}

if (
body.tool_choice !== undefined &&
!this.supportsToolCalls.includes(body.model)
) {
throw new InputError(
`Detected a 'tool_choice' parameter, but the following model does not support tools: ${body.model}`
)
}

if (typeof body.temperature === 'number' && body.temperature > 2) {
throw new InputError(
`Expected a temperature less than or equal to 2, but got: ${body.temperature}`
Expand Down
4 changes: 0 additions & 4 deletions src/handlers/groq.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ export class GroqHandler extends BaseHandler<GroqModel> {
)
}
}

if (body.tools && body.tools?.length > 0) {
throw new InputError(`Groq does not support tools`)
}
}

async create(
Expand Down
65 changes: 37 additions & 28 deletions src/handlers/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import chalk from 'chalk'
import { lookup } from 'mime-types'

import { LLMChatModel, LLMProvider } from '../chat'
import { models } from '../models'
import { SUPPORTED_MODELS } from '../models'
import { ConfigOptions } from '../userTypes'
import { AI21Handler } from './ai21'
import { AnthropicHandler } from './anthropic'
Expand All @@ -21,65 +21,74 @@ export const Handlers: Record<string, (opts: ConfigOptions) => any> = {
['openai']: (opts: ConfigOptions) =>
new OpenAIHandler(
opts,
models.openai.models,
models.openai.supportsJSON,
models.openai.supportsImages
SUPPORTED_MODELS.openai.models,
SUPPORTED_MODELS.openai.supportsJSON,
SUPPORTED_MODELS.openai.supportsImages,
SUPPORTED_MODELS.openai.supportsToolCalls
),
['anthropic']: (opts: ConfigOptions) =>
new AnthropicHandler(
opts,
models.anthropic.models,
models.anthropic.supportsJSON,
models.anthropic.supportsImages
SUPPORTED_MODELS.anthropic.models,
SUPPORTED_MODELS.anthropic.supportsJSON,
SUPPORTED_MODELS.anthropic.supportsImages,
SUPPORTED_MODELS.anthropic.supportsToolCalls
),
['gemini']: (opts: ConfigOptions) =>
new GeminiHandler(
opts,
models.gemini.models,
models.gemini.supportsJSON,
models.gemini.supportsImages
SUPPORTED_MODELS.gemini.models,
SUPPORTED_MODELS.gemini.supportsJSON,
SUPPORTED_MODELS.gemini.supportsImages,
SUPPORTED_MODELS.gemini.supportsToolCalls
),
['cohere']: (opts: ConfigOptions) =>
new CohereHandler(
opts,
models.cohere.models,
models.cohere.supportsJSON,
models.cohere.supportsImages
SUPPORTED_MODELS.cohere.models,
SUPPORTED_MODELS.cohere.supportsJSON,
SUPPORTED_MODELS.cohere.supportsImages,
SUPPORTED_MODELS.cohere.supportsToolCalls
),
['bedrock']: (opts: ConfigOptions) =>
new BedrockHandler(
opts,
models.bedrock.models,
models.bedrock.supportsJSON,
models.bedrock.supportsImages
SUPPORTED_MODELS.bedrock.models,
SUPPORTED_MODELS.bedrock.supportsJSON,
SUPPORTED_MODELS.bedrock.supportsImages,
SUPPORTED_MODELS.bedrock.supportsToolCalls
),
['mistral']: (opts: ConfigOptions) =>
new MistralHandler(
opts,
models.mistral.models,
models.mistral.supportsJSON,
models.mistral.supportsImages
SUPPORTED_MODELS.mistral.models,
SUPPORTED_MODELS.mistral.supportsJSON,
SUPPORTED_MODELS.mistral.supportsImages,
SUPPORTED_MODELS.mistral.supportsToolCalls
),
['groq']: (opts: ConfigOptions) =>
new GroqHandler(
opts,
models.groq.models,
models.groq.supportsJSON,
models.groq.supportsImages
SUPPORTED_MODELS.groq.models,
SUPPORTED_MODELS.groq.supportsJSON,
SUPPORTED_MODELS.groq.supportsImages,
SUPPORTED_MODELS.groq.supportsToolCalls
),
['ai21']: (opts: ConfigOptions) =>
new AI21Handler(
opts,
models.ai21.models,
models.ai21.supportsJSON,
models.ai21.supportsImages
SUPPORTED_MODELS.ai21.models,
SUPPORTED_MODELS.ai21.supportsJSON,
SUPPORTED_MODELS.ai21.supportsImages,
SUPPORTED_MODELS.ai21.supportsToolCalls
),
['perplexity']: (opts: ConfigOptions) =>
new PerplexityHandler(
opts,
models.perplexity.models,
models.perplexity.supportsJSON,
models.perplexity.supportsImages
SUPPORTED_MODELS.perplexity.models,
SUPPORTED_MODELS.perplexity.supportsJSON,
SUPPORTED_MODELS.perplexity.supportsImages,
SUPPORTED_MODELS.perplexity.supportsToolCalls
),
}

Expand Down
52 changes: 51 additions & 1 deletion src/models.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export const models = {
export const SUPPORTED_MODELS = {
openai: {
models: [
'gpt-4o',
Expand Down Expand Up @@ -46,11 +46,27 @@ export const models = {
'gpt-4-1106-preview',
'gpt-4-vision-preview',
] as const,
supportsToolCalls: [
'gpt-4o',
'gpt-4o-2024-05-13',
'gpt-4-turbo',
'gpt-4-turbo-2024-04-09',
'gpt-4-turbo-preview',
'gpt-4-0125-preview',
'gpt-4-1106-preview',
'gpt-4',
'gpt-4-0613',
'gpt-3.5-turbo',
'gpt-3.5-turbo-0125',
'gpt-3.5-turbo-1106',
'gpt-3.5-turbo-0613',
] as const,
},
ai21: {
models: ['jamba-instruct'] as const,
supportsJSON: [] as const,
supportsImages: [] as const,
supportsToolCalls: [] as const,
},
anthropic: {
models: [
Expand All @@ -69,11 +85,22 @@ export const models = {
'claude-3-sonnet-20240229',
'claude-3-haiku-20240307',
] as const,
supportsToolCalls: [
'claude-3-5-sonnet-20240620',
'claude-3-opus-20240229',
'claude-3-sonnet-20240229',
'claude-3-haiku-20240307',
] as const,
},
gemini: {
models: ['gemini-1.5-pro', 'gemini-1.5-flash', 'gemini-1.0-pro'] as const,
supportsJSON: ['gemini-1.5-pro', 'gemini-1.5-flash'] as const,
supportsImages: ['gemini-1.5-pro', 'gemini-1.5-flash'] as const,
supportsToolCalls: [
'gemini-1.5-pro',
'gemini-1.5-flash',
'gemini-1.0-pro',
] as const,
},
cohere: {
models: [
Expand All @@ -86,6 +113,11 @@ export const models = {
] as const,
supportsJSON: [] as const,
supportsImages: [] as const,
supportsToolCalls: [
'command-r-plus',
'command-r',
'command-nightly',
] as const,
},
bedrock: {
models: [
Expand Down Expand Up @@ -119,6 +151,14 @@ export const models = {
'anthropic.claude-3-opus-20240229-v1:0',
'anthropic.claude-3-haiku-20240307-v1:0',
] as const,
supportsToolCalls: [
'anthropic.claude-3-opus-20240229-v1:0',
'anthropic.claude-3-sonnet-20240229-v1:0',
'anthropic.claude-3-haiku-20240307-v1:0',
'cohere.command-r-plus-v1:0',
'cohere.command-r-v1:0',
'mistral.mistral-large-2402-v1:0',
] as const,
},
mistral: {
models: [
Expand Down Expand Up @@ -152,6 +192,14 @@ export const models = {
'codestral-2405',
] as const,
supportsImages: [] as const,
supportsToolCalls: [
'open-mixtral-8x22b',
'open-mixtral-8x22b-2404',
'mistral-small-latest',
'mistral-small-2402',
'mistral-large-latest',
'mistral-large-2402',
] as const,
},
groq: {
models: [
Expand All @@ -167,6 +215,7 @@ export const models = {
// to ensure that we only support models that reliably produce decent results.
supportsJSON: ['llama3-70b-8192', 'gemma-7b-it', 'gemma2-9b-it'] as const,
supportsImages: [] as const,
supportsToolCalls: [] as const,
},
perplexity: {
models: [
Expand All @@ -180,5 +229,6 @@ export const models = {
] as const,
supportsJSON: [] as const,
supportsImages: [] as const,
supportsToolCalls: [] as const,
},
}
Loading

0 comments on commit 93909a0

Please sign in to comment.