Skip to content

Commit

Permalink
[imporve] move ai package and improve ai code. (apache#2542)
Browse files Browse the repository at this point in the history
Co-authored-by: 刘进山 <[email protected]>
Co-authored-by: YuLuo <[email protected]>
Co-authored-by: linDong <[email protected]>
Co-authored-by: Calvin <[email protected]>
  • Loading branch information
5 people authored Aug 18, 2024
1 parent 7ac21d3 commit a2bd619
Show file tree
Hide file tree
Showing 14 changed files with 103 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.hertzbeat.common.constants;

import java.util.Arrays;

/**
* Ai type Enum
*/
Expand Down Expand Up @@ -48,13 +50,10 @@ public enum AiTypeEnum {
* get type
*/
public static AiTypeEnum getTypeByName(String type) {
for (AiTypeEnum aiTypeEnum : values()) {
if (aiTypeEnum.name().equals(type)) {
return aiTypeEnum;
}

}
return null;
return Arrays.stream(values())
.filter(ai -> ai.name().equals(type))
.findFirst()
.orElse(null);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import org.apache.hertzbeat.manager.config.AiProperties;
import org.apache.hertzbeat.manager.service.AiService;
import org.apache.hertzbeat.manager.service.impl.AiServiceFactoryImpl;
import org.apache.hertzbeat.manager.service.ai.AiService;
import org.apache.hertzbeat.manager.service.ai.factory.AiServiceFactoryImpl;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -59,10 +59,8 @@ public class AiController {
@Operation(summary = "Artificial intelligence questions and Answers",
description = "Artificial intelligence questions and Answers")
public Flux<ServerSentEvent<String>> requestAi(@Parameter(description = "Request text", example = "Who are you") @RequestParam("text") String text) {

Assert.notNull(aiServiceFactory, "please check that your type value is consistent with the documentation on the website");
AiService aiServiceImplBean = aiServiceFactory.getAiServiceImplBean(aiProperties.getType());

return aiServiceImplBean.requestAi(text);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.hertzbeat.manager.service;
package org.apache.hertzbeat.manager.service.ai;


import org.apache.hertzbeat.common.constants.AiTypeEnum;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.hertzbeat.manager.service.impl;
package org.apache.hertzbeat.manager.service.ai;

import java.util.List;
import java.util.Objects;
Expand All @@ -27,7 +27,6 @@
import org.apache.hertzbeat.manager.pojo.dto.AiMessage;
import org.apache.hertzbeat.manager.pojo.dto.AliAiRequestParamDTO;
import org.apache.hertzbeat.manager.pojo.dto.AliAiResponse;
import org.apache.hertzbeat.manager.service.AiService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.http.HttpHeaders;
Expand Down Expand Up @@ -77,46 +76,39 @@ public AiTypeEnum getType() {
@Override
public Flux<ServerSentEvent<String>> requestAi(String text) {
checkParam(text, aiProperties.getModel(), aiProperties.getApiKey());
try {
AliAiRequestParamDTO aliAiRequestParamDTO = AliAiRequestParamDTO.builder()
.model(aiProperties.getModel())
.input(AliAiRequestParamDTO.Input.builder()
.messages(List.of(new AiMessage(AiConstants.AliAiConstants.REQUEST_ROLE, text)))
.build())
.parameters(AliAiRequestParamDTO.Parameters.builder()
.maxTokens(AiConstants.AliAiConstants.MAX_TOKENS)
.temperature(AiConstants.AliAiConstants.TEMPERATURE)
.enableSearch(true)
.resultFormat("message")
.incrementalOutput(true)
.build())
.build();
AliAiRequestParamDTO aliAiRequestParamDTO = AliAiRequestParamDTO.builder()
.model(aiProperties.getModel())
.input(AliAiRequestParamDTO.Input.builder()
.messages(List.of(new AiMessage(AiConstants.AliAiConstants.REQUEST_ROLE, text)))
.build())
.parameters(AliAiRequestParamDTO.Parameters.builder()
.maxTokens(AiConstants.AliAiConstants.MAX_TOKENS)
.temperature(AiConstants.AliAiConstants.TEMPERATURE)
.enableSearch(true)
.resultFormat("message")
.incrementalOutput(true)
.build())
.build();


return webClient.post()
.body(BodyInserters.fromValue(aliAiRequestParamDTO))
.retrieve()
.bodyToFlux(AliAiResponse.class)
.map(aliAiResponse -> {
if (Objects.nonNull(aliAiResponse)) {
List<AliAiResponse.Choice> choices = aliAiResponse.getOutput().getChoices();
if (CollectionUtils.isEmpty(choices)) {
return ServerSentEvent.<String>builder().build();
}
String content = choices.get(0).getMessage().getContent();
return ServerSentEvent.<String>builder()
.data(content)
.build();
return webClient.post()
.body(BodyInserters.fromValue(aliAiRequestParamDTO))
.retrieve()
.bodyToFlux(AliAiResponse.class)
.map(aliAiResponse -> {
if (Objects.nonNull(aliAiResponse)) {
List<AliAiResponse.Choice> choices = aliAiResponse.getOutput().getChoices();
if (CollectionUtils.isEmpty(choices)) {
return ServerSentEvent.<String>builder().build();
}
return ServerSentEvent.<String>builder().build();
})
.doOnError(error -> log.info("AiResponse Exception:{}", error.toString()));

} catch (Exception e) {
log.info("KimiAiServiceImpl.requestAi exception:{}", e.toString());
throw e;
}

String content = choices.get(0).getMessage().getContent();
return ServerSentEvent.<String>builder()
.data(content)
.build();
}
return ServerSentEvent.<String>builder().build();
})
.doOnError(error -> log.info("AlibabaAiServiceImpl.requestAi exception:{}", error.getMessage()));
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.hertzbeat.manager.service.impl;
package org.apache.hertzbeat.manager.service.ai;

import java.util.List;
import javax.annotation.PostConstruct;
Expand All @@ -26,7 +26,6 @@
import org.apache.hertzbeat.manager.pojo.dto.AiMessage;
import org.apache.hertzbeat.manager.pojo.dto.OpenAiRequestParamDTO;
import org.apache.hertzbeat.manager.pojo.dto.OpenAiResponse;
import org.apache.hertzbeat.manager.service.AiService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.http.HttpHeaders;
Expand Down Expand Up @@ -71,37 +70,28 @@ public AiTypeEnum getType() {

@Override
public Flux<ServerSentEvent<String>> requestAi(String text) {
try {
checkParam(text, aiProperties.getModel(), aiProperties.getApiKey());
OpenAiRequestParamDTO zhiPuRequestParamDTO = OpenAiRequestParamDTO.builder()
.model(aiProperties.getModel())
.stream(Boolean.TRUE)
.maxTokens(AiConstants.KimiAiConstants.MAX_TOKENS)
.temperature(AiConstants.KimiAiConstants.TEMPERATURE)
.messages(List.of(new AiMessage(AiConstants.KimiAiConstants.REQUEST_ROLE, text)))
.build();


return webClient.post()
.body(BodyInserters.fromValue(zhiPuRequestParamDTO))
.retrieve()
.bodyToFlux(String.class)
.filter(aiResponse -> !"[DONE]".equals(aiResponse))
.map(OpenAiResponse::convertToResponse)
.doOnError(error -> log.info("AiResponse Exception:{}", error.toString()));


} catch (Exception e) {
log.info("KimiAiServiceImpl.requestAi exception:{}", e.toString());
throw e;
}
checkParam(text, aiProperties.getModel(), aiProperties.getApiKey());
OpenAiRequestParamDTO zhiPuRequestParamDTO = OpenAiRequestParamDTO.builder()
.model(aiProperties.getModel())
.stream(Boolean.TRUE)
.maxTokens(AiConstants.KimiAiConstants.MAX_TOKENS)
.temperature(AiConstants.KimiAiConstants.TEMPERATURE)
.messages(List.of(new AiMessage(AiConstants.KimiAiConstants.REQUEST_ROLE, text)))
.build();


return webClient.post()
.body(BodyInserters.fromValue(zhiPuRequestParamDTO))
.retrieve()
.bodyToFlux(String.class)
.filter(aiResponse -> !"[DONE]".equals(aiResponse))
.map(OpenAiResponse::convertToResponse)
.doOnError(error -> log.info("KimiAiServiceImpl.requestAi exception:{}", error.getMessage()));
}

private void checkParam(String param, String model, String apiKey) {
Assert.notNull(param, "text is null");
Assert.notNull(param, "model is null");
Assert.notNull(model, "model is null");
Assert.notNull(apiKey, "ai.api-key is null");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.hertzbeat.manager.service.impl;
package org.apache.hertzbeat.manager.service.ai;

import java.util.List;
import javax.annotation.PostConstruct;
Expand All @@ -26,7 +26,6 @@
import org.apache.hertzbeat.manager.pojo.dto.AiMessage;
import org.apache.hertzbeat.manager.pojo.dto.OpenAiRequestParamDTO;
import org.apache.hertzbeat.manager.pojo.dto.OpenAiResponse;
import org.apache.hertzbeat.manager.service.AiService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.http.HttpHeaders;
Expand Down Expand Up @@ -78,33 +77,29 @@ public AiTypeEnum getType() {

@Override
public Flux<ServerSentEvent<String>> requestAi(String text) {
checkParam(text, aiProperties.getApiKey(), aiProperties.getModel());
OpenAiRequestParamDTO zhiPuRequestParamDTO = OpenAiRequestParamDTO.builder()
.model(aiProperties.getModel())
//sse
.stream(Boolean.TRUE)
.maxTokens(AiConstants.SparkDeskConstants.MAX_TOKENS)
.temperature(AiConstants.SparkDeskConstants.TEMPERATURE)
.messages(List.of(new AiMessage(AiConstants.SparkDeskConstants.REQUEST_ROLE, text)))
.build();

try {
checkParam(text, aiProperties.getApiKey(), aiProperties.getModel());
OpenAiRequestParamDTO zhiPuRequestParamDTO = OpenAiRequestParamDTO.builder()
.model(aiProperties.getModel())
//sse
.stream(Boolean.TRUE)
.maxTokens(AiConstants.SparkDeskConstants.MAX_TOKENS)
.temperature(AiConstants.SparkDeskConstants.TEMPERATURE)
.messages(List.of(new AiMessage(AiConstants.SparkDeskConstants.REQUEST_ROLE, text)))
.build();
return webClient.post()
.body(BodyInserters.fromValue(zhiPuRequestParamDTO))
.retrieve()
.bodyToFlux(String.class)
.filter(aiResponse -> !"[DONE]".equals(aiResponse))
.map(OpenAiResponse::convertToResponse)
.doOnError(error -> log.info("SparkDeskAiServiceImpl.requestAi exception:{}", error.getMessage()));

return webClient.post()
.body(BodyInserters.fromValue(zhiPuRequestParamDTO))
.retrieve()
.bodyToFlux(String.class)
.filter(aiResponse -> !"[DONE]".equals(aiResponse))
.map(OpenAiResponse::convertToResponse);
} catch (Exception e) {
log.info("SparkDeskAiServiceImpl.requestAi exception:{}", e.toString());
throw e;
}
}

private void checkParam(String param, String apiKey, String model) {
Assert.notNull(param, "text is null");
Assert.notNull(param, "model is null");
Assert.notNull(model, "model is null");
Assert.notNull(apiKey, "ai.api-key is null");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.hertzbeat.manager.service.impl;
package org.apache.hertzbeat.manager.service.ai;


import java.util.List;
Expand All @@ -27,7 +27,6 @@
import org.apache.hertzbeat.manager.pojo.dto.AiMessage;
import org.apache.hertzbeat.manager.pojo.dto.OpenAiRequestParamDTO;
import org.apache.hertzbeat.manager.pojo.dto.OpenAiResponse;
import org.apache.hertzbeat.manager.service.AiService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.http.HttpHeaders;
Expand Down Expand Up @@ -73,33 +72,28 @@ public AiTypeEnum getType() {

@Override
public Flux<ServerSentEvent<String>> requestAi(String text) {
try {
checkParam(text, aiProperties.getModel(), aiProperties.getApiKey());
OpenAiRequestParamDTO zhiPuRequestParamDTO = OpenAiRequestParamDTO.builder()
.model(aiProperties.getModel())
//sse
.stream(Boolean.TRUE)
.maxTokens(AiConstants.ZhiPuConstants.MAX_TOKENS)
.temperature(AiConstants.ZhiPuConstants.TEMPERATURE)
.messages(List.of(new AiMessage(AiConstants.ZhiPuConstants.REQUEST_ROLE, text)))
.build();

return webClient.post()
.body(BodyInserters.fromValue(zhiPuRequestParamDTO))
.retrieve()
.bodyToFlux(String.class)
.filter(aiResponse -> !"[DONE]".equals(aiResponse))
.map(OpenAiResponse::convertToResponse)
.doOnError(error -> log.info("AiResponse Exception:{}", error.toString()));
checkParam(text, aiProperties.getModel(), aiProperties.getApiKey());
OpenAiRequestParamDTO zhiPuRequestParamDTO = OpenAiRequestParamDTO.builder()
.model(aiProperties.getModel())
//sse
.stream(Boolean.TRUE)
.maxTokens(AiConstants.ZhiPuConstants.MAX_TOKENS)
.temperature(AiConstants.ZhiPuConstants.TEMPERATURE)
.messages(List.of(new AiMessage(AiConstants.ZhiPuConstants.REQUEST_ROLE, text)))
.build();

} catch (Exception e) {
log.info("ZhiPuServiceImpl.requestAi exception:{}", e.toString());
throw e;
}
return webClient.post()
.body(BodyInserters.fromValue(zhiPuRequestParamDTO))
.retrieve()
.bodyToFlux(String.class)
.filter(aiResponse -> !"[DONE]".equals(aiResponse))
.map(OpenAiResponse::convertToResponse)
.doOnError(error -> log.info("ZhiPuServiceImpl.requestAi exception:{}", error.getMessage()));
}

private void checkParam(String param, String model, String apiKey) {
Assert.notNull(param, "text is null");
Assert.notNull(model, "model is null");
Assert.notNull(apiKey, "ai.api-key is null");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.hertzbeat.manager.service.impl;
package org.apache.hertzbeat.manager.service.ai.factory;

import java.util.HashMap;
import java.util.List;
Expand All @@ -24,7 +24,7 @@
import java.util.stream.Collectors;
import javax.annotation.PostConstruct;
import org.apache.hertzbeat.common.constants.AiTypeEnum;
import org.apache.hertzbeat.manager.service.AiService;
import org.apache.hertzbeat.manager.service.ai.AiService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.stereotype.Component;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
package org.apache.hertzbeat.manager.controller;

import org.apache.hertzbeat.manager.config.AiProperties;
import org.apache.hertzbeat.manager.service.AiService;
import org.apache.hertzbeat.manager.service.impl.AiServiceFactoryImpl;
import org.apache.hertzbeat.manager.service.ai.AiService;
import org.apache.hertzbeat.manager.service.ai.factory.AiServiceFactoryImpl;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.hertzbeat.common.constants.AiTypeEnum;
import org.apache.hertzbeat.manager.service.impl.AiServiceFactoryImpl;
import org.apache.hertzbeat.manager.service.ai.AiService;
import org.apache.hertzbeat.manager.service.ai.factory.AiServiceFactoryImpl;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand Down
Loading

0 comments on commit a2bd619

Please sign in to comment.