Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add api key auth and usage tracking #7

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Built binary
openai-api

# env
.env

# Binaries for programs and plugins
*.exe
*.exe~
Expand Down
10 changes: 6 additions & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@ FROM golang:1.21 AS builder
# Set the working directory inside the container
WORKDIR /app

# Copy from local
COPY . .
COPY . ./

# Download all the dependencies
RUN go mod download

# Generate the Prisma Client Go client
RUN go generate ./db

# Build the Go application with CGO disabled and statically linked
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o app .
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o api .

# Use a minimal base image for running the application
FROM alpine:latest
Expand All @@ -28,4 +30,4 @@ COPY --from=builder /app/app .
EXPOSE 8080

# Set the entry point to run the binary
ENTRYPOINT ["./app"]
ENTRYPOINT ["./api"]
7 changes: 5 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Default target
all: build
all: db build

# Build the Go application
build:
go build openai-api.go
go build openai-api.go

db:
go generate ./db
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,23 @@ type LlmStreamChunk struct {
Done bool `json:"done,omitempty"`
}
```

## Prisma & Postgres

1. Install the auto-generated query builder for go

```sh
go get github.com/steebchen/prisma-client-go
```

To generate the schema from an existing database, specify the DB source in our `schema.prisma` file and pull the schema

```sh
go run github.com/steebchen/prisma-client-go db pull
```

Once we have our schema we can generate our client bindings:

```sh
make db
```
2 changes: 2 additions & 0 deletions db/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# gitignore generated by Prisma Client Go. DO NOT EDIT.
*_gen.go
146 changes: 146 additions & 0 deletions db/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
//go:generate go run github.com/steebchen/prisma-client-go generate

package db

import (
"context"
"errors"
"time"
)

type APIKeyStore struct {
client *PrismaClient
}

func NewAPIKeyStore() (*APIKeyStore, error) {
client := NewClient()
if err := client.Prisma.Connect(); err != nil {
return nil, err
}

return &APIKeyStore{
client: client,
}, nil
}

func (s *APIKeyStore) ValidateAndGetAPIKey(ctx context.Context, apiKey string) (*APIKeyModel, error) {
key, err := s.client.APIKey.FindFirst(
APIKey.Key.Equals(apiKey),
APIKey.IsActive.Equals(true),
).With(
APIKey.User.Fetch(),
).Exec(ctx)

if err != nil {
if errors.Is(err, ErrNotFound) {
return nil, errors.New("invalid API key")
}
return nil, err
}

return key, nil
}

func (s *APIKeyStore) RecordAPIUsage(ctx context.Context, apiKeyID string, userID string, tokens int) error {
now := time.Now()
startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())

// First get current usage (outside transaction)
user, err := s.client.User.FindUnique(
User.ID.Equals(userID),
).Exec(ctx)
if err != nil {
return err
}

// Calculate new usage values
dailyUsage := user.DailyTokenUsage
if user.UsageResetAt.Before(startOfDay) {
dailyUsage = 0
}

// Create transaction for updates
updateUserQuery := s.client.User.FindUnique(
User.ID.Equals(userID),
).Update(
User.DailyTokenUsage.Set(dailyUsage+tokens),
User.UsageResetAt.Set(now),
).Tx()

updateAPIKeyQuery := s.client.APIKey.FindUnique(
APIKey.ID.Equals(apiKeyID),
).Update(
APIKey.DailyTokenUsage.Set(tokens),
APIKey.UsageResetAt.Set(now),
APIKey.LastUsedAt.Set(now),
).Tx()

upsertUsageQuery := s.client.DailyUsage.UpsertOne(
DailyUsage.DateUserIDAPIKeyID(
DailyUsage.Date.Equals(startOfDay),
DailyUsage.UserID.Equals(userID),
DailyUsage.APIKeyID.Equals(apiKeyID)),
).Create(
DailyUsage.Date.Set(startOfDay),
DailyUsage.User.Link(User.ID.Equals(userID)),
DailyUsage.APIKeyID.Set(apiKeyID),
DailyUsage.TokenUsage.Set(tokens),
).Update(
DailyUsage.TokenUsage.Increment(tokens),
).Tx()

// Execute transaction for updates only
err = s.client.Prisma.TX.Transaction(
updateUserQuery,
updateAPIKeyQuery,
upsertUsageQuery,
).Exec(ctx)

return err
}

func (s *APIKeyStore) GetUserForAPIKey(ctx context.Context, apiKey string) (*UserModel, error) {
user, err := s.client.User.FindFirst(
User.APIKey.Some(
APIKey.Key.Equals(apiKey),
),
).Exec(ctx)

if err != nil {
return nil, err
}

return user, nil
}

func (s *APIKeyStore) CheckUsageLimit(ctx context.Context, userID string, tokens int) (bool, error) {
user, err := s.client.User.FindUnique(
User.ID.Equals(userID),
).Exec(ctx)

if err != nil {
return false, err
}

now := time.Now()
startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())

// Reset daily usage if it's a new day
dailyUsage := user.DailyTokenUsage
if user.UsageResetAt.Before(startOfDay) {
dailyUsage = 0
}

// Get limit based on user tier
var limit int
switch user.Tier {
case UserTierFree:
limit = 10000
case UserTierPro:
limit = 100000
case UserTierEnterprise:
limit = 1000000
}

return (dailyUsage + tokens) <= limit, nil
}
104 changes: 104 additions & 0 deletions db/schema.prisma
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
datasource db {
provider = "postgresql"
url = env("DATABASE_URL")
}

generator db {
provider = "go run github.com/steebchen/prisma-client-go"
output = "./"
}

model APIKey {
id String @id
key String @unique
name String?
userId String
type ApiKeyType @default(CUSTOMER)
dailyTokenUsage Int @default(0)
usageResetAt DateTime @default(now())
createdAt DateTime @default(now())
lastUsedAt DateTime @default(now())
isActive Boolean @default(true)
User User @relation(fields: [userId], references: [id])
DailyUsage DailyUsage[]

@@index([key, isActive])
@@index([userId, lastUsedAt])
@@index([userId, type])
}

model Conversation {
id String @id
title String
userId String
createdAt DateTime @default(now())
updatedAt DateTime
User User @relation(fields: [userId], references: [id])
Message Message[]

@@index([userId, updatedAt])
}

model DailyUsage {
id String @id @default(cuid())
date DateTime @db.Date
tokenUsage Int @default(0)
userId String
apiKeyId String?
createdAt DateTime @default(now())
APIKey APIKey? @relation(fields: [apiKeyId], references: [id])
User User @relation(fields: [userId], references: [id])

@@unique([date, userId, apiKeyId])
@@index([apiKeyId, date])
@@index([userId, date])
}

model Message {
id String @id
orderIndex Int @default(autoincrement())
content String
role MessageRole
conversationId String
createdAt DateTime @default(now())
Conversation Conversation @relation(fields: [conversationId], references: [id])

@@index([conversationId, createdAt])
@@index([conversationId, orderIndex])
}

model User {
id String @id
name String?
email String? @unique
image String?
loginType String?
dailyTokenUsage Int @default(0)
usageResetAt DateTime @default(now())
tier UserTier @default(FREE)
isActive Boolean @default(true)
createdAt DateTime @default(now())
updatedAt DateTime
APIKey APIKey[]
Conversation Conversation[]
DailyUsage DailyUsage[]

@@index([email])
}

enum ApiKeyType {
MASTER
CUSTOMER
}

enum MessageRole {
system
user
assistant
}

enum UserTier {
FREE
PRO
ENTERPRISE
}
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ require (
github.com/gin-gonic/gin v1.10.0
github.com/golang/glog v1.2.2
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
github.com/livepeer/ai-worker v0.7.0
github.com/shopspring/decimal v1.4.0
github.com/steebchen/prisma-client-go v0.42.0
)

require (
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY=
github.com/invopop/yaml v0.2.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
Expand Down Expand Up @@ -119,9 +121,13 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
github.com/steebchen/prisma-client-go v0.42.0 h1:83keN+4jGvoTccCKCk74UU5JQj6pOwPcg3/zkoqxKJE=
github.com/steebchen/prisma-client-go v0.42.0/go.mod h1:wp2xU9HO5WIefc65vcl1HOiFUzaHKyOhHw5atrzs8hc=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
Expand Down
7 changes: 6 additions & 1 deletion openai-api.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"flag"
"log"

"github.com/livepool-io/openai-middleware/db"
"github.com/livepool-io/openai-middleware/middleware"
"github.com/livepool-io/openai-middleware/server"
)
Expand All @@ -12,8 +13,12 @@ func main() {
gatewayURL := flag.String("gateway", "http://your-api-host", "The URL of the gateway API")
port := flag.String("port", "8080", "The port to run the server on")
flag.Parse()
apiKeyStore, err := db.NewAPIKeyStore()
if err != nil {
log.Fatalf("Failed to create Supabase API key store: %v", err)
}
gateway := middleware.NewGateway(*gatewayURL)
server, err := server.NewServer(gateway)
server, err := server.NewServer(apiKeyStore, gateway)
if err != nil {
log.Fatalf("Failed to create server: %v", err)
}
Expand Down
Loading
Loading