Merge pull request #10 from enricobellato/main
Implement 'apiKey' Parameter for Direct API Key Overrides
zya authored Dec 21, 2023
2 parents 7e2df4e + f6ca576 commit 69409ee
Showing 11 changed files with 41 additions and 11 deletions.
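The change threads an optional apiKey through EmbeddingParams and HandlerParamsBase and into every provider handler; each handler falls back to its provider's environment variable when no key is passed. A minimal usage sketch of the override follows — the completion entry point and the model name are assumptions for illustration, not taken from this diff:

import { completion } from 'llmjs'; // hypothetical entry point for this library

// Pass a key per call; it takes precedence over MISTRAL_API_KEY in the environment.
const result = await completion({
  model: 'mistral/mistral-tiny',
  messages: [{ role: 'user', content: 'Hello!' }],
  apiKey: 'my-mistral-key', // direct override added by this PR
});

Because every handler resolves the key with ??, only a missing (undefined or null) apiKey falls through to the environment variable.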
1 change: 1 addition & 0 deletions src/embedding.ts
@@ -8,6 +8,7 @@ import { MistralEmbeddingHandler } from './handlers/mistralEmbedding';
 export interface EmbeddingParams {
   input: string | string[];
   model: string;
+  apiKey?: string;
   baseUrl?: string;
 }
 
6 changes: 4 additions & 2 deletions src/handlers/ai21.ts
@@ -102,11 +102,12 @@ async function getAI21Response(
   model: string,
   prompt: string,
   baseUrl: string,
+  apiKey: string,
 ): Promise<Response> {
   return fetch(`${baseUrl}/studio/v1/${model}/complete`, {
     method: 'POST',
     headers: {
-      Authorization: `Bearer ${process.env.AI21_API_KEY}`,
+      Authorization: `Bearer ${apiKey}`,
       'Content-Type': 'application/json',
       accept: 'application/json',
     },
@@ -132,10 +133,11 @@ export async function AI21Handler(
   params: HandlerParams,
 ): Promise<ResultNotStreaming | ResultStreaming> {
   const baseUrl = params.baseUrl ?? 'https://api.ai21.com';
+  const apiKey = params.apiKey ?? process.env.AI21_API_KEY!;
   const model = params.model;
   const prompt = combinePrompts(params.messages);
 
-  const res = await getAI21Response(model, prompt, baseUrl);
+  const res = await getAI21Response(model, prompt, baseUrl, apiKey);
 
   if (!res.ok) {
     throw new Error(`Received an error with code ${res.status} from AI21 API.`);
4 changes: 3 additions & 1 deletion src/handlers/anthropic.ts
@@ -87,8 +87,10 @@ export async function AnthropicHandler(
 export async function AnthropicHandler(
   params: HandlerParams,
 ): Promise<ResultNotStreaming | ResultStreaming> {
+  const apiKey = params.apiKey ?? process.env.ANTHROPIC_API_KEY;
+
   const anthropic = new Anthropic({
-    apiKey: process.env.ANTHROPIC_API_KEY,
+    apiKey: apiKey,
   });
   const prompt = toAnthropicPrompt(params.messages);
 
4 changes: 3 additions & 1 deletion src/handlers/cohere.ts
@@ -51,7 +51,9 @@ export async function CohereHandler(
 export async function CohereHandler(
   params: HandlerParams,
 ): Promise<ResultNotStreaming | ResultStreaming> {
-  cohere.init(process.env.COHERE_API_KEY!);
+  const apiKey = params.apiKey ?? process.env.COHERE_API_KEY!;
+
+  cohere.init(apiKey);
   const textsCombined = combinePrompts(params.messages);
 
   const config = {
5 changes: 4 additions & 1 deletion src/handlers/deepinfra.ts
@@ -36,13 +36,14 @@ async function getDeepInfraResponse(
   model: string,
   messages: Message[],
   baseUrl: string,
+  apiKey: string,
   stream: boolean,
 ): Promise<Response> {
   return fetch(`${baseUrl}/v1/openai/chat/completions`, {
     method: 'POST',
     headers: {
       'Content-Type': 'application/json',
-      Authorization: `Bearer ${process.env.DEEPINFRA_API_KEY}`,
+      Authorization: `Bearer ${apiKey}`,
     },
     body: JSON.stringify({
       messages,
@@ -68,12 +69,14 @@ export async function DeepInfraHandler(
   params: HandlerParams,
 ): Promise<ResultNotStreaming | ResultStreaming> {
   const baseUrl = params.baseUrl ?? 'https://api.deepinfra.com';
+  const apiKey = params.apiKey ?? process.env.DEEPINFRA_API_KEY!;
   const model = params.model.split('deepinfra/')[1];
 
   const res = await getDeepInfraResponse(
     model,
     params.messages,
     baseUrl,
+    apiKey,
     params.stream ?? false,
   );
 
5 changes: 4 additions & 1 deletion src/handlers/mistral.ts
@@ -37,13 +37,14 @@ async function getMistralResponse(
   model: string,
   messages: Message[],
   baseUrl: string,
+  apiKey: string,
   stream: boolean,
 ): Promise<Response> {
   return fetch(`${baseUrl}/v1/chat/completions`, {
     method: 'POST',
     headers: {
       'Content-Type': 'application/json',
-      Authorization: `Bearer ${process.env.MISTRAL_API_KEY}`,
+      Authorization: `Bearer ${apiKey}`,
     },
     body: JSON.stringify({
       messages,
@@ -69,12 +70,14 @@ export async function MistralHandler(
   params: HandlerParams,
 ): Promise<ResultNotStreaming | ResultStreaming> {
   const baseUrl = params.baseUrl ?? 'https://api.mistral.ai';
+  const apiKey = params.apiKey ?? process.env.MISTRAL_API_KEY!;
   const model = params.model.split('mistral/')[1];
 
   const res = await getMistralResponse(
     model,
     params.messages,
     baseUrl,
+    apiKey,
     params.stream ?? false,
   );
 
11 changes: 9 additions & 2 deletions src/handlers/mistralEmbedding.ts
@@ -5,12 +5,13 @@ async function getMistralResponse(
   model: string,
   input: EmbeddingParams['input'],
   baseUrl: string,
+  apiKey: string,
 ): Promise<Response> {
   return fetch(`${baseUrl}/v1/embeddings`, {
     method: 'POST',
     headers: {
       'Content-Type': 'application/json',
-      Authorization: `Bearer ${process.env.MISTRAL_API_KEY}`,
+      Authorization: `Bearer ${apiKey}`,
     },
     body: JSON.stringify({
       model,
@@ -24,7 +25,13 @@ export async function MistralEmbeddingHandler(
 ): Promise<EmbeddingResponse> {
   const model = params.model.split('mistral/')[1];
   const baseUrl = params.baseUrl ?? 'https://api.mistral.ai';
-  const response = await getMistralResponse(model, params.input, baseUrl);
+  const apiKey = params.apiKey ?? process.env.MISTRAL_API_KEY!;
+  const response = await getMistralResponse(
+    model,
+    params.input,
+    baseUrl,
+    apiKey,
+  );
 
   if (!response.ok) {
     throw new Error(
6 changes: 5 additions & 1 deletion src/handlers/openai.ts
@@ -45,7 +45,11 @@ export async function OpenAIHandler(
 export async function OpenAIHandler(
   params: HandlerParams,
 ): Promise<ResultNotStreaming | ResultStreaming> {
-  const openai = new OpenAI();
+  const apiKey = params.apiKey ?? process.env.OPENAI_API_KEY;
+
+  const openai = new OpenAI({
+    apiKey: apiKey,
+  });
 
   if (params.stream) {
     const response = await openai.chat.completions.create({
6 changes: 5 additions & 1 deletion src/handlers/openaiEmbedding.ts
@@ -4,6 +4,10 @@ import { EmbeddingParams, EmbeddingResponse } from '../embedding';
 export async function OpenAIEmbeddingHandler(
   params: EmbeddingParams,
 ): Promise<EmbeddingResponse> {
-  const openai = new OpenAI();
+  const apiKey = params.apiKey ?? process.env.OPENAI_API_KEY;
+
+  const openai = new OpenAI({
+    apiKey: apiKey,
+  });
   return openai.embeddings.create({ input: params.input, model: params.model });
 }
3 changes: 2 additions & 1 deletion src/handlers/replicate.ts
@@ -120,8 +120,9 @@ export async function ReplicateHandler(
 export async function ReplicateHandler(
   params: HandlerParams,
 ): Promise<ResultNotStreaming | ResultStreaming> {
+  const apiKey = params.apiKey ?? process.env.REPLICATE_API_KEY;
   const replicate = new Replicate({
-    auth: process.env.REPLICATE_API_KEY,
+    auth: apiKey,
   });
   const model = params.model.split('replicate/')[1];
   const version = model.split(':')[1];
1 change: 1 addition & 0 deletions src/types.ts
@@ -71,6 +71,7 @@ export interface HandlerParamsBase {
   presence_penalty?: number | null;
   n?: number | null;
   max_tokens?: number | null;
+  apiKey?: string;
   functions?: ChatCompletionCreateParams.Function[];
   function_call?:
     | 'none'
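All eleven files apply the same precedence rule, distilled below; PROVIDER_API_KEY is a placeholder name, not a variable in this diff:

// Per-call override wins; the environment variable is only the fallback.
// Note that ?? falls back on null/undefined only, so an explicitly passed
// empty string would be forwarded to the provider as-is.
const apiKey = params.apiKey ?? process.env.PROVIDER_API_KEY;

One nuance visible in the diff: the fetch-based handlers (AI21, DeepInfra, Mistral, Cohere) assert the fallback with a non-null assertion (!), while the SDK-backed handlers (Anthropic, OpenAI, Replicate) omit it and leave missing-key handling to the SDK.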
