Skip to content

Commit

Permalink
refactor: throw error if model does not support tools
Browse files Browse the repository at this point in the history
  • Loading branch information
sam-goldman committed Jul 12, 2024
1 parent 0ab2506 commit 118411c
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 23 deletions.
4 changes: 2 additions & 2 deletions scripts/parallel-function-calls/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ async function runConversation() {
},
{
role: 'system',
content: "To reiterate, respond to the user's question, using tools when necessary.",
content:
"To reiterate, respond to the user's question, using tools when necessary.",
},
]
const tools: CompletionParams['tools'] = [
Expand Down Expand Up @@ -98,7 +99,6 @@ async function runConversation() {
}
messages.push(message)
}
messages.push({ content: 'Hi', role: 'user' })
const secondResponse = await client.chat.completions.create({
provider,
model,
Expand Down
23 changes: 22 additions & 1 deletion src/handlers/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@ export abstract class BaseHandler<T extends LLMChatModel> {
protected models: readonly T[]
protected supportsJSON: readonly T[]
protected supportsImages: readonly T[]
protected supportsToolCalls: readonly T[]

constructor(
opts: ConfigOptions,
models: readonly T[],
supportsJSON: readonly T[],
supportsImages: readonly T[]
supportsImages: readonly T[],
supportsToolCalls: readonly T[]
) {
this.opts = opts
this.models = models
this.supportsJSON = supportsJSON
this.supportsImages = supportsImages
this.supportsToolCalls = supportsToolCalls
}

abstract create(
Expand All @@ -33,6 +36,24 @@ export abstract class BaseHandler<T extends LLMChatModel> {
throw new InputError(`Invalid 'model' field: ${body.model}.`)
}

if (
body.tools !== undefined &&
!this.supportsToolCalls.includes(body.model)
) {
throw new InputError(
`Detected a 'tools' parameter, but the following model does not support tools: ${body.model}`
)
}

if (
body.tool_choice !== undefined &&
!this.supportsToolCalls.includes(body.model)
) {
throw new InputError(
`Detected a 'tool_choice' parameter, but the following model does not support tools: ${body.model}`
)
}

if (typeof body.temperature === 'number' && body.temperature > 2) {
throw new InputError(
`Expected a temperature less than or equal to 2, but got: ${body.temperature}`
Expand Down
4 changes: 0 additions & 4 deletions src/handlers/groq.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ export class GroqHandler extends BaseHandler<GroqModel> {
)
}
}

if (body.tools && body.tools?.length > 0) {
throw new InputError(`Groq does not support tools`)
}
}

async create(
Expand Down
27 changes: 18 additions & 9 deletions src/handlers/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,63 +23,72 @@ export const Handlers: Record<string, (opts: ConfigOptions) => any> = {
opts,
models.openai.models,
models.openai.supportsJSON,
models.openai.supportsImages
models.openai.supportsImages,
models.openai.supportsToolCalls
),
['anthropic']: (opts: ConfigOptions) =>
new AnthropicHandler(
opts,
models.anthropic.models,
models.anthropic.supportsJSON,
models.anthropic.supportsImages
models.anthropic.supportsImages,
models.anthropic.supportsToolCalls
),
['gemini']: (opts: ConfigOptions) =>
new GeminiHandler(
opts,
models.gemini.models,
models.gemini.supportsJSON,
models.gemini.supportsImages
models.gemini.supportsImages,
models.gemini.supportsToolCalls
),
['cohere']: (opts: ConfigOptions) =>
new CohereHandler(
opts,
models.cohere.models,
models.cohere.supportsJSON,
models.cohere.supportsImages
models.cohere.supportsImages,
models.cohere.supportsToolCalls
),
['bedrock']: (opts: ConfigOptions) =>
new BedrockHandler(
opts,
models.bedrock.models,
models.bedrock.supportsJSON,
models.bedrock.supportsImages
models.bedrock.supportsImages,
models.bedrock.supportsToolCalls
),
['mistral']: (opts: ConfigOptions) =>
new MistralHandler(
opts,
models.mistral.models,
models.mistral.supportsJSON,
models.mistral.supportsImages
models.mistral.supportsImages,
models.mistral.supportsToolCalls
),
['groq']: (opts: ConfigOptions) =>
new GroqHandler(
opts,
models.groq.models,
models.groq.supportsJSON,
models.groq.supportsImages
models.groq.supportsImages,
models.groq.supportsToolCalls
),
['ai21']: (opts: ConfigOptions) =>
new AI21Handler(
opts,
models.ai21.models,
models.ai21.supportsJSON,
models.ai21.supportsImages
models.ai21.supportsImages,
models.ai21.supportsToolCalls
),
['perplexity']: (opts: ConfigOptions) =>
new PerplexityHandler(
opts,
models.perplexity.models,
models.perplexity.supportsJSON,
models.perplexity.supportsImages
models.perplexity.supportsImages,
models.perplexity.supportsToolCalls
),
}

Expand Down
50 changes: 50 additions & 0 deletions src/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,27 @@ export const models = {
'gpt-4-1106-preview',
'gpt-4-vision-preview',
] as const,
supportsToolCalls: [
'gpt-4o',
'gpt-4o-2024-05-13',
'gpt-4-turbo',
'gpt-4-turbo-2024-04-09',
'gpt-4-turbo-preview',
'gpt-4-0125-preview',
'gpt-4-1106-preview',
'gpt-4',
'gpt-4-0613',
'gpt-3.5-turbo',
'gpt-3.5-turbo-0125',
'gpt-3.5-turbo-1106',
'gpt-3.5-turbo-0613',
] as const,
},
ai21: {
models: ['jamba-instruct'] as const,
supportsJSON: [] as const,
supportsImages: [] as const,
supportsToolCalls: [] as const,
},
anthropic: {
models: [
Expand All @@ -69,11 +85,22 @@ export const models = {
'claude-3-sonnet-20240229',
'claude-3-haiku-20240307',
] as const,
supportsToolCalls: [
'claude-3-5-sonnet-20240620',
'claude-3-opus-20240229',
'claude-3-sonnet-20240229',
'claude-3-haiku-20240307',
] as const,
},
gemini: {
models: ['gemini-1.5-pro', 'gemini-1.5-flash', 'gemini-1.0-pro'] as const,
supportsJSON: ['gemini-1.5-pro', 'gemini-1.5-flash'] as const,
supportsImages: ['gemini-1.5-pro', 'gemini-1.5-flash'] as const,
supportsToolCalls: [
'gemini-1.5-pro',
'gemini-1.5-flash',
'gemini-1.0-pro',
] as const,
},
cohere: {
models: [
Expand All @@ -86,6 +113,11 @@ export const models = {
] as const,
supportsJSON: [] as const,
supportsImages: [] as const,
supportsToolCalls: [
'command-r-plus',
'command-r',
'command-nightly',
] as const,
},
bedrock: {
models: [
Expand Down Expand Up @@ -119,6 +151,14 @@ export const models = {
'anthropic.claude-3-opus-20240229-v1:0',
'anthropic.claude-3-haiku-20240307-v1:0',
] as const,
supportsToolCalls: [
'anthropic.claude-3-opus-20240229-v1:0',
'anthropic.claude-3-sonnet-20240229-v1:0',
'anthropic.claude-3-haiku-20240307-v1:0',
'cohere.command-r-plus-v1:0',
'cohere.command-r-v1:0',
'mistral.mistral-large-2402-v1:0',
] as const,
},
mistral: {
models: [
Expand Down Expand Up @@ -152,6 +192,14 @@ export const models = {
'codestral-2405',
] as const,
supportsImages: [] as const,
supportsToolCalls: [
'open-mixtral-8x22b',
'open-mixtral-8x22b-2404',
'mistral-small-latest',
'mistral-small-2402',
'mistral-large-latest',
'mistral-large-2402',
] as const,
},
groq: {
models: [
Expand All @@ -167,6 +215,7 @@ export const models = {
// to ensure that we only support models that reliably produce decent results.
supportsJSON: ['llama3-70b-8192', 'gemma-7b-it', 'gemma2-9b-it'] as const,
supportsImages: [] as const,
supportsToolCalls: [] as const,
},
perplexity: {
models: [
Expand All @@ -180,5 +229,6 @@ export const models = {
] as const,
supportsJSON: [] as const,
supportsImages: [] as const,
supportsToolCalls: [] as const,
},
}
40 changes: 39 additions & 1 deletion test/handlers/base.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import { describe, expect, it } from 'vitest'

import { TokenJS } from '../../src'
import { getDummyMessages, getDummyMessagesWithImage } from '../dummy'
import { InputError } from '../../src/handlers/types'
import {
getDummyMessages,
getDummyMessagesWithImage,
getDummyTool,
} from '../dummy'

describe('Base Handler', () => {
it('throws an error for a number greater than the max temperature', async () => {
Expand Down Expand Up @@ -32,4 +36,38 @@ describe('Base Handler', () => {
)
)
})

it("throws an error when 'tool_choice' parameter is present but the model doesn't support tools", async () => {
const tokenjs = new TokenJS()
await expect(
tokenjs.chat.completions.create({
provider: 'ai21',
model: 'jamba-instruct',
messages: getDummyMessages(),
temperature: 0.5,
tool_choice: 'auto',
})
).rejects.toThrow(
new InputError(
`Detected a 'tool_choice' parameter, but the following model does not support tools: jamba-instruct`
)
)
})

it("throws an error when 'tools' parameter is present but the model doesn't support tools", async () => {
const tokenjs = new TokenJS()
await expect(
tokenjs.chat.completions.create({
provider: 'ai21',
model: 'jamba-instruct',
messages: getDummyMessages(),
temperature: 0.5,
tools: [getDummyTool()],
})
).rejects.toThrow(
new InputError(
`Detected a 'tools' parameter, but the following model does not support tools: jamba-instruct`
)
)
})
})
9 changes: 6 additions & 3 deletions test/handlers/gemini.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,8 @@ describe('GeminiHandler', () => {
handlerOptions,
models.gemini.models,
models.gemini.supportsJSON,
models.gemini.supportsImages
models.gemini.supportsImages,
models.gemini.supportsToolCalls
)

;(GoogleGenerativeAI as any).mockImplementationOnce(() => ({
Expand Down Expand Up @@ -1454,7 +1455,8 @@ describe('GeminiHandler', () => {
handlerOptions,
models.gemini.models,
models.gemini.supportsJSON,
models.gemini.supportsImages
models.gemini.supportsImages,
models.gemini.supportsToolCalls
)

;(GoogleGenerativeAI as any).mockImplementationOnce(() => ({
Expand Down Expand Up @@ -1510,7 +1512,8 @@ describe('GeminiHandler', () => {
handlerOptions,
models.gemini.models,
models.gemini.supportsJSON,
models.gemini.supportsImages
models.gemini.supportsImages,
models.gemini.supportsToolCalls
)

;(GoogleGenerativeAI as any).mockImplementationOnce(() => ({
Expand Down
7 changes: 4 additions & 3 deletions test/handlers/mistral.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ describe('MistralHandler', () => {
id: '46d695c96b4545f4b842d4801632da3b',
object: 'chat.completion',
created: 1720551878,
model: 'open-mistral-7b',
model: 'open-mixtral-8x22b',
choices: [
{
index: 0,
Expand Down Expand Up @@ -747,7 +747,8 @@ describe('MistralHandler', () => {
handlerOptions,
models.mistral.models,
models.mistral.supportsJSON,
models.mistral.supportsImages
models.mistral.supportsImages,
models.mistral.supportsToolCalls
)

it('should return a completion response', async () => {
Expand All @@ -758,7 +759,7 @@ describe('MistralHandler', () => {

const params: CompletionParams = {
provider: 'mistral',
model: 'open-mistral-7b',
model: 'open-mixtral-8x22b',
messages: [
{
role: 'user',
Expand Down

0 comments on commit 118411c

Please sign in to comment.