Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stream #35

Merged
merged 3 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 0 additions & 34 deletions src/main/java/com/coco/boot/aspect/ResponseHeaderAspect.java

This file was deleted.

15 changes: 9 additions & 6 deletions src/main/java/com/coco/boot/controller/CoCoPilotController.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package com.coco.boot.controller;

import com.alibaba.fastjson2.JSONObject;
import com.coco.boot.common.R;
import com.coco.boot.entity.ServiceStatus;
import com.coco.boot.pojo.Conversation;
import com.coco.boot.service.CoCoPilotService;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
Expand All @@ -18,6 +19,7 @@
@AllArgsConstructor
@RestController
@RequestMapping("/")
@Slf4j
public class CoCoPilotController {


Expand Down Expand Up @@ -57,14 +59,15 @@ public ResponseEntity<String> callback(@RequestParam String code, @RequestParam


@RequestMapping(value = "/v1/**", method = {RequestMethod.GET, RequestMethod.POST})
public ResponseEntity chat(@RequestBody Conversation requestBody,
public ResponseEntity<?> chat(@RequestBody Conversation requestBody,
@RequestHeader("Authorization") String auth,
HttpServletRequest request) {
HttpServletRequest request,
HttpServletResponse response) {
try {
System.out.println("request = " + request.getRequestURI());
return coCoPilotService.chat(requestBody, auth, request.getRequestURI());
log.info("request = {}", request.getRequestURI());
return coCoPilotService.chat(requestBody, auth, request.getRequestURI(), response);
} catch (Exception e) {
e.printStackTrace();
log.error("chat error", e);
return new ResponseEntity<>(null, HttpStatus.INTERNAL_SERVER_ERROR);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/com/coco/boot/pojo/Conversation.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@ public class Conversation {
/**
* 是否流式
*/
private Boolean stream;
private Boolean stream = true;

}
4 changes: 2 additions & 2 deletions src/main/java/com/coco/boot/service/CoCoPilotService.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package com.coco.boot.service;


import com.alibaba.fastjson2.JSONObject;
import com.coco.boot.common.R;
import com.coco.boot.entity.ServiceStatus;
import com.coco.boot.pojo.Conversation;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestHeader;
Expand Down Expand Up @@ -33,7 +33,7 @@ public interface CoCoPilotService {
/**
* chat 接口
*/
ResponseEntity chat(@RequestBody Conversation requestBody, @RequestHeader("Authorization") String auth, String path);
ResponseEntity<?> chat(@RequestBody Conversation requestBody, @RequestHeader("Authorization") String auth, String path, HttpServletResponse response);

/**
* 服务状态
Expand Down
102 changes: 77 additions & 25 deletions src/main/java/com/coco/boot/service/impl/CoCoPilotServiceImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import com.coco.boot.pojo.Conversation;
import com.coco.boot.service.CoCoPilotService;
import com.coco.boot.task.CoCoTask;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.servlet.http.HttpServletResponse;
import jodd.util.StringUtil;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -22,17 +24,15 @@
import org.redisson.api.map.event.EntryExpiredListener;
import org.redisson.client.codec.StringCodec;
import org.springframework.http.*;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.stereotype.Service;
import org.springframework.util.DigestUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.util.*;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.view.RedirectView;


import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.CompletableFuture;
Expand All @@ -45,6 +45,7 @@
@AllArgsConstructor
@Service
@Slf4j
@EnableAsync
public class CoCoPilotServiceImpl implements CoCoPilotService {


Expand All @@ -54,6 +55,7 @@ public class CoCoPilotServiceImpl implements CoCoPilotService {

private final CoCoConfig coCoConfig;
private final RiskContrConfig rcConfig;
private final ObjectMapper objectMapper = new ObjectMapper();

private static final JSONObject NO_KEYS = JSON.parseObject("{\"message\": \"No keys\"}");
private static final JSONObject CODE_429 = JSON.parseObject("{\"message\": \"Rate limit\"}");
Expand Down Expand Up @@ -96,6 +98,7 @@ public R<String> uploadGhu(String data) {
aliveSet.add(key);
} else if (statusCode == HttpStatus.TOO_MANY_REQUESTS.value()) {
String retryAfter = response.headers().firstValue(HEADER_RETRY).orElse("100");
map.put(key, "限流");
setCoolkey(key, retryAfter);
log.info("upload 存活校验限流: {}, 返回: {}", key, response.body());
} else {
Expand Down Expand Up @@ -205,7 +208,7 @@ public ResponseEntity<String> callback(String code, String state) {


@Override
public ResponseEntity chat(Conversation requestBody, String auth, String path) {
public ResponseEntity<?> chat(Conversation requestBody, String auth, String path, HttpServletResponse response) {
JSONObject userInfo = ChatInterceptor.tl.get();
auth = auth.substring("Bearer ".length());
String userId = userInfo.getString("id");
Expand All @@ -223,8 +226,8 @@ public ResponseEntity chat(Conversation requestBody, String auth, String path) {
String tokenKey = DigestUtils.md5DigestAsHex((auth + userId).getBytes());
if (rateLimiter.tryAcquire()) {
// 调用 handleProxy 方法并获取响应
ResponseEntity response = getBaseProxyResponse(requestBody, path, ghuAliveKey);
if (response.getStatusCode().is2xxSuccessful()) {
ResponseEntity<?> result = getBaseProxyResponse(requestBody, path, ghuAliveKey, response);
if (result.getStatusCode().is2xxSuccessful()) {
//成功访问
long tokenSuccess = redissonClient.getAtomicLong(RC_TOKEN_SUCCESS_REQ + tokenKey).incrementAndGet();
if (tokenSuccess > ((long) rcConfig.getTokenMaxReq() * trustLevel)) {
Expand All @@ -235,13 +238,7 @@ public ResponseEntity chat(Conversation requestBody, String auth, String path) {
redissonClient.getBucket(RC_TEMPORARY_BAN + userId).set(true, Duration.ofHours(rcConfig.getUserMaxTime()));
}
}
MediaType contentType = response.getHeaders().getContentType();
if (contentType.equals(MediaType.APPLICATION_JSON)) {
// 处理 application/json 类型数据
return new ResponseEntity<>(JSON.parseObject(response.getBody().toString()),response.getStatusCode());
}
return response;

return result;
} else {
long l = redissonClient.getAtomicLong(RC_USER_TOKEN_LIMIT_NUM + tokenKey).incrementAndGet();
RAtomicLong userLimit = redissonClient.getAtomicLong(RC_USER_LIMIT_NUM + userId);
Expand Down Expand Up @@ -279,10 +276,8 @@ public ServiceStatus getServiceStatus() {
}




@NotNull
private ResponseEntity getBaseProxyResponse(Object requestBody, String path, RSet<String> ghuAliveKey) {
private ResponseEntity<?> getBaseProxyResponse(Conversation requestBody, String path, RSet<String> ghuAliveKey, HttpServletResponse response) {
int i = 0;
while (i < 2) {
String ghu = getGhu(ghuAliveKey);
Expand All @@ -292,22 +287,67 @@ private ResponseEntity getBaseProxyResponse(Object requestBody, String path, RSe
log.info("{}可用令牌数量,当前选择{}", ghuAliveKey.size(), ghu);
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON,MediaType.TEXT_EVENT_STREAM));
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM));
headers.set("Authorization", "Bearer " + ghu);
StopWatch sw = new StopWatch();
sw.start("进入代理");
ResponseEntity<String> response = rest.postForEntity(coCoConfig.getBaseProxy() + path, new HttpEntity<>(requestBody, headers), String.class);
ResponseEntity<String> result = null;

if (!requestBody.getStream()) {
// 非流式处理使用postForEntity方法
result = rest.postForEntity(
coCoConfig.getBaseProxy() + path,
new HttpEntity<>(requestBody, headers),
String.class);
} else {
// 流式处理使用execute方法
try {
rest.execute(
coCoConfig.getBaseProxy() + path,
HttpMethod.POST,
requestCallback -> {
HttpHeaders requestHeaders = requestCallback.getHeaders();
requestHeaders.setContentType(MediaType.APPLICATION_JSON);
requestHeaders.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM));
requestHeaders.set("Authorization", "Bearer " + ghu);
String jsonBody = convertObjectToJson(requestBody);
requestCallback.getBody().write(jsonBody.getBytes(StandardCharsets.UTF_8));
},
responseExtractor -> {
responseExtractor.getHeaders().forEach((key, value) -> {
response.setHeader(key, String.join(",", value));
});
response.setStatus(responseExtractor.getRawStatusCode());
response.setHeader("Content-Type", "application/stream+json");
StreamUtils.copy(responseExtractor.getBody(), response.getOutputStream());
//ghu使用成功次数
RAtomicLong atomicLong = redissonClient.getAtomicLong(USING_GHU + ghu);
atomicLong.incrementAndGet();
return null;
}
);
return ResponseEntity.ok().build(); // 表示流式处理成功
} catch (Exception e) {
log.error("流式处理异常", e);
// 根据实际情况处理异常,例如设置重试或返回错误响应
}
}
sw.stop();
log.info(sw.prettyPrint(TimeUnit.SECONDS));
if (response.getStatusCode().is2xxSuccessful()) {
if (result == null) {
i++;
continue;
}
if (result.getStatusCode().is2xxSuccessful()) {
//ghu使用成功次数
RAtomicLong atomicLong = redissonClient.getAtomicLong(USING_GHU + ghu);
atomicLong.incrementAndGet();
return response;
// 客户端请求全部数据
return result;
} else {
ghuAliveKey.remove(ghu);
if (response.getStatusCode() == HttpStatus.TOO_MANY_REQUESTS) {
String retryAfter = response.getHeaders().getFirst(HEADER_RETRY);
if (result.getStatusCode() == HttpStatus.TOO_MANY_REQUESTS) {
String retryAfter = result.getHeaders().getFirst(HEADER_RETRY);
setCoolkey(ghu, retryAfter);
} else {
redissonClient.getSet(GHU_NO_ALIVE_KEY, StringCodec.INSTANCE).addAsync(ghu);
Expand Down Expand Up @@ -351,7 +391,7 @@ private void setCoolkey(String ghu, String retryAfter) {
if (StringUtil.isNotBlank(retryAfter)) {
try {
time = Long.parseLong(retryAfter);
} catch (NumberFormatException e) {
} catch (NumberFormatException ignored) {
}
}

Expand All @@ -369,4 +409,16 @@ private void setCoolkey(String ghu, String retryAfter) {
}
}

public String convertObjectToJson(Object obj) {
ObjectMapper objectMapper = new ObjectMapper();
try {
return objectMapper.writeValueAsString(obj);
} catch (Exception e) {
// 处理异常:实际应用中可能需要更复杂的异常处理逻辑
log.error("Error converting object to JSON", e);
return null;
}
}


}
16 changes: 8 additions & 8 deletions src/main/resources/application.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ coco:
# 请求L站 state 过期时间 分钟
expirationTtl: 5
#重定向地址
redirectUri: http://localhost:8181/oauth2/callback
clientId: hi3geJYfTotoiR5S62u3rh4W5tSeC5UG
clientSecret: VMPBVoAfOB5ojkGXRDEtzvDhRLENHpaN
redirectUri:
clientId:
clientSecret:
# L站
authorizationEndpoint: https://connect.linux.do/oauth2/authorize
tokenEndpoint: https://connect.linux.do/oauth2/token
userEndpoint: https://connect.linux.do/api/user
authorizationEndpoint:
tokenEndpoint:
userEndpoint:
#代理节点
baseApi: https://api.cocopilot.org/copilot_internal/v2/token
baseProxy: https://proxy.cocopilot.org
baseApi:
baseProxy:
#ghu 频率秒 1秒8次
frequencyTime: 1
#ghu频率数
Expand Down