diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
new file mode 100644
index 0000000..f1e4e84
--- /dev/null
+++ b/.github/workflows/docker.yml
@@ -0,0 +1,58 @@
+name: Docker Build and Push
+
+on:
+  pull_request:
+  push:
+    branches: [ master ]
+    tags: [ 'v*.*.*' ]
+
+jobs:
+  docker:
+    runs-on: ubuntu-latest
+    permissions:
+      contents: read
+      packages: write
+      id-token: write
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v3
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      - name: Docker meta
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: livepool/openai-api
+          tags: |
+            type=ref,event=branch
+            type=ref,event=pr
+            type=semver,pattern={{version}}
+            type=semver,pattern={{major}}.{{minor}}
+            type=sha,format=long
+            type=raw,value=latest,enable={{is_default_branch}}
+
+      - name: Login to Docker Hub
+        if: github.event_name != 'pull_request'
+        uses: docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+      - name: Build and push
+        uses: docker/build-push-action@v6
+        with:
+          context: .
+          platforms: linux/amd64,linux/arm64
+          push: ${{ github.event_name != 'pull_request' }}
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}
+          cache-from: type=gha
+          cache-to: type=gha,mode=max
+          provenance: false
+          sbom: false
\ No newline at end of file
diff --git a/common/utils.go b/common/utils.go
index 00f81b2..2b489cc 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -92,22 +92,20 @@ func TransformResponse(req *worker.GenLLMFormdataRequestBody, resp *http.Respons
 				FinishReason: "stop",
 			},
 		},
-		Usage: models.Usage{
-			PromptTokens:     len(req.Prompt), // TODO: count actual tokens
-			CompletionTokens: res.TokensUsed,
-			TotalTokens:      res.TokensUsed + len(req.Prompt), // Adjust if you have prompt tokens count
+		Usage: &models.Usage{
+			TotalTokens: res.TokensUsed, // TokensUsed already includes prompt tokens
 		},
 	}
 
 	return openAIResp, nil
 }
 
-func TransformStreamResponse(chunk worker.LlmStreamChunk, streamID string) (models.OpenAIStreamResponse, error) {
+func TransformStreamResponse(chunk worker.LlmStreamChunk, req *worker.GenLLMFormdataRequestBody, streamID string) (models.OpenAIStreamResponse, error) {
 	openAIResp := models.OpenAIStreamResponse{
 		ID:      streamID,
-		Object:  "text_completion",
+		Object:  "chat.completion.chunk",
 		Created: time.Now().Unix(),
-		Model:   "gpt-3.5-turbo-0301", // You might want to make this configurable
+		Model:   *req.ModelId,
 		Choices: []models.StreamChoice{
 			{
 				Index: 0,
@@ -121,12 +119,16 @@ func TransformStreamResponse(chunk worker.LlmStreamChunk, streamID string) (mode
 
 	if chunk.Done {
 		openAIResp.Choices[0].FinishReason = "stop"
+		// Only include usage information in the final chunk
+		openAIResp.Usage = &models.Usage{
+			TotalTokens: chunk.TokensUsed, // TokensUsed already includes prompt tokens
+		}
 	}
 
 	return openAIResp, nil
 }
 
-func HandleStreamingResponse(ctx context.Context, resp *http.Response) (<-chan models.OpenAIStreamResponse, <-chan error) {
+func HandleStreamingResponse(ctx context.Context, req *worker.BodyGenLLM, resp *http.Response) (<-chan models.OpenAIStreamResponse, <-chan error) {
 	streamChan := make(chan models.OpenAIStreamResponse)
 	errChan := make(chan error, 1) // Buffered channel to avoid goroutine leak
 
@@ -152,7 +154,7 @@ func HandleStreamingResponse(ctx context.Context, resp *http.Response) (<-chan m
 			data := strings.TrimPrefix(line, "data: ")
 			if data == "[DONE]" {
 				chunk := worker.LlmStreamChunk{Chunk: "DONE", Done: true, TokensUsed: totalTokens}
-				openAIChunk, err := TransformStreamResponse(chunk, streamID)
+				openAIChunk, err := TransformStreamResponse(chunk, req, streamID)
 				if err != nil {
 					errChan <- fmt.Errorf("error converting final chunk: %w", err)
 					return
@@ -168,7 +170,7 @@ func HandleStreamingResponse(ctx context.Context, resp *http.Response) (<-chan m
 			}
 
 			totalTokens += chunk.TokensUsed
-			openAIChunk, err := TransformStreamResponse(chunk, streamID)
+			openAIChunk, err := TransformStreamResponse(chunk, req, streamID)
 			if err != nil {
 				errChan <- fmt.Errorf("error converting chunk: %w", err)
 				return
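The `TransformStreamResponse` change above threads the original request through so each chunk reports the caller's model instead of a hard-coded one, and attaches usage only when `chunk.Done` is set. A minimal test sketch of that contract, assuming it lives next to the function in the `common` package (the model name, stream ID, and token counts are illustrative; field names are taken from the diff itself):

```go
package common

import (
	"testing"

	"github.com/livepeer/ai-worker/worker"
)

func TestTransformStreamResponse_UsageOnlyOnFinalChunk(t *testing.T) {
	model := "test-model" // illustrative model id
	req := &worker.GenLLMFormdataRequestBody{ModelId: &model}

	// Intermediate chunk: usage must stay nil so omitempty drops it from the JSON.
	mid, err := TransformStreamResponse(worker.LlmStreamChunk{Chunk: "hi", TokensUsed: 2}, req, "stream-1")
	if err != nil {
		t.Fatal(err)
	}
	if mid.Usage != nil {
		t.Errorf("expected nil usage on intermediate chunk, got %+v", mid.Usage)
	}

	// Final chunk: finish reason and usage are both set.
	fin, err := TransformStreamResponse(worker.LlmStreamChunk{Chunk: "DONE", Done: true, TokensUsed: 42}, req, "stream-1")
	if err != nil {
		t.Fatal(err)
	}
	if fin.Choices[0].FinishReason != "stop" || fin.Usage == nil || fin.Usage.TotalTokens != 42 {
		t.Errorf("unexpected final chunk: %+v", fin)
	}
}
```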
diff --git a/middleware/gateway.go b/middleware/gateway.go
index 4e687ab..1e4ef07 100644
--- a/middleware/gateway.go
+++ b/middleware/gateway.go
@@ -1,19 +1,16 @@
 package middleware
 
 import (
-	"bufio"
 	"bytes"
 	"context"
 	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
-	"strings"
 
 	"github.com/golang/glog"
 	"github.com/livepeer/ai-worker/worker"
 	"github.com/livepool-io/openai-middleware/common"
-	"github.com/livepool-io/openai-middleware/models"
 )
 
 type Gateway struct {
@@ -53,69 +50,8 @@ func (g *Gateway) PostLlmGenerate(req worker.GenLLMFormdataRequestBody) (*http.R
 	return client.Do(httpReq)
 }
 
-func HandleStreamingResponse(ctx context.Context, resp *http.Response) (<-chan models.OpenAIStreamResponse, <-chan error) {
-	streamChan := make(chan models.OpenAIStreamResponse)
-	errChan := make(chan error, 1) // Buffered channel to avoid goroutine leak
-
-	go func() {
-		defer close(streamChan)
-		defer close(errChan)
-
-		streamID := common.GenerateUniqueID()
-		scanner := bufio.NewScanner(resp.Body)
-		var totalTokens int
-
-		for scanner.Scan() {
-			select {
-			case <-ctx.Done():
-				errChan <- ctx.Err()
-				return
-			default:
-				line := scanner.Text()
-				if !strings.HasPrefix(line, "data: ") {
-					continue
-				}
-
-				data := strings.TrimPrefix(line, "data: ")
-				if data == "[DONE]" {
-					chunk := worker.LlmStreamChunk{Chunk: "DONE", Done: true, TokensUsed: totalTokens}
-					openAIChunk, err := common.TransformStreamResponse(chunk, streamID)
-					if err != nil {
-						errChan <- fmt.Errorf("error converting final chunk: %w", err)
-						return
-					}
-					streamChan <- openAIChunk
-					return
-				}
-
-				var chunk worker.LlmStreamChunk
-				if err := json.Unmarshal([]byte(data), &chunk); err != nil {
-					errChan <- fmt.Errorf("error unmarshalling SSE chunk: %w", err)
-					continue
-				}
-
-				totalTokens += chunk.TokensUsed
-				openAIChunk, err := common.TransformStreamResponse(chunk, streamID)
-				if err != nil {
-					errChan <- fmt.Errorf("error converting chunk: %w", err)
-					return
-				}
-
-				streamChan <- openAIChunk
-			}
-		}
-
-		if err := scanner.Err(); err != nil {
-			errChan <- fmt.Errorf("error reading SSE stream: %w", err)
-		}
-	}()
-
-	return streamChan, errChan
-}
-
-func (g *Gateway) HandleStreamingResponse(w http.ResponseWriter, r *http.Request, resp *http.Response) error {
-	ctx := r.Context()
-	streamChan, errChan := common.HandleStreamingResponse(ctx, resp)
+func (g *Gateway) HandleStreamingResponse(ctx context.Context, req *worker.BodyGenLLM, w http.ResponseWriter, resp *http.Response) error {
+	streamChan, errChan := common.HandleStreamingResponse(ctx, req, resp)
 
 	w.Header().Set("Content-Type", "text/event-stream")
 	w.Header().Set("Cache-Control", "no-cache")
diff --git a/models/models.go b/models/models.go
index 99412d1..f18f810 100644
--- a/models/models.go
+++ b/models/models.go
@@ -25,7 +25,7 @@ type OpenAIResponse struct {
 	Created int64    `json:"created"`
 	Model   string   `json:"model"`
 	Choices []Choice `json:"choices"`
-	Usage   Usage    `json:"usage"`
+	Usage   *Usage   `json:"usage"`
 }
 
 type OpenAIStreamResponse struct {
@@ -34,6 +34,7 @@ type OpenAIStreamResponse struct {
 	Created int64          `json:"created"`
 	Model   string         `json:"model"`
 	Choices []StreamChoice `json:"choices"`
+	Usage   *Usage         `json:"usage,omitempty"`
 }
 
 type StreamChoice struct {
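On the models side, swapping `Usage Usage` for `Usage *Usage` (with `omitempty` on the stream response) is what actually keeps intermediate chunks clean: a nil pointer drops the field entirely, whereas a value type would always serialize a zeroed `"usage"` object. A standalone sketch of that marshaling difference (local structs mirror `models.go` for illustration only):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type Usage struct {
	TotalTokens int `json:"total_tokens"`
}

type chunk struct {
	Object string `json:"object"`
	Usage  *Usage `json:"usage,omitempty"` // nil pointer => field omitted
}

func main() {
	mid, _ := json.Marshal(chunk{Object: "chat.completion.chunk"})
	fin, _ := json.Marshal(chunk{Object: "chat.completion.chunk", Usage: &Usage{TotalTokens: 42}})
	fmt.Println(string(mid)) // {"object":"chat.completion.chunk"}
	fmt.Println(string(fin)) // {"object":"chat.completion.chunk","usage":{"total_tokens":42}}
}
```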
`json:"usage"` + Usage *Usage `json:"usage"` } type OpenAIStreamResponse struct { @@ -34,6 +34,7 @@ type OpenAIStreamResponse struct { Created int64 `json:"created"` Model string `json:"model"` Choices []StreamChoice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` } type StreamChoice struct { diff --git a/server/http.go b/server/http.go index 29afc2c..341f116 100644 --- a/server/http.go +++ b/server/http.go @@ -70,7 +70,7 @@ func (s *Server) handleChatCompletion(c *gin.Context) { // Handle streaming response and // forward stream to caller in OpenAPI format - if err := s.gateway.HandleStreamingResponse(c.Writer, c.Request, resp); err != nil { + if err := s.gateway.HandleStreamingResponse(c.Request.Context(), req, c.Writer, resp); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return }