Feature: Add support for AI21 models
zya committed Oct 23, 2023
1 parent a68af75 commit d6202c7
Showing 7 changed files with 169 additions and 6 deletions.
3 changes: 2 additions & 1 deletion .example.env
@@ -1,3 +1,4 @@
OPENAI_API_KEY=<Your OpenAI API key>
COHERE_API_KEY=<Your Cohere API key>
-ANTHROPIC_API_KEY=<Your Anthropic API key>
+ANTHROPIC_API_KEY=<Your Anthropic API key>
+AI21_API_KEY=<Your AI21 API key>
2 changes: 1 addition & 1 deletion README.md
@@ -47,13 +47,13 @@ We aim to support all features that [LiteLLM python package](https://github.com/
| [cohere](https://docs.litellm.ai/docs/providers/cohere) | ✅ | ✅ |
| [anthropic](https://docs.litellm.ai/docs/providers/anthropic) | ✅ | ✅ |
| [ollama](https://docs.litellm.ai/docs/providers/ollama) | ✅ | ✅ |
+| [ai21](https://docs.litellm.ai/docs/providers/ai21) | ✅ | ✅ |
| [replicate](https://docs.litellm.ai/docs/providers/replicate) | ❌ | ❌ |
| [huggingface](https://docs.litellm.ai/docs/providers/huggingface) | ❌ | ❌ |
| [together_ai](https://docs.litellm.ai/docs/providers/togetherai) | ❌ | ❌ |
| [openrouter](https://docs.litellm.ai/docs/providers/openrouter) | ❌ | ❌ |
| [vertex_ai](https://docs.litellm.ai/docs/providers/vertex) | ❌ | ❌ |
| [palm](https://docs.litellm.ai/docs/providers/palm) | ❌ | ❌ |
-| [ai21](https://docs.litellm.ai/docs/providers/ai21) | ❌ | ❌ |
| [baseten](https://docs.litellm.ai/docs/providers/baseten) | ❌ | ❌ |
| [azure](https://docs.litellm.ai/docs/providers/azure) | ❌ | ❌ |
| [sagemaker](https://docs.litellm.ai/docs/providers/aws_sagemaker) | ❌ | ❌ |
2 changes: 2 additions & 0 deletions src/completion.ts
@@ -12,12 +12,14 @@ import { AnthropicHandler } from './handlers/anthropic';
import { CohereHandler } from './handlers/cohere';
import { OllamaHandler } from './handlers/ollama';
import { OpenAIHandler } from './handlers/openai';
+import { AI21Handler } from './handlers/ai21';

export const MODEL_HANDLER_MAPPINGS: Record<string, Handler> = {
  'claude-2': AnthropicHandler,
  'gpt-': OpenAIHandler,
  command: CohereHandler,
  'ollama/': OllamaHandler,
+  'j2-': AI21Handler,
};

export async function completion(
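Note that the keys of MODEL_HANDLER_MAPPINGS are model-name prefixes rather than exact model names, which is why the single 'j2-' entry covers every Jurassic-2 variant exercised in the tests below. As a rough illustration, prefix resolution can be as simple as the following sketch (hypothetical code; the real logic lives in src/handlers/getHandler.ts and may differ in detail):

// resolve-sketch.ts — illustrative only; `Handler` is stubbed here because
// the repo's own Handler type from src/types is not shown in this diff.
type Handler = (params: unknown) => Promise<unknown>;

function resolveHandler(
  model: string,
  mappings: Record<string, Handler>,
): Handler | null {
  for (const [prefix, handler] of Object.entries(mappings)) {
    // 'j2-light', 'j2-ultra-instruct', etc. all start with 'j2-'.
    if (model.startsWith(prefix)) {
      return handler;
    }
  }
  return null; // mirrors the `expectedHandler: null` case in the tests
}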
151 changes: 151 additions & 0 deletions src/handlers/ai21.ts
@@ -0,0 +1,151 @@
import {
  ConsistentResponseUsage,
  FinishReason,
  HandlerParams,
  HandlerParamsNotStreaming,
  HandlerParamsStreaming,
  ResultNotStreaming,
  ResultStreaming,
  Role,
  StreamingChunk,
} from '../types';
import { combinePrompts } from '../utils/combinePrompts';
import { getUnixTimestamp } from '../utils/getUnixTimestamp';

// Maps AI21's finish reasons onto the OpenAI-style FinishReason values
// used by the shared types.
const FINISH_REASON_MAP: Record<string, FinishReason> = {
  length: 'length',
  endoftext: 'stop',
};

interface AI21GeneratedToken {
  generatedToken: {
    token: string;
    logprob: number;
    raw_logprob: number;
  };
}

interface AI21Response {
  id: string;
  prompt: {
    text: string;
    tokens: AI21GeneratedToken[];
  };
  completions: {
    finishReason: {
      reason: string;
    };
    data: {
      text: string;
      tokens: AI21GeneratedToken[];
    };
  }[];
}

// AI21 returns token arrays rather than counts, so usage is derived from
// the lengths of those arrays.
function toUsage(response: AI21Response): ConsistentResponseUsage {
  const promptTokens = response.prompt.tokens.length;
  const completionTokens = response.completions.reduce((acc, completion) => {
    return acc + completion.data.tokens.length;
  }, 0);
  return {
    prompt_tokens: promptTokens,
    completion_tokens: completionTokens,
    total_tokens: promptTokens + completionTokens,
  };
}

// The AI21 complete endpoint does not stream, so the full response is wrapped
// in an async generator that yields a single chunk.
// eslint-disable-next-line @typescript-eslint/require-await
async function* toStream(
  response: AI21Response,
  model: string,
): AsyncIterable<StreamingChunk> {
  yield {
    model: model,
    created: getUnixTimestamp(),
    usage: toUsage(response),
    choices: [
      {
        delta: {
          content: response.completions[0].data.text,
          role: 'assistant',
        },
        finish_reason:
          FINISH_REASON_MAP[response.completions[0].finishReason.reason] ??
          'stop',
        index: 0,
      },
    ],
  };
}

// Converts the AI21 payload into the package's OpenAI-compatible shape.
function toResponse(response: AI21Response, model: string): ResultNotStreaming {
  const choices = response.completions.map((completion, i) => {
    return {
      finish_reason:
        FINISH_REASON_MAP[completion.finishReason.reason] ?? 'stop',
      index: i,
      message: {
        content: completion.data.text,
        role: 'assistant' as Role,
      },
    };
  });
  return {
    model: model,
    created: getUnixTimestamp(),
    usage: toUsage(response),
    choices: choices,
  };
}

async function getAI21Response(
  model: string,
  prompt: string,
  baseUrl: string,
): Promise<Response> {
  return fetch(`${baseUrl}/studio/v1/${model}/complete`, {
    method: 'POST',
    headers: {
      Authorization: `Bearer ${process.env.AI21_API_KEY}`,
      'Content-Type': 'application/json',
      accept: 'application/json',
    },
    body: JSON.stringify({
      prompt,
    }),
  });
}

// Overloads so callers get the correct return type based on params.stream.
export async function AI21Handler(
  params: HandlerParamsNotStreaming,
): Promise<ResultNotStreaming>;

export async function AI21Handler(
  params: HandlerParamsStreaming,
): Promise<ResultStreaming>;

export async function AI21Handler(
  params: HandlerParams,
): Promise<ResultNotStreaming | ResultStreaming>;

export async function AI21Handler(
  params: HandlerParams,
): Promise<ResultNotStreaming | ResultStreaming> {
  const baseUrl = params.baseUrl ?? 'https://api.ai21.com';
  const model = params.model;
  const prompt = combinePrompts(params.messages);

  const res = await getAI21Response(model, prompt, baseUrl);

  if (!res.ok) {
    throw new Error(`Received an error with code ${res.status} from AI21 API.`);
  }

  const body = (await res.json()) as AI21Response;

  if (params.stream) {
    return toStream(body, model);
  }

  return toResponse(body, model);
}
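For context, here is a hedged usage sketch showing how the new handler is reached through the completion entry point. The import specifier and exact result types are assumptions based on src/completion.ts; the model name is illustrative, and AI21_API_KEY must be set as in .example.env:

// usage-sketch.ts — assumes the package exposes `completion` as declared in
// src/completion.ts and that this file runs as an ES module (top-level await).
import { completion } from 'litellm';

const result = await completion({
  model: 'j2-light', // matches the 'j2-' prefix, so AI21Handler is selected
  messages: [{ role: 'user', content: 'Say hello.' }],
  stream: false,
});
console.log(result.choices[0].message.content);

// Streaming also works, but the whole completion arrives as one chunk,
// since toStream wraps the non-streaming AI21 response in a generator.
const stream = await completion({
  model: 'j2-light',
  messages: [{ role: 'user', content: 'Say hello.' }],
  stream: true,
});
for await (const chunk of stream) {
  process.stdout.write(chunk.choices[0].delta.content ?? '');
}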
7 changes: 3 additions & 4 deletions src/handlers/ollama.ts
@@ -83,13 +83,12 @@ async function getOllamaResponse(
): Promise<Response> {
  return fetch(`${baseUrl}/api/generate`, {
    method: 'POST',
-
+    headers: {
+      'Content-Type': 'application/json',
+    },
    body: JSON.stringify({
      model,
      prompt,
-      headers: {
-        'Content-Type': 'application/json',
-      },
    }),
  });
}
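(Side note on the fix above: previously the Content-Type header object was serialized into the JSON request body sent to Ollama; moving it into the fetch options means it is now sent as an actual HTTP header.)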
3 changes: 3 additions & 0 deletions tests/e2e.test.ts
@@ -13,6 +13,9 @@ const MODELS = [
  {
    model: 'command-nightly',
  },
+  {
+    model: 'j2-light',
+  },
];

const EMBEDDING_MODELS = [
7 changes: 7 additions & 0 deletions tests/handlers/getHandler.test.ts
@@ -1,4 +1,5 @@
import { MODEL_HANDLER_MAPPINGS } from '../../src/completion';
+import { AI21Handler } from '../../src/handlers/ai21';
import { AnthropicHandler } from '../../src/handlers/anthropic';
import { CohereHandler } from '../../src/handlers/cohere';
import { getHandler } from '../../src/handlers/getHandler';
@@ -11,6 +12,12 @@ describe('getHandler', () => {
    { model: 'gpt-3.5-turbo', expectedHandler: OpenAIHandler },
    { model: 'ollama/llama2', expectedHandler: OllamaHandler },
    { model: 'command-nightly', expectedHandler: CohereHandler },
+    { model: 'j2-light', expectedHandler: AI21Handler },
+    { model: 'j2-mid', expectedHandler: AI21Handler },
+    { model: 'j2-ultra', expectedHandler: AI21Handler },
+    { model: 'j2-grande-instruct', expectedHandler: AI21Handler },
+    { model: 'j2-mid-instruct', expectedHandler: AI21Handler },
+    { model: 'j2-ultra-instruct', expectedHandler: AI21Handler },
    { model: 'unknown', expectedHandler: null },
])(
'should return the correct handler for a given model name',
