-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature: Add support for AI21 models
- Loading branch information
Showing
7 changed files
with
169 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
OPENAI_API_KEY=<Your OpenAI API key> | ||
COHERE_API_KEY=<Your Cohere API key> | ||
ANTHROPIC_API_KEY=<Your Anthropic API key> | ||
ANTHROPIC_API_KEY=<Your Anthropic API key> | ||
AI21_API_KEY=<Your AI21 API key> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import { | ||
ConsistentResponseUsage, | ||
FinishReason, | ||
HandlerParams, | ||
HandlerParamsNotStreaming, | ||
HandlerParamsStreaming, | ||
ResultNotStreaming, | ||
ResultStreaming, | ||
Role, | ||
StreamingChunk, | ||
} from '../types'; | ||
import { combinePrompts } from '../utils/combinePrompts'; | ||
import { getUnixTimestamp } from '../utils/getUnixTimestamp'; | ||
|
||
/**
 * Translates the finish reasons reported by AI21 into the finish reasons
 * exposed by this library. Reasons without an entry resolve to `undefined`;
 * the lookups in this file then fall back to 'stop'.
 *
 * The table is built on a null-prototype object so that looking up an
 * arbitrary reason string such as 'constructor' or 'toString' yields
 * `undefined` rather than a member inherited from `Object.prototype`, which
 * would bypass the `?? 'stop'` fallback and leak a non-string value.
 */
const FINISH_REASON_MAP: Record<string, FinishReason> = Object.assign(
  Object.create(null) as Record<string, FinishReason>,
  {
    length: 'length',
    endoftext: 'stop',
  } as const,
);
|
||
/**
 * Per-token payload nested under `generatedToken`.
 * NOTE(review): `logprob` / `raw_logprob` are presumably log-probabilities
 * reported by the API; nothing in this file reads them — confirm against the
 * AI21 API reference before relying on their meaning.
 */
interface AI21TokenDetails {
  token: string;
  logprob: number;
  raw_logprob: number;
}

/**
 * One entry of a `tokens` array in an AI21 response. This file only relies on
 * how many of these entries there are when computing token usage.
 */
interface AI21GeneratedToken {
  generatedToken: AI21TokenDetails;
}
|
||
/** A piece of text together with its per-token breakdown. */
interface AI21TokenizedText {
  text: string;
  tokens: AI21GeneratedToken[];
}

/** A single completion: the generated text and why generation ended. */
interface AI21Completion {
  finishReason: {
    /** Raw reason string; translated through FINISH_REASON_MAP by the converters in this file. */
    reason: string;
  };
  data: AI21TokenizedText;
}

/**
 * The subset of the AI21 completion response body that this module reads.
 * The handler casts the parsed JSON body to this shape without validating it.
 */
interface AI21Response {
  id: string;
  /** The prompt, echoed back with its tokens. */
  prompt: AI21TokenizedText;
  /** The generated completions. */
  completions: AI21Completion[];
}
|
||
/**
 * Derives token usage from an AI21 response by counting token entries.
 *
 * @param response - Parsed AI21 response body.
 * @returns Usage where `prompt_tokens` is the number of prompt token entries,
 *   `completion_tokens` is the number of token entries summed over every
 *   completion, and `total_tokens` is the sum of the two.
 */
function toUsage(response: AI21Response): ConsistentResponseUsage {
  const promptCount = response.prompt.tokens.length;

  let completionCount = 0;
  for (const { data } of response.completions) {
    completionCount += data.tokens.length;
  }

  return {
    prompt_tokens: promptCount,
    completion_tokens: completionCount,
    total_tokens: promptCount + completionCount,
  };
}
|
||
/**
 * Adapts an AI21 response to the streaming result shape.
 *
 * The response handed to this function is already complete, so it is emitted
 * as exactly one chunk. Every completion becomes one entry of `choices`
 * (indexed by its position), mirroring how `toResponse` builds its choices; a
 * response without completions therefore produces a chunk with an empty
 * `choices` array instead of failing on a missing first element.
 *
 * @param response - Parsed AI21 response body.
 * @param model - Model name to report on the chunk.
 * @returns An async iterable yielding a single chunk.
 */
