From a5ee4324b0e4ea1341cf872ba8e949e87bf7dadf Mon Sep 17 00:00:00 2001
From: Donnie Adams <donnie@acorn.io>
Date: Wed, 4 Dec 2024 12:28:04 -0500
Subject: [PATCH] fix: restart model providers when its credential changes

Signed-off-by: Donnie Adams <donnie@acorn.io>
---
 go.mod                                        |  2 +-
 go.sum                                        |  4 +-
 pkg/api/handlers/modelprovider.go             | 11 ++-
 pkg/api/router/router.go                      |  2 +-
 pkg/controller/controller.go                  |  2 +-
 .../handlers/toolreference/toolreference.go   | 95 +++++++++++--------
 pkg/gateway/server/dispatcher/dispatcher.go   | 24 ++++-
 pkg/server/server.go                          | 22 ++---
 8 files changed, 97 insertions(+), 65 deletions(-)

diff --git a/go.mod b/go.mod
index e45d687bc..a53f0bf5f 100644
--- a/go.mod
+++ b/go.mod
@@ -17,7 +17,7 @@ require (
 	github.com/gptscript-ai/chat-completion-client v0.0.0-20241127005108-02b41e1cd02e
 	github.com/gptscript-ai/cmd v0.0.0-20240907001148-ffd49061124a
 	github.com/gptscript-ai/go-gptscript v0.9.6-0.20241115201052-7efb3409cfcc
-	github.com/gptscript-ai/gptscript v0.9.6-0.20241121180135-e5fe428c6858
+	github.com/gptscript-ai/gptscript v0.9.6-0.20241204172147-c39a0693ee94
 	github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de
 	github.com/mhale/smtpd v0.8.3
 	github.com/oauth2-proxy/oauth2-proxy/v7 v7.0.0-00010101000000-000000000000
diff --git a/go.sum b/go.sum
index 8f18530e0..655f33979 100644
--- a/go.sum
+++ b/go.sum
@@ -341,8 +341,8 @@ github.com/gptscript-ai/cmd v0.0.0-20240907001148-ffd49061124a h1:LX7AOcbBoTnUk/
 github.com/gptscript-ai/cmd v0.0.0-20240907001148-ffd49061124a/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw=
 github.com/gptscript-ai/go-gptscript v0.9.6-0.20241115201052-7efb3409cfcc h1:D0N65peenVgPor9Ph0LiMLfgI4qCE9VGxgw8KxlvLgE=
 github.com/gptscript-ai/go-gptscript v0.9.6-0.20241115201052-7efb3409cfcc/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q=
-github.com/gptscript-ai/gptscript v0.9.6-0.20241121180135-e5fe428c6858 h1:FCHDABLrqOgxppU0NqLW4PlyWeF/K8WD969PQso3mVY=
-github.com/gptscript-ai/gptscript v0.9.6-0.20241121180135-e5fe428c6858/go.mod h1:1ECuES7S+IjL4oua0nJzytsWw45tCCv750T44+/JczQ=
+github.com/gptscript-ai/gptscript v0.9.6-0.20241204172147-c39a0693ee94 h1:Vkaujp51uMhix6Mag2LMHBCR8XEBXbnvxR9klUiG5iE=
+github.com/gptscript-ai/gptscript v0.9.6-0.20241204172147-c39a0693ee94/go.mod h1:1ECuES7S+IjL4oua0nJzytsWw45tCCv750T44+/JczQ=
 github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6 h1:vkgNZVWQgbE33VD3z9WKDwuu7B/eJVVMMPM62ixfCR8=
 github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6/go.mod h1:frrl/B+ZH3VSs3Tqk2qxEIIWTONExX3tuUa4JsVnqx4=
 github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7 h1:pdN6V1QBWetyv/0+wjACpqVH+eVULgEjkurDLq3goeM=
diff --git a/pkg/api/handlers/modelprovider.go b/pkg/api/handlers/modelprovider.go
index 26e6fb440..2151609f4 100644
--- a/pkg/api/handlers/modelprovider.go
+++ b/pkg/api/handlers/modelprovider.go
@@ -8,18 +8,21 @@ import (
 	"github.com/gptscript-ai/go-gptscript"
 	"github.com/otto8-ai/otto8/apiclient/types"
 	"github.com/otto8-ai/otto8/pkg/api"
+	"github.com/otto8-ai/otto8/pkg/gateway/server/dispatcher"
 	v1 "github.com/otto8-ai/otto8/pkg/storage/apis/otto.otto8.ai/v1"
 	"k8s.io/apimachinery/pkg/fields"
 	kclient "sigs.k8s.io/controller-runtime/pkg/client"
 )
 
 type ModelProviderHandler struct {
-	gptscript *gptscript.GPTScript
+	gptscript  *gptscript.GPTScript
+	dispatcher *dispatcher.Dispatcher
 }
 
-func NewModelProviderHandler(gClient *gptscript.GPTScript) *ModelProviderHandler {
+func NewModelProviderHandler(gClient *gptscript.GPTScript, dispatcher *dispatcher.Dispatcher) *ModelProviderHandler {
 	return &ModelProviderHandler{
-		gptscript: gClient,
+		gptscript:  gClient,
+		dispatcher: dispatcher,
 	}
 }
 
@@ -93,6 +96,8 @@ func (mp *ModelProviderHandler) Configure(req api.Context) error {
 		return fmt.Errorf("failed to create credential: %w", err)
 	}
 
+	mp.dispatcher.StopModelProvider(ref.Namespace, ref.Name)
+
 	return nil
 }
 
diff --git a/pkg/api/router/router.go b/pkg/api/router/router.go
index ffa2ff78d..5d3820749 100644
--- a/pkg/api/router/router.go
+++ b/pkg/api/router/router.go
@@ -23,7 +23,7 @@ func Router(services *services.Services) (http.Handler, error) {
 	cronJobs := handlers.NewCronJobHandler()
 	models := handlers.NewModelHandler(services.GPTClient)
 	availableModels := handlers.NewAvailableModelsHandler(services.GPTClient, services.ModelProviderDispatcher)
-	modelProviders := handlers.NewModelProviderHandler(services.GPTClient)
+	modelProviders := handlers.NewModelProviderHandler(services.GPTClient, services.ModelProviderDispatcher)
 	prompt := handlers.NewPromptHandler(services.GPTClient)
 	emailreceiver := handlers.NewEmailReceiverHandler(services.EmailServerName)
 	defaultModelAliases := handlers.NewDefaultModelAliasHandler()
diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go
index c95640e88..d55c8624f 100644
--- a/pkg/controller/controller.go
+++ b/pkg/controller/controller.go
@@ -37,7 +37,7 @@ func (c *Controller) PostStart(ctx context.Context) error {
 		return fmt.Errorf("failed to apply data: %w", err)
 	}
 	go c.toolRefHandler.PollRegistry(ctx, c.services.Router.Backend())
-	return nil
+	return c.toolRefHandler.EnsureOpenAIEnvCredential(ctx, c.services.Router.Backend())
 }
 
 func (c *Controller) Start(ctx context.Context) error {
diff --git a/pkg/controller/handlers/toolreference/toolreference.go b/pkg/controller/handlers/toolreference/toolreference.go
index f986a5daf..d1701c39b 100644
--- a/pkg/controller/handlers/toolreference/toolreference.go
+++ b/pkg/controller/handlers/toolreference/toolreference.go
@@ -16,6 +16,7 @@ import (
 	"github.com/otto8-ai/otto8/logger"
 	v1 "github.com/otto8-ai/otto8/pkg/storage/apis/otto.otto8.ai/v1"
 	"github.com/otto8-ai/otto8/pkg/system"
+	apierrors "k8s.io/apimachinery/pkg/api/errors"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"sigs.k8s.io/controller-runtime/pkg/client"
 	"sigs.k8s.io/yaml"
@@ -193,7 +194,6 @@ func (h *Handler) PollRegistry(ctx context.Context, c client.Client) {
 		break
 	}
 
-	var openAICredentialSet bool
 	t := time.NewTicker(time.Hour)
 	defer t.Stop()
 	for {
@@ -201,45 +201,6 @@ func (h *Handler) PollRegistry(ctx context.Context, c client.Client) {
 			log.Errorf("Failed to read from registry: %v", err)
 		}
 
-		// If the openai-model-provider exists and the OPENAI_API_KEY environment variable is set, then ensure the credential exists.
-		if !openAICredentialSet && os.Getenv("OPENAI_API_KEY") != "" {
-			var openAIModelProvider v1.ToolReference
-
-			// Not reporting errors here, best-effort only.
-			if err := c.Get(ctx, client.ObjectKey{Namespace: system.DefaultNamespace, Name: "openai-model-provider"}, &openAIModelProvider); err == nil {
-				if cred, err := h.gptClient.RevealCredential(ctx, []string{string(openAIModelProvider.UID)}, "openai-model-provider"); err != nil && strings.HasSuffix(err.Error(), "credential not found") {
-					// The credential doesn't exist, so create it.
-					err = h.gptClient.CreateCredential(ctx, gptscript.Credential{
-						Context:  string(openAIModelProvider.UID),
-						ToolName: "openai-model-provider",
-						Type:     gptscript.CredentialTypeModelProvider,
-						Env: map[string]string{
-							"OTTO8_OPENAI_MODEL_PROVIDER_API_KEY": os.Getenv("OPENAI_API_KEY"),
-						},
-					})
-
-					openAICredentialSet = err == nil
-				} else if err == nil && cred.Env["OTTO8_OPENAI_MODEL_PROVIDER_API_KEY"] != os.Getenv("OPENAI_API_KEY") {
-					// If the credential exists, but has a different value, then update it.
-					// The only way to update it is to delete the existing credential and recreate it.
-					if err = h.gptClient.DeleteCredential(ctx, string(openAIModelProvider.UID), "openai-model-provider"); err == nil {
-						err = h.gptClient.CreateCredential(ctx, gptscript.Credential{
-							Context:  string(openAIModelProvider.UID),
-							ToolName: "openai-model-provider",
-							Type:     gptscript.CredentialTypeModelProvider,
-							Env: map[string]string{
-								"OTTO8_OPENAI_MODEL_PROVIDER_API_KEY": os.Getenv("OPENAI_API_KEY"),
-							},
-						})
-
-						openAICredentialSet = err == nil
-					}
-				} else {
-					openAICredentialSet = true
-				}
-			}
-		}
-
 		select {
 		case <-t.C:
 		case <-ctx.Done():
@@ -297,6 +258,60 @@ func (h *Handler) Populate(req router.Request, resp router.Response) error {
 	return nil
 }
 
+func (h *Handler) EnsureOpenAIEnvCredential(ctx context.Context, c client.Client) error {
+	if os.Getenv("OPENAI_API_KEY") == "" {
+		return nil
+	}
+
+	for {
+		select {
+		case <-time.After(2 * time.Second):
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+
+		// If the openai-model-provider exists and the OPENAI_API_KEY environment variable is set, then ensure the credential exists.
+		var openAIModelProvider v1.ToolReference
+		if err := c.Get(ctx, client.ObjectKey{Namespace: system.DefaultNamespace, Name: "openai-model-provider"}, &openAIModelProvider); apierrors.IsNotFound(err) {
+			continue
+		} else if err != nil {
+			return err
+		}
+
+		if cred, err := h.gptClient.RevealCredential(ctx, []string{string(openAIModelProvider.UID)}, "openai-model-provider"); err != nil {
+			if strings.HasSuffix(err.Error(), "credential not found") {
+				// The credential doesn't exist, so create it.
+				return h.gptClient.CreateCredential(ctx, gptscript.Credential{
+					Context:  string(openAIModelProvider.UID),
+					ToolName: "openai-model-provider",
+					Type:     gptscript.CredentialTypeModelProvider,
+					Env: map[string]string{
+						"OTTO8_OPENAI_MODEL_PROVIDER_API_KEY": os.Getenv("OPENAI_API_KEY"),
+					},
+				})
+			}
+
+			return fmt.Errorf("failed to check OpenAI credential: %w", err)
+		} else if cred.Env["OTTO8_OPENAI_MODEL_PROVIDER_API_KEY"] != os.Getenv("OPENAI_API_KEY") {
+			// If the credential exists, but has a different value, then update it.
+			// The only way to update it is to delete the existing credential and recreate it.
+			if err = h.gptClient.DeleteCredential(ctx, string(openAIModelProvider.UID), "openai-model-provider"); err != nil {
+				return fmt.Errorf("failed to delete credential: %w", err)
+			}
+			return h.gptClient.CreateCredential(ctx, gptscript.Credential{
+				Context:  string(openAIModelProvider.UID),
+				ToolName: "openai-model-provider",
+				Type:     gptscript.CredentialTypeModelProvider,
+				Env: map[string]string{
+					"OTTO8_OPENAI_MODEL_PROVIDER_API_KEY": os.Getenv("OPENAI_API_KEY"),
+				},
+			})
+		}
+
+		return nil
+	}
+}
+
 func (h *Handler) RemoveModelProviderCredential(req router.Request, _ router.Response) error {
 	toolRef := req.Object.(*v1.ToolReference)
 	if toolRef.Spec.Type != types.ToolReferenceTypeModelProvider || toolRef.Status.Tool == nil || toolRef.Status.Tool.Metadata["envVars"] == "" {
diff --git a/pkg/gateway/server/dispatcher/dispatcher.go b/pkg/gateway/server/dispatcher/dispatcher.go
index 088ead29b..816e0f9ef 100644
--- a/pkg/gateway/server/dispatcher/dispatcher.go
+++ b/pkg/gateway/server/dispatcher/dispatcher.go
@@ -43,11 +43,12 @@ func New(invoker *invoke.Invoker, c kclient.Client, gClient *gptscript.GPTScript
 }
 
 func (d *Dispatcher) URLForModelProvider(ctx context.Context, namespace, modelProviderName string) (*url.URL, error) {
+	key := namespace + "/" + modelProviderName
 	// Check the map with the read lock.
 	d.lock.RLock()
-	u, ok := d.urls[modelProviderName]
+	u, ok := d.urls[key]
 	d.lock.RUnlock()
-	if ok && (u.Scheme == "https" || engine.IsDaemonRunning(u.String())) {
+	if ok && (u.Hostname() == "127.0.0.1" || engine.IsDaemonRunning(u.String())) {
 		return u, nil
 	}
 
@@ -56,8 +57,8 @@ func (d *Dispatcher) URLForModelProvider(ctx context.Context, namespace, modelPr
 
 	// If we didn't find anything with the read lock, check with the write lock.
 	// It could be that another thread beat us to the write lock and added the model provider we desire.
-	u, ok = d.urls[modelProviderName]
-	if ok && (u.Scheme == "https" || engine.IsDaemonRunning(u.String())) {
+	u, ok = d.urls[key]
+	if ok && (u.Hostname() != "127.0.0.1" || engine.IsDaemonRunning(u.String())) {
 		return u, nil
 	}
 
@@ -67,10 +68,23 @@ func (d *Dispatcher) URLForModelProvider(ctx context.Context, namespace, modelPr
 		return nil, err
 	}
 
-	d.urls[modelProviderName] = u
+	d.urls[key] = u
 	return u, nil
 }
 
+func (d *Dispatcher) StopModelProvider(namespace, modelProviderName string) {
+	key := namespace + "/" + modelProviderName
+	d.lock.Lock()
+	defer d.lock.Unlock()
+
+	u := d.urls[key]
+	if u != nil && u.Hostname() == "127.0.0.1" && engine.IsDaemonRunning(u.String()) {
+		engine.StopDaemon(u.String())
+	}
+
+	delete(d.urls, key)
+}
+
 func (d *Dispatcher) TransformRequest(req *http.Request, namespace string) error {
 	body, err := readBody(req)
 	if err != nil {
diff --git a/pkg/server/server.go b/pkg/server/server.go
index 1a30a912f..7010808d6 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -20,18 +20,16 @@ func Run(ctx context.Context, c services.Config) error {
 		return err
 	}
 
-	go func() {
-		c, err := controller.New(svcs)
-		if err != nil {
-			log.Fatalf("Failed to start controller: %v", err)
-		}
-		if err := c.Start(ctx); err != nil {
-			log.Fatalf("Failed to start controller: %v", err)
-		}
-		if err := c.PostStart(ctx); err != nil {
-			log.Fatalf("Failed to post start controller: %v", err)
-		}
-	}()
+	ctrl, err := controller.New(svcs)
+	if err != nil {
+		log.Fatalf("Failed to start controller: %v", err)
+	}
+	if err = ctrl.Start(ctx); err != nil {
+		log.Fatalf("Failed to start controller: %v", err)
+	}
+	if err = ctrl.PostStart(ctx); err != nil {
+		log.Fatalf("Failed to post start controller: %v", err)
+	}
 
 	handler, err := router.Router(svcs)
 	if err != nil {