Skip to content

Commit

Permalink
feat: interpreter ws 反向连接
Browse files Browse the repository at this point in the history
  • Loading branch information
bincooo committed Jun 26, 2024
1 parent c862c17 commit 7988f82
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 62 deletions.
2 changes: 1 addition & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ serverless:
interpreter:
baseUrl: http://127.0.0.1:8000
echoCode: false
ws: false
ws: true

custom-llm:
baseUrl: http://127.0.0.1:8080
Expand Down
5 changes: 1 addition & 4 deletions internal/gin.handler/basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@ func Bind(port int, version, proxies string) {
route.GET("/v1/models", models)
route.Static("/file/tmp/", "tmp")

if plugin.IO != nil {
route.GET("/socket.io/*any", gin.WrapH(plugin.IO.ServeHandler(nil)))
route.POST("/socket.io/*any", gin.WrapH(plugin.IO.ServeHandler(nil)))
}
route.Any("/socket.io/*any", gin.WrapH(plugin.IO.ServeHandler(nil)))

addr := ":" + strconv.Itoa(port)
logger.Info(fmt.Sprintf("server start by http://0.0.0.0%s/v1", addr))
Expand Down
49 changes: 15 additions & 34 deletions internal/plugin/llm/interpreter/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ var (
Adapter = API{}
Model = "open-interpreter"

mu sync.Mutex
ws *socketio.Socket
mu sync.Mutex
ws *socketio.Socket
wsChan chan string
)

type API struct {
Expand Down Expand Up @@ -55,7 +56,7 @@ func init() {
}

func (API) Match(_ *gin.Context, model string) bool {
return model == Model
return model == Model || model == Model+"-ws"
}

func (API) Models() []plugin.Model {
Expand All @@ -66,6 +67,12 @@ func (API) Models() []plugin.Model {
Created: 1686935002,
By: "interpreter-adapter",
},
{
Id: "open-interpreter-ws",
Object: "model",
Created: 1686935002,
By: "interpreter-adapter",
},
}
}

Expand All @@ -76,6 +83,11 @@ func (API) Completion(ctx *gin.Context) {
matchers = common.GetGinMatchers(ctx)
)

if completion.Model == Model+"-ws" {
completionWS(ctx)
return
}

r, tokens, err := fetch(ctx, proxies, completion)
if err != nil {
logger.Error(err)
Expand All @@ -89,34 +101,3 @@ func (API) Completion(ctx *gin.Context) {
response.Error(ctx, -1, "EMPTY RESPONSE")
}
}

func initSocketIO(w *socketio.Socket) bool {
if ws != nil {
return false
}

r := w.Request()
if token := r.GetPathInfo(); token != "/socket.io/open-i/" {
return false
}

w.On("disconnect", func(...any) {
mu.Lock()
defer mu.Unlock()
ws = nil
})

w.On("ping", func(...any) {
w.Emit("pong", "ok")
})

w.On("message", func(args ...any) {
message := args[0].(string)
// TODO -
logger.Infof("message: %s", message)
w.Emit("message", "ok")
})

ws = w
return true
}
49 changes: 26 additions & 23 deletions internal/plugin/llm/interpreter/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,29 @@ import (

func fetch(ctx *gin.Context, proxies string, completion pkg.ChatCompletion) (response *http.Response, tokens int, err error) {
var (
baseUrl = pkg.Config.GetString("interpreter.baseUrl")
useProxy = pkg.Config.GetBool("interpreter.useProxy")
baseUrl = pkg.Config.GetString("interpreter.baseUrl")
)

if !useProxy {
proxies = ""
tokens, message, err := mergeMessages(ctx, proxies, baseUrl, completion)
if err != nil {
return nil, -1, err
}

response, err = emit.ClientBuilder(plugin.HTTPClient).
Context(common.GetGinContext(ctx)).
Proxies(proxies).
POST(baseUrl+"/chat").
Body(map[string]string{
"message": message,
}).
DoC(emit.Status(http.StatusOK), emit.IsSTREAM)
if err != nil {
err = logger.WarpError(err)
}
return
}

func mergeMessages(ctx *gin.Context, proxies, baseUrl string, completion pkg.ChatCompletion) (tokens int, message string, err error) {
condition := func(role string) string {
switch role {
case "assistant", "end":
Expand All @@ -49,8 +64,7 @@ func fetch(ctx *gin.Context, proxies string, completion pkg.ChatCompletion) (res
tokens += common.CalcTokens(opts.Message["content"])
// 复合消息
if _, ok := opts.Message["multi"]; ok && role == "user" {
message := opts.Initial()
content, e := common.MergeMultiMessage(ctx.Request.Context(), proxies, message)
content, e := common.MergeMultiMessage(ctx.Request.Context(), proxies, opts.Initial())
if e != nil {
return nil, e
}
Expand Down Expand Up @@ -87,10 +101,11 @@ func fetch(ctx *gin.Context, proxies string, completion pkg.ChatCompletion) (res
})

if messageL := len(messages); !messages[messageL-1].Is("role", "user") {
return nil, 0, errors.Errorf("messages[%d] is not `user` role", messageL-1)
err = errors.Errorf("messages[%d] is not `user` role", messageL-1)
return
}

message := messages[len(messages)-1].GetString("content")
message = messages[len(messages)-1].GetString("content")
messages = messages[:len(messages)-1]

obj := map[string]interface{}{
Expand All @@ -101,28 +116,16 @@ func fetch(ctx *gin.Context, proxies string, completion pkg.ChatCompletion) (res
obj["system"] = system
}

response, err = emit.ClientBuilder(plugin.HTTPClient).
response, e := emit.ClientBuilder(plugin.HTTPClient).
Context(common.GetGinContext(ctx)).
Proxies(proxies).
POST(baseUrl+"/settings").
Body(obj).
DoC(emit.Status(http.StatusOK), emit.IsJSON)
if err != nil {
err = logger.WarpError(err)
if e != nil {
err = logger.WarpError(e)
return
}
logger.Info(emit.TextResponse(response))

response, err = emit.ClientBuilder(plugin.HTTPClient).
Context(common.GetGinContext(ctx)).
Proxies(proxies).
POST(baseUrl+"/chat").
Body(map[string]string{
"message": message,
}).
DoC(emit.Status(http.StatusOK), emit.IsSTREAM)
if err != nil {
err = logger.WarpError(err)
}
return
}
101 changes: 101 additions & 0 deletions internal/plugin/llm/interpreter/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,104 @@ func waitResponse(ctx *gin.Context, matchers []common.Matcher, r *http.Response,
}
return
}

func waitResponseWS(ctx *gin.Context, matchers []common.Matcher, sse bool) (content string) {
defer close(wsChan)

logger.Info("waitResponse ...")
created := time.Now().Unix()
completion := common.GetGinCompletion(ctx)
toolId := common.GetGinToolValue(ctx).GetString("id")
toolId = plugin.NameWithTools(toolId, completion.Tools)
echoCode := pkg.Config.GetBool("interpreter.echoCode")

// return true 结束
handler := func() bool {
data, ok := <-wsChan
if !ok {
return true
}

logger.Tracef("--------- ORIGINAL MESSAGE ---------")
logger.Tracef("%s", data)

//if len(data) < 6 || data[:6] != "data: " {
// return false
//}
//
//data = data[6:]
if data == "[DONE]" || data == "[DONE]\r" {
return true
}

var message pkg.Keyv[interface{}]
err := json.Unmarshal([]byte(data), &message)
if err != nil {
logger.Error(err)
return false
}

if !message.Is("role", "assistant") || !message.In("type", "message", "code") {
return false
}

// 控制是否输出代码
if !echoCode && message.Is("type", "code") {
return false
}

raw := message.GetString("content")

if message.Is("type", "code") && message.Is("start", true) {
raw += fmt.Sprintf("\n```%s\n", message.GetString("format"))
}

if message.Is("type", "code") && message.Is("end", true) {
raw += "\n```\n"
}

if len(raw) == 0 {
return false
}

logger.Debug("----- raw -----")
logger.Debug(raw)

raw = common.ExecMatchers(matchers, raw)
if len(raw) == 0 {
return false
}

if sse && len(raw) > 0 {
response.SSEResponse(ctx, Model, raw, created)
}
content += raw
return false
}

c := common.GetGinContext(ctx)
for {
select {
case <-c.Done():
logger.Error("timed out [ws] waitResponse")
_ = ws.Emit("cancel")
goto label
default:
if handler() {
goto label
}
}
}

label:
if content == "" && response.NotSSEHeader(ctx) {
return
}

if !sse {
response.Response(ctx, Model, content)
} else {
response.SSEResponse(ctx, Model, "[DONE]", created)
}
return
}
79 changes: 79 additions & 0 deletions internal/plugin/llm/interpreter/socket.io.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package interpreter

import (
"github.com/bincooo/chatgpt-adapter/internal/common"
"github.com/bincooo/chatgpt-adapter/internal/gin.handler/response"
"github.com/bincooo/chatgpt-adapter/logger"
"github.com/bincooo/chatgpt-adapter/pkg"
"github.com/gin-gonic/gin"
socketio "github.com/zishang520/socket.io/socket"
)

func completionWS(ctx *gin.Context) {
if ws == nil {
response.Error(ctx, -1, "socket.io connection is closed / disabled")
return
}

var (
baseUrl = pkg.Config.GetString("interpreter.baseUrl")
proxies = ctx.GetString("proxies")
completion = common.GetGinCompletion(ctx)
matchers = common.GetGinMatchers(ctx)
)

tokens, message, err := mergeMessages(ctx, proxies, baseUrl, completion)
if err != nil {
logger.Error(err)
response.Error(ctx, -1, err)
return
}
ctx.Set(ginTokens, tokens)

err = ws.Emit("reply", message)
if err != nil {
logger.Error(err)
response.Error(ctx, -1, err)
return
}

wsChan = make(chan string)
content := waitResponseWS(ctx, matchers, completion.Stream)
if content == "" && response.NotResponse(ctx) {
response.Error(ctx, -1, "EMPTY RESPONSE")
}
}

func initSocketIO(w *socketio.Socket) bool {
if ws != nil {
return false
}

r := w.Request()
if token := r.GetPathInfo(); token != "/socket.io/open-i/" {
return false
}

w.On("disconnect", func(...any) {
mu.Lock()
defer mu.Unlock()
ws = nil
})

w.On("ping", func(...any) {
w.Emit("pong", "ok")
})

w.On("message", func(args ...any) {
message := args[0].(string)
// TODO -
if ws != nil {
wsChan <- message
}
logger.Infof("message: %s", message)
w.Emit("message", "ok")
})

ws = w
return true
}

0 comments on commit 7988f82

Please sign in to comment.