🔧 #161: Refined the chat stream req/responses (#179)
- Separated the chat and chat stream request schemas
- Introduced a new finish reason field
- Added metadata to the chat stream response
- Allowed attaching metadata to a chat stream request; it is then propagated to each chat stream chunk (see the sketch below)
- Adjusted the error message schema to include the request ID and metadata
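
For context, a minimal sketch of how a client might build a streaming chat request with attached metadata under the new schema; the import path, the request ID, and the metadata keys are assumptions for illustration, not part of this commit:

package main

import (
	"encoding/json"
	"fmt"

	"glide/pkg/api/schemas" // assumed import path for the schemas package
)

func main() {
	// Build a streaming chat request with the helper added in this commit.
	req := schemas.NewChatStreamFromStr("What is the capital of the United Kingdom?")
	req.ID = "req-1" // the `id` field is marked as required in the schema

	// Hypothetical caller-defined metadata; per this commit it is echoed back on every chunk.
	meta := schemas.Metadata{"chatId": "chat-42", "userId": "user-7"}
	req.Metadata = &meta

	payload, err := json.Marshal(req)
	if err != nil {
		panic(err)
	}

	fmt.Println(string(payload))
}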
roma-glushko authored Mar 19, 2024
1 parent 3742035 commit dc7d42f
Showing 16 changed files with 169 additions and 110 deletions.
9 changes: 5 additions & 4 deletions pkg/api/http/handlers.go
@@ -118,9 +118,8 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout
// websocket.Conn bindings https://pkg.go.dev/github.com/fasthttp/websocket?tab=doc#pkg-index

var (
err error
chatRequest schemas.ChatRequest
wg sync.WaitGroup
err error
wg sync.WaitGroup
)

chunkResultC := make(chan *schemas.ChatStreamResult)
@@ -151,6 +150,8 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout
}()

for {
var chatRequest schemas.ChatStreamRequest

if err = c.ReadJSON(&chatRequest); err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
tel.L().Warn("Streaming Chat connection is closed", zap.Error(err), zap.String("routerID", routerID))
@@ -164,7 +165,7 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout
// TODO: handle termination gracefully
wg.Add(1)

go func(chatRequest schemas.ChatRequest) {
go func(chatRequest schemas.ChatStreamRequest) {
defer wg.Done()

router.ChatStream(context.Background(), &chatRequest, chunkResultC)
2 changes: 1 addition & 1 deletion pkg/api/schemas/chat.go
@@ -17,7 +17,7 @@ func NewChatFromStr(message string) *ChatRequest {
Message: ChatMessage{
"human",
message,
"roma",
"glide",
},
}
}
56 changes: 40 additions & 16 deletions pkg/api/schemas/chat_stream.go
@@ -1,31 +1,55 @@
package schemas

import "time"
type (
Metadata = map[string]any
FinishReason = string
)

var Complete FinishReason = "complete"

// ChatStreamRequest defines a message that requests a new streaming chat
type ChatStreamRequest struct {
// TODO: implement
ID string `json:"id" validate:"required"`
Message ChatMessage `json:"message" validate:"required"`
MessageHistory []ChatMessage `json:"messageHistory" validate:"required"`
Override *OverrideChatRequest `json:"overrideMessage,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
}

func NewChatStreamFromStr(message string) *ChatStreamRequest {
return &ChatStreamRequest{
Message: ChatMessage{
"human",
message,
"glide",
},
}
}

type ModelChunkResponse struct {
Metadata *Metadata `json:"metadata,omitempty"`
Message ChatMessage `json:"message"`
FinishReason *FinishReason `json:"finishReason,omitempty"`
}

// ChatStreamChunk defines a message for a chunk of streaming chat response
type ChatStreamChunk struct {
// TODO: modify according to the streaming chat needs
ID string `json:"id,omitempty"`
Created int `json:"created,omitempty"`
Provider string `json:"provider,omitempty"`
RouterID string `json:"router,omitempty"`
ModelID string `json:"model_id,omitempty"`
ModelName string `json:"model,omitempty"`
Cached bool `json:"cached,omitempty"`
ModelResponse ModelResponse `json:"modelResponse,omitempty"`
Latency *time.Duration `json:"-"`
// TODO: add chat request-specific context
ID string `json:"id"`
CreatedAt int `json:"createdAt"`
Provider string `json:"providerId"`
RouterID string `json:"routerId"`
ModelID string `json:"modelId"`
Cached bool `json:"cached"`
ModelName string `json:"modelName"`
Metadata *Metadata `json:"metadata,omitempty"`
ModelResponse ModelChunkResponse `json:"modelResponse"`
}

type ChatStreamError struct {
// TODO: add chat request-specific context
Reason string `json:"reason"`
Message string `json:"message"`
ID string `json:"id"`
ErrCode string `json:"errCode"`
Message string `json:"message"`
Metadata *Metadata `json:"metadata,omitempty"`
}

type ChatStreamResult struct {
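As a usage note on the new finish reason: a sketch of how a provider integration might mark the terminating chunk of a stream. The finalChunk helper, the example package, and the import path are hypothetical; only the schemas types come from this commit.

package example

import "glide/pkg/api/schemas" // assumed import path

// finalChunk builds the last chunk of a stream, marking it as complete.
func finalChunk(streamID string, msg schemas.ChatMessage, meta *schemas.Metadata) schemas.ChatStreamChunk {
	finishReason := schemas.Complete // currently the only predefined finish reason ("complete")

	return schemas.ChatStreamChunk{
		ID:       streamID,
		Metadata: meta, // request metadata is carried over onto every chunk
		ModelResponse: schemas.ModelChunkResponse{
			Message:      msg,
			FinishReason: &finishReason, // left nil on intermediate chunks
		},
	}
}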
2 changes: 1 addition & 1 deletion pkg/providers/anthropic/chat_stream.go
@@ -11,6 +11,6 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
return nil, clients.ErrChatStreamNotImplemented
}
2 changes: 1 addition & 1 deletion pkg/providers/azureopenai/chat_stream.go
@@ -11,6 +11,6 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
return nil, clients.ErrChatStreamNotImplemented
}
2 changes: 1 addition & 1 deletion pkg/providers/bedrock/chat_stream.go
@@ -11,6 +11,6 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
return nil, clients.ErrChatStreamNotImplemented
}
2 changes: 1 addition & 1 deletion pkg/providers/cohere/chat_stream.go
@@ -11,6 +11,6 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
return nil, clients.ErrChatStreamNotImplemented
}
6 changes: 3 additions & 3 deletions pkg/providers/lang.go
@@ -19,14 +19,14 @@ type LangProvider interface {
SupportChatStream() bool

Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error)
ChatStream(ctx context.Context, req *schemas.ChatRequest) (clients.ChatStream, error)
ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error)
}

type LangModel interface {
Model
Provider() string
Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error)
ChatStream(ctx context.Context, req *schemas.ChatRequest) (<-chan *clients.ChatStreamResult, error)
ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (<-chan *clients.ChatStreamResult, error)
}

// LanguageModel wraps provider client and expend it with health & latency tracking
@@ -102,7 +102,7 @@ func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest)
return resp, err
}

func (m *LanguageModel) ChatStream(ctx context.Context, req *schemas.ChatRequest) (<-chan *clients.ChatStreamResult, error) {
func (m *LanguageModel) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (<-chan *clients.ChatStreamResult, error) {
stream, err := m.client.ChatStream(ctx, req)
if err != nil {
m.healthTracker.TrackErr(err)
2 changes: 1 addition & 1 deletion pkg/providers/octoml/chat_stream.go
@@ -11,6 +11,6 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
return nil, clients.ErrChatStreamNotImplemented
}
2 changes: 1 addition & 1 deletion pkg/providers/ollama/chat_stream.go
@@ -11,6 +11,6 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
return nil, clients.ErrChatStreamNotImplemented
}
33 changes: 14 additions & 19 deletions pkg/providers/openai/chat.go
@@ -33,26 +33,13 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
}
}

func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage {
messages := make([]ChatMessage, 0, len(request.MessageHistory)+1)

// Add items from messageHistory first and the new chat message last
for _, message := range request.MessageHistory {
messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content})
}

messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})

return messages
}

// Chat sends a chat request to the specified OpenAI model.
func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
// Create a new chat request
chatRequest := *c.createChatRequestSchema(request)
chatRequest := c.createRequestSchema(request)
chatRequest.Stream = false

chatResponse, err := c.doChatRequest(ctx, &chatRequest)
chatResponse, err := c.doChatRequest(ctx, chatRequest)
if err != nil {
return nil, err
}
@@ -64,12 +51,20 @@ func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schem
return chatResponse, nil
}

func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest {
func (c *Client) createRequestSchema(request *schemas.ChatRequest) *ChatRequest {
// TODO: consider using objectpool to optimize memory allocation
chatRequest := c.chatRequestTemplate // hoping to get a copy of the template
chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request)
chatRequest := *c.chatRequestTemplate // hoping to get a copy of the template

chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1)

// Add items from messageHistory first and the new chat message last
for _, message := range request.MessageHistory {
chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content})
}

chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})

return chatRequest
return &chatRequest
}

func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
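The createRequestSchema change above also switches chatRequest := c.chatRequestTemplate to chatRequest := *c.chatRequestTemplate: since the template is stored as a pointer, the old form aliased the shared template, while the new form copies the struct before Messages is reassigned. A standalone sketch of that difference, using a hypothetical Request type rather than the provider's ChatRequest:

package main

import "fmt"

type Request struct {
	Stream   bool
	Messages []string
}

func main() {
	template := &Request{}

	alias := template   // copies only the pointer: both names refer to the same struct
	alias.Stream = true // mutates the shared template

	copyReq := *template              // copies the struct value
	copyReq.Messages = []string{"hi"} // the template's Messages stays untouched

	fmt.Println(template.Stream, len(template.Messages)) // prints: true 0
}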