diff --git a/pkg/api/handlers/agent.go b/pkg/api/handlers/agent.go index 39d6346e7..a9357653b 100644 --- a/pkg/api/handlers/agent.go +++ b/pkg/api/handlers/agent.go @@ -20,12 +20,15 @@ import ( type AgentHandler struct { gptscript *gptscript.GPTScript serverURL string + // This is currently a hack to access the workflow handler + workflowHandler *WorkflowHandler } func NewAgentHandler(gClient *gptscript.GPTScript, serverURL string) *AgentHandler { return &AgentHandler{ - serverURL: serverURL, - gptscript: gClient, + serverURL: serverURL, + gptscript: gClient, + workflowHandler: NewWorkflowHandler(gClient, serverURL, nil), } } @@ -484,8 +487,13 @@ func (a *AgentHandler) DeleteKnowledgeSource(req api.Context) error { } func (a *AgentHandler) EnsureCredentialForKnowledgeSource(req api.Context) error { + agentID := req.PathValue("id") + if system.IsWorkflowID(agentID) { + return a.workflowHandler.EnsureCredentialForKnowledgeSource(req) + } + var agent v1.Agent - if err := req.Get(&agent, req.PathValue("agent_id")); err != nil { + if err := req.Get(&agent, agentID); err != nil { return err } diff --git a/pkg/api/handlers/workflows.go b/pkg/api/handlers/workflows.go index c25ba6f45..4ad2d3877 100644 --- a/pkg/api/handlers/workflows.go +++ b/pkg/api/handlers/workflows.go @@ -14,6 +14,7 @@ import ( "github.com/otto8-ai/otto8/pkg/render" v1 "github.com/otto8-ai/otto8/pkg/storage/apis/otto.otto8.ai/v1" "github.com/otto8-ai/otto8/pkg/system" + "github.com/otto8-ai/otto8/pkg/wait" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -165,6 +166,74 @@ func (a *WorkflowHandler) List(req api.Context) error { return req.Write(resp) } +func (a *WorkflowHandler) EnsureCredentialForKnowledgeSource(req api.Context) error { + var wf v1.Workflow + if err := req.Get(&wf, req.PathValue("id")); err != nil { + return err + } + + ref := req.PathValue("ref") + authStatus := wf.Status.External.AuthStatus[ref] + + // If auth is not required, then don't continue. + if authStatus.Required != nil && !*authStatus.Required { + return req.Write(convertWorkflow(wf, server.GetURLPrefix(req))) + } + + // if auth is already authenticated, then don't continue. + if authStatus.Authenticated { + return req.Write(convertWorkflow(wf, server.GetURLPrefix(req))) + } + + credentialTool, err := v1.CredentialTool(req.Context(), req.Storage, req.Namespace(), ref) + if err != nil { + return err + } + + if credentialTool == "" { + // The only way to get here is if the controller hasn't set the field yet. + if wf.Status.External.AuthStatus == nil { + wf.Status.External.AuthStatus = make(map[string]types.OAuthAppLoginAuthStatus) + } + + authStatus.Required = &[]bool{false}[0] + wf.Status.External.AuthStatus[ref] = authStatus + return req.Write(convertWorkflow(wf, server.GetURLPrefix(req))) + } + + oauthLogin := &v1.OAuthAppLogin{ + ObjectMeta: metav1.ObjectMeta{ + Name: system.OAuthAppLoginPrefix + wf.Name + ref, + Namespace: req.Namespace(), + }, + Spec: v1.OAuthAppLoginSpec{ + CredentialContext: wf.Name, + ToolReference: ref, + }, + } + + if err = req.Delete(oauthLogin); err != nil { + return err + } + + oauthLogin, err = wait.For(req.Context(), req.Storage, oauthLogin, func(obj *v1.OAuthAppLogin) bool { + return obj.Status.External.Authenticated || obj.Status.External.Error != "" || obj.Status.External.URL != "" + }, wait.Option{ + Create: true, + }) + if err != nil { + return fmt.Errorf("failed to ensure credential for workflow %q: %w", wf.Name, err) + } + + // Don't need to actually update the knowledge ref, there is a controller that will do that. + if wf.Status.External.AuthStatus == nil { + wf.Status.External.AuthStatus = make(map[string]types.OAuthAppLoginAuthStatus) + } + wf.Status.External.AuthStatus[ref] = oauthLogin.Status.External + + return req.Write(convertWorkflow(wf, server.GetURLPrefix(req))) +} + func (a *WorkflowHandler) Script(req api.Context) error { var ( id = req.Request.PathValue("id") diff --git a/pkg/api/router/router.go b/pkg/api/router/router.go index 6a8ede4ab..7cedd0090 100644 --- a/pkg/api/router/router.go +++ b/pkg/api/router/router.go @@ -35,7 +35,7 @@ func Router(services *services.Services) (http.Handler, error) { mux.HandleFunc("POST /api/agents", agents.Create) mux.HandleFunc("PUT /api/agents/{id}", agents.Update) mux.HandleFunc("DELETE /api/agents/{id}", agents.Delete) - mux.HandleFunc("POST /api/agents/{agent_id}/oauth-credentials/{ref}/login", agents.EnsureCredentialForKnowledgeSource) + mux.HandleFunc("POST /api/agents/{id}/oauth-credentials/{ref}/login", agents.EnsureCredentialForKnowledgeSource) // Assistants mux.HandleFunc("GET /api/assistants", assistants.List) @@ -87,6 +87,7 @@ func Router(services *services.Services) (http.Handler, error) { mux.HandleFunc("POST /api/workflows/{id}/authenticate", workflows.Authenticate) mux.HandleFunc("PUT /api/workflows/{id}", workflows.Update) mux.HandleFunc("DELETE /api/workflows/{id}", workflows.Delete) + mux.HandleFunc("POST /api/workflows/{id}/oauth-credentials/{ref}/login", workflows.EnsureCredentialForKnowledgeSource) // Workflow knowledge files mux.HandleFunc("GET /api/workflows/{agent_id}/knowledge-files", agents.ListKnowledgeFiles)