Skip to content

Commit

Permalink
🛠️ #45: Added WRR routing strategy (#75)
Browse files Browse the repository at this point in the history
* #45: Added weight field and WRR routing strategy

* #45: Covered the distribution by tests

* #45 Updated the openAPI specs

* #45 linting
  • Loading branch information
roma-glushko authored Jan 14, 2024
1 parent 6aec59f commit 5c233f6
Show file tree
Hide file tree
Showing 14 changed files with 303 additions and 22 deletions.
3 changes: 3 additions & 0 deletions docs/docs.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,9 @@ const docTemplate = `{
},
"openai": {
"$ref": "#/definitions/openai.Config"
},
"weight": {
"type": "integer"
}
}
},
Expand Down
3 changes: 3 additions & 0 deletions docs/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@
},
"openai": {
"$ref": "#/definitions/openai.Config"
},
"weight": {
"type": "integer"
}
}
},
Expand Down
2 changes: 2 additions & 0 deletions docs/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ definitions:
$ref: '#/definitions/latency.Config'
openai:
$ref: '#/definitions/openai.Config'
weight:
type: integer
required:
- id
type: object
Expand Down
4 changes: 3 additions & 1 deletion pkg/providers/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type LangModelConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"` // Is the model enabled?
ErrorBudget *health.ErrorBudget `yaml:"error_budget" json:"error_budget" swaggertype:"primitive,string"`
Latency *latency.Config `yaml:"latency" json:"latency"`
Weight int `yaml:"weight" json:"weight"`
Client *clients.ClientConfig `yaml:"client" json:"client"`
OpenAI *openai.Config `yaml:"openai" json:"openai"`
AzureOpenAI *azureopenai.Config `yaml:"azureopenai" json:"azureopenai"`
Expand All @@ -36,6 +37,7 @@ func DefaultLangModelConfig() *LangModelConfig {
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
Weight: 1,
}
}

Expand All @@ -61,7 +63,7 @@ func (c *LangModelConfig) ToModel(tel *telemetry.Telemetry) (*LangModel, error)
}

if client != nil {
return NewLangModel(c.ID, client, *c.ErrorBudget, *c.Latency), nil
return NewLangModel(c.ID, client, *c.ErrorBudget, *c.Latency, c.Weight), nil
}

return nil, ErrProviderNotFound
Expand Down
9 changes: 8 additions & 1 deletion pkg/providers/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type Model interface {
Healthy() bool
Latency() *latency.MovingAverage
LatencyUpdateInterval() *time.Duration
Weight() int
}

