diff --git a/src/main/java/com/coco/boot/aspect/ResponseHeaderAspect.java b/src/main/java/com/coco/boot/aspect/ResponseHeaderAspect.java deleted file mode 100644 index 9bdd4d2..0000000 --- a/src/main/java/com/coco/boot/aspect/ResponseHeaderAspect.java +++ /dev/null @@ -1,34 +0,0 @@ -//package com.coco.boot.aspect; -// -//import jakarta.servlet.http.HttpServletResponse; -//import org.aspectj.lang.ProceedingJoinPoint; -//import org.aspectj.lang.annotation.Around; -//import org.aspectj.lang.annotation.Aspect; -//import org.aspectj.lang.annotation.Pointcut; -//import org.springframework.stereotype.Component; -//import org.springframework.web.context.request.RequestContextHolder; -//import org.springframework.web.context.request.ServletRequestAttributes; -// -//@Aspect -//@Component -//public class ResponseHeaderAspect { -// -// @Pointcut("execution(public * *(..)) && @within(org.springframework.web.bind.annotation.RestController)") -// public void controllerMethods() { -// } -// -// @Around("controllerMethods()") -// public Object apiResponseAdvice(ProceedingJoinPoint pjp) throws Throwable { -// Object proceed = pjp.proceed(); -// ServletRequestAttributes sra = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); -// -// if (sra != null) { -// HttpServletResponse response = sra.getResponse(); -// if (response != null) { -// response.setHeader("Content-Type", "application/json; charset=utf-8"); -// } -// } -// -// return proceed; -// } -//} diff --git a/src/main/java/com/coco/boot/controller/CoCoPilotController.java b/src/main/java/com/coco/boot/controller/CoCoPilotController.java index 641338e..a821787 100644 --- a/src/main/java/com/coco/boot/controller/CoCoPilotController.java +++ b/src/main/java/com/coco/boot/controller/CoCoPilotController.java @@ -1,12 +1,13 @@ package com.coco.boot.controller; -import com.alibaba.fastjson2.JSONObject; import com.coco.boot.common.R; import com.coco.boot.entity.ServiceStatus; import com.coco.boot.pojo.Conversation; import com.coco.boot.service.CoCoPilotService; import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; import lombok.AllArgsConstructor; +import lombok.extern.slf4j.Slf4j; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.*; @@ -18,6 +19,7 @@ @AllArgsConstructor @RestController @RequestMapping("/") +@Slf4j public class CoCoPilotController { @@ -57,14 +59,15 @@ public ResponseEntity callback(@RequestParam String code, @RequestParam @RequestMapping(value = "/v1/**", method = {RequestMethod.GET, RequestMethod.POST}) - public ResponseEntity chat(@RequestBody Conversation requestBody, + public ResponseEntity chat(@RequestBody Conversation requestBody, @RequestHeader("Authorization") String auth, - HttpServletRequest request) { + HttpServletRequest request, + HttpServletResponse response) { try { - System.out.println("request = " + request.getRequestURI()); - return coCoPilotService.chat(requestBody, auth, request.getRequestURI()); + log.info("request = {}", request.getRequestURI()); + return coCoPilotService.chat(requestBody, auth, request.getRequestURI(), response); } catch (Exception e) { - e.printStackTrace(); + log.error("chat error", e); return new ResponseEntity<>(null, HttpStatus.INTERNAL_SERVER_ERROR); } } diff --git a/src/main/java/com/coco/boot/pojo/Conversation.java b/src/main/java/com/coco/boot/pojo/Conversation.java index c3c8c68..c673909 100644 --- a/src/main/java/com/coco/boot/pojo/Conversation.java +++ b/src/main/java/com/coco/boot/pojo/Conversation.java @@ -30,6 +30,6 @@ public class Conversation { /** * 是否流式 */ - private Boolean stream; + private Boolean stream = true; } diff --git a/src/main/java/com/coco/boot/service/CoCoPilotService.java b/src/main/java/com/coco/boot/service/CoCoPilotService.java index 717a4a1..5ec1e05 100644 --- a/src/main/java/com/coco/boot/service/CoCoPilotService.java +++ b/src/main/java/com/coco/boot/service/CoCoPilotService.java @@ -1,10 +1,10 @@ package com.coco.boot.service; -import com.alibaba.fastjson2.JSONObject; import com.coco.boot.common.R; import com.coco.boot.entity.ServiceStatus; import com.coco.boot.pojo.Conversation; +import jakarta.servlet.http.HttpServletResponse; import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestHeader; @@ -33,7 +33,7 @@ public interface CoCoPilotService { /** * chat 接口 */ - ResponseEntity chat(@RequestBody Conversation requestBody, @RequestHeader("Authorization") String auth, String path); + ResponseEntity chat(@RequestBody Conversation requestBody, @RequestHeader("Authorization") String auth, String path, HttpServletResponse response); /** * 服务状态 diff --git a/src/main/java/com/coco/boot/service/impl/CoCoPilotServiceImpl.java b/src/main/java/com/coco/boot/service/impl/CoCoPilotServiceImpl.java index ff54a47..67900d3 100644 --- a/src/main/java/com/coco/boot/service/impl/CoCoPilotServiceImpl.java +++ b/src/main/java/com/coco/boot/service/impl/CoCoPilotServiceImpl.java @@ -14,6 +14,8 @@ import com.coco.boot.pojo.Conversation; import com.coco.boot.service.CoCoPilotService; import com.coco.boot.task.CoCoTask; +import com.fasterxml.jackson.databind.ObjectMapper; +import jakarta.servlet.http.HttpServletResponse; import jodd.util.StringUtil; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -22,17 +24,15 @@ import org.redisson.api.map.event.EntryExpiredListener; import org.redisson.client.codec.StringCodec; import org.springframework.http.*; +import org.springframework.scheduling.annotation.EnableAsync; import org.springframework.stereotype.Service; -import org.springframework.util.DigestUtils; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; -import org.springframework.util.StringUtils; +import org.springframework.util.*; import org.springframework.web.client.RestTemplate; import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.view.RedirectView; - import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.*; import java.util.concurrent.CompletableFuture; @@ -45,6 +45,7 @@ @AllArgsConstructor @Service @Slf4j +@EnableAsync public class CoCoPilotServiceImpl implements CoCoPilotService { @@ -54,6 +55,7 @@ public class CoCoPilotServiceImpl implements CoCoPilotService { private final CoCoConfig coCoConfig; private final RiskContrConfig rcConfig; + private final ObjectMapper objectMapper = new ObjectMapper(); private static final JSONObject NO_KEYS = JSON.parseObject("{\"message\": \"No keys\"}"); private static final JSONObject CODE_429 = JSON.parseObject("{\"message\": \"Rate limit\"}"); @@ -96,6 +98,7 @@ public R uploadGhu(String data) { aliveSet.add(key); } else if (statusCode == HttpStatus.TOO_MANY_REQUESTS.value()) { String retryAfter = response.headers().firstValue(HEADER_RETRY).orElse("100"); + map.put(key, "限流"); setCoolkey(key, retryAfter); log.info("upload 存活校验限流: {}, 返回: {}", key, response.body()); } else { @@ -205,7 +208,7 @@ public ResponseEntity callback(String code, String state) { @Override - public ResponseEntity chat(Conversation requestBody, String auth, String path) { + public ResponseEntity chat(Conversation requestBody, String auth, String path, HttpServletResponse response) { JSONObject userInfo = ChatInterceptor.tl.get(); auth = auth.substring("Bearer ".length()); String userId = userInfo.getString("id"); @@ -223,8 +226,8 @@ public ResponseEntity chat(Conversation requestBody, String auth, String path) { String tokenKey = DigestUtils.md5DigestAsHex((auth + userId).getBytes()); if (rateLimiter.tryAcquire()) { // 调用 handleProxy 方法并获取响应 - ResponseEntity response = getBaseProxyResponse(requestBody, path, ghuAliveKey); - if (response.getStatusCode().is2xxSuccessful()) { + ResponseEntity result = getBaseProxyResponse(requestBody, path, ghuAliveKey, response); + if (result.getStatusCode().is2xxSuccessful()) { //成功访问 long tokenSuccess = redissonClient.getAtomicLong(RC_TOKEN_SUCCESS_REQ + tokenKey).incrementAndGet(); if (tokenSuccess > ((long) rcConfig.getTokenMaxReq() * trustLevel)) { @@ -235,13 +238,7 @@ public ResponseEntity chat(Conversation requestBody, String auth, String path) { redissonClient.getBucket(RC_TEMPORARY_BAN + userId).set(true, Duration.ofHours(rcConfig.getUserMaxTime())); } } - MediaType contentType = response.getHeaders().getContentType(); - if (contentType.equals(MediaType.APPLICATION_JSON)) { - // 处理 application/json 类型数据 - return new ResponseEntity<>(JSON.parseObject(response.getBody().toString()),response.getStatusCode()); - } - return response; - + return result; } else { long l = redissonClient.getAtomicLong(RC_USER_TOKEN_LIMIT_NUM + tokenKey).incrementAndGet(); RAtomicLong userLimit = redissonClient.getAtomicLong(RC_USER_LIMIT_NUM + userId); @@ -279,10 +276,8 @@ public ServiceStatus getServiceStatus() { } - - @NotNull - private ResponseEntity getBaseProxyResponse(Object requestBody, String path, RSet ghuAliveKey) { + private ResponseEntity getBaseProxyResponse(Conversation requestBody, String path, RSet ghuAliveKey, HttpServletResponse response) { int i = 0; while (i < 2) { String ghu = getGhu(ghuAliveKey); @@ -292,22 +287,67 @@ private ResponseEntity getBaseProxyResponse(Object requestBody, String path, RSe log.info("{}可用令牌数量,当前选择{}", ghuAliveKey.size(), ghu); HttpHeaders headers = new HttpHeaders(); headers.setContentType(MediaType.APPLICATION_JSON); - headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON,MediaType.TEXT_EVENT_STREAM)); + headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM)); headers.set("Authorization", "Bearer " + ghu); StopWatch sw = new StopWatch(); sw.start("进入代理"); - ResponseEntity response = rest.postForEntity(coCoConfig.getBaseProxy() + path, new HttpEntity<>(requestBody, headers), String.class); + ResponseEntity result = null; + + if (!requestBody.getStream()) { + // 非流式处理使用postForEntity方法 + result = rest.postForEntity( + coCoConfig.getBaseProxy() + path, + new HttpEntity<>(requestBody, headers), + String.class); + } else { + // 流式处理使用execute方法 + try { + rest.execute( + coCoConfig.getBaseProxy() + path, + HttpMethod.POST, + requestCallback -> { + HttpHeaders requestHeaders = requestCallback.getHeaders(); + requestHeaders.setContentType(MediaType.APPLICATION_JSON); + requestHeaders.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM)); + requestHeaders.set("Authorization", "Bearer " + ghu); + String jsonBody = convertObjectToJson(requestBody); + requestCallback.getBody().write(jsonBody.getBytes(StandardCharsets.UTF_8)); + }, + responseExtractor -> { + responseExtractor.getHeaders().forEach((key, value) -> { + response.setHeader(key, String.join(",", value)); + }); + response.setStatus(responseExtractor.getRawStatusCode()); + response.setHeader("Content-Type", "application/stream+json"); + StreamUtils.copy(responseExtractor.getBody(), response.getOutputStream()); + //ghu使用成功次数 + RAtomicLong atomicLong = redissonClient.getAtomicLong(USING_GHU + ghu); + atomicLong.incrementAndGet(); + return null; + } + ); + return ResponseEntity.ok().build(); // 表示流式处理成功 + } catch (Exception e) { + log.error("流式处理异常", e); + // 根据实际情况处理异常,例如设置重试或返回错误响应 + } + } sw.stop(); log.info(sw.prettyPrint(TimeUnit.SECONDS)); - if (response.getStatusCode().is2xxSuccessful()) { + if (result == null) { + i++; + continue; + } + if (result.getStatusCode().is2xxSuccessful()) { //ghu使用成功次数 RAtomicLong atomicLong = redissonClient.getAtomicLong(USING_GHU + ghu); atomicLong.incrementAndGet(); - return response; + // 客户端请求全部数据 + return result; } else { ghuAliveKey.remove(ghu); - if (response.getStatusCode() == HttpStatus.TOO_MANY_REQUESTS) { - String retryAfter = response.getHeaders().getFirst(HEADER_RETRY); + if (result.getStatusCode() == HttpStatus.TOO_MANY_REQUESTS) { + String retryAfter = result.getHeaders().getFirst(HEADER_RETRY); setCoolkey(ghu, retryAfter); } else { redissonClient.getSet(GHU_NO_ALIVE_KEY, StringCodec.INSTANCE).addAsync(ghu); @@ -351,7 +391,7 @@ private void setCoolkey(String ghu, String retryAfter) { if (StringUtil.isNotBlank(retryAfter)) { try { time = Long.parseLong(retryAfter); - } catch (NumberFormatException e) { + } catch (NumberFormatException ignored) { } } @@ -369,4 +409,16 @@ private void setCoolkey(String ghu, String retryAfter) { } } + public String convertObjectToJson(Object obj) { + ObjectMapper objectMapper = new ObjectMapper(); + try { + return objectMapper.writeValueAsString(obj); + } catch (Exception e) { + // 处理异常:实际应用中可能需要更复杂的异常处理逻辑 + log.error("Error converting object to JSON", e); + return null; + } + } + + } diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index 41f75f4..ba44d36 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -40,16 +40,16 @@ coco: # 请求L站 state 过期时间 分钟 expirationTtl: 5 #重定向地址 - redirectUri: http://localhost:8181/oauth2/callback - clientId: hi3geJYfTotoiR5S62u3rh4W5tSeC5UG - clientSecret: VMPBVoAfOB5ojkGXRDEtzvDhRLENHpaN + redirectUri: + clientId: + clientSecret: # L站 - authorizationEndpoint: https://connect.linux.do/oauth2/authorize - tokenEndpoint: https://connect.linux.do/oauth2/token - userEndpoint: https://connect.linux.do/api/user + authorizationEndpoint: + tokenEndpoint: + userEndpoint: #代理节点 - baseApi: https://api.cocopilot.org/copilot_internal/v2/token - baseProxy: https://proxy.cocopilot.org + baseApi: + baseProxy: #ghu 频率秒 1秒8次 frequencyTime: 1 #ghu频率数