Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: model provider configuration validation #1019

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 79 additions & 1 deletion pkg/api/handlers/modelprovider.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
package handlers

import (
"encoding/json"
"fmt"
"strings"

"github.com/gptscript-ai/go-gptscript"
"github.com/obot-platform/obot/apiclient/types"
"github.com/obot-platform/obot/pkg/api"
"github.com/obot-platform/obot/pkg/gateway/server/dispatcher"
"github.com/obot-platform/obot/pkg/invoke"
v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1"
"github.com/obot-platform/obot/pkg/system"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/fields"
kclient "sigs.k8s.io/controller-runtime/pkg/client"
)

type ModelProviderHandler struct {
gptscript *gptscript.GPTScript
dispatcher *dispatcher.Dispatcher
invoker *invoke.Invoker
}

func NewModelProviderHandler(gClient *gptscript.GPTScript, dispatcher *dispatcher.Dispatcher) *ModelProviderHandler {
func NewModelProviderHandler(gClient *gptscript.GPTScript, dispatcher *dispatcher.Dispatcher, invoker *invoke.Invoker) *ModelProviderHandler {
return &ModelProviderHandler{
gptscript: gClient,
dispatcher: dispatcher,
invoker: invoker,
}
}

Expand Down Expand Up @@ -89,6 +95,78 @@ func (mp *ModelProviderHandler) List(req api.Context) error {
return req.Write(types.ModelProviderList{Items: resp})
}

type ValidationError struct {
Err string `json:"error"`
}

func (ve *ValidationError) Error() string {
return fmt.Sprintf("model-provider credentials validation failed: {\"error\": \"%s\"}", ve.Err)
thedadams marked this conversation as resolved.
Show resolved Hide resolved
}

func (mp *ModelProviderHandler) Validate(req api.Context) error {
var ref v1.ToolReference
if err := req.Get(&ref, req.PathValue("id")); err != nil {
return err
}

if ref.Spec.Type != types.ToolReferenceTypeModelProvider {
return types.NewErrBadRequest("%q is not a model provider", ref.Name)
}

log.Debugf("Validating model provider %q", ref.Name)

var envVars map[string]string
if err := req.Read(&envVars); err != nil {
return err
}

envs := make([]string, 0, len(envVars))
for key, val := range envVars {
envs = append(envs, key+"="+val)
}

thread := &v1.Thread{
ObjectMeta: metav1.ObjectMeta{
GenerateName: system.ThreadPrefix + "-" + ref.Name + "-validate-",
Namespace: ref.Namespace,
},
Spec: v1.ThreadSpec{
SystemTask: true,
},
}

if err := req.Create(thread); err != nil {
return fmt.Errorf("failed to create thread: %w", err)
}

defer func() { _ = req.Delete(thread) }()

task, err := mp.invoker.SystemTask(req.Context(), thread, "validate from "+ref.Spec.Reference, "", invoke.SystemTaskOptions{Env: envs})
if err != nil {
return err
}
defer task.Close()

res, err := task.Result(req.Context())
if err != nil {
if strings.Contains(err.Error(), "tool not found: validate from "+ref.Spec.Reference) { // there's no simple way to do errors.As/.Is at this point unfortunately
log.Errorf("Model provider %q does not provide a validate tool. Looking for 'validate from %s'", ref.Name, ref.Spec.Reference)
return types.NewErrNotFound(
fmt.Sprintf("`validate from %s` tool not found", ref.Spec.Reference),
ref.Name,
)
}
return &ValidationError{Err: strings.Trim(err.Error(), "\"'")}
}

var validationError ValidationError
if json.Unmarshal([]byte(res.Output), &validationError) == nil && validationError.Err != "" {
return &validationError
}

return nil
}

func (mp *ModelProviderHandler) Configure(req api.Context) error {
var ref v1.ToolReference
if err := req.Get(&ref, req.PathValue("id")); err != nil {
Expand Down
3 changes: 2 additions & 1 deletion pkg/api/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func Router(services *services.Services) (http.Handler, error) {
cronJobs := handlers.NewCronJobHandler()
models := handlers.NewModelHandler()
availableModels := handlers.NewAvailableModelsHandler(services.GPTClient, services.ModelProviderDispatcher)
modelProviders := handlers.NewModelProviderHandler(services.GPTClient, services.ModelProviderDispatcher)
modelProviders := handlers.NewModelProviderHandler(services.GPTClient, services.ModelProviderDispatcher, services.Invoker)
prompt := handlers.NewPromptHandler(services.GPTClient)
emailreceiver := handlers.NewEmailReceiverHandler(services.EmailServerName)
defaultModelAliases := handlers.NewDefaultModelAliasHandler()
Expand Down Expand Up @@ -295,6 +295,7 @@ func Router(services *services.Services) (http.Handler, error) {
// Model providers
mux.HandleFunc("GET /api/model-providers", modelProviders.List)
mux.HandleFunc("GET /api/model-providers/{id}", modelProviders.ByID)
mux.HandleFunc("POST /api/model-providers/{id}/validate", modelProviders.Validate)
mux.HandleFunc("POST /api/model-providers/{id}/configure", modelProviders.Configure)
mux.HandleFunc("POST /api/model-providers/{id}/deconfigure", modelProviders.Deconfigure)
mux.HandleFunc("POST /api/model-providers/{id}/reveal", modelProviders.Reveal)
Expand Down
8 changes: 4 additions & 4 deletions pkg/availablemodels/availablemodels.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func ForProvider(ctx context.Context, dispatcher *dispatcher.Dispatcher, modelPr

r, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String()+"/v1/models", nil)
if err != nil {
return nil, fmt.Errorf("failed to create request to model provider %s: %w", modelProviderName, err)
return nil, fmt.Errorf("failed to create request to model provider %q: %w", modelProviderName, err)
}

if token != "" {
Expand All @@ -28,18 +28,18 @@ func ForProvider(ctx context.Context, dispatcher *dispatcher.Dispatcher, modelPr

resp, err := http.DefaultClient.Do(r)
if err != nil {
return nil, fmt.Errorf("failed to make request to model provider %s: %w", modelProviderName, err)
return nil, fmt.Errorf("failed to make request to model provider %q: %w", modelProviderName, err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
message, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("failed to get model list from model provider %s: %s", modelProviderName, message)
return nil, fmt.Errorf("failed to get model list from model provider %q: %s", modelProviderName, message)
}

var oModels openai.ModelsList
if err = json.NewDecoder(resp.Body).Decode(&oModels); err != nil {
return nil, fmt.Errorf("failed to decode model list from model provider %s: %w", modelProviderName, err)
return nil, fmt.Errorf("failed to decode model list from model provider %q: %w", modelProviderName, err)
}

return &oModels, nil
Expand Down
21 changes: 20 additions & 1 deletion pkg/controller/handlers/toolreference/toolreference.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package toolreference
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"os"
"regexp"
"strings"
"time"

Expand All @@ -28,6 +30,8 @@ import (

var log = logger.Package()

var jsonErrRegexp = regexp.MustCompile(`\{.*"error":.*}`)

type indexEntry struct {
Reference string `json:"reference,omitempty"`
All bool `json:"all,omitempty"`
Expand Down Expand Up @@ -406,7 +410,22 @@ func (h *Handler) BackPopulateModels(req router.Request, _ router.Response) erro
availableModels, err := availablemodels.ForProvider(req.Ctx, h.dispatcher, req.Namespace, req.Name)
if err != nil {
// Don't error and retry because it will likely fail again. Log the error, and the user can re-sync manually.
log.Errorf("Failed to get available models for model provider %q: %v", toolRef.Name, err)
// Also, the toolRef.Status.Error field will bubble up to the user in the UI.

// Check if the model provider returned a properly formatted error message and set it as status
match := jsonErrRegexp.FindString(err.Error())
if match != "" {
toolRef.Status.Error = match
type errorResponse struct {
Error string `json:"error"`
}
var eR errorResponse
if err := json.Unmarshal([]byte(match), &eR); err == nil {
toolRef.Status.Error = eR.Error
}
}

log.Errorf("%v", err)
return nil
}

Expand Down
55 changes: 51 additions & 4 deletions ui/admin/app/components/model-providers/ModelProviderForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,21 @@ export function ModelProviderForm({
}
);

const validateAndConfigureModelProvider = useAsync(
ModelProviderApiService.validateModelProviderById,
{
onSuccess: async (data, params) => {
// Only configure the model provider if validation was successful
const [modelProviderId, configParams] = params;
await configureModelProvider.execute(modelProviderId, configParams);
},
onError: (error) => {
// Handle validation errors
console.error("Validation failed:", error);
},
}
);

const configureModelProvider = useAsync(
ModelProviderApiService.configureModelProviderById,
{
Expand Down Expand Up @@ -184,7 +199,10 @@ export function ModelProviderForm({
}
);

await configureModelProvider.execute(modelProvider.id, allConfigParams);
await validateAndConfigureModelProvider.execute(
modelProvider.id,
allConfigParams
);
}
);

Expand All @@ -193,23 +211,52 @@ export function ModelProviderForm({
modelProvider.id === "azure-openai-model-provider";

const loading =
validateAndConfigureModelProvider.isLoading ||
fetchAvailableModels.isLoading ||
configureModelProvider.isLoading ||
isLoading;

return (
<div className="flex flex-col">
{fetchAvailableModels.error !== null && (
{validateAndConfigureModelProvider.error !== null && (
<div className="px-4">
<Alert variant="destructive">
<CircleAlertIcon className="h-4 w-4" />
<AlertTitle>An error occurred!</AlertTitle>
<AlertDescription>
Your configuration was saved, but we were not able to connect to
the model provider. Please check your configuration and try again.
Your configuration could not be saved, because it failed
validation:{" "}
<strong>
{(typeof validateAndConfigureModelProvider.error === "object" &&
"message" in validateAndConfigureModelProvider.error &&
(validateAndConfigureModelProvider.error
.message as string)) ??
"Unknown error"}
</strong>
</AlertDescription>
</Alert>
</div>
)}
{validateAndConfigureModelProvider.error === null &&
fetchAvailableModels.error !== null && (
<div className="px-4">
<Alert variant="destructive">
<CircleAlertIcon className="h-4 w-4" />
<AlertTitle>An error occurred!</AlertTitle>
<AlertDescription>
Your configuration was saved, but we were not able to connect to
the model provider. Please check your configuration and try
again:{" "}
<strong>
{(typeof fetchAvailableModels.error === "object" &&
"message" in fetchAvailableModels.error &&
(fetchAvailableModels.error.message as string)) ??
"Unknown error"}
</strong>
</AlertDescription>
</Alert>
</div>
)}
<ScrollArea className="max-h-[50vh]">
<div className="flex flex-col gap-4 p-4">
<h4 className="text-md font-semibold">Required Configuration</h4>
Expand Down
11 changes: 11 additions & 0 deletions ui/admin/app/hooks/useAsync.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ export function useAsync<TData, TParams extends unknown[]>(
onSuccess?.(data, params);
})
.catch((error) => {
// If the response data is a JSON string with an error field
// unpack that error field's value into the error message
if (error.response && typeof error.response.data === "string") {
const errorMessageMatch =
error.response.data.match(/{"error":\s+"(.*?)"}/);
if (errorMessageMatch) {
const errorMessage = JSON.parse(errorMessageMatch[0]).error;
console.log("Error: ", errorMessage);
error.message = errorMessage;
}
}
setError(error);
onError?.(error, params);
})
Expand Down
2 changes: 2 additions & 0 deletions ui/admin/app/lib/routers/apiRoutes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ export const ApiRoutes = {
getModelProviders: () => buildUrl("/model-providers"),
getModelProviderById: (modelProviderKey: string) =>
buildUrl(`/model-providers/${modelProviderKey}`),
validateModelProviderById: (modelProviderKey: string) =>
buildUrl(`/model-providers/${modelProviderKey}/validate`),
configureModelProviderById: (modelProviderKey: string) =>
buildUrl(`/model-providers/${modelProviderKey}/configure`),
revealModelProviderById: (modelProviderKey: string) =>
Expand Down
17 changes: 17 additions & 0 deletions ui/admin/app/lib/service/api/modelProviderApiService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ getModelProviderById.key = (modelProviderId?: string) => {
};
};

const validateModelProviderById = async (
modelProviderKey: string,
modelProviderConfig: ModelProviderConfig
) => {
const res = await request<ModelProvider>({
url: ApiRoutes.modelProviders.validateModelProviderById(modelProviderKey)
.url,
method: "POST",
data: modelProviderConfig,
errorMessage:
"Failed to validate configuration values on the requested modal provider.",
});

return res.data;
};

const configureModelProviderById = async (
modelProviderKey: string,
modelProviderConfig: ModelProviderConfig
Expand Down Expand Up @@ -81,6 +97,7 @@ const deconfigureModelProviderById = async (modelProviderKey: string) => {
export const ModelProviderApiService = {
getModelProviders,
getModelProviderById,
validateModelProviderById,
configureModelProviderById,
revealModelProviderById,
deconfigureModelProviderById,
Expand Down