type LanguageModel interface {
Expand All @@ -33,21 +34,23 @@ type LanguageModel interface {
// LangModel wraps provider client and expend it with health & latency tracking
type LangModel struct {
modelID string
weight int
client LangModelProvider
rateLimit *health.RateLimitTracker
errorBudget *health.TokenBucket // TODO: centralize provider API health tracking in the registry
latency *latency.MovingAverage
latencyUpdateInterval *time.Duration
}

func NewLangModel(modelID string, client LangModelProvider, budget health.ErrorBudget, latencyConfig latency.Config) *LangModel {
func NewLangModel(modelID string, client LangModelProvider, budget health.ErrorBudget, latencyConfig latency.Config, weight int) *LangModel {
return &LangModel{
modelID: modelID,
client: client,
rateLimit: health.NewRateLimitTracker(),
errorBudget: health.NewTokenBucket(budget.TimePerTokenMicro(), budget.Budget()),
latency: latency.NewMovingAverage(latencyConfig.Decay, latencyConfig.WarmupSamples),
latencyUpdateInterval: latencyConfig.UpdateInterval,
weight: weight,
}
}

Expand All @@ -71,6 +74,10 @@ func (m *LangModel) Healthy() bool {
return !m.rateLimit.Limited() && m.errorBudget.HasTokens()
}

func (m *LangModel) Weight() int {
return m.weight
}

func (m *LangModel) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) {
// TODO: we may want to track time-to-first-byte to "normalize" response latency wrt response size
startedAt := time.Now()
Expand Down
8 changes: 7 additions & 1 deletion pkg/providers/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ type LangModelMock struct {
modelID string
healthy bool
latency *latency.MovingAverage
weight int
}

func NewLangModelMock(ID string, healthy bool, avgLatency float64) *LangModelMock {
func NewLangModelMock(ID string, healthy bool, avgLatency float64, weight int) *LangModelMock {
movingAverage := latency.NewMovingAverage(0.06, 3)

if avgLatency > 0.0 {
Expand All @@ -72,6 +73,7 @@ func NewLangModelMock(ID string, healthy bool, avgLatency float64) *LangModelMoc
modelID: ID,
healthy: healthy,
latency: movingAverage,
weight: weight,
}
}

Expand All @@ -92,3 +94,7 @@ func (m *LangModelMock) LatencyUpdateInterval() *time.Duration {

return &updateInterval
}

func (m *LangModelMock) Weight() int {
return m.weight
}
4 changes: 3 additions & 1 deletion pkg/routers/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,11 @@ func (c *LangRouterConfig) BuildRouting(models []providers.LanguageModel) (routi

switch c.RoutingStrategy {
case routing.Priority:
return routing.NewPriorityRouting(m), nil
return routing.NewPriority(m), nil
case routing.RoundRobin:
return routing.NewRoundRobinRouting(m), nil
case routing.WeightedRoundRobin:
return routing.NewWeightedRoundRobin(m), nil
case routing.LeastLatency:
return routing.NewLeastLatencyRouting(m), nil
}
Expand Down
21 changes: 16 additions & 5 deletions pkg/routers/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ func TestLangRouter_Priority_PickFistHealthy(t *testing.T) {
providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}, {Msg: "2"}}),
*budget,
*latConfig,
1,
),
providers.NewLangModel(
"second",
providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}}),
*budget,
*latConfig,
1,
),
}

Expand All @@ -46,7 +48,7 @@ func TestLangRouter_Priority_PickFistHealthy(t *testing.T) {
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
routing: routing.NewPriorityRouting(models),
routing: routing.NewPriority(models),
models: langModels,
telemetry: telemetry.NewTelemetryMock(),
}
Expand All @@ -72,18 +74,21 @@ func TestLangRouter_Priority_PickThirdHealthy(t *testing.T) {
providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "3"}}),
*budget,
*latConfig,
1,
),
providers.NewLangModel(
"second",
providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "4"}}),
*budget,
*latConfig,
1,
),
providers.NewLangModel(
"third",
providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}, {Msg: "2"}}),
*budget,
*latConfig,
1,
),
}

Expand All @@ -98,7 +103,7 @@ func TestLangRouter_Priority_PickThirdHealthy(t *testing.T) {
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
routing: routing.NewPriorityRouting(models),
routing: routing.NewPriority(models),
models: langModels,
telemetry: telemetry.NewTelemetryMock(),
}
Expand All @@ -124,12 +129,14 @@ func TestLangRouter_Priority_SuccessOnRetry(t *testing.T) {
providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "2"}}),
*budget,
*latConfig,
1,
),
providers.NewLangModel(
"second",
providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "1"}}),
*budget,
*latConfig,
1,
),
}

Expand All @@ -142,7 +149,7 @@ func TestLangRouter_Priority_SuccessOnRetry(t *testing.T) {
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil),
routing: routing.NewPriorityRouting(models),
routing: routing.NewPriority(models),
models: langModels,
telemetry: telemetry.NewTelemetryMock(),
}
Expand All @@ -163,12 +170,14 @@ func TestLangRouter_Priority_UnhealthyModelInThePool(t *testing.T) {
providers.NewProviderMock([]providers.ResponseMock{{Err: &clients.ErrProviderUnavailable}, {Msg: "3"}}),
*budget,
*latConfig,
1,
),
providers.NewLangModel(
"second",
providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}, {Msg: "2"}}),
*budget,
*latConfig,
1,
),
}

