Skip to content

Commit

Permalink
Merge pull request #465 from boostcampwm2023/BE/feature/#460-AI-통신-So…
Browse files Browse the repository at this point in the history
…cket-통신-서버-측-테스트-코드-작성-and-예외-처리

Be/feature/#460 ai 통신 socket 통신 서버 측 리팩토링 and 테스트 코드 작성
  • Loading branch information
HeoJiye authored Jan 18, 2024
2 parents d95fbd2 + e1afb99 commit 0efdf01
Show file tree
Hide file tree
Showing 38 changed files with 1,474 additions and 496 deletions.
69 changes: 69 additions & 0 deletions backend/was/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions backend/was/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
"husky": "^8.0.3",
"jest": "^29.5.0",
"prettier": "^3.0.0",
"socket.io-client": "^4.7.3",
"source-map-support": "^0.5.21",
"supertest": "^6.3.3",
"ts-jest": "^29.1.0",
Expand Down
6 changes: 4 additions & 2 deletions backend/was/src/app.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ import { ConfigModule } from '@nestjs/config';
import { APP_INTERCEPTOR } from '@nestjs/core';
import { AuthModule } from './auth/auth.module';
import { ChatModule } from './chat/chat.module';
import { ChatbotModule } from './chatbot/chatbot.module';
import { RedisCacheModule } from './common/config/cache/redis-cache.module';
import { MysqlModule } from './common/config/database/mysql.module';
import { JwtConfigModule } from './common/config/jwt/jwt.module';
import { ErrorsInterceptor } from './common/interceptors/errors.interceptor';
import { EventsModule } from './events/events.module';
import { LoggerModule } from './logger/logger.module';
import { MembersModule } from './members/members.module';
import { SocketModule } from './socket/socket.module';
import { TarotModule } from './tarot/tarot.module';

