Skip to content

Commit

Permalink
Merge pull request #6 from goormthon-Univ/develope
Browse files Browse the repository at this point in the history
feat : gpt 설정 및 s3 설정 추가
  • Loading branch information
Sirius506775 authored Mar 23, 2024
2 parents cb8e3e3 + deb5f43 commit f6a5ee4
Show file tree
Hide file tree
Showing 12 changed files with 264 additions and 63 deletions.
6 changes: 6 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ dependencies {
implementation 'io.awspring.cloud:spring-cloud-starter-aws:2.4.2'

implementation 'javax.xml.bind:jaxb-api:2.3.1'

implementation 'com.amazonaws:aws-java-sdk-s3:1.11.238'

/* 검증을 위한 validation 의존성 추가*/
implementation 'org.springframework.boot:spring-boot-starter-validation'

}

tasks.named('test') {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package site.balpyo.ai.controller;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.function.ServerRequest;
import site.balpyo.ai.dto.AIGenerateRequest;
import site.balpyo.ai.service.AIGenerateService;
import site.balpyo.common.dto.CommonResponse;
Expand All @@ -18,6 +18,7 @@
@RestController
@RequestMapping("/user/ai")
@RequiredArgsConstructor
@Slf4j
public class AIUserController {

private final AIGenerateService aiGenerateService;
Expand All @@ -34,6 +35,10 @@ public ResponseEntity<CommonResponse> generateScript(@Valid @RequestBody AIGener
System.out.println(uid);
if(CommonUtils.isAnyParameterNullOrBlank(uid))return CommonResponse.error(ErrorEnum.BALPYO_UID_KEY_MISSING);

log.info("-------------------- 스크립트 생성 요청");
log.info("-------------------- 요청 내용 ");
log.info("--------------------" + aiGenerateRequest);

return aiGenerateService.generateScript(aiGenerateRequest,uid);
}

Expand Down
11 changes: 10 additions & 1 deletion src/main/java/site/balpyo/ai/controller/PollyController.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ public class PollyController {
@PostMapping("/generateAudio")
public ResponseEntity<?> synthesizeText(@RequestBody PollyDTO pollyDTO) {

log.info("--------------------controller로 텍스트 음성 변환 요청");

if (!BALPYO_API_KEY.equals(pollyDTO.getBalpyoAPIKey())) {
return CommonResponse.error(ErrorEnum.BALPYO_API_KEY_ERROR);
}
Expand All @@ -50,6 +52,12 @@ public ResponseEntity<?> synthesizeText(@RequestBody PollyDTO pollyDTO) {
// Amazon Polly와 통합하여 텍스트를 음성으로 변환
InputStream audioStream = pollyService.synthesizeSpeech(pollyDTO);


if (audioStream == null) {
log.error("Amazon Polly 음성 변환 실패: 반환된 오디오 스트림이 null입니다.");
return CommonResponse.error(ErrorEnum.INTERNAL_SERVER_ERROR);
}

// InputStream을 byte 배열로 변환
byte[] audioBytes = IOUtils.toByteArray(audioStream);

Expand All @@ -63,9 +71,10 @@ public ResponseEntity<?> synthesizeText(@RequestBody PollyDTO pollyDTO) {
.body(audioBytes);

} catch (IOException e) {
e.printStackTrace();
log.error("내부 서버 오류: " + e.getMessage());
return CommonResponse.error(ErrorEnum.INTERNAL_SERVER_ERROR);
}
}


}
7 changes: 7 additions & 0 deletions src/main/java/site/balpyo/ai/dto/AudioDTO.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package site.balpyo.ai.dto;

/**
 * Data holder pairing two string values: {@code profileUrl} and {@code audio}.
 *
 * <p>NOTE(review): no constructor, accessors, or Lombok annotations are declared here, so as
 * written the fields cannot be read or set from outside this class — confirm whether
 * getters/setters are still to be added.
 */
public class AudioDTO {

// presumably the URL under which the audio is stored — TODO confirm against the caller
private String profileUrl;
// audio payload or reference as a string; the exact format is not visible here — TODO confirm
private String audio;
}
6 changes: 6 additions & 0 deletions src/main/java/site/balpyo/ai/entity/AIGenerateLogEntity.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ public class AIGenerateLogEntity {
@CreationTimestamp
private LocalDateTime createdAt;

// One-to-one link to the FlowAudio row for this log entry. The association is mapped by the
// "aiGenerateLogEntity" field on FlowAudio, cascades every persistence operation to it,
// is fetched lazily, and removes the FlowAudio when it is detached from this entity.
@OneToOne(mappedBy = "aiGenerateLogEntity",
cascade = CascadeType.ALL,
fetch = FetchType.LAZY,
orphanRemoval = true)
private FlowAudio flowAudio;

public AIGenerateLogEntity convertToEntity(AIGenerateRequest aiGenerateRequest, GPTInfoEntity gptInfoEntity,GuestEntity guestEntity){
return AIGenerateLogEntity.builder()
.secTime(aiGenerateRequest.getSecTime())
Expand Down
30 changes: 30 additions & 0 deletions src/main/java/site/balpyo/ai/entity/FlowAudio.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package site.balpyo.ai.entity;

import lombok.*;

import javax.persistence.Column;
import javax.persistence.Entity;
import javax.persistence.Id;
import javax.persistence.OneToOne;

/**
 * JPA entity identified by a {@code profileUrl} string and linked one-to-one to an
 * {@link AIGenerateLogEntity}.
 *
 * <p>Lombok generates getters, a builder, and both the all-args and no-args constructors.
 * The generated {@code toString()} leaves out {@code aiGenerateLogEntity}
 * (presumably to avoid walking the bidirectional association — confirm).
 */
@Entity
@Getter
@Builder
@AllArgsConstructor
@NoArgsConstructor
@ToString(exclude = "aiGenerateLogEntity")
public class FlowAudio {

// Primary key; stored in the column named "profileUrl".
@Id
@Column(name = "profileUrl")
private String profileUrl;

// The generation-log entry this audio belongs to.
@OneToOne
private AIGenerateLogEntity aiGenerateLogEntity;

/**
 * Replaces the associated generation-log entry.
 *
 * <p>NOTE(review): despite the name, this sets the log-entity reference, not audio data.
 *
 * @param aiGenerateLogEntity the log entry to associate with this audio
 */
public void changeAudio(AIGenerateLogEntity aiGenerateLogEntity) {
this.aiGenerateLogEntity = aiGenerateLogEntity;
}

}

13 changes: 6 additions & 7 deletions src/main/java/site/balpyo/ai/service/AIGenerateService.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package site.balpyo.ai.service;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
Expand All @@ -11,20 +11,19 @@
import site.balpyo.ai.entity.AIGenerateLogEntity;
import site.balpyo.ai.entity.GPTInfoEntity;
import site.balpyo.ai.repository.AIGenerateLogRepository;
import site.balpyo.ai.repository.GPTInfoRepository;
import site.balpyo.common.dto.CommonResponse;
import site.balpyo.common.dto.ErrorEnum;
import site.balpyo.common.util.CommonUtils;
import site.balpyo.guest.entity.GuestEntity;
import site.balpyo.guest.repository.GuestRepository;

import javax.transaction.Transactional;
import java.lang.reflect.Array;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

@Service
@Slf4j
@RequiredArgsConstructor
public class AIGenerateService {

Expand All @@ -35,7 +34,6 @@ public class AIGenerateService {

private final GuestRepository guestRepository;


@Value("${secrets.GPT_API_KEY}")
public String GPT_API_KEY;
@Transactional
Expand All @@ -50,7 +48,7 @@ public ResponseEntity<CommonResponse> generateScript(AIGenerateRequest request,S
//1. 주제, 소주제, 시간을 기반으로 프롬프트 생성
String currentPromptString = aiGenerateUtils.createPromptString(request.getTopic(), request.getKeywords(), request.getSecTime());
//2. 작성된 프롬프트를 기반으로 GPT에게 대본작성 요청
ResponseEntity<Map> generatedScriptObject = aiGenerateUtils.requestGPTTextGeneration(currentPromptString, 0.5f, 9000, CURRENT_GPT_API_KEY);
ResponseEntity<Map> generatedScriptObject = aiGenerateUtils.requestGPTTextGeneration(currentPromptString, 0.5f, 100000, CURRENT_GPT_API_KEY);
//3. GPT응답을 기반으로 대본 추출 + 대본이 없다면 대본 생성 실패 에러 반환
Object resultScript = generatedScriptObject.getBody().get("choices"); if(CommonUtils.isAnyParameterNullOrBlank(resultScript)) return CommonResponse.error(ErrorEnum.GPT_GENERATION_ERROR);

Expand All @@ -74,6 +72,7 @@ public ResponseEntity<CommonResponse> generateScript(AIGenerateRequest request,S
aiGenerateLogRepository.save(aiGenerateLog); //저장
String GPTId = aiGenerateLog.getGptInfoEntity().getGptInfoId();

log.info("-------------------- 저장된 사용 기록 : " + aiGenerateLog);

return CommonResponse.success(new AIGenerateResponse(resultScript,GPTId));
}
Expand Down
46 changes: 31 additions & 15 deletions src/main/java/site/balpyo/ai/service/AIGenerateUtils.java
Original file line number Diff line number Diff line change
@@ -1,39 +1,55 @@
package site.balpyo.ai.service;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import site.balpyo.ai.dto.GPTResponse;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

@Component
@Slf4j
public class AIGenerateUtils {

private static final String ENDPOINT = "https://api.openai.com/v1/chat/completions";

/**
 * Builds the GPT prompt that asks for a Korean presentation script sized to a target duration.
 *
 * <p>Sizing: a baseline of 425 characters (spaces included) per 60 seconds gives the
 * characters-per-second rate; that rate times {@code sec} (truncated to an int) is the target
 * character count, which is multiplied by 3 to obtain the byte count quoted in the prompt.
 *
 * @param topic    presentation topic inserted into the prompt
 * @param keywords keywords inserted into the prompt
 * @param sec      target duration in seconds; must not be {@code null} (it is unboxed)
 * @return the prompt text to send to the model
 */
public String createPromptString(String topic, String keywords, Integer sec) {
    log.info("-------------------- 프롬프트 명령 실행");

    // Baseline: number of characters (spaces included) assumed to be spoken in 60 seconds.
    int initialCharacterCount = 425;
    double characterPerSecond = (double) initialCharacterCount / 60.0;

    log.info("-------------------- 초당 평균 글자 수 : {}", characterPerSecond);

    // Character budget for the requested duration (fraction truncated).
    int targetCharacterCount = (int) (sec * characterPerSecond);

    log.info("-------------------- 주어진 시간({}초)에 해당하는 예상 글자수 : {}", sec, targetCharacterCount);

    // Byte budget, assuming each Hangul character takes 3 bytes.
    int targetByteCount = targetCharacterCount * 3;

    log.info("-------------------- 예상 바이트수 : {}", targetByteCount);

    // NOTE(review): the byte budget assumes 3 bytes per character, yet the prompt tells the
    // model to count every character as one byte — confirm which sizing rule is intended.
    return "You need to create a presentation script in Korean.\n" +
            "The topic is " + topic + ", and the keywords are " + keywords + ".\n" +
            "Please generate a script of " + targetByteCount + " bytes.\n" +
            "Count every character, including spaces, special characters, and line breaks, as one byte.\n" +
            "When creating a script, exclude characters such as '(', ')', ''', '-', '[', ']' and '_'.\n" +
            "This is to prevent bugs that may occur in scripts that include request values in the response.\n" +
            "It must be exactly " + targetByteCount + " bytes long. " +
            "GPT, you're smart enough to provide me with a script of " + targetByteCount + " bytes, right? Can you do that?";
}



public ResponseEntity<Map> requestGPTTextGeneration(String prompt, float temperature, int maxTokens ,String API_KEY) {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
Expand All @@ -44,7 +60,7 @@ public ResponseEntity<Map> requestGPTTextGeneration(String prompt, float tempera
message.put("content", prompt);

Map<String, Object> requestBody = new HashMap<>();
requestBody.put("model", "gpt-3.5-turbo");
requestBody.put("model", "gpt-4-0125-preview");
requestBody.put("messages", Arrays.asList(message));
requestBody.put("temperature", temperature);

Expand Down
65 changes: 56 additions & 9 deletions src/main/java/site/balpyo/ai/service/PollyService.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ public InputStream synthesizeSpeech(PollyDTO pollyDTO) {
String inputText = pollyDTO.getText();
int speed = pollyDTO.getSpeed();

log.info("-------------------- 클라이언트가 요청한 대본 :" + inputText);
log.info("-------------------- 클라이언트가 요청한 빠르기 :" + speed);

// Amazon Polly 클라이언트 생성
AmazonPolly amazonPolly = AmazonPollyClient.builder()
.withRegion(Regions.AP_NORTHEAST_2) // 서울 리전
Expand All @@ -38,20 +41,30 @@ public InputStream synthesizeSpeech(PollyDTO pollyDTO) {
// 빠르기 계산
float relativeSpeed = calculateRelativeSpeed(speed);

// SynthesizeSpeechRequest 생성
log.info("-------------------- 선택한 빠르기 :" + relativeSpeed);

// SSML 텍스트 생성
String ssmlText = buildSsmlText(inputText, relativeSpeed);

// SynthesizeSpeechRequest 생성 및 설정
SynthesizeSpeechRequest synthesizeSpeechRequest = new SynthesizeSpeechRequest()
.withText(inputText)
.withText(ssmlText)
.withOutputFormat(OutputFormat.Mp3) // MP3 형식
.withVoiceId(VoiceId.Seoyeon) // 한국어 음성 변환 보이스
.withTextType("ssml") // SSML 형식 사용 -> <prosody> 태그와 rate로 설정 가능
.withText("<speak><prosody rate=\"" + relativeSpeed + "\">" + inputText + "</prosody></speak>");

// 텍스트를 음성으로 변환하여 InputStream으로 반환
SynthesizeSpeechResult synthesizeSpeechResult = amazonPolly.synthesizeSpeech(synthesizeSpeechRequest);
return synthesizeSpeechResult.getAudioStream();
}
.withTextType("ssml"); // SSML 형식 사용

try { // 텍스트를 음성으로 변환하여 InputStream으로 반환
SynthesizeSpeechResult synthesizeSpeechResult = amazonPolly.synthesizeSpeech(synthesizeSpeechRequest);

log.info("-------------------- 요청된 문자열 개수 : " + synthesizeSpeechResult.getRequestCharacters());
log.info("-------------------- 음성변환 요청 성공");

return synthesizeSpeechResult.getAudioStream();
} catch (AmazonPollyException e) {
log.error("-------------------- 음성 변환 실패: " + e.getErrorMessage());
throw e;
}
}

/**
* mp3 audio 생성 시, 빠르기 설정 메소드
Expand All @@ -74,5 +87,39 @@ private static float calculateRelativeSpeed(int speed) {
}
}

/**
 * Builds the SSML document for the given script text.
 *
 * <p>The text is wrapped in {@code <speak><prosody rate="N%">...</prosody></speak>}, where
 * {@code N} is {@code relativeSpeed * 100}. Selected characters are replaced by explicit
 * pauses: {@code ','} by 500ms, {@code '.'} and {@code '!'} by 600ms, {@code '?'} by 800ms
 * and a line feed by 300ms. XML reserved characters in the text are escaped so that the
 * script cannot break or inject SSML markup.
 *
 * @param inputText     script text to convert; must not be {@code null}
 * @param relativeSpeed speaking-rate multiplier (1.0 yields a rate of 100%)
 * @return the complete SSML string
 */
private String buildSsmlText(String inputText, float relativeSpeed) {
    StringBuilder ssmlBuilder = new StringBuilder();
    ssmlBuilder.append("<speak>");
    // Locale.ROOT keeps '.' as the decimal separator whatever the server's default locale is.
    ssmlBuilder.append(String.format(java.util.Locale.ROOT, "<prosody rate=\"%f%%\">", relativeSpeed * 100));

    for (char ch : inputText.toCharArray()) {
        switch (ch) {
            case ',':
                // Comma: short breathing pause.
                ssmlBuilder.append("<break time=\"500ms\"/>");
                break;
            case '.':
            case '!':
                // End of a sentence: longer pause.
                ssmlBuilder.append("<break time=\"600ms\"/>");
                break;
            case '?':
                // Question: longest pause. The break below stops the fall-through that
                // previously appended the line-feed pause as well.
                ssmlBuilder.append("<break time=\"800ms\"/>");
                break;
            case '\n':
                // Line feed: brief pause.
                ssmlBuilder.append("<break time=\"300ms\"/>");
                break;
            case '&':
                ssmlBuilder.append("&amp;");
                break;
            case '<':
                ssmlBuilder.append("&lt;");
                break;
            case '>':
                ssmlBuilder.append("&gt;");
                break;
            case '"':
                ssmlBuilder.append("&quot;");
                break;
            case '\'':
                ssmlBuilder.append("&apos;");
                break;
            default:
                // Any other character is spoken as-is.
                ssmlBuilder.append(ch);
                break;
        }
    }

    ssmlBuilder.append("</prosody>");
    ssmlBuilder.append("</speak>");
    return ssmlBuilder.toString();
}
}
37 changes: 37 additions & 0 deletions src/main/java/site/balpyo/common/s3/S3Config.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package site.balpyo.common.s3;

import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * Spring configuration that publishes an {@link AmazonS3} client bean built from the
 * access key, secret key, and region supplied through application properties.
 */
@Configuration
public class S3Config {

    /** Value of {@code cloud.aws.s3.bucket}; injected here but not read within this class. */
    @Value("${cloud.aws.s3.bucket}")
    private String bucket;

    /** Value of {@code cloud.aws.credentials.access-key}. */
    @Value("${cloud.aws.credentials.access-key}")
    private String accessKey;

    /** Value of {@code cloud.aws.credentials.secret-key}. */
    @Value("${cloud.aws.credentials.secret-key}")
    private String secretKey;

    /** Value of {@code cloud.aws.region.static}. */
    @Value("${cloud.aws.region.static}")
    private String region;

    /**
     * Creates the S3 client using static credentials and the configured region.
     *
     * @return the configured {@link AmazonS3} client
     */
    @Bean
    public AmazonS3 amazonS3Client() {
        final AWSCredentials staticKeys = new BasicAWSCredentials(accessKey, secretKey);
        final AWSStaticCredentialsProvider keyProvider = new AWSStaticCredentialsProvider(staticKeys);

        final AmazonS3ClientBuilder clientBuilder = AmazonS3ClientBuilder.standard()
                .withCredentials(keyProvider)
                .withRegion(region);
        return clientBuilder.build();
    }
}
Loading

0 comments on commit f6a5ee4

Please sign in to comment.