diff --git a/apiclient/types/toolreference.go b/apiclient/types/toolreference.go index 80a0a716a..bd2aa37e2 100644 --- a/apiclient/types/toolreference.go +++ b/apiclient/types/toolreference.go @@ -21,13 +21,12 @@ type ToolReferenceManifest struct { type ToolReference struct { Metadata ToolReferenceManifest - Resolved bool `json:"resolved,omitempty"` - Error string `json:"error,omitempty"` - Builtin bool `json:"builtin,omitempty"` - Description string `json:"description,omitempty"` - Credential string `json:"credential,omitempty"` - Params map[string]string `json:"params,omitempty"` - ModelProviderStatus *ModelProviderStatus `json:"modelProviderStatus,omitempty"` + Resolved bool `json:"resolved,omitempty"` + Error string `json:"error,omitempty"` + Builtin bool `json:"builtin,omitempty"` + Description string `json:"description,omitempty"` + Credential string `json:"credential,omitempty"` + Params map[string]string `json:"params,omitempty"` } type ToolReferenceList List[ToolReference] diff --git a/apiclient/types/zz_generated.deepcopy.go b/apiclient/types/zz_generated.deepcopy.go index d2e696a4b..9b084b8ac 100644 --- a/apiclient/types/zz_generated.deepcopy.go +++ b/apiclient/types/zz_generated.deepcopy.go @@ -1662,11 +1662,6 @@ func (in *ToolReference) DeepCopyInto(out *ToolReference) { (*out)[key] = val } } - if in.ModelProviderStatus != nil { - in, out := &in.ModelProviderStatus, &out.ModelProviderStatus - *out = new(ModelProviderStatus) - (*in).DeepCopyInto(*out) - } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ToolReference. diff --git a/pkg/api/handlers/availablemodels.go b/pkg/api/handlers/availablemodels.go index a04269124..573047c33 100644 --- a/pkg/api/handlers/availablemodels.go +++ b/pkg/api/handlers/availablemodels.go @@ -29,8 +29,8 @@ func NewAvailableModelsHandler(gClient *gptscript.GPTScript, dispatcher *dispatc } func (a *AvailableModelsHandler) List(req api.Context) error { - var modelProviderReference v1.ToolReferenceList - if err := req.List(&modelProviderReference, &kclient.ListOptions{ + var modelProviderReferences v1.ToolReferenceList + if err := req.List(&modelProviderReferences, &kclient.ListOptions{ Namespace: req.Namespace(), FieldSelector: fields.SelectorFromSet(map[string]string{ "spec.type": string(types.ToolReferenceTypeModelProvider), @@ -39,11 +39,27 @@ func (a *AvailableModelsHandler) List(req api.Context) error { return err } + credCtxs := make([]string, 0, len(modelProviderReferences.Items)) + for _, ref := range modelProviderReferences.Items { + credCtxs = append(credCtxs, string(ref.UID)) + } + + creds, err := a.gptscript.ListCredentials(req.Context(), gptscript.ListCredentialsOptions{ + CredentialContexts: credCtxs, + }) + if err != nil { + return fmt.Errorf("failed to list model provider credentials: %w", err) + } + + credMap := make(map[string]map[string]string, len(creds)) + for _, cred := range creds { + credMap[cred.Context+cred.ToolName] = cred.Env + } + var oModels openai.ModelsList - for _, modelProvider := range modelProviderReference.Items { - if convertedModelProvider, err := convertModelProviderToolRef(req.Context(), a.gptscript, modelProvider); err != nil { - return fmt.Errorf("failed to determine if model provider %q is configured: %w", modelProvider.Name, err) - } else if !convertedModelProvider.Configured || modelProvider.Name == system.ModelProviderTool { + for _, modelProvider := range modelProviderReferences.Items { + convertedModelProvider := convertModelProviderToolRef(modelProvider, credMap[string(modelProvider.UID)+modelProvider.Name]) + if !convertedModelProvider.Configured || modelProvider.Name == system.ModelProviderTool { continue } @@ -74,9 +90,19 @@ func (a *AvailableModelsHandler) ListForModelProvider(req api.Context) error { return types.NewErrBadRequest("%s is not a model provider", modelProviderReference.Name) } - if modelProvider, err := convertModelProviderToolRef(req.Context(), a.gptscript, modelProviderReference); err != nil { - return fmt.Errorf("failed to determine if model provider is configured: %w", err) - } else if !modelProvider.Configured { + var credEnvVars map[string]string + if modelProviderReference.Status.Tool != nil { + if envVars := modelProviderReference.Status.Tool.Metadata["envVars"]; envVars != "" { + cred, err := a.gptscript.RevealCredential(req.Context(), []string{string(modelProviderReference.UID)}, modelProviderReference.Name) + if err != nil && !strings.HasSuffix(err.Error(), "credential not found") { + return fmt.Errorf("failed to reveal credential for model provider %q: %w", modelProviderReference.Name, err) + } else if err == nil { + credEnvVars = cred.Env + } + } + } + + if modelProvider := convertModelProviderToolRef(modelProviderReference, credEnvVars); !modelProvider.Configured { return types.NewErrBadRequest("model provider %s is not configured, missing configuration parameters: %s", modelProviderReference.Name, strings.Join(modelProvider.MissingConfigurationParameters, ", ")) } diff --git a/pkg/api/handlers/modelprovider.go b/pkg/api/handlers/modelprovider.go index cc6c5f483..c26e81977 100644 --- a/pkg/api/handlers/modelprovider.go +++ b/pkg/api/handlers/modelprovider.go @@ -1,7 +1,6 @@ package handlers import ( - "context" "fmt" "strings" @@ -39,12 +38,19 @@ func (mp *ModelProviderHandler) ByID(req api.Context) error { ) } - modelProvider, err := convertToolReferenceToModelProvider(req.Context(), mp.gptscript, ref) - if err != nil { - return err + var credEnvVars map[string]string + if ref.Status.Tool != nil { + if envVars := ref.Status.Tool.Metadata["envVars"]; envVars != "" { + cred, err := mp.gptscript.RevealCredential(req.Context(), []string{string(ref.UID)}, ref.Name) + if err != nil && !strings.HasSuffix(err.Error(), "credential not found") { + return fmt.Errorf("failed to reveal credential for model provider %q: %w", ref.Name, err) + } else if err == nil { + credEnvVars = cred.Env + } + } } - return req.Write(modelProvider) + return req.Write(convertToolReferenceToModelProvider(ref, credEnvVars)) } func (mp *ModelProviderHandler) List(req api.Context) error { @@ -58,14 +64,26 @@ func (mp *ModelProviderHandler) List(req api.Context) error { return err } - resp := make([]types.ModelProvider, 0, len(refList.Items)) + credCtxs := make([]string, 0, len(refList.Items)) for _, ref := range refList.Items { - modelProvider, err := convertToolReferenceToModelProvider(req.Context(), mp.gptscript, ref) - if err != nil { - return fmt.Errorf("failed to determine model provider status: %w", err) - } + credCtxs = append(credCtxs, string(ref.UID)) + } - resp = append(resp, modelProvider) + creds, err := mp.gptscript.ListCredentials(req.Context(), gptscript.ListCredentialsOptions{ + CredentialContexts: credCtxs, + }) + if err != nil { + return fmt.Errorf("failed to list model provider credentials: %w", err) + } + + credMap := make(map[string]map[string]string, len(creds)) + for _, cred := range creds { + credMap[cred.Context+cred.ToolName] = cred.Env + } + + resp := make([]types.ModelProvider, 0, len(refList.Items)) + for _, ref := range refList.Items { + resp = append(resp, convertToolReferenceToModelProvider(ref, credMap[string(ref.UID)+ref.Name])) } return req.Write(types.ModelProviderList{Items: resp}) @@ -91,6 +109,12 @@ func (mp *ModelProviderHandler) Configure(req api.Context) error { return fmt.Errorf("failed to update credential: %w", err) } + for key, val := range envVars { + if val == "" { + delete(envVars, key) + } + } + if err := mp.gptscript.CreateCredential(req.Context(), gptscript.Credential{ Context: string(ref.UID), ToolName: ref.Name, @@ -120,12 +144,18 @@ func (mp *ModelProviderHandler) Reveal(req api.Context) error { return err } + if ref.Spec.Type != types.ToolReferenceTypeModelProvider { + return types.NewErrBadRequest("%q is not a model provider", ref.Name) + } + cred, err := mp.gptscript.RevealCredential(req.Context(), []string{string(ref.UID)}, ref.Name) if err != nil && !strings.HasSuffix(err.Error(), "credential not found") { return fmt.Errorf("failed to reveal credential: %w", err) + } else if err == nil { + return req.Write(cred.Env) } - return req.Write(cred.Env) + return types.NewErrNotFound("no credential found for %q", ref.Name) } func (mp *ModelProviderHandler) RefreshModels(req api.Context) error { @@ -138,11 +168,19 @@ func (mp *ModelProviderHandler) RefreshModels(req api.Context) error { return types.NewErrBadRequest("%q is not a model provider", ref.Name) } - modelProvider, err := convertToolReferenceToModelProvider(req.Context(), mp.gptscript, ref) - if err != nil { - return err + var credEnvVars map[string]string + if ref.Status.Tool != nil { + if envVars := ref.Status.Tool.Metadata["envVars"]; envVars != "" { + cred, err := mp.gptscript.RevealCredential(req.Context(), []string{string(ref.UID)}, ref.Name) + if err != nil && !strings.HasSuffix(err.Error(), "credential not found") { + return fmt.Errorf("failed to reveal credential for model provider %q: %w", ref.Name, err) + } else if err == nil { + credEnvVars = cred.Env + } + } } + modelProvider := convertToolReferenceToModelProvider(ref, credEnvVars) if !modelProvider.Configured { return types.NewErrBadRequest("model provider %s is not configured, missing configuration parameters: %s", modelProvider.ModelProviderManifest.Name, strings.Join(modelProvider.MissingConfigurationParameters, ", ")) } @@ -156,19 +194,14 @@ func (mp *ModelProviderHandler) RefreshModels(req api.Context) error { delete(ref.Annotations, v1.ModelProviderSyncAnnotation) } - if err = req.Update(&ref); err != nil { + if err := req.Update(&ref); err != nil { return fmt.Errorf("failed to sync models for model provider %q: %w", ref.Name, err) } return req.Write(modelProvider) } -func convertToolReferenceToModelProvider(ctx context.Context, gClient *gptscript.GPTScript, ref v1.ToolReference) (types.ModelProvider, error) { - status, err := convertModelProviderToolRef(ctx, gClient, ref) - if err != nil { - return types.ModelProvider{}, err - } - +func convertToolReferenceToModelProvider(ref v1.ToolReference, credEnvVars map[string]string) types.ModelProvider { name := ref.Name if ref.Status.Tool != nil { name = ref.Status.Tool.Name @@ -180,34 +213,27 @@ func convertToolReferenceToModelProvider(ctx context.Context, gClient *gptscript Name: name, ToolReference: ref.Spec.Reference, }, - ModelProviderStatus: *status, + ModelProviderStatus: *convertModelProviderToolRef(ref, credEnvVars), } mp.Type = "modelprovider" - return mp, nil + return mp } -func convertModelProviderToolRef(ctx context.Context, gptscript *gptscript.GPTScript, toolRef v1.ToolReference) (*types.ModelProviderStatus, error) { +func convertModelProviderToolRef(toolRef v1.ToolReference, cred map[string]string) *types.ModelProviderStatus { var ( requiredEnvVars, missingEnvVars []string icon string ) if toolRef.Status.Tool != nil { if toolRef.Status.Tool.Metadata["envVars"] != "" { - cred, err := gptscript.RevealCredential(ctx, []string{string(toolRef.UID)}, toolRef.Name) - if err != nil && !strings.HasSuffix(err.Error(), "credential not found") { - return nil, fmt.Errorf("failed to reveal credential for model provider %q: %w", toolRef.Name, err) - } - - if toolRef.Status.Tool.Metadata["envVars"] != "" { - requiredEnvVars = strings.Split(toolRef.Status.Tool.Metadata["envVars"], ",") - } + requiredEnvVars = strings.Split(toolRef.Status.Tool.Metadata["envVars"], ",") + } - for _, envVar := range requiredEnvVars { - if cred.Env[envVar] == "" { - missingEnvVars = append(missingEnvVars, envVar) - } + for _, envVar := range requiredEnvVars { + if _, ok := cred[envVar]; !ok { + missingEnvVars = append(missingEnvVars, envVar) } } @@ -227,5 +253,5 @@ func convertModelProviderToolRef(ctx context.Context, gptscript *gptscript.GPTSc ModelsBackPopulated: modelsPopulated, RequiredConfigurationParameters: requiredEnvVars, MissingConfigurationParameters: missingEnvVars, - }, nil + } } diff --git a/pkg/api/handlers/toolreferences.go b/pkg/api/handlers/toolreferences.go index de173b005..132b7a405 100644 --- a/pkg/api/handlers/toolreferences.go +++ b/pkg/api/handlers/toolreferences.go @@ -1,7 +1,6 @@ package handlers import ( - "context" "fmt" "net/http" "regexp" @@ -27,7 +26,7 @@ func NewToolReferenceHandler(gClient *gptscript.GPTScript) *ToolReferenceHandler } } -func convertToolReference(ctx context.Context, gClient *gptscript.GPTScript, toolRef v1.ToolReference) (types.ToolReference, error) { +func convertToolReference(toolRef v1.ToolReference) types.ToolReference { tf := types.ToolReference{ Metadata: MetadataFrom(&toolRef), ToolReferenceManifest: types.ToolReferenceManifest{ @@ -52,11 +51,7 @@ func convertToolReference(ctx context.Context, gClient *gptscript.GPTScript, too tf.Credential = toolRef.Status.Tool.Credential } - var err error - if toolRef.Spec.Type == types.ToolReferenceTypeModelProvider { - tf.ModelProviderStatus, err = convertModelProviderToolRef(ctx, gClient, toolRef) - } - return tf, err + return tf } func (a *ToolReferenceHandler) ByID(req api.Context) error { @@ -69,12 +64,7 @@ func (a *ToolReferenceHandler) ByID(req api.Context) error { return err } - tr, err := convertToolReference(req.Context(), a.gptscript, toolRef) - if err != nil { - return err - } - - return req.Write(tr) + return req.Write(convertToolReference(toolRef)) } var validCharsRegexp = regexp.MustCompile(`[^a-zA-Z0-9-]+`) @@ -148,12 +138,7 @@ func (a *ToolReferenceHandler) Create(req api.Context) (err error) { return err } - tr, err := convertToolReference(req.Context(), a.gptscript, *toolRef) - if err != nil { - return err - } - - return req.Write(tr) + return req.Write(convertToolReference(*toolRef)) } func (a *ToolReferenceHandler) Delete(req api.Context) error { @@ -213,12 +198,7 @@ func (a *ToolReferenceHandler) Update(req api.Context) error { return err } - tr, err := convertToolReference(req.Context(), a.gptscript, existing) - if err != nil { - return err - } - - return req.Write(tr) + return req.Write(convertToolReference(existing)) } func (a *ToolReferenceHandler) List(req api.Context) error { @@ -234,11 +214,7 @@ func (a *ToolReferenceHandler) List(req api.Context) error { var resp types.ToolReferenceList for _, toolRef := range toolRefList.Items { if toolType == "" || toolRef.Spec.Type == toolType { - tr, err := convertToolReference(req.Context(), a.gptscript, toolRef) - if err != nil { - return err - } - resp.Items = append(resp.Items, tr) + resp.Items = append(resp.Items, convertToolReference(toolRef)) } } diff --git a/pkg/storage/openapi/generated/openapi_generated.go b/pkg/storage/openapi/generated/openapi_generated.go index 19dea6557..b80847318 100644 --- a/pkg/storage/openapi/generated/openapi_generated.go +++ b/pkg/storage/openapi/generated/openapi_generated.go @@ -3440,17 +3440,12 @@ func schema_otto8_ai_otto8_apiclient_types_ToolReference(ref common.ReferenceCal }, }, }, - "modelProviderStatus": { - SchemaProps: spec.SchemaProps{ - Ref: ref("github.com/otto8-ai/otto8/apiclient/types.ModelProviderStatus"), - }, - }, }, Required: []string{"Metadata", "ToolReferenceManifest"}, }, }, Dependencies: []string{ - "github.com/otto8-ai/otto8/apiclient/types.Metadata", "github.com/otto8-ai/otto8/apiclient/types.ModelProviderStatus", "github.com/otto8-ai/otto8/apiclient/types.ToolReferenceManifest"}, + "github.com/otto8-ai/otto8/apiclient/types.Metadata", "github.com/otto8-ai/otto8/apiclient/types.ToolReferenceManifest"}, } }