Skip to content

Commit

Permalink
Validate JSON for request schema (jupyterlab#261)
Browse files Browse the repository at this point in the history
* Corrects capitalization for SageMaker endpoint

* WIP: Pass expected format in model for field

* Validates JSON using JSON.parse

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: Validate JSON in magics

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix reference to error

* Update packages/jupyter-ai-magics/jupyter_ai_magics/magics.py

Co-authored-by: david qiu <[email protected]>

* Avoids redundant parameter

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: david qiu <[email protected]>
  • Loading branch information
3 people authored Jul 17, 2023
1 parent d524707 commit 584878d
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 19 deletions.
9 changes: 9 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,15 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
provider_params["request_schema"] = args.request_schema
provider_params["response_path"] = args.response_path

# Validate that the request schema is well-formed JSON
try:
json.loads(args.request_schema)
except json.JSONDecodeError as e:
raise ValueError(
"request-schema must be valid JSON. "
f"Error at line {e.lineno}, column {e.colno}: {e.msg}"
) from None

provider = Provider(**provider_params)

# generate output from model via provider
Expand Down
30 changes: 12 additions & 18 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,19 @@ class AwsAuthStrategy(BaseModel):
]


class TextField(BaseModel):
type: Literal["text"] = "text"
class Field(BaseModel):
key: str
label: str
# "text" accepts any text
format: Literal["json", "jsonpath", "text"]


class TextField(Field):
type: Literal["text"] = "text"


class MultilineTextField(BaseModel):
class MultilineTextField(Field):
type: Literal["text-multiline"] = "text-multiline"
key: str
label: str


Field = Union[TextField, MultilineTextField]
Expand Down Expand Up @@ -393,7 +396,7 @@ def transform_output(self, output: bytes) -> str:

class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
id = "sagemaker-endpoint"
name = "Sagemaker Endpoint"
name = "SageMaker endpoint"
models = ["*"]
model_id_key = "endpoint_name"
# This all needs to be on one line of markdown, for use in a table
Expand All @@ -408,18 +411,9 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
auth_strategy = AwsAuthStrategy()
registry = True
fields = [
TextField(
key="region_name",
label="Region name",
),
MultilineTextField(
key="request_schema",
label="Request schema",
),
TextField(
key="response_path",
label="Response path",
),
TextField(key="region_name", label="Region name", format="text"),
MultilineTextField(key="request_schema", label="Request schema", format="json"),
TextField(key="response_path", label="Response path", format="jsonpath"),
]

def __init__(self, *args, **kwargs):
Expand Down
27 changes: 26 additions & 1 deletion packages/jupyter-ai/src/components/settings/model-fields.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import React from 'react';
import React, { useState } from 'react';
import { AiService } from '../../handler';
import { TextField } from '@mui/material';

Expand All @@ -13,9 +13,30 @@ export type ModelFieldProps = {
};

export function ModelField(props: ModelFieldProps): JSX.Element {
const [errorMessage, setErrorMessage] = useState<string | null>(null);

function handleChange(
e: React.ChangeEvent<HTMLInputElement | HTMLTextAreaElement>
) {
// Perform validation based on the field format
switch (props.field.format) {
case 'json':
try {
// JSON.parse does not allow single quotes or trailing commas
JSON.parse(e.target.value);
setErrorMessage(null);
} catch (exc) {
setErrorMessage('You must specify a value in JSON format.');
}
break;
case 'jsonpath':
// TODO: Do JSONPath validation
break;
default:
// No validation performed
break;
}

props.setConfig({
...props.config,
fields: {
Expand All @@ -34,6 +55,8 @@ export function ModelField(props: ModelFieldProps): JSX.Element {
label={props.field.label}
value={props.config.fields[props.gmid]?.[props.field.key]}
onChange={handleChange}
error={!!errorMessage}
helperText={errorMessage ?? undefined}
fullWidth
/>
);
Expand All @@ -47,6 +70,8 @@ export function ModelField(props: ModelFieldProps): JSX.Element {
onChange={handleChange}
fullWidth
multiline
error={!!errorMessage}
helperText={errorMessage ?? undefined}
minRows={2}
/>
);
Expand Down
2 changes: 2 additions & 0 deletions packages/jupyter-ai/src/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,14 @@ export namespace AiService {
type: 'text';
key: string;
label: string;
format: string;
};

export type MultilineTextField = {
type: 'text-multiline';
key: string;
label: string;
format: string;
};

export type Field = TextField | MultilineTextField;
Expand Down

0 comments on commit 584878d

Please sign in to comment.