@Module({
Expand All @@ -21,7 +22,8 @@ import { TarotModule } from './tarot/tarot.module';
MysqlModule,
ChatModule,
TarotModule,
EventsModule,
ChatbotModule,
SocketModule,
LoggerModule,
AuthModule,
],
Expand Down
13 changes: 6 additions & 7 deletions backend/was/src/chat/dto/create-chatting-message.dto.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { ApiProperty } from '@nestjs/swagger';
import { IsBoolean, IsString, IsUUID } from 'class-validator';
import { Message } from 'src/events/type';
import { ChatLog } from 'src/common/types/chatbot';

export class CreateChattingMessageDto {
@IsUUID()
Expand All @@ -20,11 +20,10 @@ export class CreateChattingMessageDto {
})
readonly message: string;

static fromMessage(message: Message): CreateChattingMessageDto {
return {
roomId: message.roomId,
isHost: message.chat.role === 'assistant',
message: message.chat.content,
};
static fromChatLog(
roomId: string,
chatLog: ChatLog,
): CreateChattingMessageDto {
return { roomId, ...chatLog };
}
}
12 changes: 12 additions & 0 deletions backend/was/src/chatbot/chatbot.interface.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import type { ChatLog } from 'src/common/types/chatbot';

export interface ChatbotService {
generateTalk(
chatLogs: ChatLog[],
message: string,
): Promise<ReadableStream<Uint8Array>>;
generateTarotReading(
chatLogs: ChatLog[],
cardIdx: number,
): Promise<ReadableStream<Uint8Array>>;
}
8 changes: 8 additions & 0 deletions backend/was/src/chatbot/chatbot.module.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import { Module } from '@nestjs/common';
import { ClovaStudioService } from './clova-studio/clova-studio.service';

@Module({
providers: [{ provide: 'ChatbotService', useClass: ClovaStudioService }],
exports: ['ChatbotService'],
})
export class ChatbotModule {}
44 changes: 44 additions & 0 deletions backend/was/src/chatbot/clova-studio/api.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import {
CLOVA_API_DEFAULT_BODY_OPTIONS,
CLOVA_API_DEFAULT_HEADER_OPTIONS,
CLOVA_URL,
} from 'src/common/constants/clova-studio';
import { ERR_MSG } from 'src/common/constants/errors';
import type {
ClovaStudioApiKeys,
ClovaStudioMessage,
} from 'src/common/types/clova-studio';

type APIOptions = {
apiKeys: ClovaStudioApiKeys;
messages: ClovaStudioMessage[];
maxTokens: number;
};

export async function clovaStudioApi({
apiKeys,
messages,
maxTokens,
}: APIOptions): Promise<ReadableStream<Uint8Array>> {
const response = await fetch(CLOVA_URL, {
method: 'POST',
headers: {
...CLOVA_API_DEFAULT_HEADER_OPTIONS,
...apiKeys,
},
body: JSON.stringify({
...CLOVA_API_DEFAULT_BODY_OPTIONS,
maxTokens,
messages,
}),
});

if (!response.ok) {
const errorMessage = `${ERR_MSG.AI_API_FAILED}: 상태코드 ${response.statusText}`;
throw new Error(errorMessage);
}
if (!response.body) {
throw new Error(ERR_MSG.AI_API_RESPONSE_EMPTY);
}
return response.body;
}
72 changes: 72 additions & 0 deletions backend/was/src/chatbot/clova-studio/clova-studio.service.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import { ConfigService } from '@nestjs/config';
import { Test, TestingModule } from '@nestjs/testing';
import { CLOVA_API_KEY_NAMES } from 'src/common/constants/clova-studio';
import { string2Uint8ArrayStream } from 'src/common/utils/stream';
import {
clovaStudioApiMock,
configServieMock,
createAllEventStringMock,
vaildateTokenStream,
} from 'src/mocks/clova-studio';
import { ClovaStudioService, getAPIKeys } from './clova-studio.service';

jest.mock('./api');

describe('ClovaStudioService', () => {
let clovaStudioService: ClovaStudioService;
const tokens = ['안', '녕', '하', '세', '요'];

beforeEach(async () => {
jest.clearAllMocks();

const module: TestingModule = await Test.createTestingModule({
providers: [
ClovaStudioService,
{ provide: ConfigService, useValue: configServieMock },
],
}).compile();

clovaStudioService = module.get<ClovaStudioService>(ClovaStudioService);
});

it('ClovaStudioService 생성', () => {
expect(clovaStudioService).toBeDefined();
});

describe('function getAPIKeys()', () => {
it('getAPIKeys(): clova api key 불러 와서 객체로 만들어서 반환', () => {
const apiKeys = getAPIKeys(configServieMock);

CLOVA_API_KEY_NAMES.forEach((key) => {
expect(apiKeys[key.replaceAll('_', '-')]).toBe(key);
});
});
});

describe('ClovaStudioService.generateTalk()', () => {
it('사용자의 메세지 입력으로 AI의 답변을 생성해서 token stream 형식으로 반환', async () => {
setApiMock(tokens);

const tokenStream = await clovaStudioService.generateTalk([], '안녕!');

const result = await vaildateTokenStream(tokenStream, tokens);
expect(result).toBeTruthy();
});
});

describe('ClovaStudioService.generateTarotReading()', () => {
it('사용자가 뽑은 카드 인덱스로 AI의 해설을 생성해서 token stream 형식으로 반환', async () => {
setApiMock(tokens);

const tokenStream = await clovaStudioService.generateTarotReading([], 21);

const result = await vaildateTokenStream(tokenStream, tokens);
expect(result).toBeTruthy();
});
});
});

function setApiMock(tokens: string[]) {
const chunks = createAllEventStringMock(tokens);
clovaStudioApiMock.mockReturnValueOnce(string2Uint8ArrayStream(chunks));
}
71 changes: 71 additions & 0 deletions backend/was/src/chatbot/clova-studio/clova-studio.service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import {
CHAT_MAX_TOKENS,
CLOVA_API_KEY_NAMES,
TAROT_MAX_TOKENS,
} from 'src/common/constants/clova-studio';
import { ERR_MSG } from 'src/common/constants/errors';
import type { ChatLog } from 'src/common/types/chatbot';
import type {
ClovaStudioApiKeys,
ClovaStudioMessage,
} from 'src/common/types/clova-studio';
import { ChatbotService } from '../chatbot.interface';
import { clovaStudioApi } from './api';
import {
buildTalkMessages,
buildTarotReadingMessages,
chatLog2clovaStudioMessages,
} from './message';
import { apiResponseStream2TokenStream } from './stream';

@Injectable()
export class ClovaStudioService implements ChatbotService {
private readonly apiKeys: ClovaStudioApiKeys;

constructor(private readonly configService: ConfigService) {
this.apiKeys = getAPIKeys(this.configService);
}

generateTalk(
chatLogs: ChatLog[],
userMessage: string,
): Promise<ReadableStream<Uint8Array>> {
const convertedMessages = chatLog2clovaStudioMessages(chatLogs);
const messages = buildTalkMessages(convertedMessages, userMessage);

return this.api(messages, CHAT_MAX_TOKENS);
}

generateTarotReading(
chatLogs: ChatLog[],
cardIdx: number,
): Promise<ReadableStream<Uint8Array>> {
const convertedMessages = chatLog2clovaStudioMessages(chatLogs);
const messages = buildTarotReadingMessages(convertedMessages, cardIdx);

return this.api(messages, TAROT_MAX_TOKENS);
}

private async api(
messages: ClovaStudioMessage[],
maxTokens: number,
): Promise<ReadableStream<Uint8Array>> {
const options = { apiKeys: this.apiKeys, messages, maxTokens };
const responseStream = await clovaStudioApi(options);

return await apiResponseStream2TokenStream(responseStream);
}
}

export function getAPIKeys(configService: ConfigService) {
return CLOVA_API_KEY_NAMES.reduce((acc, key) => {
const value = configService.get(key);

if (!value) throw new Error(ERR_MSG.AI_API_KEY_NOT_FOUND);

acc[key.replaceAll('_', '-')] = value;
return acc;
}, {} as ClovaStudioApiKeys);
}
Loading

0 comments on commit 0efdf01

Please sign in to comment.