Skip to content

Commit

Permalink
[improve] improve the way Ai is entered and requested (apache#2762)
Browse files Browse the repository at this point in the history
Co-authored-by: tomsun28 <[email protected]>
Co-authored-by: shown <[email protected]>
Co-authored-by: aias00 <[email protected]>
Co-authored-by: Zhang Yuxuan <[email protected]>
  • Loading branch information
5 people authored Oct 31, 2024
1 parent bff7575 commit 4373ab2
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import org.apache.hertzbeat.manager.config.AiProperties;
import org.apache.hertzbeat.manager.pojo.dto.AiControllerRequestParam;
import org.apache.hertzbeat.manager.service.ai.AiService;
import org.apache.hertzbeat.manager.service.ai.factory.AiServiceFactoryImpl;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.util.Assert;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;

Expand All @@ -52,15 +53,16 @@ public class AiController {

/**
* request AI
* @param text request text
* @param requestParam request text
* @return AI response
*/
@GetMapping(path = "/get", produces = {TEXT_EVENT_STREAM_VALUE})
@PostMapping(path = "/get", produces = {TEXT_EVENT_STREAM_VALUE})
@Operation(summary = "Artificial intelligence questions and Answers",
description = "Artificial intelligence questions and Answers")
public Flux<ServerSentEvent<String>> requestAi(@Parameter(description = "Request text", example = "Who are you") @RequestParam("text") String text) {
public Flux<ServerSentEvent<String>> requestAi(@Parameter(description = "Request text", example = "Who are you")
@RequestBody AiControllerRequestParam requestParam) {
Assert.notNull(aiServiceFactory, "please check that your type value is consistent with the documentation on the website");
AiService aiServiceImplBean = aiServiceFactory.getAiServiceImplBean(aiProperties.getType());
return aiServiceImplBean.requestAi(text);
return aiServiceImplBean.requestAi(requestParam.getText());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.hertzbeat.manager.pojo.dto;

import lombok.Data;

/**
* AiControllerRequestParam
*/
@Data
public class AiControllerRequestParam {

/**
* required parameter
*/
String text;
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.when;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
import org.apache.hertzbeat.manager.config.AiProperties;
Expand All @@ -34,6 +33,7 @@
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import reactor.core.publisher.Flux;

Expand Down Expand Up @@ -74,9 +74,12 @@ public void testRequestAi() throws Exception {
when(aiService.requestAi(anyString())).thenReturn(responseFlux);
when(aiProperties.getType()).thenReturn("alibabaAi");

mockMvc.perform(get("/api/ai/get")
.param("text", "Who are you")
.accept(MediaType.TEXT_EVENT_STREAM))
String requestBody = "{\"text\":\"Who are you\"}";

mockMvc.perform((MockMvcRequestBuilders.post("/api/ai/get")
.content(requestBody)
.contentType(MediaType.APPLICATION_JSON)
.accept(MediaType.TEXT_EVENT_STREAM)))
.andExpect(status().isOk())
.andExpect(content().contentType(MediaType.TEXT_EVENT_STREAM_VALUE))
.andExpect(content().string("data:response\n\n"));
Expand Down

0 comments on commit 4373ab2

Please sign in to comment.