diff --git a/pkg/providers/lang.go b/pkg/providers/lang.go index b11d3374..c1250e72 100644 --- a/pkg/providers/lang.go +++ b/pkg/providers/lang.go @@ -70,6 +70,10 @@ func (m LanguageModel) ChatLatency() *latency.MovingAverage { return m.chatLatency } +func (m LanguageModel) ChatStreamLatency() *latency.MovingAverage { + return m.chatStreamLatency +} + func (m LanguageModel) Healthy() bool { return !m.rateLimit.Limited() && m.errBudget.HasTokens() } @@ -88,10 +92,10 @@ func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest) return resp, err } - var rle *clients.RateLimitError + var rateLimitErr *clients.RateLimitError - if errors.As(err, &rle) { - m.rateLimit.SetLimited(rle.UntilReset()) + if errors.As(err, &rateLimitErr) { + m.rateLimit.SetLimited(rateLimitErr.UntilReset()) return resp, err } @@ -102,8 +106,23 @@ func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest) } func (m *LanguageModel) ChatStream(ctx context.Context, request *schemas.ChatRequest, responseC chan<- schemas.ChatResponse) error { - // TODO: implement health & latency tracking - return m.client.ChatStream(ctx, request, responseC) + err := m.client.ChatStream(ctx, request, responseC) + + if err == nil { + return err + } + + var rateLimitErr *clients.RateLimitError + + if errors.As(err, &rateLimitErr) { + m.rateLimit.SetLimited(rateLimitErr.UntilReset()) + + return err + } + + _ = m.errBudget.Take(1) + + return err } func (m *LanguageModel) SupportChatStream() bool { diff --git a/pkg/routers/routing/least_latency.go b/pkg/routers/routing/least_latency.go index fac129b1..422fd863 100644 --- a/pkg/routers/routing/least_latency.go +++ b/pkg/routers/routing/least_latency.go @@ -14,6 +14,7 @@ const ( LeastLatency Strategy = "least_latency" ) +// LatencyGetter defines where to find latency for the specific model action type LatencyGetter = func(model providers.Model) *latency.MovingAverage // ModelSchedule defines latency update schedule for models