Expand All @@ -181,7 +190,7 @@ func TestLangRouter_Priority_UnhealthyModelInThePool(t *testing.T) {
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil),
routing: routing.NewPriorityRouting(models),
routing: routing.NewPriority(models),
models: langModels,
telemetry: telemetry.NewTelemetryMock(),
}
Expand All @@ -204,12 +213,14 @@ func TestLangRouter_Priority_AllModelsUnavailable(t *testing.T) {
providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Err: &ErrNoModelAvailable}}),
*budget,
*latConfig,
1,
),
providers.NewLangModel(
"second",
providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Err: &ErrNoModelAvailable}}),
*budget,
*latConfig,
1,
),
}

Expand All @@ -222,7 +233,7 @@ func TestLangRouter_Priority_AllModelsUnavailable(t *testing.T) {
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil),
routing: routing.NewPriorityRouting(models),
routing: routing.NewPriority(models),
models: langModels,
telemetry: telemetry.NewTelemetryMock(),
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/routers/routing/least_latency_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestLeastLatencyRouting_Warmup(t *testing.T) {
models := make([]providers.Model, 0, len(tc.models))

for _, model := range tc.models {
models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, model.latency))
models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, model.latency, 1))
}

routing := NewLeastLatencyRouting(models)
Expand Down Expand Up @@ -108,6 +108,7 @@ func TestLeastLatencyRouting_Routing(t *testing.T) {
model.modelID,
model.healthy,
model.latency,
1,
),
expireAt: model.expireAt,
})
Expand Down Expand Up @@ -142,7 +143,7 @@ func TestLeastLatencyRouting_NoHealthyModels(t *testing.T) {
models := make([]providers.Model, 0, len(latencies))

for idx, latency := range latencies {
models = append(models, providers.NewLangModelMock(strconv.Itoa(idx), false, latency))
models = append(models, providers.NewLangModelMock(strconv.Itoa(idx), false, latency, 1))
}

routing := NewLeastLatencyRouting(models)
Expand Down
2 changes: 1 addition & 1 deletion pkg/routers/routing/priority.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type PriorityRouting struct {
models []providers.Model
}

func NewPriorityRouting(models []providers.Model) *PriorityRouting {
func NewPriority(models []providers.Model) *PriorityRouting {
return &PriorityRouting{
models: models,
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/routers/routing/priority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) {
models := make([]providers.Model, 0, len(tc.models))

for _, model := range tc.models {
models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, 100))
models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, 100, 1))
}

routing := NewPriorityRouting(models)
routing := NewPriority(models)
iterator := routing.Iterator()

// loop three times over the whole pool to check if we return back to the begging of the list
Expand All @@ -47,12 +47,12 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) {

func TestPriorityRouting_NoHealthyModels(t *testing.T) {
models := []providers.Model{
providers.NewLangModelMock("first", false, 0),
providers.NewLangModelMock("second", false, 0),
providers.NewLangModelMock("third", false, 0),
providers.NewLangModelMock("first", false, 0, 1),
providers.NewLangModelMock("second", false, 0, 1),
providers.NewLangModelMock("third", false, 0, 1),
}

routing := NewPriorityRouting(models)
routing := NewPriority(models)
iterator := routing.Iterator()

_, err := iterator.Next()
Expand Down
8 changes: 4 additions & 4 deletions pkg/routers/routing/round_robin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) {
models := make([]providers.Model, 0, len(tc.models))

for _, model := range tc.models {
models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, 100))
models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, 100, 1))
}

routing := NewRoundRobinRouting(models)
Expand All @@ -50,9 +50,9 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) {

func TestRoundRobinRouting_NoHealthyModels(t *testing.T) {
models := []providers.Model{
providers.NewLangModelMock("first", false, 0),
providers.NewLangModelMock("second", false, 0),
providers.NewLangModelMock("third", false, 0),
providers.NewLangModelMock("first", false, 0, 1),
providers.NewLangModelMock("second", false, 0, 1),
providers.NewLangModelMock("third", false, 0, 1),
}

routing := NewRoundRobinRouting(models)
Expand Down
Loading

0 comments on commit 5c233f6

Please sign in to comment.