// eslint-disable-next-line @typescript-eslint/require-await
async function* toStream(
  response: AI21Response,
  model: string,
): AsyncIterable<StreamingChunk> {
  yield {
    model: model,
    created: getUnixTimestamp(),
    usage: toUsage(response),
    choices: response.completions.map((completion, index) => ({
      delta: {
        content: completion.data.text,
        role: 'assistant' as const,
      },
      // Reasons that have no mapping are reported as a normal stop.
      finish_reason:
        FINISH_REASON_MAP[completion.finishReason.reason] ?? 'stop',
      index,
    })),
  };
}
|
||
/**
 * Adapts an AI21 response to the non-streaming result shape.
 *
 * @param response - Parsed AI21 response body.
 * @param model - Model name to report on the result.
 * @returns A result with one choice per completion (indexed by position), the
 *   current Unix timestamp as `created`, and usage computed by `toUsage`.
 */
function toResponse(response: AI21Response, model: string): ResultNotStreaming {
  return {
    model,
    created: getUnixTimestamp(),
    usage: toUsage(response),
    choices: response.completions.map(({ finishReason, data }, index) => ({
      // Reasons that have no mapping are reported as a normal stop.
      finish_reason: FINISH_REASON_MAP[finishReason.reason] ?? 'stop',
      index,
      message: {
        content: data.text,
        role: 'assistant' as Role,
      },
    })),
  };
}
|
||
/**
 * Sends a completion request to the AI21 API.
 *
 * Issues a POST to `<baseUrl>/studio/v1/<model>/complete` with a JSON body
 * containing only the prompt. The bearer token is read from the
 * `AI21_API_KEY` environment variable. The HTTP status is not inspected here;
 * that is left to the caller.
 *
 * @param model - Model name, interpolated into the request path as-is.
 * @param prompt - Prompt text to complete.
 * @param baseUrl - API origin, without a trailing slash.
 * @returns The raw fetch response.
 */
async function getAI21Response(
  model: string,
  prompt: string,
  baseUrl: string,
): Promise<Response> {
  const endpoint = `${baseUrl}/studio/v1/${model}/complete`;
  const headers = {
    Authorization: `Bearer ${process.env.AI21_API_KEY}`,
    'Content-Type': 'application/json',
    accept: 'application/json',
  };
  const payload = JSON.stringify({ prompt });

  return fetch(endpoint, { method: 'POST', headers, body: payload });
}
|
||
/**
 * Runs a completion against the AI21 API.
 *
 * The messages are merged into one prompt with `combinePrompts`, sent to the
 * API, and the parsed response is converted to either the streaming or the
 * non-streaming result shape depending on `params.stream`.
 *
 * @param params - Handler parameters; `baseUrl` defaults to
 *   `https://api.ai21.com` when not provided.
 * @returns The converted result.
 * @throws Error when the API responds with a non-OK HTTP status.
 */
export async function AI21Handler(
  params: HandlerParamsNotStreaming,
): Promise<ResultNotStreaming>;

export async function AI21Handler(
  params: HandlerParamsStreaming,
): Promise<ResultStreaming>;

export async function AI21Handler(
  params: HandlerParams,
): Promise<ResultNotStreaming | ResultStreaming>;

export async function AI21Handler(
  params: HandlerParams,
): Promise<ResultNotStreaming | ResultStreaming> {
  const { model } = params;
  const baseUrl = params.baseUrl ?? 'https://api.ai21.com';

  const res = await getAI21Response(
    model,
    combinePrompts(params.messages),
    baseUrl,
  );

  if (!res.ok) {
    // The message text (including its spelling) is kept unchanged because
    // callers or tests may match on it.
    throw new Error(`Recieved an error with code ${res.status} from AI21 API.`);
  }

  // The body is trusted to match AI21Response; it is not validated.
  const body = (await res.json()) as AI21Response;

  return params.stream ? toStream(body, model) : toResponse(body, model);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters