Skip to content

Commit

Permalink
fix: improve performance for checking for configured model providers
Browse files Browse the repository at this point in the history
This change simplifies the process for checking for configured model
providers. When listing model providers, instead of making a call to
reveal each credential, this change will make one call to list
credentials.

A side effect is that empty environment variables are not treated as
"configured" while there were not previously. Care is taken to remove
empty environment variables from the credential before saving it.

Signed-off-by: Donnie Adams <[email protected]>
  • Loading branch information
thedadams committed Dec 10, 2024
1 parent 11d536a commit d841401
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 95 deletions.
13 changes: 6 additions & 7 deletions apiclient/types/toolreference.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@ type ToolReferenceManifest struct {
type ToolReference struct {
Metadata
ToolReferenceManifest
Resolved bool `json:"resolved,omitempty"`
Error string `json:"error,omitempty"`
Builtin bool `json:"builtin,omitempty"`
Description string `json:"description,omitempty"`
Credential string `json:"credential,omitempty"`
Params map[string]string `json:"params,omitempty"`
ModelProviderStatus *ModelProviderStatus `json:"modelProviderStatus,omitempty"`
Resolved bool `json:"resolved,omitempty"`
Error string `json:"error,omitempty"`
Builtin bool `json:"builtin,omitempty"`
Description string `json:"description,omitempty"`
Credential string `json:"credential,omitempty"`
Params map[string]string `json:"params,omitempty"`
}

type ToolReferenceList List[ToolReference]
5 changes: 0 additions & 5 deletions apiclient/types/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 35 additions & 9 deletions pkg/api/handlers/availablemodels.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ func NewAvailableModelsHandler(gClient *gptscript.GPTScript, dispatcher *dispatc
}

func (a *AvailableModelsHandler) List(req api.Context) error {
var modelProviderReference v1.ToolReferenceList
if err := req.List(&modelProviderReference, &kclient.ListOptions{
var modelProviderReferences v1.ToolReferenceList
if err := req.List(&modelProviderReferences, &kclient.ListOptions{
Namespace: req.Namespace(),
FieldSelector: fields.SelectorFromSet(map[string]string{
"spec.type": string(types.ToolReferenceTypeModelProvider),
Expand All @@ -39,11 +39,27 @@ func (a *AvailableModelsHandler) List(req api.Context) error {
return err
}

credCtxs := make([]string, 0, len(modelProviderReferences.Items))
for _, ref := range modelProviderReferences.Items {
credCtxs = append(credCtxs, string(ref.UID))
}

creds, err := a.gptscript.ListCredentials(req.Context(), gptscript.ListCredentialsOptions{
CredentialContexts: credCtxs,
})
if err != nil {
return fmt.Errorf("failed to list model provider credentials: %w", err)
}

credMap := make(map[string]map[string]string, len(creds))
for _, cred := range creds {
credMap[cred.Context+cred.ToolName] = cred.Env
}

var oModels openai.ModelsList
for _, modelProvider := range modelProviderReference.Items {
if convertedModelProvider, err := convertModelProviderToolRef(req.Context(), a.gptscript, modelProvider); err != nil {
return fmt.Errorf("failed to determine if model provider %q is configured: %w", modelProvider.Name, err)
} else if !convertedModelProvider.Configured || modelProvider.Name == system.ModelProviderTool {
for _, modelProvider := range modelProviderReferences.Items {
convertedModelProvider := convertModelProviderToolRef(modelProvider, credMap[string(modelProvider.UID)+modelProvider.Name])
if !convertedModelProvider.Configured || modelProvider.Name == system.ModelProviderTool {
continue
}

Expand Down Expand Up @@ -74,9 +90,19 @@ func (a *AvailableModelsHandler) ListForModelProvider(req api.Context) error {
return types.NewErrBadRequest("%s is not a model provider", modelProviderReference.Name)
}

if modelProvider, err := convertModelProviderToolRef(req.Context(), a.gptscript, modelProviderReference); err != nil {
return fmt.Errorf("failed to determine if model provider is configured: %w", err)
} else if !modelProvider.Configured {
var credEnvVars map[string]string
if modelProviderReference.Status.Tool != nil {
if envVars := modelProviderReference.Status.Tool.Metadata["envVars"]; envVars != "" {
cred, err := a.gptscript.RevealCredential(req.Context(), []string{string(modelProviderReference.UID)}, modelProviderReference.Name)
if err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
return fmt.Errorf("failed to reveal credential for model provider %q: %w", modelProviderReference.Name, err)
} else if err == nil {
credEnvVars = cred.Env
}
}
}

if modelProvider := convertModelProviderToolRef(modelProviderReference, credEnvVars); !modelProvider.Configured {
return types.NewErrBadRequest("model provider %s is not configured, missing configuration parameters: %s", modelProviderReference.Name, strings.Join(modelProvider.MissingConfigurationParameters, ", "))
}

Expand Down
102 changes: 64 additions & 38 deletions pkg/api/handlers/modelprovider.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package handlers

import (
"context"
"fmt"
"strings"

Expand Down Expand Up @@ -39,12 +38,19 @@ func (mp *ModelProviderHandler) ByID(req api.Context) error {
)
}

modelProvider, err := convertToolReferenceToModelProvider(req.Context(), mp.gptscript, ref)
if err != nil {
return err
var credEnvVars map[string]string
if ref.Status.Tool != nil {
if envVars := ref.Status.Tool.Metadata["envVars"]; envVars != "" {
cred, err := mp.gptscript.RevealCredential(req.Context(), []string{string(ref.UID)}, ref.Name)
if err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
return fmt.Errorf("failed to reveal credential for model provider %q: %w", ref.Name, err)
} else if err == nil {
credEnvVars = cred.Env
}
}
}

return req.Write(modelProvider)
return req.Write(convertToolReferenceToModelProvider(ref, credEnvVars))
}

func (mp *ModelProviderHandler) List(req api.Context) error {
Expand All @@ -58,14 +64,26 @@ func (mp *ModelProviderHandler) List(req api.Context) error {
return err
}

resp := make([]types.ModelProvider, 0, len(refList.Items))
credCtxs := make([]string, 0, len(refList.Items))
for _, ref := range refList.Items {
modelProvider, err := convertToolReferenceToModelProvider(req.Context(), mp.gptscript, ref)
if err != nil {
return fmt.Errorf("failed to determine model provider status: %w", err)
}
credCtxs = append(credCtxs, string(ref.UID))
}

resp = append(resp, modelProvider)
creds, err := mp.gptscript.ListCredentials(req.Context(), gptscript.ListCredentialsOptions{
CredentialContexts: credCtxs,
})
if err != nil {
return fmt.Errorf("failed to list model provider credentials: %w", err)
}

credMap := make(map[string]map[string]string, len(creds))
for _, cred := range creds {
credMap[cred.Context+cred.ToolName] = cred.Env
}

resp := make([]types.ModelProvider, 0, len(refList.Items))
for _, ref := range refList.Items {
resp = append(resp, convertToolReferenceToModelProvider(ref, credMap[string(ref.UID)+ref.Name]))
}

return req.Write(types.ModelProviderList{Items: resp})
Expand All @@ -91,6 +109,12 @@ func (mp *ModelProviderHandler) Configure(req api.Context) error {
return fmt.Errorf("failed to update credential: %w", err)
}

for key, val := range envVars {
if val == "" {
delete(envVars, key)
}
}

if err := mp.gptscript.CreateCredential(req.Context(), gptscript.Credential{
Context: string(ref.UID),
ToolName: ref.Name,
Expand Down Expand Up @@ -120,12 +144,18 @@ func (mp *ModelProviderHandler) Reveal(req api.Context) error {
return err
}

if ref.Spec.Type != types.ToolReferenceTypeModelProvider {
return types.NewErrBadRequest("%q is not a model provider", ref.Name)
}

cred, err := mp.gptscript.RevealCredential(req.Context(), []string{string(ref.UID)}, ref.Name)
if err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
return fmt.Errorf("failed to reveal credential: %w", err)
} else if err == nil {
return req.Write(cred.Env)
}

return req.Write(cred.Env)
return types.NewErrNotFound("no credential found for %q", ref.Name)
}

func (mp *ModelProviderHandler) RefreshModels(req api.Context) error {
Expand All @@ -138,11 +168,19 @@ func (mp *ModelProviderHandler) RefreshModels(req api.Context) error {
return types.NewErrBadRequest("%q is not a model provider", ref.Name)
}

modelProvider, err := convertToolReferenceToModelProvider(req.Context(), mp.gptscript, ref)
if err != nil {
return err
var credEnvVars map[string]string
if ref.Status.Tool != nil {
if envVars := ref.Status.Tool.Metadata["envVars"]; envVars != "" {
cred, err := mp.gptscript.RevealCredential(req.Context(), []string{string(ref.UID)}, ref.Name)
if err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
return fmt.Errorf("failed to reveal credential for model provider %q: %w", ref.Name, err)
} else if err == nil {
credEnvVars = cred.Env
}
}
}

modelProvider := convertToolReferenceToModelProvider(ref, credEnvVars)
if !modelProvider.Configured {
return types.NewErrBadRequest("model provider %s is not configured, missing configuration parameters: %s", modelProvider.ModelProviderManifest.Name, strings.Join(modelProvider.MissingConfigurationParameters, ", "))
}
Expand All @@ -156,19 +194,14 @@ func (mp *ModelProviderHandler) RefreshModels(req api.Context) error {
delete(ref.Annotations, v1.ModelProviderSyncAnnotation)
}

if err = req.Update(&ref); err != nil {
if err := req.Update(&ref); err != nil {
return fmt.Errorf("failed to sync models for model provider %q: %w", ref.Name, err)
}

return req.Write(modelProvider)
}

func convertToolReferenceToModelProvider(ctx context.Context, gClient *gptscript.GPTScript, ref v1.ToolReference) (types.ModelProvider, error) {
status, err := convertModelProviderToolRef(ctx, gClient, ref)
if err != nil {
return types.ModelProvider{}, err
}

func convertToolReferenceToModelProvider(ref v1.ToolReference, credEnvVars map[string]string) types.ModelProvider {
name := ref.Name
if ref.Status.Tool != nil {
name = ref.Status.Tool.Name
Expand All @@ -180,34 +213,27 @@ func convertToolReferenceToModelProvider(ctx context.Context, gClient *gptscript
Name: name,
ToolReference: ref.Spec.Reference,
},
ModelProviderStatus: *status,
ModelProviderStatus: *convertModelProviderToolRef(ref, credEnvVars),
}

mp.Type = "modelprovider"

return mp, nil
return mp
}

func convertModelProviderToolRef(ctx context.Context, gptscript *gptscript.GPTScript, toolRef v1.ToolReference) (*types.ModelProviderStatus, error) {
func convertModelProviderToolRef(toolRef v1.ToolReference, cred map[string]string) *types.ModelProviderStatus {
var (
requiredEnvVars, missingEnvVars []string
icon string
)
if toolRef.Status.Tool != nil {
if toolRef.Status.Tool.Metadata["envVars"] != "" {
cred, err := gptscript.RevealCredential(ctx, []string{string(toolRef.UID)}, toolRef.Name)
if err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
return nil, fmt.Errorf("failed to reveal credential for model provider %q: %w", toolRef.Name, err)
}

if toolRef.Status.Tool.Metadata["envVars"] != "" {
requiredEnvVars = strings.Split(toolRef.Status.Tool.Metadata["envVars"], ",")
}
requiredEnvVars = strings.Split(toolRef.Status.Tool.Metadata["envVars"], ",")
}

for _, envVar := range requiredEnvVars {
if cred.Env[envVar] == "" {
missingEnvVars = append(missingEnvVars, envVar)
}
for _, envVar := range requiredEnvVars {
if _, ok := cred[envVar]; !ok {
missingEnvVars = append(missingEnvVars, envVar)
}
}

Expand All @@ -227,5 +253,5 @@ func convertModelProviderToolRef(ctx context.Context, gptscript *gptscript.GPTSc
ModelsBackPopulated: modelsPopulated,
RequiredConfigurationParameters: requiredEnvVars,
MissingConfigurationParameters: missingEnvVars,
}, nil
}
}
36 changes: 6 additions & 30 deletions pkg/api/handlers/toolreferences.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package handlers

import (
"context"
"fmt"
"net/http"
"regexp"
Expand All @@ -27,7 +26,7 @@ func NewToolReferenceHandler(gClient *gptscript.GPTScript) *ToolReferenceHandler
}
}

func convertToolReference(ctx context.Context, gClient *gptscript.GPTScript, toolRef v1.ToolReference) (types.ToolReference, error) {
func convertToolReference(toolRef v1.ToolReference) types.ToolReference {
tf := types.ToolReference{
Metadata: MetadataFrom(&toolRef),
ToolReferenceManifest: types.ToolReferenceManifest{
Expand All @@ -52,11 +51,7 @@ func convertToolReference(ctx context.Context, gClient *gptscript.GPTScript, too
tf.Credential = toolRef.Status.Tool.Credential
}

var err error
if toolRef.Spec.Type == types.ToolReferenceTypeModelProvider {
tf.ModelProviderStatus, err = convertModelProviderToolRef(ctx, gClient, toolRef)
}
return tf, err
return tf
}

func (a *ToolReferenceHandler) ByID(req api.Context) error {
Expand All @@ -69,12 +64,7 @@ func (a *ToolReferenceHandler) ByID(req api.Context) error {
return err
}

tr, err := convertToolReference(req.Context(), a.gptscript, toolRef)
if err != nil {
return err
}

return req.Write(tr)
return req.Write(convertToolReference(toolRef))
}

var validCharsRegexp = regexp.MustCompile(`[^a-zA-Z0-9-]+`)
Expand Down Expand Up @@ -148,12 +138,7 @@ func (a *ToolReferenceHandler) Create(req api.Context) (err error) {
return err
}

tr, err := convertToolReference(req.Context(), a.gptscript, *toolRef)
if err != nil {
return err
}

return req.Write(tr)
return req.Write(convertToolReference(*toolRef))
}

func (a *ToolReferenceHandler) Delete(req api.Context) error {
Expand Down Expand Up @@ -213,12 +198,7 @@ func (a *ToolReferenceHandler) Update(req api.Context) error {
return err
}

tr, err := convertToolReference(req.Context(), a.gptscript, existing)
if err != nil {
return err
}

return req.Write(tr)
return req.Write(convertToolReference(existing))
}

func (a *ToolReferenceHandler) List(req api.Context) error {
Expand All @@ -234,11 +214,7 @@ func (a *ToolReferenceHandler) List(req api.Context) error {
var resp types.ToolReferenceList
for _, toolRef := range toolRefList.Items {
if toolType == "" || toolRef.Spec.Type == toolType {
tr, err := convertToolReference(req.Context(), a.gptscript, toolRef)
if err != nil {
return err
}
resp.Items = append(resp.Items, tr)
resp.Items = append(resp.Items, convertToolReference(toolRef))
}
}

Expand Down
7 changes: 1 addition & 6 deletions pkg/storage/openapi/generated/openapi_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit d841401

Please sign in to comment.