Skip to content

Commit

Permalink
#171: support streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
mkrueger12 committed Mar 22, 2024
1 parent 369ddea commit b5388b9
Show file tree
Hide file tree
Showing 8 changed files with 496 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pkg/api/schemas/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type ChatStreamRequest struct {
func NewChatStreamFromStr(message string) *ChatStreamRequest {
return &ChatStreamRequest{
Message: ChatMessage{
"human",
"user",
message,
"glide",
},
Expand Down
22 changes: 11 additions & 11 deletions pkg/providers/cohere/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ type ChatRequest struct {
Connectors []string `json:"connectors,omitempty"`
SearchQueriesOnly bool `json:"search_queries_only,omitempty"`
CitiationQuality string `json:"citiation_quality,omitempty"`

// Stream bool `json:"stream,omitempty"`
Stream bool `json:"stream,omitempty"`
}

type Connectors struct {
Expand All @@ -61,6 +60,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
Connectors: cfg.DefaultParams.Connectors,
SearchQueriesOnly: cfg.DefaultParams.SearchQueriesOnly,
CitiationQuality: cfg.DefaultParams.CitiationQuality,
Stream: false,
}
}

Expand Down Expand Up @@ -119,7 +119,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
req.Header.Set("Content-Type", "application/json")

// TODO: this could leak information from messages which may not be a desired thing to have
c.telemetry.Logger.Debug(
c.tel.Logger.Debug(
"cohere chat request",
zap.String("chat_url", c.chatURL),
zap.Any("payload", payload),
Expand All @@ -135,10 +135,10 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
if resp.StatusCode != http.StatusOK {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
c.telemetry.Logger.Error("failed to read cohere chat response", zap.Error(err))
c.tel.Logger.Error("failed to read cohere chat response", zap.Error(err))
}

c.telemetry.Logger.Error(
c.tel.Logger.Error(
"cohere chat request failed",
zap.Int("status_code", resp.StatusCode),
zap.String("response", string(bodyBytes)),
Expand All @@ -156,7 +156,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
// Read the response body into a byte slice
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
c.telemetry.Logger.Error("failed to read cohere chat response", zap.Error(err))
c.tel.Logger.Error("failed to read cohere chat response", zap.Error(err))
return nil, err
}

Expand All @@ -165,7 +165,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche

err = json.Unmarshal(bodyBytes, &responseJSON)
if err != nil {
c.telemetry.Logger.Error("failed to parse cohere chat response", zap.Error(err))
c.tel.Logger.Error("failed to parse cohere chat response", zap.Error(err))
return nil, err
}

Expand All @@ -174,7 +174,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche

err = json.Unmarshal(bodyBytes, &cohereCompletion)
if err != nil {
c.telemetry.Logger.Error("failed to parse cohere chat response", zap.Error(err))
c.tel.Logger.Error("failed to parse cohere chat response", zap.Error(err))
return nil, err
}

Expand All @@ -191,7 +191,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
"responseId": cohereCompletion.ResponseID,
},
Message: schemas.ChatMessage{
Role: "model", // TODO: Does this need to change?
Role: "model",
Content: cohereCompletion.Text,
Name: "",
},
Expand All @@ -209,11 +209,11 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
func (c *Client) handleErrorResponse(resp *http.Response) (*schemas.ChatResponse, error) {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
c.telemetry.Logger.Error("failed to read cohere chat response", zap.Error(err))
c.tel.Logger.Error("failed to read cohere chat response", zap.Error(err))
return nil, err
}

c.telemetry.Logger.Error(
c.tel.Logger.Error(
"cohere chat request failed",
zap.Int("status_code", resp.StatusCode),
zap.String("response", string(bodyBytes)),
Expand Down
238 changes: 234 additions & 4 deletions pkg/providers/cohere/chat_stream.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,246 @@
package cohere


import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"

"glide/pkg/api/schemas"
"github.com/r3labs/sse/v2"
"glide/pkg/providers/clients"
"glide/pkg/telemetry"

"go.uber.org/zap"

"glide/pkg/api/schemas"
)


// StopReason holds the literal "stream-end".
// NOTE(review): nothing in the visible code reads it; presumably it is the
// event marker Cohere emits when a stream finishes — confirm before relying on it.
var StopReason = "stream-end"

// ChatStream represents cohere chat stream for a specific request
// ChatStream is a single Cohere streaming chat exchange: it owns the prepared
// HTTP request, and once opened, the response and the SSE reader over its body.
type ChatStream struct {
	// tel provides the logger used while receiving chunks.
	tel *telemetry.Telemetry
	// client executes req when the stream is opened.
	client *http.Client
	// req is the fully prepared streaming request.
	req *http.Request
	// reqID is copied into the ID of every chunk returned to the caller.
	reqID string
	// reqMetadata is copied into the Metadata of every chunk returned to the caller.
	reqMetadata *schemas.Metadata
	// resp is set by Open on success; Close releases its body.
	resp *http.Response
	// reader decodes server-sent events from resp.Body.
	reader *sse.EventStreamReader
	// errMapper turns a non-200 provider response into an error.
	errMapper *ErrorMapper
}

// NewChatStream builds a ChatStream around an already prepared request.
// No network activity happens here; the response and the event reader stay
// unset until Open is called.
func NewChatStream(
	tel *telemetry.Telemetry,
	client *http.Client,
	req *http.Request,
	reqID string,
	reqMetadata *schemas.Metadata,
	errMapper *ErrorMapper,
) *ChatStream {
	stream := new(ChatStream)

	stream.tel = tel
	stream.client = client
	stream.req = req
	stream.reqID = reqID
	stream.reqMetadata = reqMetadata
	stream.errMapper = errMapper

	return stream
}

func (s *ChatStream) Open() error {
resp, err := s.client.Do(s.req) //nolint:bodyclose
fmt.Print(resp.StatusCode)
if err != nil {

Check failure on line 58 in pkg/providers/cohere/chat_stream.go

View workflow job for this annotation

GitHub Actions / Static Checks

if statements should only be cuddled with assignments (wsl)
return err
}

if resp.StatusCode != http.StatusOK {
return s.errMapper.Map(resp)
}

s.resp = resp
s.reader = sse.NewEventStreamReader(resp.Body, 4096) // TODO: should we expose maxBufferSize?

return nil
}

func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
var completionChunk ChatCompletionChunk

for {
rawEvent, err := s.reader.ReadEvent()
fmt.Print(err)
if err != nil {

Check failure on line 78 in pkg/providers/cohere/chat_stream.go

View workflow job for this annotation

GitHub Actions / Static Checks

if statements should only be cuddled with assignments (wsl)
s.tel.L().Warn(
"Chat stream is unexpectedly disconnected",
zap.String("provider", providerName),
zap.Error(err),
)

// if err is io.EOF, this still means that the stream is interrupted unexpectedly
// because the normal stream termination is done via finding out streamDoneMarker

return nil, clients.ErrProviderUnavailable
}

s.tel.L().Debug(
"Raw chat stream chunk",
zap.String("provider", providerName),
zap.ByteString("rawChunk", rawEvent),
)

event, err := clients.ParseSSEvent(rawEvent)

if err != nil {
return nil, fmt.Errorf("failed to parse chat stream message: %v", err)
}

if !event.HasContent() {
s.tel.L().Debug(
"Received an empty message in chat stream, skipping it",
zap.String("provider", providerName),
zap.Any("msg", event),
)

continue
}

err = json.Unmarshal(event.Data, &completionChunk)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal chat stream chunk: %v", err)
}

responseChunk := completionChunk

var finishReason *schemas.FinishReason

if responseChunk.IsFinished {
finishReason = &schemas.Complete
return &schemas.ChatStreamChunk{

Check failure on line 124 in pkg/providers/cohere/chat_stream.go

View workflow job for this annotation

GitHub Actions / Static Checks

return statements should not be cuddled if block has more than two lines (wsl)
ID: s.reqID,
Provider: providerName,
Cached: false,
ModelName: "NA",
Metadata: s.reqMetadata,
ModelResponse: schemas.ModelChunkResponse{
Metadata: &schemas.Metadata{
"generationId": responseChunk.Response.GenerationID,
"responseId": responseChunk.Response.ResponseID,
},
Message: schemas.ChatMessage{
Role: "model",
Content: responseChunk.Text,
},
FinishReason: finishReason,
},
}, nil
}

// TODO: use objectpool here
return &schemas.ChatStreamChunk{
ID: s.reqID,
Provider: providerName,
Cached: false,
ModelName: "NA",
Metadata: s.reqMetadata,
ModelResponse: schemas.ModelChunkResponse{
Message: schemas.ChatMessage{
Role: "model",
Content: responseChunk.Text,
},
FinishReason: finishReason,
},
}, nil
}
}

// Close releases the response body obtained by Open and returns the error
// from closing it. It is a no-op returning nil when no response is attached.
func (s *ChatStream) Close() error {
	if s.resp == nil {
		return nil
	}

	return s.resp.Body.Close()
}

// SupportChatStream reports whether this client can serve streaming chat
// requests. It returns true because ChatStream is implemented for Cohere.
func (c *Client) SupportChatStream() bool {
	return true
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
return nil, clients.ErrChatStreamNotImplemented
func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error) {
// Create a new chat request
httpRequest, err := c.makeStreamReq(ctx, req)
if err != nil {
return nil, err
}

return NewChatStream(
c.tel,
c.httpClient,
httpRequest,
req.ID,
req.Metadata,
c.errMapper,
), nil
}

// createRequestFromStream derives a Cohere chat payload from the unified
// stream request. It starts from a shallow copy of the client's request
// template, sets the current message, and converts the message history into
// Cohere's ChatHistory entries when any history is present.
func (c *Client) createRequestFromStream(request *schemas.ChatStreamRequest) *ChatRequest {
	// TODO: consider using objectpool to optimize memory allocation
	payload := *c.chatRequestTemplate // value copy; the template itself is left untouched
	payload.Message = request.Message.Content

	historyLen := len(request.MessageHistory)
	if historyLen == 0 {
		// No history supplied: keep whatever ChatHistory the template carried.
		return &payload
	}

	// One Cohere history entry per incoming message, in the same order.
	history := make([]ChatHistory, historyLen)
	for idx := range request.MessageHistory {
		history[idx] = ChatHistory{
			Role:    request.MessageHistory[idx].Role,
			Message: request.MessageHistory[idx].Content,
			User:    "",
		}
	}

	payload.ChatHistory = history

	return &payload
}

// makeStreamReq builds the HTTP POST request for a streaming chat call.
//
// The payload is derived from req with Stream forced to true, marshalled to
// JSON and sent to the client's chat URL. The request is bound to ctx and
// carries the bearer token plus the headers asking for an event stream.
// It returns a wrapped error when the payload cannot be marshalled or the
// request cannot be constructed.
func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamRequest) (*http.Request, error) {
	chatRequest := c.createRequestFromStream(req)
	chatRequest.Stream = true

	rawPayload, err := json.Marshal(chatRequest)
	if err != nil {
		return nil, fmt.Errorf("unable to marshal cohere chat stream request payload: %w", err)
	}

	request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload))
	if err != nil {
		return nil, fmt.Errorf("unable to create cohere stream chat request: %w", err)
	}

	request.Header.Set("Content-Type", "application/json")
	request.Header.Set("Authorization", fmt.Sprintf("Bearer %v", string(c.config.APIKey)))
	request.Header.Set("Cache-Control", "no-cache")
	request.Header.Set("Accept", "text/event-stream")
	request.Header.Set("Connection", "keep-alive")

	// TODO: this could leak information from messages which may not be a desired thing to have
	c.tel.L().Debug(
		"Stream chat request",
		zap.String("chatURL", c.chatURL),
		zap.Any("payload", chatRequest),
	)

	return request, nil
}

Loading

0 comments on commit b5388b9

Please sign in to comment.