Commit
Merge pull request #6 from Livepool-io/nv/update-response-transfo
properly send usage with streamed responses
kyriediculous authored Oct 30, 2024
2 parents 6e4efdd + 3d0ea50 commit 0c62654
Showing 5 changed files with 75 additions and 78 deletions.
58 changes: 58 additions & 0 deletions .github/workflows/docker.yml
@@ -0,0 +1,58 @@
name: Docker Build and Push

on:
  pull_request:
  push:
    branches: [ master ]
    tags: [ 'v*.*.*' ]

jobs:
  docker:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      packages: write
      id-token: write

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Set up QEMU
        uses: docker/setup-qemu-action@v3

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: livepool/openai-api
          tags: |
            type=ref,event=branch
            type=ref,event=pr
            type=semver,pattern={{version}}
            type=semver,pattern={{major}}.{{minor}}
            type=sha,format=long
            type=raw,value=latest,enable={{is_default_branch}}
      - name: Login to Docker Hub
        if: github.event_name != 'pull_request'
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      - name: Build and push
        uses: docker/build-push-action@v6
        with:
          context: .
          platforms: linux/amd64,linux/arm64
          push: ${{ github.event_name != 'pull_request' }}
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
          provenance: false
          sbom: false
22 changes: 12 additions & 10 deletions common/utils.go
@@ -92,22 +92,20 @@ func TransformResponse(req *worker.GenLLMFormdataRequestBody, resp *http.Respons
 				FinishReason: "stop",
 			},
 		},
-		Usage: models.Usage{
-			PromptTokens:     len(req.Prompt), // TODO: count actual tokens
-			CompletionTokens: res.TokensUsed,
-			TotalTokens:      res.TokensUsed + len(req.Prompt), // Adjust if you have prompt tokens count
+		Usage: &models.Usage{
+			TotalTokens: res.TokensUsed, // TokensUsed already includes prompt tokens
 		},
 	}
 
 	return openAIResp, nil
 }
 
-func TransformStreamResponse(chunk worker.LlmStreamChunk, streamID string) (models.OpenAIStreamResponse, error) {
+func TransformStreamResponse(chunk worker.LlmStreamChunk, req *worker.GenLLMFormdataRequestBody, streamID string) (models.OpenAIStreamResponse, error) {
 	openAIResp := models.OpenAIStreamResponse{
 		ID:      streamID,
-		Object:  "text_completion",
+		Object:  "chat.completion.chunk",
 		Created: time.Now().Unix(),
-		Model:   "gpt-3.5-turbo-0301", // You might want to make this configurable
+		Model:   *req.ModelId,
 		Choices: []models.StreamChoice{
 			{
 				Index: 0,
@@ -121,12 +119,16 @@ func TransformStreamResponse(chunk worker.LlmStreamChunk, streamID string) (mode
 
 	if chunk.Done {
 		openAIResp.Choices[0].FinishReason = "stop"
+		// Only include usage information in the final chunk
+		openAIResp.Usage = &models.Usage{
+			TotalTokens: chunk.TokensUsed, // TokensUsed already includes prompt tokens
+		}
 	}
 
 	return openAIResp, nil
 }
 
-func HandleStreamingResponse(ctx context.Context, resp *http.Response) (<-chan models.OpenAIStreamResponse, <-chan error) {
+func HandleStreamingResponse(ctx context.Context, req *worker.BodyGenLLM, resp *http.Response) (<-chan models.OpenAIStreamResponse, <-chan error) {
 	streamChan := make(chan models.OpenAIStreamResponse)
 	errChan := make(chan error, 1) // Buffered channel to avoid goroutine leak
 
@@ -152,7 +154,7 @@ func HandleStreamingResponse(ctx context.Context, resp *http.Response) (<-chan m
 				data := strings.TrimPrefix(line, "data: ")
 				if data == "[DONE]" {
 					chunk := worker.LlmStreamChunk{Chunk: "DONE", Done: true, TokensUsed: totalTokens}
-					openAIChunk, err := TransformStreamResponse(chunk, streamID)
+					openAIChunk, err := TransformStreamResponse(chunk, req, streamID)
 					if err != nil {
 						errChan <- fmt.Errorf("error converting final chunk: %w", err)
 						return
@@ -168,7 +170,7 @@ }
 				}
 
 				totalTokens += chunk.TokensUsed
-				openAIChunk, err := TransformStreamResponse(chunk, streamID)
+				openAIChunk, err := TransformStreamResponse(chunk, req, streamID)
 				if err != nil {
 					errChan <- fmt.Errorf("error converting chunk: %w", err)
 					return
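For context, a minimal consumer sketch (not part of this commit; it assumes only the channel API shown in the diff above, and the package name, function name, and logging are hypothetical): it drains the channels returned by common.HandleStreamingResponse and picks up usage from the final chunk, which is now the only chunk carrying a non-nil Usage.

package example

import (
	"context"
	"log"
	"net/http"

	"github.com/livepeer/ai-worker/worker"
	"github.com/livepool-io/openai-middleware/common"
)

func drainStream(ctx context.Context, req *worker.BodyGenLLM, resp *http.Response) error {
	streamChan, errChan := common.HandleStreamingResponse(ctx, req, resp)

	for chunk := range streamChan {
		if chunk.Usage != nil { // only the final (Done) chunk carries usage after this change
			log.Printf("stream %s finished: %d total tokens", chunk.ID, chunk.Usage.TotalTokens)
		}
		// ... forward the chunk to the caller as an SSE "data:" line ...
	}

	// errChan is buffered and closed by the producer, so this read never blocks.
	if err := <-errChan; err != nil {
		return err
	}
	return nil
}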
68 changes: 2 additions & 66 deletions middleware/gateway.go
@@ -1,19 +1,16 @@
 package middleware
 
 import (
-	"bufio"
 	"bytes"
 	"context"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
-	"strings"
 
 	"github.com/golang/glog"
 	"github.com/livepeer/ai-worker/worker"
 	"github.com/livepool-io/openai-middleware/common"
 	"github.com/livepool-io/openai-middleware/models"
 )
 
 type Gateway struct {
@@ -53,69 +50,8 @@ func (g *Gateway) PostLlmGenerate(req worker.GenLLMFormdataRequestBody) (*http.R
 	return client.Do(httpReq)
 }
 
-func HandleStreamingResponse(ctx context.Context, resp *http.Response) (<-chan models.OpenAIStreamResponse, <-chan error) {
-	streamChan := make(chan models.OpenAIStreamResponse)
-	errChan := make(chan error, 1) // Buffered channel to avoid goroutine leak
-
-	go func() {
-		defer close(streamChan)
-		defer close(errChan)
-
-		streamID := common.GenerateUniqueID()
-		scanner := bufio.NewScanner(resp.Body)
-		var totalTokens int
-
-		for scanner.Scan() {
-			select {
-			case <-ctx.Done():
-				errChan <- ctx.Err()
-				return
-			default:
-				line := scanner.Text()
-				if !strings.HasPrefix(line, "data: ") {
-					continue
-				}
-
-				data := strings.TrimPrefix(line, "data: ")
-				if data == "[DONE]" {
-					chunk := worker.LlmStreamChunk{Chunk: "DONE", Done: true, TokensUsed: totalTokens}
-					openAIChunk, err := common.TransformStreamResponse(chunk, streamID)
-					if err != nil {
-						errChan <- fmt.Errorf("error converting final chunk: %w", err)
-						return
-					}
-					streamChan <- openAIChunk
-					return
-				}
-
-				var chunk worker.LlmStreamChunk
-				if err := json.Unmarshal([]byte(data), &chunk); err != nil {
-					errChan <- fmt.Errorf("error unmarshalling SSE chunk: %w", err)
-					continue
-				}
-
-				totalTokens += chunk.TokensUsed
-				openAIChunk, err := common.TransformStreamResponse(chunk, streamID)
-				if err != nil {
-					errChan <- fmt.Errorf("error converting chunk: %w", err)
-					return
-				}
-
-				streamChan <- openAIChunk
-			}
-		}
-
-		if err := scanner.Err(); err != nil {
-			errChan <- fmt.Errorf("error reading SSE stream: %w", err)
-		}
-	}()
-
-	return streamChan, errChan
-}
-
-func (g *Gateway) HandleStreamingResponse(w http.ResponseWriter, r *http.Request, resp *http.Response) error {
-	ctx := r.Context()
-	streamChan, errChan := common.HandleStreamingResponse(ctx, resp)
+func (g *Gateway) HandleStreamingResponse(ctx context.Context, req *worker.BodyGenLLM, w http.ResponseWriter, resp *http.Response) error {
+	streamChan, errChan := common.HandleStreamingResponse(ctx, req, resp)
 
 	w.Header().Set("Content-Type", "text/event-stream")
 	w.Header().Set("Cache-Control", "no-cache")
3 changes: 2 additions & 1 deletion models/models.go
@@ -25,7 +25,7 @@ type OpenAIResponse struct {
 	Created int64    `json:"created"`
 	Model   string   `json:"model"`
 	Choices []Choice `json:"choices"`
-	Usage   Usage    `json:"usage"`
+	Usage   *Usage   `json:"usage"`
 }
 
 type OpenAIStreamResponse struct {
@@ -34,6 +34,7 @@ type OpenAIStreamResponse struct {
 	Created int64          `json:"created"`
 	Model   string         `json:"model"`
 	Choices []StreamChoice `json:"choices"`
+	Usage   *Usage         `json:"usage,omitempty"`
 }
 
 type StreamChoice struct {
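A small illustration of the omitempty change (a sketch using only the fields visible in this diff; the stream id and model id below are hypothetical): with Usage now a pointer, intermediate chunks marshal without a "usage" key, while the final chunk includes it.

package example

import (
	"encoding/json"
	"fmt"
	"time"

	"github.com/livepool-io/openai-middleware/models"
)

func printChunks() {
	interim := models.OpenAIStreamResponse{
		ID:      "chatcmpl-123", // hypothetical stream id
		Object:  "chat.completion.chunk",
		Created: time.Now().Unix(),
		Model:   "llama-3.1-8b-instruct", // hypothetical model id
		Choices: []models.StreamChoice{{Index: 0}},
	}

	final := interim
	final.Choices = []models.StreamChoice{{Index: 0, FinishReason: "stop"}}
	final.Usage = &models.Usage{TotalTokens: 42}

	a, _ := json.Marshal(interim)
	b, _ := json.Marshal(final)
	fmt.Println(string(a)) // no "usage" key at all (nil pointer + omitempty)
	fmt.Println(string(b)) // carries a "usage" object with the token count
}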
2 changes: 1 addition & 1 deletion server/http.go
@@ -70,7 +70,7 @@ func (s *Server) handleChatCompletion(c *gin.Context) {
 
 	// Handle streaming response and
 	// forward stream to caller in OpenAPI format
-	if err := s.gateway.HandleStreamingResponse(c.Writer, c.Request, resp); err != nil {
+	if err := s.gateway.HandleStreamingResponse(c.Request.Context(), req, c.Writer, resp); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
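End to end, the effect of this change can be checked with a plain streaming request. A rough sketch follows; the localhost address, route path, and model name are assumptions and not taken from this diff. After this commit, the final SSE chunk of the stream should contain the usage object.

package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	body := strings.NewReader(`{"model":"llama-3.1-8b-instruct","stream":true,` +
		`"messages":[{"role":"user","content":"hello"}]}`)

	// Assumed address and route; adjust to however the middleware is actually exposed.
	resp, err := http.Post("http://localhost:8080/v1/chat/completions", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := strings.TrimSpace(strings.TrimPrefix(scanner.Text(), "data: "))
		if line == "" {
			continue
		}
		if strings.Contains(line, `"usage"`) {
			fmt.Println("final chunk with usage:", line)
		}
	}
}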
