From b5388b9eaf10d15b6dcffb375df610684f693308 Mon Sep 17 00:00:00 2001 From: Max Date: Fri, 22 Mar 2024 11:44:45 -0600 Subject: [PATCH] #171: support streaming --- pkg/api/schemas/chat_stream.go | 2 +- pkg/providers/cohere/chat.go | 22 +-- pkg/providers/cohere/chat_stream.go | 238 ++++++++++++++++++++++- pkg/providers/cohere/chat_stream_test.go | 164 ++++++++++++++++ pkg/providers/cohere/client.go | 5 +- pkg/providers/cohere/config.go | 2 +- pkg/providers/cohere/error.go | 64 ++++++ pkg/providers/cohere/schemas.go | 18 ++ 8 files changed, 496 insertions(+), 19 deletions(-) create mode 100644 pkg/providers/cohere/chat_stream_test.go create mode 100644 pkg/providers/cohere/error.go diff --git a/pkg/api/schemas/chat_stream.go b/pkg/api/schemas/chat_stream.go index e9095511..d77310c7 100644 --- a/pkg/api/schemas/chat_stream.go +++ b/pkg/api/schemas/chat_stream.go @@ -19,7 +19,7 @@ type ChatStreamRequest struct { func NewChatStreamFromStr(message string) *ChatStreamRequest { return &ChatStreamRequest{ Message: ChatMessage{ - "human", + "user", message, "glide", }, diff --git a/pkg/providers/cohere/chat.go b/pkg/providers/cohere/chat.go index 8b72c69c..b3554883 100644 --- a/pkg/providers/cohere/chat.go +++ b/pkg/providers/cohere/chat.go @@ -38,8 +38,7 @@ type ChatRequest struct { Connectors []string `json:"connectors,omitempty"` SearchQueriesOnly bool `json:"search_queries_only,omitempty"` CitiationQuality string `json:"citiation_quality,omitempty"` - - // Stream bool `json:"stream,omitempty"` + Stream bool `json:"stream,omitempty"` } type Connectors struct { @@ -61,6 +60,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { Connectors: cfg.DefaultParams.Connectors, SearchQueriesOnly: cfg.DefaultParams.SearchQueriesOnly, CitiationQuality: cfg.DefaultParams.CitiationQuality, + Stream: false, } } @@ -119,7 +119,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche req.Header.Set("Content-Type", "application/json") // TODO: this 
could leak information from messages which may not be a desired thing to have - c.telemetry.Logger.Debug( + c.tel.Logger.Debug( "cohere chat request", zap.String("chat_url", c.chatURL), zap.Any("payload", payload), @@ -135,10 +135,10 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche if resp.StatusCode != http.StatusOK { bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - c.telemetry.Logger.Error("failed to read cohere chat response", zap.Error(err)) + c.tel.Logger.Error("failed to read cohere chat response", zap.Error(err)) } - c.telemetry.Logger.Error( + c.tel.Logger.Error( "cohere chat request failed", zap.Int("status_code", resp.StatusCode), zap.String("response", string(bodyBytes)), @@ -156,7 +156,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche // Read the response body into a byte slice bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - c.telemetry.Logger.Error("failed to read cohere chat response", zap.Error(err)) + c.tel.Logger.Error("failed to read cohere chat response", zap.Error(err)) return nil, err } @@ -165,7 +165,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche err = json.Unmarshal(bodyBytes, &responseJSON) if err != nil { - c.telemetry.Logger.Error("failed to parse cohere chat response", zap.Error(err)) + c.tel.Logger.Error("failed to parse cohere chat response", zap.Error(err)) return nil, err } @@ -174,7 +174,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche err = json.Unmarshal(bodyBytes, &cohereCompletion) if err != nil { - c.telemetry.Logger.Error("failed to parse cohere chat response", zap.Error(err)) + c.tel.Logger.Error("failed to parse cohere chat response", zap.Error(err)) return nil, err } @@ -191,7 +191,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche "responseId": cohereCompletion.ResponseID, }, Message: schemas.ChatMessage{ - Role: "model", // 
TODO: Does this need to change?
+			Role:    "model",
 			Content: cohereCompletion.Text,
 			Name:    "",
 		},
@@ -209,11 +209,11 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
 func (c *Client) handleErrorResponse(resp *http.Response) (*schemas.ChatResponse, error) {
 	bodyBytes, err := io.ReadAll(resp.Body)
 	if err != nil {
-		c.telemetry.Logger.Error("failed to read cohere chat response", zap.Error(err))
+		c.tel.Logger.Error("failed to read cohere chat response", zap.Error(err))
 		return nil, err
 	}
 
-	c.telemetry.Logger.Error(
+	c.tel.Logger.Error(
 		"cohere chat request failed",
 		zap.Int("status_code", resp.StatusCode),
 		zap.String("response", string(bodyBytes)),
diff --git a/pkg/providers/cohere/chat_stream.go b/pkg/providers/cohere/chat_stream.go
index e8a4bd2a..46b6f311 100644
--- a/pkg/providers/cohere/chat_stream.go
+++ b/pkg/providers/cohere/chat_stream.go
@@ -1,16 +1,246 @@
 package cohere
+
 import (
+	"bytes"
 	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
 
-	"glide/pkg/api/schemas"
+	"github.com/r3labs/sse/v2"
 	"glide/pkg/providers/clients"
+	"glide/pkg/telemetry"
+
+	"go.uber.org/zap"
+
+	"glide/pkg/api/schemas"
 )
+
+var (
+	StopReason = "stream-end"
+)
+
+// ChatStream represents cohere chat stream for a specific request
+type ChatStream struct {
+	tel         *telemetry.Telemetry
+	client      *http.Client
+	req         *http.Request
+	reqID       string
+	reqMetadata *schemas.Metadata
+	resp        *http.Response
+	reader      *sse.EventStreamReader
+	errMapper   *ErrorMapper
+}
+
+func NewChatStream(
+	tel *telemetry.Telemetry,
+	client *http.Client,
+	req *http.Request,
+	reqID string,
+	reqMetadata *schemas.Metadata,
+	errMapper *ErrorMapper,
+) *ChatStream {
+	return &ChatStream{
+		tel:         tel,
+		client:      client,
+		req:         req,
+		reqID:       reqID,
+		reqMetadata: reqMetadata,
+		errMapper:   errMapper,
+	}
+}
+
+func (s *ChatStream) Open() error {
+	resp, err := s.client.Do(s.req) //nolint:bodyclose
+	// NOTE(review): the error must be checked before touching resp — on a
+	// transport error resp is nil, so the previous debug
+	// fmt.Print(resp.StatusCode) placed above this check would panic.
+	// Debug prints removed; use the structured logger for diagnostics.
+	if err != nil {
+		return err
+	}
+
+	if resp.StatusCode !=
http.StatusOK {
+		return s.errMapper.Map(resp)
+	}
+
+	s.resp = resp
+	s.reader = sse.NewEventStreamReader(resp.Body, 4096) // TODO: should we expose maxBufferSize?
+
+	return nil
+}
+
+func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
+	var completionChunk ChatCompletionChunk
+
+	for {
+		rawEvent, err := s.reader.ReadEvent()
+		if err != nil {
+			s.tel.L().Warn(
+				"Chat stream is unexpectedly disconnected",
+				zap.String("provider", providerName),
+				zap.Error(err),
+			)
+
+			// if err is io.EOF, this still means that the stream is interrupted unexpectedly
+			// because the normal stream termination is done via finding out streamDoneMarker
+
+			return nil, clients.ErrProviderUnavailable
+		}
+
+		s.tel.L().Debug(
+			"Raw chat stream chunk",
+			zap.String("provider", providerName),
+			zap.ByteString("rawChunk", rawEvent),
+		)
+
+		event, err := clients.ParseSSEvent(rawEvent)
+
+		if err != nil {
+			return nil, fmt.Errorf("failed to parse chat stream message: %v", err)
+		}
+
+		if !event.HasContent() {
+			s.tel.L().Debug(
+				"Received an empty message in chat stream, skipping it",
+				zap.String("provider", providerName),
+				zap.Any("msg", event),
+			)
+
+			continue
+		}
+
+		err = json.Unmarshal(event.Data, &completionChunk)
+		if err != nil {
+			return nil, fmt.Errorf("failed to unmarshal chat stream chunk: %v", err)
+		}
+
+		responseChunk := completionChunk
+
+		var finishReason *schemas.FinishReason
+
+		if responseChunk.IsFinished {
+			finishReason = &schemas.Complete
+			return &schemas.ChatStreamChunk{
+				ID:        s.reqID,
+				Provider:  providerName,
+				Cached:    false,
+				ModelName: "NA",
+				Metadata:  s.reqMetadata,
+				ModelResponse: schemas.ModelChunkResponse{
+					Metadata: &schemas.Metadata{
+						"generationId": responseChunk.Response.GenerationID,
+						"responseId":   responseChunk.Response.ResponseID,
+					},
+					Message: schemas.ChatMessage{
+						Role:    "model",
+						Content: responseChunk.Text,
+					},
+					FinishReason: finishReason,
+				},
+			}, nil
+		}
+
+		// TODO: use objectpool here
+		return
&schemas.ChatStreamChunk{ + ID: s.reqID, + Provider: providerName, + Cached: false, + ModelName: "NA", + Metadata: s.reqMetadata, + ModelResponse: schemas.ModelChunkResponse{ + Message: schemas.ChatMessage{ + Role: "model", + Content: responseChunk.Text, + }, + FinishReason: finishReason, + }, + }, nil + } +} + +func (s *ChatStream) Close() error { + if s.resp != nil { + return s.resp.Body.Close() + } + + return nil +} + func (c *Client) SupportChatStream() bool { - return false + return true } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) { - return nil, clients.ErrChatStreamNotImplemented +func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error) { + // Create a new chat request + httpRequest, err := c.makeStreamReq(ctx, req) + if err != nil { + return nil, err + } + + return NewChatStream( + c.tel, + c.httpClient, + httpRequest, + req.ID, + req.Metadata, + c.errMapper, + ), nil } + +func (c *Client) createRequestFromStream(request *schemas.ChatStreamRequest) *ChatRequest { + // TODO: consider using objectpool to optimize memory allocation + chatRequest := *c.chatRequestTemplate // hoping to get a copy of the template + + chatRequest.Message = request.Message.Content + + // Build the Cohere specific ChatHistory + if len(request.MessageHistory) > 0 { + chatRequest.ChatHistory = make([]ChatHistory, len(request.MessageHistory)) + for i, message := range request.MessageHistory { + chatRequest.ChatHistory[i] = ChatHistory{ + // Copy the necessary fields from message to ChatHistory + // For example, if ChatHistory has a field called "Text", you can do: + Role: message.Role, + Message: message.Content, + User: "", + } + } + } + + return &chatRequest +} + +func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamRequest) (*http.Request, error) { + chatRequest := c.createRequestFromStream(req) + + chatRequest.Stream = true + + 
+	rawPayload, err := json.Marshal(chatRequest)
+	if err != nil {
+		return nil, fmt.Errorf("unable to marshal cohere chat stream request payload: %w", err)
+	}
+
+	request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload))
+	if err != nil {
+		return nil, fmt.Errorf("unable to create cohere stream chat request: %w", err)
+	}
+
+	request.Header.Set("Content-Type", "application/json")
+	request.Header.Set("Authorization", fmt.Sprintf("Bearer %v", string(c.config.APIKey)))
+	request.Header.Set("Cache-Control", "no-cache")
+	request.Header.Set("Accept", "text/event-stream")
+	request.Header.Set("Connection", "keep-alive")
+
+	// TODO: this could leak information from messages which may not be a desired thing to have
+	c.tel.L().Debug(
+		"Stream chat request",
+		zap.String("chatURL", c.chatURL),
+		zap.Any("payload", chatRequest),
+	)
+
+	return request, nil
+}
diff --git a/pkg/providers/cohere/chat_stream_test.go b/pkg/providers/cohere/chat_stream_test.go
new file mode 100644
index 00000000..dd5dbbba
--- /dev/null
+++ b/pkg/providers/cohere/chat_stream_test.go
@@ -0,0 +1,164 @@
+package cohere
+
+import (
+	"context"
+	"encoding/json"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"glide/pkg/api/schemas"
+
+	"glide/pkg/providers/clients"
+	"glide/pkg/telemetry"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestCohere_ChatStreamSupported(t *testing.T) {
+	providerCfg := DefaultConfig()
+	clientCfg := clients.DefaultClientConfig()
+
+	client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+	require.NoError(t, err)
+
+	require.True(t, client.SupportChatStream())
+}
+
+func TestCohere_ChatStreamRequest(t *testing.T) {
+	tests := map[string]string{
+		"success stream": "./testdata/chat_stream.success.txt",
+	}
+
+	for name, streamFile := range tests {
+		t.Run(name, func(t *testing.T) {
+			openAIMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				rawPayload, _ := io.ReadAll(r.Body)
+
+				var data interface{}
+				// Parse the JSON body
+				err := json.Unmarshal(rawPayload, &data)
+				if err != nil {
+					t.Errorf("error decoding payload (%q): %v", string(rawPayload), err)
+				}
+
+				chatResponse, err := os.ReadFile(filepath.Clean(streamFile))
+				if err != nil {
+					t.Errorf("error reading openai chat mock response: %v", err)
+				}
+
+				w.Header().Set("Content-Type", "text/event-stream")
+
+				_, err = w.Write(chatResponse)
+				if err != nil {
+					t.Errorf("error on sending chat response: %v", err)
+				}
+			})
+
+			openAIServer := httptest.NewServer(openAIMock)
+			defer openAIServer.Close()
+
+			ctx := context.Background()
+			providerCfg := DefaultConfig()
+			clientCfg := clients.DefaultClientConfig()
+
+			// NOTE(review): the test must hit the local mock server. The previous
+			// version pointed at the live Cohere API with a hard-coded API key
+			// (a leaked secret — revoke that key), which made the test flaky,
+			// network-dependent, and unsafe to commit.
+			providerCfg.BaseURL = openAIServer.URL
+
+			client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+			require.NoError(t, err)
+
+			req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?")
+
+			stream, err := client.ChatStream(ctx, req)
+			require.NoError(t, err)
+
+			err = stream.Open()
+			require.NoError(t, err)
+
+			for {
+				chunk, err := stream.Recv()
+				if err == io.EOF {
+					return
+				}
+
+				require.NoError(t, err)
+				require.NotNil(t, chunk)
+			}
+		})
+	}
+}
+
+func TestCohere_ChatStreamRequestInterrupted(t *testing.T) {
+	tests := map[string]string{
+		"success stream, but no last done message": "./testdata/chat_stream.nodone.txt",
+		"success stream, but with empty event":     "./testdata/chat_stream.empty.txt",
+	}
+
+	for name, streamFile := range tests {
+		t.Run(name, func(t *testing.T) {
+			openAIMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				rawPayload, _ := io.ReadAll(r.Body)
+
+				var data interface{}
+				// Parse the JSON body
+				err := json.Unmarshal(rawPayload, &data)
+				if err != nil {
+					t.Errorf("error decoding payload (%q): %v", string(rawPayload), err)
+				}
+
+				chatResponse, err := os.ReadFile(filepath.Clean(streamFile))
+				if err != nil {
+					t.Errorf("error reading openai chat mock response: %v", err)
+				}
+
+				w.Header().Set("Content-Type", "text/event-stream")
+
+				_, err = w.Write(chatResponse)
+				if err != nil {
+					t.Errorf("error on sending chat response: %v", err)
+				}
+			})
+
+			openAIServer := httptest.NewServer(openAIMock)
+			defer openAIServer.Close()
+
+			ctx := context.Background()
+			providerCfg := DefaultConfig()
+			clientCfg := clients.DefaultClientConfig()
+
+			providerCfg.BaseURL = openAIServer.URL
+
+			client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+			require.NoError(t, err)
+
+			req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?")
+			stream, err := client.ChatStream(ctx, req)
+			require.NoError(t, err)
+
+			err = stream.Open()
+			require.NoError(t, err)
+
+			for {
+				chunk, err := stream.Recv()
+				if err != nil {
+					require.ErrorIs(t, err, clients.ErrProviderUnavailable)
+					return
+				}
+
+				require.NoError(t, err)
+				require.NotNil(t, chunk)
+			}
+		})
+	}
+}
diff --git a/pkg/providers/cohere/client.go b/pkg/providers/cohere/client.go
index a6cc9cf5..57a73e15 100644
--- a/pkg/providers/cohere/client.go
+++ b/pkg/providers/cohere/client.go
@@ -23,9 +23,10 @@ type Client struct {
 	baseURL             string
 	chatURL             string
 	chatRequestTemplate *ChatRequest
+	errMapper           *ErrorMapper
 	config              *Config
 	httpClient          *http.Client
-	telemetry           *telemetry.Telemetry
+	tel                 *telemetry.Telemetry
 }
 
 // NewClient creates a new Cohere client for the Cohere API.
