From 50325641eb417d6ba77961d26a3f75bc95485981 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Fri, 15 Nov 2024 14:17:37 -0500 Subject: [PATCH] fix: use custom oauth configurations for knowledge syncing (#597) Signed-off-by: Donnie Adams --- pkg/api/handlers/agent.go | 9 +++++---- pkg/api/handlers/assistants.go | 2 +- pkg/api/handlers/webhooks.go | 4 ++-- pkg/api/handlers/workflows.go | 1 + pkg/controller/handlers/oauthapp/oauthapplogin.go | 15 ++++++++++++--- pkg/controller/routes.go | 2 +- pkg/invoke/system.go | 3 +++ pkg/render/render.go | 8 ++++---- pkg/storage/apis/otto.otto8.ai/v1/oauthapp.go | 5 +++-- .../otto.otto8.ai/v1/zz_generated.deepcopy.go | 7 ++++++- .../openapi/generated/openapi_generated.go | 14 ++++++++++++++ 11 files changed, 52 insertions(+), 18 deletions(-) diff --git a/pkg/api/handlers/agent.go b/pkg/api/handlers/agent.go index c0f7ee51..bd4305c1 100644 --- a/pkg/api/handlers/agent.go +++ b/pkg/api/handlers/agent.go @@ -292,7 +292,7 @@ func (a *AgentHandler) CreateKnowledgeSource(req api.Context) error { var input types.KnowledgeSourceManifest if err := req.Read(&input); err != nil { - return types.NewErrBadRequest("failed to decode request body: %w", err) + return types.NewErrBadRequest("failed to decode request body: %v", err) } if err := input.Validate(); err != nil { @@ -312,7 +312,7 @@ func (a *AgentHandler) CreateKnowledgeSource(req api.Context) error { } if err := req.Create(&source); err != nil { - return types.NewErrBadRequest("failed to create RemoteKnowledgeSource: %w", err) + return types.NewErrBadRequest("failed to create RemoteKnowledgeSource: %v", err) } return req.Write(convertKnowledgeSource(agentName, source)) @@ -326,7 +326,7 @@ func (a *AgentHandler) UpdateKnowledgeSource(req api.Context) error { var manifest types.KnowledgeSourceManifest if err := req.Read(&manifest); err != nil { - return types.NewErrBadRequest("failed to decode request body: %w", err) + return types.NewErrBadRequest("failed to decode request body: %v", err) } if err := manifest.Validate(); err != nil { @@ -535,6 +535,7 @@ func (a *AgentHandler) EnsureCredentialForKnowledgeSource(req api.Context) error Spec: v1.OAuthAppLoginSpec{ CredentialContext: agent.Name, ToolReference: ref, + OAuthApps: agent.Spec.Manifest.OAuthApps, }, } @@ -566,7 +567,7 @@ func (a *AgentHandler) Script(req api.Context) error { agent v1.Agent ) if err := req.Get(&agent, id); err != nil { - return types.NewErrBadRequest("failed to get agent with id %s: %w", id, err) + return types.NewErrBadRequest("failed to get agent with id %s: %v", id, err) } tools, extraEnv, err := render.Agent(req.Context(), req.Storage, &agent, a.serverURL, render.AgentOptions{}) diff --git a/pkg/api/handlers/assistants.go b/pkg/api/handlers/assistants.go index bd21d631..bfdcba21 100644 --- a/pkg/api/handlers/assistants.go +++ b/pkg/api/handlers/assistants.go @@ -346,7 +346,7 @@ func (a *AssistantHandler) DeleteKnowledge(req api.Context) error { } if len(thread.Status.KnowledgeSetNames) == 0 { - return types.NewErrHttp(http.StatusTooEarly, fmt.Sprintf("knowledge set is not created yet")) + return types.NewErrHttp(http.StatusTooEarly, "knowledge set is not created yet") } return deleteKnowledge(req, req.PathValue("file"), thread.Status.KnowledgeSetNames[0]) diff --git a/pkg/api/handlers/webhooks.go b/pkg/api/handlers/webhooks.go index 402214f6..77437d2e 100644 --- a/pkg/api/handlers/webhooks.go +++ b/pkg/api/handlers/webhooks.go @@ -283,7 +283,7 @@ func validateSecretHeader(secret string, body []byte, values []string) error { func validateManifest(req api.Context, manifest types.WebhookManifest) error { // Ensure that the WorkflowID is set and the workflow exists if manifest.WorkflowID == "" { - return apierrors.NewBadRequest(fmt.Sprintf("webhook manifest must have a workflow name")) + return apierrors.NewBadRequest("webhook manifest must have a workflow name") } var workflow v1.Workflow @@ -297,7 +297,7 @@ func validateManifest(req api.Context, manifest types.WebhookManifest) error { } if (manifest.ValidationHeader != "") != (manifest.Secret != "") { - return apierrors.NewBadRequest(fmt.Sprintf("webhook must have secret and header set together")) + return apierrors.NewBadRequest("webhook must have secret and header set together") } return nil diff --git a/pkg/api/handlers/workflows.go b/pkg/api/handlers/workflows.go index 4ad2d387..54d94dcb 100644 --- a/pkg/api/handlers/workflows.go +++ b/pkg/api/handlers/workflows.go @@ -209,6 +209,7 @@ func (a *WorkflowHandler) EnsureCredentialForKnowledgeSource(req api.Context) er Spec: v1.OAuthAppLoginSpec{ CredentialContext: wf.Name, ToolReference: ref, + OAuthApps: wf.Spec.Manifest.OAuthApps, }, } diff --git a/pkg/controller/handlers/oauthapp/oauthapplogin.go b/pkg/controller/handlers/oauthapp/oauthapplogin.go index 6ae10ed6..228608e9 100644 --- a/pkg/controller/handlers/oauthapp/oauthapplogin.go +++ b/pkg/controller/handlers/oauthapp/oauthapplogin.go @@ -8,6 +8,7 @@ import ( "github.com/otto8-ai/nah/pkg/router" "github.com/otto8-ai/otto8/apiclient/types" "github.com/otto8-ai/otto8/pkg/invoke" + "github.com/otto8-ai/otto8/pkg/render" v1 "github.com/otto8-ai/otto8/pkg/storage/apis/otto.otto8.ai/v1" "github.com/otto8-ai/otto8/pkg/system" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -16,12 +17,14 @@ import ( ) type LoginHandler struct { - invoker *invoke.Invoker + invoker *invoke.Invoker + serverURL string } -func NewLogin(invoker *invoke.Invoker) *LoginHandler { +func NewLogin(invoker *invoke.Invoker, serverURL string) *LoginHandler { return &LoginHandler{ - invoker: invoker, + invoker: invoker, + serverURL: serverURL, } } @@ -50,6 +53,11 @@ func (h *LoginHandler) RunTool(req router.Request, _ router.Response) error { return err } + oauthAppEnv, err := render.OAuthAppEnv(req.Ctx, req.Client, login.Spec.OAuthApps, login.Namespace, h.serverURL) + if err != nil { + return err + } + task, err := h.invoker.SystemTask(req.Ctx, &thread, []gptscript.ToolDef{ { Credentials: []string{credentialTool}, @@ -57,6 +65,7 @@ func (h *LoginHandler) RunTool(req router.Request, _ router.Response) error { }, }, "", invoke.SystemTaskOptions{ CredentialContextIDs: []string{login.Spec.CredentialContext}, + Env: oauthAppEnv, }) if err != nil { return err diff --git a/pkg/controller/routes.go b/pkg/controller/routes.go index 856678e6..bbcd0d2f 100644 --- a/pkg/controller/routes.go +++ b/pkg/controller/routes.go @@ -34,7 +34,7 @@ func (c *Controller) setupRoutes() error { runs := runs.New(c.services.Invoker) webHooks := webhook.New() cronJobs := cronjob.New() - oauthLogins := oauthapp.NewLogin(c.services.Invoker) + oauthLogins := oauthapp.NewLogin(c.services.Invoker, c.services.ServerURL) // Runs root.Type(&v1.Run{}).FinalizeFunc(v1.RunFinalizer, runs.DeleteRunState) diff --git a/pkg/invoke/system.go b/pkg/invoke/system.go index 2eadd94b..03d0c010 100644 --- a/pkg/invoke/system.go +++ b/pkg/invoke/system.go @@ -9,11 +9,13 @@ import ( type SystemTaskOptions struct { CredentialContextIDs []string + Env []string } func complete(opts []SystemTaskOptions) (result SystemTaskOptions) { for _, opt := range opts { result.CredentialContextIDs = append(result.CredentialContextIDs, opt.CredentialContextIDs...) + result.Env = append(result.Env, opt.Env...) } return } @@ -42,6 +44,7 @@ func (i *Invoker) SystemTask(ctx context.Context, thread *v1.Thread, tool, input } return i.createRun(ctx, i.uncached, thread, tool, inputString, runOptions{ + Env: opt.Env, CredentialContextIDs: opt.CredentialContextIDs, Synchronous: true, }) diff --git a/pkg/render/render.go b/pkg/render/render.go index e804efb0..77f751d4 100644 --- a/pkg/render/render.go +++ b/pkg/render/render.go @@ -78,7 +78,7 @@ func Agent(ctx context.Context, db kclient.Client, agent *v1.Agent, oauthServerU return nil, nil, err } - if oauthEnv, err := setupOAuthApps(ctx, db, agent, oauthServerURL); err != nil { + if oauthEnv, err := OAuthAppEnv(ctx, db, agent.Spec.Manifest.OAuthApps, agent.Namespace, oauthServerURL); err != nil { return nil, nil, err } else { extraEnv = append(extraEnv, oauthEnv...) @@ -87,8 +87,8 @@ func Agent(ctx context.Context, db kclient.Client, agent *v1.Agent, oauthServerU return append([]gptscript.ToolDef{mainTool}, otherTools...), extraEnv, nil } -func setupOAuthApps(ctx context.Context, db kclient.Client, agent *v1.Agent, serverURL string) (extraEnv []string, _ error) { - apps, err := oauthAppsByName(ctx, db, agent.Namespace) +func OAuthAppEnv(ctx context.Context, db kclient.Client, oauthAppNames []string, namespace, serverURL string) (extraEnv []string, _ error) { + apps, err := oauthAppsByName(ctx, db, namespace) if err != nil { return nil, err } @@ -102,7 +102,7 @@ func setupOAuthApps(ctx context.Context, db kclient.Client, agent *v1.Agent, ser activeIntegrations[app.Spec.Manifest.Integration] = app } - for _, appRef := range agent.Spec.Manifest.OAuthApps { + for _, appRef := range oauthAppNames { app, ok := apps[appRef] if !ok { return nil, fmt.Errorf("oauth app %s not found", appRef) diff --git a/pkg/storage/apis/otto.otto8.ai/v1/oauthapp.go b/pkg/storage/apis/otto.otto8.ai/v1/oauthapp.go index 143b3e26..d8394fb9 100644 --- a/pkg/storage/apis/otto.otto8.ai/v1/oauthapp.go +++ b/pkg/storage/apis/otto.otto8.ai/v1/oauthapp.go @@ -158,8 +158,9 @@ func (o *OAuthAppLogin) DeleteRefs() []Ref { } type OAuthAppLoginSpec struct { - CredentialContext string `json:"credentialContext,omitempty"` - ToolReference string `json:"toolReference,omitempty"` + CredentialContext string `json:"credentialContext,omitempty"` + ToolReference string `json:"toolReference,omitempty"` + OAuthApps []string `json:"oauthApps,omitempty"` } type OAuthAppLoginStatus struct { diff --git a/pkg/storage/apis/otto.otto8.ai/v1/zz_generated.deepcopy.go b/pkg/storage/apis/otto.otto8.ai/v1/zz_generated.deepcopy.go index a7118030..dc44cec7 100644 --- a/pkg/storage/apis/otto.otto8.ai/v1/zz_generated.deepcopy.go +++ b/pkg/storage/apis/otto.otto8.ai/v1/zz_generated.deepcopy.go @@ -692,7 +692,7 @@ func (in *OAuthAppLogin) DeepCopyInto(out *OAuthAppLogin) { *out = *in out.TypeMeta = in.TypeMeta in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) - out.Spec = in.Spec + in.Spec.DeepCopyInto(&out.Spec) in.Status.DeepCopyInto(&out.Status) } @@ -749,6 +749,11 @@ func (in *OAuthAppLoginList) DeepCopyObject() runtime.Object { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *OAuthAppLoginSpec) DeepCopyInto(out *OAuthAppLoginSpec) { *out = *in + if in.OAuthApps != nil { + in, out := &in.OAuthApps, &out.OAuthApps + *out = make([]string, len(*in)) + copy(*out, *in) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new OAuthAppLoginSpec. diff --git a/pkg/storage/openapi/generated/openapi_generated.go b/pkg/storage/openapi/generated/openapi_generated.go index 22b2faa4..78f941ed 100644 --- a/pkg/storage/openapi/generated/openapi_generated.go +++ b/pkg/storage/openapi/generated/openapi_generated.go @@ -4713,6 +4713,20 @@ func schema_storage_apis_ottootto8ai_v1_OAuthAppLoginSpec(ref common.ReferenceCa Format: "", }, }, + "oauthApps": { + SchemaProps: spec.SchemaProps{ + Type: []string{"array"}, + Items: &spec.SchemaOrArray{ + Schema: &spec.Schema{ + SchemaProps: spec.SchemaProps{ + Default: "", + Type: []string{"string"}, + Format: "", + }, + }, + }, + }, + }, }, }, },