diff --git a/pkg/api/handlers/modelprovider.go b/pkg/api/handlers/modelprovider.go index 5ba57850..230223ff 100644 --- a/pkg/api/handlers/modelprovider.go +++ b/pkg/api/handlers/modelprovider.go @@ -155,7 +155,16 @@ func (mp *ModelProviderHandler) Deconfigure(req api.Context) error { // Stop the model provider so that the credential is completely removed from the system. mp.dispatcher.StopModelProvider(ref.Namespace, ref.Name) - return nil + if ref.Annotations[v1.ModelProviderSyncAnnotation] == "" { + if ref.Annotations == nil { + ref.Annotations = make(map[string]string, 1) + } + ref.Annotations[v1.ModelProviderSyncAnnotation] = "true" + } else { + delete(ref.Annotations, v1.ModelProviderSyncAnnotation) + } + + return req.Update(&ref) } func (mp *ModelProviderHandler) Reveal(req api.Context) error { diff --git a/pkg/controller/handlers/toolreference/toolreference.go b/pkg/controller/handlers/toolreference/toolreference.go index 5633257b..7eb2dd0a 100644 --- a/pkg/controller/handlers/toolreference/toolreference.go +++ b/pkg/controller/handlers/toolreference/toolreference.go @@ -3,6 +3,7 @@ package toolreference import ( "context" "crypto/sha256" + "errors" "fmt" "net/url" "os" @@ -370,8 +371,8 @@ func (h *Handler) BackPopulateModels(req router.Request, _ router.Response) erro cred, err := h.gptClient.RevealCredential(req.Ctx, []string{string(toolRef.UID)}, toolRef.Name) if err != nil { if strings.Contains(err.Error(), "credential not found") { - // Model provider is not configured, don't error - return nil + // Unable to find credential, ensure all models remove for this model provider + return removeModelsForProvider(req.Ctx, req.Client, req.Namespace, req.Name) } return err } @@ -420,6 +421,27 @@ func (h *Handler) BackPopulateModels(req router.Request, _ router.Response) erro return nil } +func removeModelsForProvider(ctx context.Context, c client.Client, namespace, name string) error { + var models v1.ModelList + if err := c.List(ctx, &models, &client.ListOptions{ + Namespace: namespace, + FieldSelector: fields.SelectorFromSet(fields.Set{ + "spec.manifest.modelProvider": name, + }), + }); err != nil { + return fmt.Errorf("failed to list models for model provider %q for cleanup: %w", name, err) + } + + var errs []error + for _, model := range models.Items { + if err := client.IgnoreNotFound(c.Delete(ctx, &model)); err != nil { + errs = append(errs, fmt.Errorf("failed to delete model %q for cleanup: %w", model.Name, err)) + } + } + + return errors.Join(errs...) +} + func (h *Handler) CleanupModelProvider(req router.Request, _ router.Response) error { toolRef := req.Object.(*v1.ToolReference) if toolRef.Spec.Type != types.ToolReferenceTypeModelProvider || toolRef.Status.Tool == nil { @@ -432,23 +454,7 @@ func (h *Handler) CleanupModelProvider(req router.Request, _ router.Response) er } } - var models v1.ModelList - if err := req.List(&models, &client.ListOptions{ - Namespace: req.Namespace, - FieldSelector: fields.SelectorFromSet(fields.Set{ - "spec.manifest.modelProvider": toolRef.Name, - }), - }); err != nil { - return fmt.Errorf("failed to list models for model provider %q for cleanup: %w", toolRef.Name, err) - } - - for _, model := range models.Items { - if err := client.IgnoreNotFound(req.Delete(&model)); err != nil { - return fmt.Errorf("failed to delete model %q for cleanup: %w", model.Name, err) - } - } - - return nil + return removeModelsForProvider(req.Ctx, req.Client, req.Namespace, req.Name) } func modelName(modelProviderName, modelName string) string {