Skip to content

Commit

Permalink
Feature: Add support for Mistral API
Browse files Browse the repository at this point in the history
  • Loading branch information
zya committed Dec 17, 2023
1 parent ad6bc25 commit d73b02c
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .example.env
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ OPENAI_API_KEY=<Your OpenAI API key>
COHERE_API_KEY=<Your Cohere API key>
ANTHROPIC_API_KEY=<Your Anthropic API key>
AI21_API_KEY=<Your Anthropic API key>
REPLICATE_API_KEY=<Your Replicate API key>
REPLICATE_API_KEY=<Your Replicate API key>
MISTRAL_API_KEY=<Your Mistral API Key>
2 changes: 2 additions & 0 deletions src/completion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import { OpenAIHandler } from './handlers/openai';
import { AI21Handler } from './handlers/ai21';
import { ReplicateHandler } from './handlers/replicate';
import { DeepInfraHandler } from './handlers/deepinfra';
import { MistralHandler } from './handlers/mistral';

export const MODEL_HANDLER_MAPPINGS: Record<string, Handler> = {
'claude-': AnthropicHandler,
Expand All @@ -24,6 +25,7 @@ export const MODEL_HANDLER_MAPPINGS: Record<string, Handler> = {
'j2-': AI21Handler,
'replicate/': ReplicateHandler,
'deepinfra/': DeepInfraHandler,
'mistral/': MistralHandler,
};

export async function completion(
Expand Down
86 changes: 86 additions & 0 deletions src/handlers/mistral.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import { ChatCompletion } from 'openai/resources/chat';

import {
HandlerParams,
HandlerParamsNotStreaming,
HandlerParamsStreaming,
Message,
ResultNotStreaming,
ResultStreaming,
StreamingChunk,
} from '../types';

async function* iterateResponse(
response: Response,
): AsyncIterable<StreamingChunk> {
const reader = response.body?.getReader();
let done = false;

while (!done) {
const next = await reader?.read();
if (next?.value) {
done = next.done;
const decoded = new TextDecoder().decode(next.value);
if (decoded.startsWith('data: [DONE]')) {
done = true;
} else {
const [, value] = decoded.split('data: ');
yield JSON.parse(value);
}
} else {
done = true;
}
}
}

async function getMistralResponse(
model: string,
messages: Message[],
baseUrl: string,
stream: boolean,
): Promise<Response> {
return fetch(`${baseUrl}/v1/chat/completions`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${process.env.MISTRAL_API_KEY}`,
},
body: JSON.stringify({
messages,
model,
stream,
}),
});
}

export async function MistralHandler(
params: HandlerParamsNotStreaming,
): Promise<ResultNotStreaming>;

export async function MistralHandler(
params: HandlerParamsStreaming,
): Promise<ResultStreaming>;

export async function MistralHandler(
params: HandlerParams,
): Promise<ResultNotStreaming | ResultStreaming>;

export async function MistralHandler(
params: HandlerParams,
): Promise<ResultNotStreaming | ResultStreaming> {
const baseUrl = params.baseUrl ?? 'https://api.mistral.ai';
const model = params.model.split('mistral/')[1];

const res = await getMistralResponse(
model,
params.messages,
baseUrl,
params.stream ?? false,
);

if (params.stream) {
return iterateResponse(res);
}

return res.json() as Promise<ChatCompletion>;
}
2 changes: 2 additions & 0 deletions tests/e2e.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ describe('e2e', () => {
${'j2-light'}
${'replicate/meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3'}
${'deepinfra/mistralai/Mistral-7B-Instruct-v0.1'}
${'mistral/mistral-tiny'}
`(
'gets response from supported model $model',
async ({ model }) => {
Expand All @@ -39,6 +40,7 @@ describe('e2e', () => {
${'j2-light'}
${'replicate/meta/llama-2-7b-chat:ac944f2e49c55c7e965fc3d93ad9a7d9d947866d6793fb849dd6b4747d0c061c'}
${'deepinfra/mistralai/Mistral-7B-Instruct-v0.1'}
${'mistral/mistral-tiny'}
`(
'gets streaming response from supported model $model',
async ({ model }) => {
Expand Down
2 changes: 2 additions & 0 deletions tests/handlers/getHandler.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { AnthropicHandler } from '../../src/handlers/anthropic';
import { CohereHandler } from '../../src/handlers/cohere';
import { DeepInfraHandler } from '../../src/handlers/deepinfra';
import { getHandler } from '../../src/handlers/getHandler';
import { MistralHandler } from '../../src/handlers/mistral';
import { OllamaHandler } from '../../src/handlers/ollama';
import { OpenAIHandler } from '../../src/handlers/openai';
import { ReplicateHandler } from '../../src/handlers/replicate';
Expand All @@ -23,6 +24,7 @@ describe('getHandler', () => {
{ model: 'j2-ultra-instruct', expectedHandler: AI21Handler },
{ model: 'replicate/test/test', expectedHandler: ReplicateHandler },
{ model: 'deepinfra/test/test', expectedHandler: DeepInfraHandler },
{ model: 'mistral/mistral-tiny', expectedHandler: MistralHandler },
{ model: 'unknown', expectedHandler: null },
])(
'should return the correct handler for a given model name',
Expand Down

0 comments on commit d73b02c

Please sign in to comment.