@@ -48,7 +49,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * MaxIdleConnsPerHost: 2, }, }, - telemetry: tel, + tel: tel, } return c, nil diff --git a/pkg/providers/cohere/config.go b/pkg/providers/cohere/config.go index 1a38aefa..50dbcde4 100644 --- a/pkg/providers/cohere/config.go +++ b/pkg/providers/cohere/config.go @@ -8,7 +8,7 @@ import ( // TODO: Add validations type Params struct { Temperature float64 `json:"temperature,omitempty"` - Stream bool `json:"stream,omitempty"` // unsupported right now + Stream bool `json:"stream,omitempty"` PreambleOverride string `json:"preamble_override,omitempty"` ChatHistory []ChatHistory `json:"chat_history,omitempty"` ConversationID string `json:"conversation_id,omitempty"` diff --git a/pkg/providers/cohere/error.go b/pkg/providers/cohere/error.go new file mode 100644 index 00000000..3e5ae89e --- /dev/null +++ b/pkg/providers/cohere/error.go @@ -0,0 +1,64 @@ +package cohere + +import ( + "fmt" + "io" + "net/http" + "time" + + "glide/pkg/providers/clients" + "glide/pkg/telemetry" + "go.uber.org/zap" +) + +type ErrorMapper struct { + tel *telemetry.Telemetry +} + +func NewErrorMapper(tel *telemetry.Telemetry) *ErrorMapper { + return &ErrorMapper{ + tel: tel, + } +} + +func (m *ErrorMapper) Map(resp *http.Response) error { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + m.tel.Logger.Error( + "Failed to unmarshal chat response error", + zap.String("provider", providerName), + zap.Error(err), + zap.ByteString("rawResponse", bodyBytes), + ) + + return clients.ErrProviderUnavailable + } + + m.tel.Logger.Error( + "Chat request failed", + zap.String("provider", providerName), + zap.Int("statusCode", resp.StatusCode), + zap.String("response", string(bodyBytes)), + zap.Any("headers", resp.Header), + ) + + if resp.StatusCode == http.StatusTooManyRequests { + // Read the value of the "Retry-After" header to get the cooldown delay + retryAfter := resp.Header.Get("Retry-After") + + // 
Retry-After is delay-seconds per RFC 9110 (e.g. "120"), or an
+	// HTTP-date — it is NOT a Go duration string, so the previous
+	// time.ParseDuration(retryAfter) always failed on real responses.
+	var delaySeconds int
+	if _, err := fmt.Sscanf(retryAfter, "%d", &delaySeconds); err != nil {
+		return fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
+	}
+
+	cooldownDelay := time.Duration(delaySeconds) * time.Second
+
+	return clients.NewRateLimitError(&cooldownDelay)
+	}
+
+	if resp.StatusCode == http.StatusUnauthorized {
+		return clients.ErrUnauthorized
+	}
+
+	// Server & client errors result in the same error to keep gateway resilient
+	return clients.ErrProviderUnavailable
+}
diff --git a/pkg/providers/cohere/schemas.go b/pkg/providers/cohere/schemas.go
index c807aa56..e692c3b0 100644
--- a/pkg/providers/cohere/schemas.go
+++ b/pkg/providers/cohere/schemas.go
@@ -65,3 +65,21 @@ type ConnectorsResponse struct {
 	ContOnFail string            `json:"continue_on_failure"`
 	Options    map[string]string `json:"options"`
 }
+
+// ChatCompletionChunk represents SSEvent a chat response is broken down on chat streaming
+// Ref: https://docs.cohere.com/reference/about
+type ChatCompletionChunk struct {
+	IsFinished bool          `json:"is_finished"`
+	EventType  string        `json:"event_type"`
+	Text       string        `json:"text"`
+	Response   FinalResponse `json:"response,omitempty"`
+}
+
+type FinalResponse struct {
+	ResponseID   string     `json:"response_id"`
+	Text         string     `json:"text"`
+	GenerationID string     `json:"generation_id"`
+	TokenCount   TokenCount `json:"token_count"`
+	Meta         Meta       `json:"meta"`
+	FinishReason string     `json:"finish_reason"`
+}