From a5ee4324b0e4ea1341cf872ba8e949e87bf7dadf Mon Sep 17 00:00:00 2001 From: Donnie Adams <donnie@acorn.io> Date: Wed, 4 Dec 2024 12:28:04 -0500 Subject: [PATCH] fix: restart model providers when its credential changes Signed-off-by: Donnie Adams <donnie@acorn.io> --- go.mod | 2 +- go.sum | 4 +- pkg/api/handlers/modelprovider.go | 11 ++- pkg/api/router/router.go | 2 +- pkg/controller/controller.go | 2 +- .../handlers/toolreference/toolreference.go | 95 +++++++++++-------- pkg/gateway/server/dispatcher/dispatcher.go | 24 ++++- pkg/server/server.go | 22 ++--- 8 files changed, 97 insertions(+), 65 deletions(-) diff --git a/go.mod b/go.mod index e45d687bc..a53f0bf5f 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/gptscript-ai/chat-completion-client v0.0.0-20241127005108-02b41e1cd02e github.com/gptscript-ai/cmd v0.0.0-20240907001148-ffd49061124a github.com/gptscript-ai/go-gptscript v0.9.6-0.20241115201052-7efb3409cfcc - github.com/gptscript-ai/gptscript v0.9.6-0.20241121180135-e5fe428c6858 + github.com/gptscript-ai/gptscript v0.9.6-0.20241204172147-c39a0693ee94 github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de github.com/mhale/smtpd v0.8.3 github.com/oauth2-proxy/oauth2-proxy/v7 v7.0.0-00010101000000-000000000000 diff --git a/go.sum b/go.sum index 8f18530e0..655f33979 100644 --- a/go.sum +++ b/go.sum @@ -341,8 +341,8 @@ github.com/gptscript-ai/cmd v0.0.0-20240907001148-ffd49061124a h1:LX7AOcbBoTnUk/ github.com/gptscript-ai/cmd v0.0.0-20240907001148-ffd49061124a/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw= github.com/gptscript-ai/go-gptscript v0.9.6-0.20241115201052-7efb3409cfcc h1:D0N65peenVgPor9Ph0LiMLfgI4qCE9VGxgw8KxlvLgE= github.com/gptscript-ai/go-gptscript v0.9.6-0.20241115201052-7efb3409cfcc/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q= -github.com/gptscript-ai/gptscript v0.9.6-0.20241121180135-e5fe428c6858 h1:FCHDABLrqOgxppU0NqLW4PlyWeF/K8WD969PQso3mVY= -github.com/gptscript-ai/gptscript v0.9.6-0.20241121180135-e5fe428c6858/go.mod h1:1ECuES7S+IjL4oua0nJzytsWw45tCCv750T44+/JczQ= +github.com/gptscript-ai/gptscript v0.9.6-0.20241204172147-c39a0693ee94 h1:Vkaujp51uMhix6Mag2LMHBCR8XEBXbnvxR9klUiG5iE= +github.com/gptscript-ai/gptscript v0.9.6-0.20241204172147-c39a0693ee94/go.mod h1:1ECuES7S+IjL4oua0nJzytsWw45tCCv750T44+/JczQ= github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6 h1:vkgNZVWQgbE33VD3z9WKDwuu7B/eJVVMMPM62ixfCR8= github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6/go.mod h1:frrl/B+ZH3VSs3Tqk2qxEIIWTONExX3tuUa4JsVnqx4= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7 h1:pdN6V1QBWetyv/0+wjACpqVH+eVULgEjkurDLq3goeM= diff --git a/pkg/api/handlers/modelprovider.go b/pkg/api/handlers/modelprovider.go index 26e6fb440..2151609f4 100644 --- a/pkg/api/handlers/modelprovider.go +++ b/pkg/api/handlers/modelprovider.go @@ -8,18 +8,21 @@ import ( "github.com/gptscript-ai/go-gptscript" "github.com/otto8-ai/otto8/apiclient/types" "github.com/otto8-ai/otto8/pkg/api" + "github.com/otto8-ai/otto8/pkg/gateway/server/dispatcher" v1 "github.com/otto8-ai/otto8/pkg/storage/apis/otto.otto8.ai/v1" "k8s.io/apimachinery/pkg/fields" kclient "sigs.k8s.io/controller-runtime/pkg/client" ) type ModelProviderHandler struct { - gptscript *gptscript.GPTScript + gptscript *gptscript.GPTScript + dispatcher *dispatcher.Dispatcher } -func NewModelProviderHandler(gClient *gptscript.GPTScript) *ModelProviderHandler { +func NewModelProviderHandler(gClient *gptscript.GPTScript, dispatcher *dispatcher.Dispatcher) *ModelProviderHandler { return &ModelProviderHandler{ - gptscript: gClient, + gptscript: gClient, + dispatcher: dispatcher, } } @@ -93,6 +96,8 @@ func (mp *ModelProviderHandler) Configure(req api.Context) error { return fmt.Errorf("failed to create credential: %w", err) } + mp.dispatcher.StopModelProvider(ref.Namespace, ref.Name) + return nil } diff --git a/pkg/api/router/router.go b/pkg/api/router/router.go index ffa2ff78d..5d3820749 100644 --- a/pkg/api/router/router.go +++ b/pkg/api/router/router.go @@ -23,7 +23,7 @@ func Router(services *services.Services) (http.Handler, error) { cronJobs := handlers.NewCronJobHandler() models := handlers.NewModelHandler(services.GPTClient) availableModels := handlers.NewAvailableModelsHandler(services.GPTClient, services.ModelProviderDispatcher) - modelProviders := handlers.NewModelProviderHandler(services.GPTClient) + modelProviders := handlers.NewModelProviderHandler(services.GPTClient, services.ModelProviderDispatcher) prompt := handlers.NewPromptHandler(services.GPTClient) emailreceiver := handlers.NewEmailReceiverHandler(services.EmailServerName) defaultModelAliases := handlers.NewDefaultModelAliasHandler() diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index c95640e88..d55c8624f 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -37,7 +37,7 @@ func (c *Controller) PostStart(ctx context.Context) error { return fmt.Errorf("failed to apply data: %w", err) } go c.toolRefHandler.PollRegistry(ctx, c.services.Router.Backend()) - return nil + return c.toolRefHandler.EnsureOpenAIEnvCredential(ctx, c.services.Router.Backend()) } func (c *Controller) Start(ctx context.Context) error { diff --git a/pkg/controller/handlers/toolreference/toolreference.go b/pkg/controller/handlers/toolreference/toolreference.go index f986a5daf..d1701c39b 100644 --- a/pkg/controller/handlers/toolreference/toolreference.go +++ b/pkg/controller/handlers/toolreference/toolreference.go @@ -16,6 +16,7 @@ import ( "github.com/otto8-ai/otto8/logger" v1 "github.com/otto8-ai/otto8/pkg/storage/apis/otto.otto8.ai/v1" "github.com/otto8-ai/otto8/pkg/system" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/yaml" @@ -193,7 +194,6 @@ func (h *Handler) PollRegistry(ctx context.Context, c client.Client) { break } - var openAICredentialSet bool t := time.NewTicker(time.Hour) defer t.Stop() for { @@ -201,45 +201,6 @@ func (h *Handler) PollRegistry(ctx context.Context, c client.Client) { log.Errorf("Failed to read from registry: %v", err) } - // If the openai-model-provider exists and the OPENAI_API_KEY environment variable is set, then ensure the credential exists. - if !openAICredentialSet && os.Getenv("OPENAI_API_KEY") != "" { - var openAIModelProvider v1.ToolReference - - // Not reporting errors here, best-effort only. - if err := c.Get(ctx, client.ObjectKey{Namespace: system.DefaultNamespace, Name: "openai-model-provider"}, &openAIModelProvider); err == nil { - if cred, err := h.gptClient.RevealCredential(ctx, []string{string(openAIModelProvider.UID)}, "openai-model-provider"); err != nil && strings.HasSuffix(err.Error(), "credential not found") { - // The credential doesn't exist, so create it. - err = h.gptClient.CreateCredential(ctx, gptscript.Credential{ - Context: string(openAIModelProvider.UID), - ToolName: "openai-model-provider", - Type: gptscript.CredentialTypeModelProvider, - Env: map[string]string{ - "OTTO8_OPENAI_MODEL_PROVIDER_API_KEY": os.Getenv("OPENAI_API_KEY"), - }, - }) - - openAICredentialSet = err == nil - } else if err == nil && cred.Env["OTTO8_OPENAI_MODEL_PROVIDER_API_KEY"] != os.Getenv("OPENAI_API_KEY") { - // If the credential exists, but has a different value, then update it. - // The only way to update it is to delete the existing credential and recreate it. - if err = h.gptClient.DeleteCredential(ctx, string(openAIModelProvider.UID), "openai-model-provider"); err == nil { - err = h.gptClient.CreateCredential(ctx, gptscript.Credential{ - Context: string(openAIModelProvider.UID), - ToolName: "openai-model-provider", - Type: gptscript.CredentialTypeModelProvider, - Env: map[string]string{ - "OTTO8_OPENAI_MODEL_PROVIDER_API_KEY": os.Getenv("OPENAI_API_KEY"), - }, - }) - - openAICredentialSet = err == nil - } - } else { - openAICredentialSet = true - } - } - } - select { case <-t.C: case <-ctx.Done(): @@ -297,6 +258,60 @@ func (h *Handler) Populate(req router.Request, resp router.Response) error { return nil } +func (h *Handler) EnsureOpenAIEnvCredential(ctx context.Context, c client.Client) error { + if os.Getenv("OPENAI_API_KEY") == "" { + return nil + } + + for { + select { + case <-time.After(2 * time.Second): + case <-ctx.Done(): + return ctx.Err() + } + + // If the openai-model-provider exists and the OPENAI_API_KEY environment variable is set, then ensure the credential exists. + var openAIModelProvider v1.ToolReference + if err := c.Get(ctx, client.ObjectKey{Namespace: system.DefaultNamespace, Name: "openai-model-provider"}, &openAIModelProvider); apierrors.IsNotFound(err) { + continue + } else if err != nil { + return err + } + + if cred, err := h.gptClient.RevealCredential(ctx, []string{string(openAIModelProvider.UID)}, "openai-model-provider"); err != nil { + if strings.HasSuffix(err.Error(), "credential not found") { + // The credential doesn't exist, so create it. + return h.gptClient.CreateCredential(ctx, gptscript.Credential{ + Context: string(openAIModelProvider.UID), + ToolName: "openai-model-provider", + Type: gptscript.CredentialTypeModelProvider, + Env: map[string]string{ + "OTTO8_OPENAI_MODEL_PROVIDER_API_KEY": os.Getenv("OPENAI_API_KEY"), + }, + }) + } + + return fmt.Errorf("failed to check OpenAI credential: %w", err) + } else if cred.Env["OTTO8_OPENAI_MODEL_PROVIDER_API_KEY"] != os.Getenv("OPENAI_API_KEY") { + // If the credential exists, but has a different value, then update it. + // The only way to update it is to delete the existing credential and recreate it. + if err = h.gptClient.DeleteCredential(ctx, string(openAIModelProvider.UID), "openai-model-provider"); err != nil { + return fmt.Errorf("failed to delete credential: %w", err) + } + return h.gptClient.CreateCredential(ctx, gptscript.Credential{ + Context: string(openAIModelProvider.UID), + ToolName: "openai-model-provider", + Type: gptscript.CredentialTypeModelProvider, + Env: map[string]string{ + "OTTO8_OPENAI_MODEL_PROVIDER_API_KEY": os.Getenv("OPENAI_API_KEY"), + }, + }) + } + + return nil + } +} + func (h *Handler) RemoveModelProviderCredential(req router.Request, _ router.Response) error { toolRef := req.Object.(*v1.ToolReference) if toolRef.Spec.Type != types.ToolReferenceTypeModelProvider || toolRef.Status.Tool == nil || toolRef.Status.Tool.Metadata["envVars"] == "" { diff --git a/pkg/gateway/server/dispatcher/dispatcher.go b/pkg/gateway/server/dispatcher/dispatcher.go index 088ead29b..816e0f9ef 100644 --- a/pkg/gateway/server/dispatcher/dispatcher.go +++ b/pkg/gateway/server/dispatcher/dispatcher.go @@ -43,11 +43,12 @@ func New(invoker *invoke.Invoker, c kclient.Client, gClient *gptscript.GPTScript } func (d *Dispatcher) URLForModelProvider(ctx context.Context, namespace, modelProviderName string) (*url.URL, error) { + key := namespace + "/" + modelProviderName // Check the map with the read lock. d.lock.RLock() - u, ok := d.urls[modelProviderName] + u, ok := d.urls[key] d.lock.RUnlock() - if ok && (u.Scheme == "https" || engine.IsDaemonRunning(u.String())) { + if ok && (u.Hostname() == "127.0.0.1" || engine.IsDaemonRunning(u.String())) { return u, nil } @@ -56,8 +57,8 @@ func (d *Dispatcher) URLForModelProvider(ctx context.Context, namespace, modelPr // If we didn't find anything with the read lock, check with the write lock. // It could be that another thread beat us to the write lock and added the model provider we desire. - u, ok = d.urls[modelProviderName] - if ok && (u.Scheme == "https" || engine.IsDaemonRunning(u.String())) { + u, ok = d.urls[key] + if ok && (u.Hostname() != "127.0.0.1" || engine.IsDaemonRunning(u.String())) { return u, nil } @@ -67,10 +68,23 @@ func (d *Dispatcher) URLForModelProvider(ctx context.Context, namespace, modelPr return nil, err } - d.urls[modelProviderName] = u + d.urls[key] = u return u, nil } +func (d *Dispatcher) StopModelProvider(namespace, modelProviderName string) { + key := namespace + "/" + modelProviderName + d.lock.Lock() + defer d.lock.Unlock() + + u := d.urls[key] + if u != nil && u.Hostname() == "127.0.0.1" && engine.IsDaemonRunning(u.String()) { + engine.StopDaemon(u.String()) + } + + delete(d.urls, key) +} + func (d *Dispatcher) TransformRequest(req *http.Request, namespace string) error { body, err := readBody(req) if err != nil { diff --git a/pkg/server/server.go b/pkg/server/server.go index 1a30a912f..7010808d6 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -20,18 +20,16 @@ func Run(ctx context.Context, c services.Config) error { return err } - go func() { - c, err := controller.New(svcs) - if err != nil { - log.Fatalf("Failed to start controller: %v", err) - } - if err := c.Start(ctx); err != nil { - log.Fatalf("Failed to start controller: %v", err) - } - if err := c.PostStart(ctx); err != nil { - log.Fatalf("Failed to post start controller: %v", err) - } - }() + ctrl, err := controller.New(svcs) + if err != nil { + log.Fatalf("Failed to start controller: %v", err) + } + if err = ctrl.Start(ctx); err != nil { + log.Fatalf("Failed to start controller: %v", err) + } + if err = ctrl.PostStart(ctx); err != nil { + log.Fatalf("Failed to post start controller: %v", err) + } handler, err := router.Router(svcs) if err != nil {