Skip to content

Commit

Permalink
Merge pull request #6 from zya/deepinfra
Browse files Browse the repository at this point in the history
Feature: Add support for DeepInfra
  • Loading branch information
zya authored Nov 14, 2023
2 parents baa5c51 + 520847c commit d5ed468
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 6 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ We aim to support all features that [LiteLLM python package](https://github.com/
| [ollama](https://docs.litellm.ai/docs/providers/ollama) |||
| [ai21](https://docs.litellm.ai/docs/providers/ai21) |||
| [replicate](https://docs.litellm.ai/docs/providers/replicate) |||
| [deepinfra](https://docs.litellm.ai/docs/providers/deepinfra) |||
| [huggingface](https://docs.litellm.ai/docs/providers/huggingface) |||
| [together_ai](https://docs.litellm.ai/docs/providers/togetherai) |||
| [openrouter](https://docs.litellm.ai/docs/providers/openrouter) |||
Expand All @@ -62,7 +63,6 @@ We aim to support all features that [LiteLLM python package](https://github.com/
| [nlp_cloud](https://docs.litellm.ai/docs/providers/nlp_cloud) |||
| [aleph alpha](https://docs.litellm.ai/docs/providers/aleph_alpha) |||
| [petals](https://docs.litellm.ai/docs/providers/petals) |||
| [deepinfra](https://docs.litellm.ai/docs/providers/deepinfra) |||

# Development

Expand Down
2 changes: 2 additions & 0 deletions src/completion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import { OllamaHandler } from './handlers/ollama';
import { OpenAIHandler } from './handlers/openai';
import { AI21Handler } from './handlers/ai21';
import { ReplicateHandler } from './handlers/replicate';
import { DeepInfraHandler } from './handlers/deepinfra';

export const MODEL_HANDLER_MAPPINGS: Record<string, Handler> = {
'claude-': AnthropicHandler,
Expand All @@ -22,6 +23,7 @@ export const MODEL_HANDLER_MAPPINGS: Record<string, Handler> = {
'ollama/': OllamaHandler,
'j2-': AI21Handler,
'replicate/': ReplicateHandler,
'deepinfra/': DeepInfraHandler,
};

export async function completion(
Expand Down
85 changes: 85 additions & 0 deletions src/handlers/deepinfra.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import { ChatCompletion } from 'openai/resources/chat';
import {
HandlerParams,
HandlerParamsNotStreaming,
HandlerParamsStreaming,
Message,
ResultNotStreaming,
ResultStreaming,
StreamingChunk,
} from '../types';

/**
 * Converts a DeepInfra server-sent-events (SSE) response body into an async
 * stream of parsed chat-completion chunks.
 *
 * Fixes over the naive implementation:
 * - A single read() chunk may carry SEVERAL `data: …` events; all of them are
 *   yielded (the old `split('data: ')[1]` kept only the first).
 * - One event may be SPLIT across two reads; lines are buffered until a
 *   newline terminates them, so JSON.parse never sees a partial payload.
 * - One TextDecoder with `{ stream: true }` is reused so multi-byte UTF-8
 *   sequences split across network chunks decode correctly.
 *
 * Iteration ends when the server sends the `data: [DONE]` sentinel or the
 * body stream is exhausted.
 */
async function* iterateResponse(
  response: Response,
): AsyncIterable<StreamingChunk> {
  const reader = response.body?.getReader();
  if (!reader) {
    return;
  }

  const decoder = new TextDecoder();
  let buffer = '';
  let done = false;

  while (!done) {
    const next = await reader.read();
    if (next.done) {
      done = true;
      break;
    }

    buffer += decoder.decode(next.value, { stream: true });

    // Complete lines end with '\n'; keep the (possibly partial) tail in
    // the buffer for the next read.
    const lines = buffer.split('\n');
    buffer = lines.pop() ?? '';

    for (const rawLine of lines) {
      const line = rawLine.trim();
      if (!line.startsWith('data:')) {
        continue; // blank separator lines or SSE comments
      }
      const data = line.slice('data:'.length).trim();
      if (data === '[DONE]') {
        return;
      }
      yield JSON.parse(data);
    }
  }
}

/**
 * Issues a chat-completion request against DeepInfra's OpenAI-compatible
 * endpoint.
 *
 * The API key is read from the DEEPINFRA_API_KEY environment variable.
 * The returned promise resolves with the raw fetch Response; the caller
 * decides whether to consume it as JSON or as an SSE stream.
 */
async function getDeepInfraResponse(
  model: string,
  messages: Message[],
  baseUrl: string,
  stream: boolean,
): Promise<Response> {
  const url = `${baseUrl}/v1/openai/chat/completions`;
  const headers = {
    'Content-Type': 'application/json',
    Authorization: `Bearer ${process.env.DEEPINFRA_API_KEY}`,
  };
  const payload = { messages, model, stream };

  return fetch(url, {
    method: 'POST',
    headers,
    body: JSON.stringify(payload),
  });
}

export async function DeepInfraHandler(
  params: HandlerParamsNotStreaming,
): Promise<ResultNotStreaming>;

export async function DeepInfraHandler(
  params: HandlerParamsStreaming,
): Promise<ResultStreaming>;

export async function DeepInfraHandler(
  params: HandlerParams,
): Promise<ResultNotStreaming | ResultStreaming>;

/**
 * Handler for DeepInfra-hosted models (model ids prefixed with 'deepinfra/').
 *
 * Returns an async stream of chunks when `params.stream` is true, otherwise
 * the parsed ChatCompletion body. Throws an Error when the DeepInfra API
 * responds with a non-2xx status — previously the error body was silently
 * returned mis-typed as a ChatCompletion (or fed into the SSE iterator).
 */
export async function DeepInfraHandler(
  params: HandlerParams,
): Promise<ResultNotStreaming | ResultStreaming> {
  const baseUrl = params.baseUrl ?? 'https://api.deepinfra.com';

  // Strip the routing prefix; DeepInfra expects the bare model id.
  // (The old `split('deepinfra/')[1]` produced undefined for an unprefixed
  // model and truncated ids that merely contained the substring.)
  const prefix = 'deepinfra/';
  const model = params.model.startsWith(prefix)
    ? params.model.slice(prefix.length)
    : params.model;

  const res = await getDeepInfraResponse(
    model,
    params.messages,
    baseUrl,
    params.stream ?? false,
  );

  if (!res.ok) {
    throw new Error(
      `DeepInfra request failed with status ${res.status}: ${await res.text()}`,
    );
  }

  if (params.stream) {
    return iterateResponse(res);
  }

  return res.json() as Promise<ChatCompletion>;
}
14 changes: 9 additions & 5 deletions src/utils/toUsage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ import { EmbeddingResponse } from '../embedding';
import { ConsistentResponseUsage } from '../types';
import { encoderCl100K } from './encoders';

export function countTokens(text: string): number {
return encoderCl100K.encode(text).length;
}

export function toUsage(
prompt: string,
completion: string | undefined,
Expand All @@ -10,12 +14,12 @@ export function toUsage(
return undefined;
}

const promptTokens = encoderCl100K.encode(prompt);
const completionTokens = encoderCl100K.encode(completion);
const promptTokens = countTokens(prompt);
const completionTokens = countTokens(completion);
return {
prompt_tokens: promptTokens.length,
completion_tokens: completionTokens.length,
total_tokens: promptTokens.concat(completionTokens).length,
prompt_tokens: promptTokens,
completion_tokens: completionTokens,
total_tokens: promptTokens + completionTokens,
};
}

Expand Down
2 changes: 2 additions & 0 deletions tests/e2e.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ describe('e2e', () => {
${'command-nightly'}
${'j2-light'}
${'replicate/meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3'}
${'deepinfra/mistralai/Mistral-7B-Instruct-v0.1'}
`(
'gets response from supported model $model',
async ({ model }) => {
Expand All @@ -37,6 +38,7 @@ describe('e2e', () => {
${'command-nightly'}
${'j2-light'}
${'replicate/meta/llama-2-7b-chat:ac944f2e49c55c7e965fc3d93ad9a7d9d947866d6793fb849dd6b4747d0c061c'}
${'deepinfra/mistralai/Mistral-7B-Instruct-v0.1'}
`(
'gets streaming response from supported model $model',
async ({ model }) => {
Expand Down
2 changes: 2 additions & 0 deletions tests/handlers/getHandler.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { MODEL_HANDLER_MAPPINGS } from '../../src/completion';
import { AI21Handler } from '../../src/handlers/ai21';
import { AnthropicHandler } from '../../src/handlers/anthropic';
import { CohereHandler } from '../../src/handlers/cohere';
import { DeepInfraHandler } from '../../src/handlers/deepinfra';
import { getHandler } from '../../src/handlers/getHandler';
import { OllamaHandler } from '../../src/handlers/ollama';
import { OpenAIHandler } from '../../src/handlers/openai';
Expand All @@ -21,6 +22,7 @@ describe('getHandler', () => {
{ model: 'j2-mid-instruct', expectedHandler: AI21Handler },
{ model: 'j2-ultra-instruct', expectedHandler: AI21Handler },
{ model: 'replicate/test/test', expectedHandler: ReplicateHandler },
{ model: 'deepinfra/test/test', expectedHandler: DeepInfraHandler },
{ model: 'unknown', expectedHandler: null },
])(
'should return the correct handler for a given model name',
Expand Down

0 comments on commit d5ed468

Please sign in to comment.