Skip to content

Commit

Permalink
Merge pull request #35 from HuaTru/dev
Browse files Browse the repository at this point in the history
stream
  • Loading branch information
smilexizheng authored Mar 26, 2024
2 parents 95ea412 + 44deb15 commit 863ba17
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 76 deletions.
34 changes: 0 additions & 34 deletions src/main/java/com/coco/boot/aspect/ResponseHeaderAspect.java

This file was deleted.

15 changes: 9 additions & 6 deletions src/main/java/com/coco/boot/controller/CoCoPilotController.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package com.coco.boot.controller;

import com.alibaba.fastjson2.JSONObject;
import com.coco.boot.common.R;
import com.coco.boot.entity.ServiceStatus;
import com.coco.boot.pojo.Conversation;
import com.coco.boot.service.CoCoPilotService;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
Expand All @@ -18,6 +19,7 @@
@AllArgsConstructor
@RestController
@RequestMapping("/")
@Slf4j
public class CoCoPilotController {


Expand Down Expand Up @@ -57,14 +59,15 @@ public ResponseEntity<String> callback(@RequestParam String code, @RequestParam


@RequestMapping(value = "/v1/**", method = {RequestMethod.GET, RequestMethod.POST})
public ResponseEntity chat(@RequestBody Conversation requestBody,
public ResponseEntity<?> chat(@RequestBody Conversation requestBody,
@RequestHeader("Authorization") String auth,
HttpServletRequest request) {
HttpServletRequest request,
HttpServletResponse response) {
try {
System.out.println("request = " + request.getRequestURI());
return coCoPilotService.chat(requestBody, auth, request.getRequestURI());
log.info("request = {}", request.getRequestURI());
return coCoPilotService.chat(requestBody, auth, request.getRequestURI(), response);
} catch (Exception e) {
e.printStackTrace();
log.error("chat error", e);
return new ResponseEntity<>(null, HttpStatus.INTERNAL_SERVER_ERROR);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/com/coco/boot/pojo/Conversation.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@ public class Conversation {
/**
* 是否流式
*/
private Boolean stream;
private Boolean stream = true;

}
4 changes: 2 additions & 2 deletions src/main/java/com/coco/boot/service/CoCoPilotService.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package com.coco.boot.service;


import com.alibaba.fastjson2.JSONObject;
import com.coco.boot.common.R;
import com.coco.boot.entity.ServiceStatus;
import com.coco.boot.pojo.Conversation;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestHeader;
Expand Down Expand Up @@ -33,7 +33,7 @@ public interface CoCoPilotService {
/**
* chat 接口
*/
ResponseEntity chat(@RequestBody Conversation requestBody, @RequestHeader("Authorization") String auth, String path);
ResponseEntity<?> chat(@RequestBody Conversation requestBody, @RequestHeader("Authorization") String auth, String path, HttpServletResponse response);

/**
* 服务状态
Expand Down
102 changes: 77 additions & 25 deletions src/main/java/com/coco/boot/service/impl/CoCoPilotServiceImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import com.coco.boot.pojo.Conversation;
import com.coco.boot.service.CoCoPilotService;
import com.coco.boot.task.CoCoTask;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.servlet.http.HttpServletResponse;
import jodd.util.StringUtil;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -22,17 +24,15 @@
import org.redisson.api.map.event.EntryExpiredListener;
import org.redisson.client.codec.StringCodec;
import org.springframework.http.*;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.stereotype.Service;
import org.springframework.util.DigestUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.util.*;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.view.RedirectView;


import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.CompletableFuture;
Expand All @@ -45,6 +45,7 @@
@AllArgsConstructor
@Service
@Slf4j
@EnableAsync
public class CoCoPilotServiceImpl implements CoCoPilotService {


Expand All @@ -54,6 +55,7 @@ public class CoCoPilotServiceImpl implements CoCoPilotService {

private final CoCoConfig coCoConfig;
private final RiskContrConfig rcConfig;
private final ObjectMapper objectMapper = new ObjectMapper();

private static final JSONObject NO_KEYS = JSON.parseObject("{\"message\": \"No keys\"}");
private static final JSONObject CODE_429 = JSON.parseObject("{\"message\": \"Rate limit\"}");
Expand Down Expand Up @@ -96,6 +98,7 @@ public R<String> uploadGhu(String data) {
aliveSet.add(key);
} else if (statusCode == HttpStatus.TOO_MANY_REQUESTS.value()) {
String retryAfter = response.headers().firstValue(HEADER_RETRY).orElse("100");
map.put(key, "限流");
setCoolkey(key, retryAfter);
log.info("upload 存活校验限流: {}, 返回: {}", key, response.body());
} else {
Expand Down Expand Up @@ -205,7 +208,7 @@ public ResponseEntity<String> callback(String code, String state) {


@Override
public ResponseEntity chat(Conversation requestBody, String auth, String path) {
public ResponseEntity<?> chat(Conversation requestBody, String auth, String path, HttpServletResponse response) {
JSONObject userInfo = ChatInterceptor.tl.get();
auth = auth.substring("Bearer ".length());
String userId = userInfo.getString("id");
Expand All @@ -223,8 +226,8 @@ public ResponseEntity chat(Conversation requestBody, String auth, String path) {
String tokenKey = DigestUtils.md5DigestAsHex((auth + userId).getBytes());
if (rateLimiter.tryAcquire()) {
// 调用 handleProxy 方法并获取响应
ResponseEntity response = getBaseProxyResponse(requestBody, path, ghuAliveKey);
if (response.getStatusCode().is2xxSuccessful()) {
ResponseEntity<?> result = getBaseProxyResponse(requestBody, path, ghuAliveKey, response);
if (result.getStatusCode().is2xxSuccessful()) {
//成功访问
long tokenSuccess = redissonClient.getAtomicLong(RC_TOKEN_SUCCESS_REQ + tokenKey).incrementAndGet();
if (tokenSuccess > ((long) rcConfig.getTokenMaxReq() * trustLevel)) {
Expand All @@ -235,13 +238,7 @@ public ResponseEntity chat(Conversation requestBody, String auth, String path) {
redissonClient.getBucket(RC_TEMPORARY_BAN + userId).set(true, Duration.ofHours(rcConfig.getUserMaxTime()));
}
}
MediaType contentType = response.getHeaders().getContentType();
if (contentType.equals(MediaType.APPLICATION_JSON)) {
// 处理 application/json 类型数据
return new ResponseEntity<>(JSON.parseObject(response.getBody().toString()),response.getStatusCode());
}
return response;

return result;
} else {
long l = redissonClient.getAtomicLong(RC_USER_TOKEN_LIMIT_NUM + tokenKey).incrementAndGet();
RAtomicLong userLimit = redissonClient.getAtomicLong(RC_USER_LIMIT_NUM + userId);
Expand Down Expand Up @@ -279,10 +276,8 @@ public ServiceStatus getServiceStatus() {
}




@NotNull
private ResponseEntity getBaseProxyResponse(Object requestBody, String path, RSet<String> ghuAliveKey) {
private ResponseEntity<?> getBaseProxyResponse(Conversation requestBody, String path, RSet<String> ghuAliveKey, HttpServletResponse response) {
int i = 0;
while (i < 2) {
String ghu = getGhu(ghuAliveKey);
Expand All @@ -292,22 +287,67 @@ private ResponseEntity getBaseProxyResponse(Object requestBody, String path, RSe
log.info("{}可用令牌数量,当前选择{}", ghuAliveKey.size(), ghu);
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON,MediaType.TEXT_EVENT_STREAM));
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM));
headers.set("Authorization", "Bearer " + ghu);
StopWatch sw = new StopWatch();
sw.start("进入代理");
ResponseEntity<String> response = rest.postForEntity(coCoConfig.getBaseProxy() + path, new HttpEntity<>(requestBody, headers), String.class);
ResponseEntity<String> result = null;

if (!requestBody.getStream()) {
// 非流式处理使用postForEntity方法
result = rest.postForEntity(
coCoConfig.getBaseProxy() + path,
new HttpEntity<>(requestBody, headers),
String.class);
} else {
// 流式处理使用execute方法
try {
rest.execute(
coCoConfig.getBaseProxy() + path,
HttpMethod.POST,
requestCallback -> {
HttpHeaders requestHeaders = requestCallback.getHeaders();
requestHeaders.setContentType(MediaType.APPLICATION_JSON);
requestHeaders.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM));
requestHeaders.set("Authorization", "Bearer " + ghu);
String jsonBody = convertObjectToJson(requestBody);
requestCallback.getBody().write(jsonBody.getBytes(StandardCharsets.UTF_8));
},
responseExtractor -> {
responseExtractor.getHeaders().forEach((key, value) -> {
response.setHeader(key, String.join(",", value));
});
response.setStatus(responseExtractor.getRawStatusCode());
response.setHeader("Content-Type", "application/stream+json");
StreamUtils.copy(responseExtractor.getBody(), response.getOutputStream());
//ghu使用成功次数
RAtomicLong atomicLong = redissonClient.getAtomicLong(USING_GHU + ghu);
atomicLong.incrementAndGet();
return null;
}
);
return ResponseEntity.ok().build(); // 表示流式处理成功
} catch (Exception e) {
log.error("流式处理异常", e);
// 根据实际情况处理异常,例如设置重试或返回错误响应
}
}
sw.stop();
log.info(sw.prettyPrint(TimeUnit.SECONDS));
if (response.getStatusCode().is2xxSuccessful()) {
if (result == null) {
i++;
continue;
}
if (result.getStatusCode().is2xxSuccessful()) {
//ghu使用成功次数
RAtomicLong atomicLong = redissonClient.getAtomicLong(USING_GHU + ghu);
atomicLong.incrementAndGet();
return response;
// 客户端请求全部数据
return result;
} else {
ghuAliveKey.remove(ghu);
if (response.getStatusCode() == HttpStatus.TOO_MANY_REQUESTS) {
String retryAfter = response.getHeaders().getFirst(HEADER_RETRY);
if (result.getStatusCode() == HttpStatus.TOO_MANY_REQUESTS) {
String retryAfter = result.getHeaders().getFirst(HEADER_RETRY);
setCoolkey(ghu, retryAfter);
} else {
redissonClient.getSet(GHU_NO_ALIVE_KEY, StringCodec.INSTANCE).addAsync(ghu);
Expand Down Expand Up @@ -351,7 +391,7 @@ private void setCoolkey(String ghu, String retryAfter) {
if (StringUtil.isNotBlank(retryAfter)) {
try {
time = Long.parseLong(retryAfter);
} catch (NumberFormatException e) {
} catch (NumberFormatException ignored) {
}
}

Expand All @@ -369,4 +409,16 @@ private void setCoolkey(String ghu, String retryAfter) {
}
}

public String convertObjectToJson(Object obj) {
ObjectMapper objectMapper = new ObjectMapper();
try {
return objectMapper.writeValueAsString(obj);
} catch (Exception e) {
// 处理异常:实际应用中可能需要更复杂的异常处理逻辑
log.error("Error converting object to JSON", e);
return null;
}
}


}
16 changes: 8 additions & 8 deletions src/main/resources/application.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ coco:
# 请求L站 state 过期时间 分钟
expirationTtl: 5
#重定向地址
redirectUri: http://localhost:8181/oauth2/callback
clientId: hi3geJYfTotoiR5S62u3rh4W5tSeC5UG
clientSecret: VMPBVoAfOB5ojkGXRDEtzvDhRLENHpaN
redirectUri:
clientId:
clientSecret:
# L站
authorizationEndpoint: https://connect.linux.do/oauth2/authorize
tokenEndpoint: https://connect.linux.do/oauth2/token
userEndpoint: https://connect.linux.do/api/user
authorizationEndpoint:
tokenEndpoint:
userEndpoint:
#代理节点
baseApi: https://api.cocopilot.org/copilot_internal/v2/token
baseProxy: https://proxy.cocopilot.org
baseApi:
baseProxy:
#ghu 频率秒 1秒8次
frequencyTime: 1
#ghu频率数
Expand Down

0 comments on commit 863ba17

Please sign in to comment.