Skip to content

Commit

Permalink
🎨 refactor: refactor the route auth as a middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
arvinxx committed Apr 10, 2024
1 parent be4bcca commit ef5ee2a
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 22 deletions.
6 changes: 3 additions & 3 deletions src/app/api/chat/[provider]/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ import { LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED } from '@/const/auth';
import { LobeRuntimeAI } from '@/libs/agent-runtime';
import { ChatErrorType } from '@/types/fetch';

import { getJWTPayload } from '../auth';
import AgentRuntime from './agentRuntime';
import AgentRuntime from '../agentRuntime';
import { getJWTPayload } from '../auth/utils';
import { POST } from './route';

vi.mock('../auth', () => ({
vi.mock('../auth/utils', () => ({
getJWTPayload: vi.fn(),
checkAuthMethod: vi.fn(),
}));
Expand Down
22 changes: 5 additions & 17 deletions src/app/api/chat/[provider]/route.ts
Original file line number Diff line number Diff line change
@@ -1,34 +1,22 @@
import { getPreferredRegion } from '@/app/api/config';
import { createErrorResponse } from '@/app/api/errorResponse';
import { LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED } from '@/const/auth';
import { AgentRuntimeError, ChatCompletionErrorPayload } from '@/libs/agent-runtime';
import { ChatCompletionErrorPayload } from '@/libs/agent-runtime';
import { ChatErrorType } from '@/types/fetch';
import { ChatStreamPayload } from '@/types/openai/chat';
import { getTracePayload } from '@/utils/trace';

import { checkAuthMethod, getJWTPayload } from '../auth';
import AgentRuntime from './agentRuntime';
import AgentRuntime from '../agentRuntime';
import { checkAuth } from '../auth';

export const runtime = 'edge';

export const preferredRegion = getPreferredRegion();

export const POST = async (req: Request, { params }: { params: { provider: string } }) => {
export const POST = checkAuth(async (req: Request, { params, jwtPayload }) => {
const { provider } = params;

try {
// ============ 1. init chat model ============ //

// get Authorization from header
const authorization = req.headers.get(LOBE_CHAT_AUTH_HEADER);
const oauthAuthorized = !!req.headers.get(OAUTH_AUTHORIZED);

if (!authorization) throw AgentRuntimeError.createError(ChatErrorType.Unauthorized);

// check the Auth With payload
const jwtPayload = await getJWTPayload(authorization);
checkAuthMethod(jwtPayload.accessCode, jwtPayload.apiKey, oauthAuthorized);

const agentRuntime = await AgentRuntime.initializeWithUserPayload(provider, jwtPayload);

// ============ 2. create chat completion ============ //
Expand All @@ -55,4 +43,4 @@ export const POST = async (req: Request, { params }: { params: { provider: strin

return createErrorResponse(errorType, { error, ...res, provider });
}
};
});
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import {
} from '@/libs/agent-runtime';
import { TraceClient } from '@/libs/traces';

import apiKeyManager from '../apiKeyManager';
import apiKeyManager from './apiKeyManager';

export interface AgentChatOptions {
enableTrace?: boolean;
Expand Down
42 changes: 42 additions & 0 deletions src/app/api/chat/auth/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import { createErrorResponse } from '@/app/api/errorResponse';
import { JWTPayload, LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED } from '@/const/auth';
import { AgentRuntimeError, ChatCompletionErrorPayload } from '@/libs/agent-runtime';
import { ChatErrorType } from '@/types/fetch';

import { checkAuthMethod, getJWTPayload } from './utils';

type RequestOptions = { params: { provider: string } };

export type RequestHandler = (
req: Request,
options: RequestOptions & { jwtPayload: JWTPayload },
) => Promise<Response>;

export const checkAuth =
(handler: RequestHandler) => async (req: Request, options: RequestOptions) => {
let jwtPayload: JWTPayload;

try {
// get Authorization from header
const authorization = req.headers.get(LOBE_CHAT_AUTH_HEADER);
const oauthAuthorized = !!req.headers.get(OAUTH_AUTHORIZED);

if (!authorization) throw AgentRuntimeError.createError(ChatErrorType.Unauthorized);

// check the Auth With payload
jwtPayload = await getJWTPayload(authorization);
checkAuthMethod(jwtPayload.accessCode, jwtPayload.apiKey, oauthAuthorized);
} catch (e) {
const {
errorType = ChatErrorType.InternalServerError,
error: errorContent,
...res
} = e as ChatCompletionErrorPayload;

const error = errorContent || e;

return createErrorResponse(errorType, { error, ...res, provider: options.params?.provider });
}

return handler(req, { ...options, jwtPayload });
};
File renamed without changes.
2 changes: 1 addition & 1 deletion src/app/api/plugin/gateway/route.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { PluginRequestPayload } from '@lobehub/chat-plugin-sdk';
import { createGatewayOnEdgeRuntime } from '@lobehub/chat-plugins-gateway';

import { getJWTPayload } from '@/app/api/chat/auth';
import { getJWTPayload } from '@/app/api/chat/auth/utils';
import { createErrorResponse } from '@/app/api/errorResponse';
import { getServerConfig } from '@/config/server';
import { LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED } from '@/const/auth';
Expand Down

0 comments on commit ef5ee2a

Please sign in to comment.