Skip to content

Commit

Permalink
#43 Defined a LangModel struct and merged all health tracking code there
Browse files Browse the repository at this point in the history
  • Loading branch information
roma-glushko committed Jan 7, 2024
1 parent 06ebe93 commit f64828c
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 78 deletions.
1 change: 1 addition & 0 deletions pkg/config/fields/error_budget.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package fields
4 changes: 2 additions & 2 deletions pkg/providers/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ func DefaultLangModelConfig() *LangModelConfig {
}
}

func (c *LangModelConfig) ToModel(tel *telemetry.Telemetry) (LanguageModel, error) {
func (c *LangModelConfig) ToModel(tel *telemetry.Telemetry) (*LangModel, error) {
if c.OpenAI != nil {
client, err := openai.NewClient(c.OpenAI, tel)
if err != nil {
return nil, fmt.Errorf("error initing openai client: %v", err)
}

return client, nil
return NewLangModel(c.ID, client), nil
}

return nil, ErrProviderNotFound
Expand Down
51 changes: 45 additions & 6 deletions pkg/providers/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,55 @@ package providers
import (
"context"

"glide/pkg/routers/health"

"glide/pkg/api/schemas"
)

// ModelProvider defines an interface all model providers should support
type ModelProvider interface {
// LangModelProvider defines the interface a provider must fulfill to be able to serve language chat requests.
// LangModel delegates both Provider() and Chat() to a value of this type.
type LangModelProvider interface {
// Provider returns a string identifying the provider (presumably its name; confirm against implementations).
Provider() string
// Chat takes a unified chat request and returns a unified chat response or an error.
Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error)
}

// LanguageModel defines the interface a provider should fulfill to be able to serve language chat requests
type LanguageModel interface {
ID() string
Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error)
// LangModel pairs a provider client with the ID of the model it serves and
// the health state kept for that model (a rate-limit tracker and an error-budget token bucket).
type LangModel struct {
// modelID is the identifier returned by ID().
modelID string
// client is the provider that Provider() and Chat() delegate to.
client LangModelProvider
// rateLimit is consulted by Healthy(): a model that is currently limited is reported as unhealthy.
rateLimit *health.RateLimitTracker
// errorBudget is consulted by Healthy(): the model is healthy only while the bucket has tokens.
errorBudget *health.TokenBucket // TODO: centralize provider API health tracking in the registry
}

// NewLangModel wraps the given provider client into a LangModel identified by modelID.
// A fresh rate-limit tracker and a fresh error-budget token bucket are created for the model.
func NewLangModel(modelID string, client LangModelProvider) *LangModel {
	model := new(LangModel)

	model.modelID = modelID
	model.client = client
	model.rateLimit = health.NewRateLimitTracker()
	model.errorBudget = health.NewTokenBucket(1, 10) // TODO: set from configs

	return model
}

// ID returns the model ID this LangModel was created with.
func (m *LangModel) ID() string {
return m.modelID
}

// Provider returns whatever the underlying provider client reports from its own Provider() method.
func (m *LangModel) Provider() string {
return m.client.Provider()
}

// Healthy reports whether the model can currently be picked for serving requests:
// it must not be rate limited and its error budget must still have tokens.
func (m *LangModel) Healthy() bool {
	// A rate-limited model is unhealthy regardless of its error budget.
	if m.rateLimit.Limited() {
		return false
	}

	return m.errorBudget.HasTokens()
}

// Chat forwards the request to the underlying provider client and returns its
// response and error unchanged. Failures are not recorded yet (see the TODO below).
func (m *LangModel) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) {
	response, chatErr := m.client.Chat(ctx, request)
	if chatErr != nil {
		// TODO: track all availability issues
		return response, chatErr
	}

	// successful response
	return response, nil
}
4 changes: 2 additions & 2 deletions pkg/routers/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ type LangRouterConfig struct {
}

// BuildModels creates a slice of LangModel instances out of the given config
func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]providers.LanguageModel, error) {
func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*providers.LangModel, error) {
var errs error

models := make([]providers.LanguageModel, 0, len(c.Models))
models := make([]*providers.LangModel, 0, len(c.Models))

for _, modelConfig := range c.Models {
if !modelConfig.Enabled {
Expand Down
48 changes: 0 additions & 48 deletions pkg/routers/health/trackers.go

This file was deleted.

17 changes: 6 additions & 11 deletions pkg/routers/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import (
"context"
"errors"

"glide/pkg/providers"

"glide/pkg/api/schemas"
"glide/pkg/routers/health"
"glide/pkg/routers/routing"
"glide/pkg/telemetry"
)
Expand All @@ -15,7 +16,7 @@ var ErrNoModels = errors.New("no models configured for router")
type LangRouter struct {
Config *LangRouterConfig
routing routing.LangModelRouting
models *[]health.LangModelHealthTracker
models []*providers.LangModel
telemetry *telemetry.Telemetry
}

Expand All @@ -25,16 +26,10 @@ func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter
return nil, err
}

modelTrackers := make([]health.LangModelHealthTracker, 0, len(models))

for _, model := range models {
modelTrackers = append(modelTrackers, *health.NewLangModelHealthTracker(model))
}

router := &LangRouter{
Config: cfg,
models: &modelTrackers,
routing: routing.NewPriorityRouting(&modelTrackers),
models: models,
routing: routing.NewPriorityRouting(models),
telemetry: tel,
}

Expand All @@ -46,7 +41,7 @@ func (r *LangRouter) ID() string {
}

func (r *LangRouter) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) {
if len(*r.models) == 0 {
if len(r.models) == 0 {
return nil, ErrNoModels
}

Expand Down
14 changes: 7 additions & 7 deletions pkg/routers/routing/priority.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package routing
import (
"sync/atomic"

"glide/pkg/routers/health"
"glide/pkg/providers"
)

const (
Expand All @@ -15,10 +15,10 @@ const (
// The priority of a model is defined by its position in the list
// (e.g. the first model definition has the highest priority, then the second model definition and so on)
type PriorityRouting struct {
models *[]health.LangModelHealthTracker
models []*providers.LangModel
}

func NewPriorityRouting(models *[]health.LangModelHealthTracker) *PriorityRouting {
func NewPriorityRouting(models []*providers.LangModel) *PriorityRouting {
return &PriorityRouting{
models: models,
}
Expand All @@ -35,11 +35,11 @@ func (r *PriorityRouting) Iterator() LangModelIterator {

type PriorityIterator struct {
idx *atomic.Uint64
models *[]health.LangModelHealthTracker
models []*providers.LangModel
}

func (r PriorityIterator) Next() (*health.LangModelHealthTracker, error) {
models := *r.models
func (r PriorityIterator) Next() (*providers.LangModel, error) {
models := r.models
idx := r.idx.Load()

for int(idx) < len(models) {
Expand All @@ -48,7 +48,7 @@ func (r PriorityIterator) Next() (*health.LangModelHealthTracker, error) {
r.idx.Add(1)

if model.Healthy() {
return &model, nil
return model, nil
}

// otherwise, try to pick the next model on the list
Expand Down
4 changes: 2 additions & 2 deletions pkg/routers/routing/strategies.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package routing
import (
"errors"

"glide/pkg/routers/health"
"glide/pkg/providers"
)

var ErrNoHealthyModels = errors.New("no healthy models found")
Expand All @@ -16,5 +16,5 @@ type LangModelRouting interface {
}

type LangModelIterator interface {
Next() (*health.LangModelHealthTracker, error)
Next() (*providers.LangModel, error)
}

0 comments on commit f64828c

Please sign in to comment.