Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: restart a model provider when its credential changes #761

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ require (
github.com/gptscript-ai/chat-completion-client v0.0.0-20241127005108-02b41e1cd02e
github.com/gptscript-ai/cmd v0.0.0-20240907001148-ffd49061124a
github.com/gptscript-ai/go-gptscript v0.9.6-0.20241115201052-7efb3409cfcc
github.com/gptscript-ai/gptscript v0.9.6-0.20241121180135-e5fe428c6858
github.com/gptscript-ai/gptscript v0.9.6-0.20241204172147-c39a0693ee94
github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de
github.com/mhale/smtpd v0.8.3
github.com/oauth2-proxy/oauth2-proxy/v7 v7.0.0-00010101000000-000000000000
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,8 @@ github.com/gptscript-ai/cmd v0.0.0-20240907001148-ffd49061124a h1:LX7AOcbBoTnUk/
github.com/gptscript-ai/cmd v0.0.0-20240907001148-ffd49061124a/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw=
github.com/gptscript-ai/go-gptscript v0.9.6-0.20241115201052-7efb3409cfcc h1:D0N65peenVgPor9Ph0LiMLfgI4qCE9VGxgw8KxlvLgE=
github.com/gptscript-ai/go-gptscript v0.9.6-0.20241115201052-7efb3409cfcc/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q=
github.com/gptscript-ai/gptscript v0.9.6-0.20241121180135-e5fe428c6858 h1:FCHDABLrqOgxppU0NqLW4PlyWeF/K8WD969PQso3mVY=
github.com/gptscript-ai/gptscript v0.9.6-0.20241121180135-e5fe428c6858/go.mod h1:1ECuES7S+IjL4oua0nJzytsWw45tCCv750T44+/JczQ=
github.com/gptscript-ai/gptscript v0.9.6-0.20241204172147-c39a0693ee94 h1:Vkaujp51uMhix6Mag2LMHBCR8XEBXbnvxR9klUiG5iE=
github.com/gptscript-ai/gptscript v0.9.6-0.20241204172147-c39a0693ee94/go.mod h1:1ECuES7S+IjL4oua0nJzytsWw45tCCv750T44+/JczQ=
github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6 h1:vkgNZVWQgbE33VD3z9WKDwuu7B/eJVVMMPM62ixfCR8=
github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6/go.mod h1:frrl/B+ZH3VSs3Tqk2qxEIIWTONExX3tuUa4JsVnqx4=
github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7 h1:pdN6V1QBWetyv/0+wjACpqVH+eVULgEjkurDLq3goeM=
Expand Down
11 changes: 8 additions & 3 deletions pkg/api/handlers/modelprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,21 @@ import (
"github.com/gptscript-ai/go-gptscript"
"github.com/otto8-ai/otto8/apiclient/types"
"github.com/otto8-ai/otto8/pkg/api"
"github.com/otto8-ai/otto8/pkg/gateway/server/dispatcher"
v1 "github.com/otto8-ai/otto8/pkg/storage/apis/otto.otto8.ai/v1"
"k8s.io/apimachinery/pkg/fields"
kclient "sigs.k8s.io/controller-runtime/pkg/client"
)

type ModelProviderHandler struct {
gptscript *gptscript.GPTScript
gptscript *gptscript.GPTScript
dispatcher *dispatcher.Dispatcher
}

func NewModelProviderHandler(gClient *gptscript.GPTScript) *ModelProviderHandler {
func NewModelProviderHandler(gClient *gptscript.GPTScript, dispatcher *dispatcher.Dispatcher) *ModelProviderHandler {
return &ModelProviderHandler{
gptscript: gClient,
gptscript: gClient,
dispatcher: dispatcher,
}
}

Expand Down Expand Up @@ -93,6 +96,8 @@ func (mp *ModelProviderHandler) Configure(req api.Context) error {
return fmt.Errorf("failed to create credential: %w", err)
}

mp.dispatcher.StopModelProvider(ref.Namespace, ref.Name)

return nil
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/api/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func Router(services *services.Services) (http.Handler, error) {
cronJobs := handlers.NewCronJobHandler()
models := handlers.NewModelHandler(services.GPTClient)
availableModels := handlers.NewAvailableModelsHandler(services.GPTClient, services.ModelProviderDispatcher)
modelProviders := handlers.NewModelProviderHandler(services.GPTClient)
modelProviders := handlers.NewModelProviderHandler(services.GPTClient, services.ModelProviderDispatcher)
prompt := handlers.NewPromptHandler(services.GPTClient)
emailreceiver := handlers.NewEmailReceiverHandler(services.EmailServerName)
defaultModelAliases := handlers.NewDefaultModelAliasHandler()
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (c *Controller) PostStart(ctx context.Context) error {
return fmt.Errorf("failed to apply data: %w", err)
}
go c.toolRefHandler.PollRegistry(ctx, c.services.Router.Backend())
return nil
return c.toolRefHandler.EnsureOpenAIEnvCredential(ctx, c.services.Router.Backend())
}

func (c *Controller) Start(ctx context.Context) error {
Expand Down
95 changes: 55 additions & 40 deletions pkg/controller/handlers/toolreference/toolreference.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/otto8-ai/otto8/logger"
v1 "github.com/otto8-ai/otto8/pkg/storage/apis/otto.otto8.ai/v1"
"github.com/otto8-ai/otto8/pkg/system"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/yaml"
Expand Down Expand Up @@ -193,53 +194,13 @@ func (h *Handler) PollRegistry(ctx context.Context, c client.Client) {
break
}

var openAICredentialSet bool
t := time.NewTicker(time.Hour)
defer t.Stop()
for {
if err := h.readFromRegistry(ctx, c); err != nil {
log.Errorf("Failed to read from registry: %v", err)
}

// If the openai-model-provider exists and the OPENAI_API_KEY environment variable is set, then ensure the credential exists.
if !openAICredentialSet && os.Getenv("OPENAI_API_KEY") != "" {
var openAIModelProvider v1.ToolReference

// Not reporting errors here, best-effort only.
if err := c.Get(ctx, client.ObjectKey{Namespace: system.DefaultNamespace, Name: "openai-model-provider"}, &openAIModelProvider); err == nil {
if cred, err := h.gptClient.RevealCredential(ctx, []string{string(openAIModelProvider.UID)}, "openai-model-provider"); err != nil && strings.HasSuffix(err.Error(), "credential not found") {
// The credential doesn't exist, so create it.
err = h.gptClient.CreateCredential(ctx, gptscript.Credential{
Context: string(openAIModelProvider.UID),
ToolName: "openai-model-provider",
Type: gptscript.CredentialTypeModelProvider,
Env: map[string]string{
"OTTO8_OPENAI_MODEL_PROVIDER_API_KEY": os.Getenv("OPENAI_API_KEY"),
},
})

openAICredentialSet = err == nil
} else if err == nil && cred.Env["OTTO8_OPENAI_MODEL_PROVIDER_API_KEY"] != os.Getenv("OPENAI_API_KEY") {
// If the credential exists, but has a different value, then update it.
// The only way to update it is to delete the existing credential and recreate it.
if err = h.gptClient.DeleteCredential(ctx, string(openAIModelProvider.UID), "openai-model-provider"); err == nil {
err = h.gptClient.CreateCredential(ctx, gptscript.Credential{
Context: string(openAIModelProvider.UID),
ToolName: "openai-model-provider",
Type: gptscript.CredentialTypeModelProvider,
Env: map[string]string{
"OTTO8_OPENAI_MODEL_PROVIDER_API_KEY": os.Getenv("OPENAI_API_KEY"),
},
})

openAICredentialSet = err == nil
}
} else {
openAICredentialSet = true
}
}
}

select {
case <-t.C:
case <-ctx.Done():
Expand Down Expand Up @@ -297,6 +258,60 @@ func (h *Handler) Populate(req router.Request, resp router.Response) error {
return nil
}

// EnsureOpenAIEnvCredential copies the value of the OPENAI_API_KEY environment
// variable into the model-provider credential of the "openai-model-provider"
// tool reference.
//
// It returns nil immediately when OPENAI_API_KEY is empty. Otherwise it waits
// two seconds, then looks up the tool reference in system.DefaultNamespace,
// repeating the wait for as long as the lookup reports NotFound. Once the tool
// reference is found, the credential is created if it is missing, deleted and
// recreated if its stored key differs from the environment, and left alone if
// it already matches.
//
// It returns ctx.Err() if ctx is done while waiting, the lookup error for any
// failure other than NotFound, and otherwise whatever error the credential
// reveal/delete/create calls produce.
func (h *Handler) EnsureOpenAIEnvCredential(ctx context.Context, c client.Client) error {
	const (
		providerName = "openai-model-provider"
		sourceEnv    = "OPENAI_API_KEY"
		credEnvKey   = "OTTO8_OPENAI_MODEL_PROVIDER_API_KEY"
	)

	if os.Getenv(sourceEnv) == "" {
		return nil
	}

	// store creates the credential for the given credential context using the
	// current value of the environment variable.
	store := func(credCtx string) error {
		return h.gptClient.CreateCredential(ctx, gptscript.Credential{
			Context:  credCtx,
			ToolName: providerName,
			Type:     gptscript.CredentialTypeModelProvider,
			Env: map[string]string{
				credEnvKey: os.Getenv(sourceEnv),
			},
		})
	}

	lookup := client.ObjectKey{Namespace: system.DefaultNamespace, Name: providerName}
	for {
		// Pause before every attempt (including the first) unless the context ends.
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(2 * time.Second):
		}

		var provider v1.ToolReference
		err := c.Get(ctx, lookup, &provider)
		if apierrors.IsNotFound(err) {
			// The tool reference has not been created yet; try again later.
			continue
		}
		if err != nil {
			return err
		}

		// The credential is stored under the tool reference's UID.
		credCtx := string(provider.UID)
		existing, err := h.gptClient.RevealCredential(ctx, []string{credCtx}, providerName)
		switch {
		case err == nil && existing.Env[credEnvKey] == os.Getenv(sourceEnv):
			// Already up to date.
			return nil
		case err == nil:
			// The stored key is stale. The only way to update a credential is
			// to delete it and create it again.
			if err := h.gptClient.DeleteCredential(ctx, credCtx, providerName); err != nil {
				return fmt.Errorf("failed to delete credential: %w", err)
			}
			return store(credCtx)
		case strings.HasSuffix(err.Error(), "credential not found"):
			// No credential yet, so create it.
			return store(credCtx)
		default:
			return fmt.Errorf("failed to check OpenAI credential: %w", err)
		}
	}
}

func (h *Handler) RemoveModelProviderCredential(req router.Request, _ router.Response) error {
toolRef := req.Object.(*v1.ToolReference)
if toolRef.Spec.Type != types.ToolReferenceTypeModelProvider || toolRef.Status.Tool == nil || toolRef.Status.Tool.Metadata["envVars"] == "" {
Expand Down
24 changes: 19 additions & 5 deletions pkg/gateway/server/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ func New(invoker *invoke.Invoker, c kclient.Client, gClient *gptscript.GPTScript
}

func (d *Dispatcher) URLForModelProvider(ctx context.Context, namespace, modelProviderName string) (*url.URL, error) {
key := namespace + "/" + modelProviderName
// Check the map with the read lock.
d.lock.RLock()
u, ok := d.urls[modelProviderName]
u, ok := d.urls[key]
d.lock.RUnlock()
if ok && (u.Scheme == "https" || engine.IsDaemonRunning(u.String())) {
if ok && (u.Hostname() == "127.0.0.1" || engine.IsDaemonRunning(u.String())) {
return u, nil
}

Expand All @@ -56,8 +57,8 @@ func (d *Dispatcher) URLForModelProvider(ctx context.Context, namespace, modelPr

// If we didn't find anything with the read lock, check with the write lock.
// It could be that another thread beat us to the write lock and added the model provider we desire.
u, ok = d.urls[modelProviderName]
if ok && (u.Scheme == "https" || engine.IsDaemonRunning(u.String())) {
u, ok = d.urls[key]
if ok && (u.Hostname() != "127.0.0.1" || engine.IsDaemonRunning(u.String())) {
return u, nil
}

Expand All @@ -67,10 +68,23 @@ func (d *Dispatcher) URLForModelProvider(ctx context.Context, namespace, modelPr
return nil, err
}

d.urls[modelProviderName] = u
d.urls[key] = u
return u, nil
}

// StopModelProvider drops the cached URL for the model provider identified by
// namespace and modelProviderName. If the cached URL's host is 127.0.0.1 and
// the engine reports a daemon running at that URL, the daemon is stopped
// first. The cache entry is removed in every case, and the dispatcher's write
// lock is held for the whole operation.
func (d *Dispatcher) StopModelProvider(namespace, modelProviderName string) {
	d.lock.Lock()
	defer d.lock.Unlock()

	cacheKey := namespace + "/" + modelProviderName

	// Only entries pointing at 127.0.0.1 are candidates for stopping a daemon.
	if cached := d.urls[cacheKey]; cached != nil && cached.Hostname() == "127.0.0.1" {
		if addr := cached.String(); engine.IsDaemonRunning(addr) {
			engine.StopDaemon(addr)
		}
	}

	delete(d.urls, cacheKey)
}

func (d *Dispatcher) TransformRequest(req *http.Request, namespace string) error {
body, err := readBody(req)
if err != nil {
Expand Down
22 changes: 10 additions & 12 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,16 @@ func Run(ctx context.Context, c services.Config) error {
return err
}

go func() {
c, err := controller.New(svcs)
if err != nil {
log.Fatalf("Failed to start controller: %v", err)
}
if err := c.Start(ctx); err != nil {
log.Fatalf("Failed to start controller: %v", err)
}
if err := c.PostStart(ctx); err != nil {
log.Fatalf("Failed to post start controller: %v", err)
}
}()
ctrl, err := controller.New(svcs)
if err != nil {
log.Fatalf("Failed to start controller: %v", err)
}
if err = ctrl.Start(ctx); err != nil {
log.Fatalf("Failed to start controller: %v", err)
}
if err = ctrl.PostStart(ctx); err != nil {
log.Fatalf("Failed to post start controller: %v", err)
}

handler, err := router.Router(svcs)
if err != nil {
Expand Down