From 520847cfbc5bb3feac3ca6cf31ddb632ce53e47d Mon Sep 17 00:00:00 2001 From: Ehsan Ziya Date: Tue, 14 Nov 2023 16:11:55 +0000 Subject: [PATCH] Feature: Add support for deep infra --- README.md | 2 +- src/completion.ts | 2 + src/handlers/deepinfra.ts | 85 +++++++++++++++++++++++++++++++ src/utils/toUsage.ts | 14 +++-- tests/e2e.test.ts | 2 + tests/handlers/getHandler.test.ts | 2 + 6 files changed, 101 insertions(+), 6 deletions(-) create mode 100644 src/handlers/deepinfra.ts diff --git a/README.md b/README.md index 9a28937..80e73e0 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,7 @@ We aim to support all features that [LiteLLM python package](https://github.com/ | [ollama](https://docs.litellm.ai/docs/providers/ollama) | ✅ | ✅ | | [ai21](https://docs.litellm.ai/docs/providers/ai21) | ✅ | ✅ | | [replicate](https://docs.litellm.ai/docs/providers/replicate) | ✅ | ✅ | +| [deepinfra](https://docs.litellm.ai/docs/providers/deepinfra) | ✅ | ✅ | | [huggingface](https://docs.litellm.ai/docs/providers/huggingface) | ❌ | ❌ | | [together_ai](https://docs.litellm.ai/docs/providers/togetherai) | ❌ | ❌ | | [openrouter](https://docs.litellm.ai/docs/providers/openrouter) | ❌ | ❌ | @@ -62,7 +63,6 @@ We aim to support all features that [LiteLLM python package](https://github.com/ | [nlp_cloud](https://docs.litellm.ai/docs/providers/nlp_cloud) | ❌ | ❌ | | [aleph alpha](https://docs.litellm.ai/docs/providers/aleph_alpha) | ❌ | ❌ | | [petals](https://docs.litellm.ai/docs/providers/petals) | ❌ | ❌ | -| [deepinfra](https://docs.litellm.ai/docs/providers/deepinfra) | ❌ | ❌ | # Development diff --git a/src/completion.ts b/src/completion.ts index 5820574..9558ce6 100644 --- a/src/completion.ts +++ b/src/completion.ts @@ -14,6 +14,7 @@ import { OllamaHandler } from './handlers/ollama'; import { OpenAIHandler } from './handlers/openai'; import { AI21Handler } from './handlers/ai21'; import { ReplicateHandler } from './handlers/replicate'; +import { DeepInfraHandler } from './handlers/deepinfra'; export const MODEL_HANDLER_MAPPINGS: Record = { 'claude-': AnthropicHandler, @@ -22,6 +23,7 @@ export const MODEL_HANDLER_MAPPINGS: Record = { 'ollama/': OllamaHandler, 'j2-': AI21Handler, 'replicate/': ReplicateHandler, + 'deepinfra/': DeepInfraHandler, }; export async function completion( diff --git a/src/handlers/deepinfra.ts b/src/handlers/deepinfra.ts new file mode 100644 index 0000000..ccf1c54 --- /dev/null +++ b/src/handlers/deepinfra.ts @@ -0,0 +1,85 @@ +import { ChatCompletion } from 'openai/resources/chat'; +import { + HandlerParams, + HandlerParamsNotStreaming, + HandlerParamsStreaming, + Message, + ResultNotStreaming, + ResultStreaming, + StreamingChunk, +} from '../types'; + +async function* iterateResponse( + response: Response, +): AsyncIterable { + const reader = response.body?.getReader(); + let done = false; + + while (!done) { + const next = await reader?.read(); + if (next?.value) { + done = next.done; + const decoded = new TextDecoder().decode(next.value); + if (decoded.startsWith('data: [DONE]')) { + done = true; + } else { + const [, value] = decoded.split('data: '); + yield JSON.parse(value); + } + } else { + done = true; + } + } +} + +async function getDeepInfraResponse( + model: string, + messages: Message[], + baseUrl: string, + stream: boolean, +): Promise { + return fetch(`${baseUrl}/v1/openai/chat/completions`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${process.env.DEEPINFRA_API_KEY}`, + }, + body: JSON.stringify({ + messages, + model, + stream, + }), + }); +} + +export async function DeepInfraHandler( + params: HandlerParamsNotStreaming, +): Promise; + +export async function DeepInfraHandler( + params: HandlerParamsStreaming, +): Promise; + +export async function DeepInfraHandler( + params: HandlerParams, +): Promise; + +export async function DeepInfraHandler( + params: HandlerParams, +): Promise { + const baseUrl = params.baseUrl ?? 'https://api.deepinfra.com'; + const model = params.model.split('deepinfra/')[1]; + + const res = await getDeepInfraResponse( + model, + params.messages, + baseUrl, + params.stream ?? false, + ); + + if (params.stream) { + return iterateResponse(res); + } + + return res.json() as Promise; +} diff --git a/src/utils/toUsage.ts b/src/utils/toUsage.ts index c8d1f71..dbfc9ac 100644 --- a/src/utils/toUsage.ts +++ b/src/utils/toUsage.ts @@ -2,6 +2,10 @@ import { EmbeddingResponse } from '../embedding'; import { ConsistentResponseUsage } from '../types'; import { encoderCl100K } from './encoders'; +export function countTokens(text: string): number { + return encoderCl100K.encode(text).length; +} + export function toUsage( prompt: string, completion: string | undefined, @@ -10,12 +14,12 @@ export function toUsage( return undefined; } - const promptTokens = encoderCl100K.encode(prompt); - const completionTokens = encoderCl100K.encode(completion); + const promptTokens = countTokens(prompt); + const completionTokens = countTokens(completion); return { - prompt_tokens: promptTokens.length, - completion_tokens: completionTokens.length, - total_tokens: promptTokens.concat(completionTokens).length, + prompt_tokens: promptTokens, + completion_tokens: completionTokens, + total_tokens: promptTokens + completionTokens, }; } diff --git a/tests/e2e.test.ts b/tests/e2e.test.ts index e85a8ef..aca4969 100644 --- a/tests/e2e.test.ts +++ b/tests/e2e.test.ts @@ -16,6 +16,7 @@ describe('e2e', () => { ${'command-nightly'} ${'j2-light'} ${'replicate/meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3'} + ${'deepinfra/mistralai/Mistral-7B-Instruct-v0.1'} `( 'gets response from supported model $model', async ({ model }) => { @@ -37,6 +38,7 @@ describe('e2e', () => { ${'command-nightly'} ${'j2-light'} ${'replicate/meta/llama-2-7b-chat:ac944f2e49c55c7e965fc3d93ad9a7d9d947866d6793fb849dd6b4747d0c061c'} + ${'deepinfra/mistralai/Mistral-7B-Instruct-v0.1'} `( 'gets streaming response from supported model $model', async ({ model }) => { diff --git a/tests/handlers/getHandler.test.ts b/tests/handlers/getHandler.test.ts index c57acd3..3dcbb38 100644 --- a/tests/handlers/getHandler.test.ts +++ b/tests/handlers/getHandler.test.ts @@ -2,6 +2,7 @@ import { MODEL_HANDLER_MAPPINGS } from '../../src/completion'; import { AI21Handler } from '../../src/handlers/ai21'; import { AnthropicHandler } from '../../src/handlers/anthropic'; import { CohereHandler } from '../../src/handlers/cohere'; +import { DeepInfraHandler } from '../../src/handlers/deepinfra'; import { getHandler } from '../../src/handlers/getHandler'; import { OllamaHandler } from '../../src/handlers/ollama'; import { OpenAIHandler } from '../../src/handlers/openai'; @@ -21,6 +22,7 @@ describe('getHandler', () => { { model: 'j2-mid-instruct', expectedHandler: AI21Handler }, { model: 'j2-ultra-instruct', expectedHandler: AI21Handler }, { model: 'replicate/test/test', expectedHandler: ReplicateHandler }, + { model: 'deepinfra/test/test', expectedHandler: DeepInfraHandler }, { model: 'unknown', expectedHandler: null }, ])( 'should return the correct handler for a given model name',