Skip to content

Commit

Permalink
Feat/admin/model-create-and-edit-dropdown (#744)
Browse files Browse the repository at this point in the history
* feat: provide available model options from api and populate usage field

Signed-off-by: Ryan Hopper-Lowe <[email protected]>

* chore: remove unused models code

* enhance: prevent unnecessary cache revalidations for available models

* chore: upgrade model provider api usage

---------

Signed-off-by: Ryan Hopper-Lowe <[email protected]>
  • Loading branch information
ryanhopperlowe authored Dec 3, 2024
1 parent 1a7a2ea commit 13e341a
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 81 deletions.
1 change: 1 addition & 0 deletions ui/admin/app/components/model/AddModel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ export function AddModel() {

<DialogContent>
<DialogTitle>Create Model</DialogTitle>

<DialogDescription hidden>Create Model</DialogDescription>

<ModelForm onSubmit={() => setOpen(false)} />
Expand Down
82 changes: 63 additions & 19 deletions ui/admin/app/components/model/ModelForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import {
ModelProvider,
ModelUsage,
getModelUsageLabel,
getModelsForProvider,
} from "~/lib/model/models";
import { ModelApiService } from "~/lib/service/api/modelApiService";
import { ModelProviderApiService } from "~/lib/service/api/modelProviderApiService";

import { ControlledCustomInput } from "~/components/form/controlledInputs";
import { Button } from "~/components/ui/button";
Expand All @@ -39,8 +39,8 @@ export function ModelForm(props: ModelFormProps) {
const { model, onSubmit } = props;

const { data: modelProviders } = useSWR(
ModelApiService.getModelProviders.key(),
ModelApiService.getModelProviders
ModelProviderApiService.getModelProviders.key(),
() => ModelProviderApiService.getModelProviders()
);

const updateModel = useAsync(ModelApiService.updateModel, {
Expand Down Expand Up @@ -76,6 +76,15 @@ export function ModelForm(props: ModelFormProps) {
defaultValues,
});

const getAvailableModels = useSWR(
ModelApiService.getAvailableModelsByProvider.key(
form.watch("modelProvider")
),
({ provider }) =>
ModelApiService.getAvailableModelsByProvider(provider),
{ revalidateIfStale: false }
);

const { loading, submit } = getSubmitInfo();

const handleSubmit = form.handleSubmit((values) =>
Expand All @@ -85,8 +94,7 @@ export function ModelForm(props: ModelFormProps) {
const providerName = (provider: ModelProvider) => {
let text = provider.name || provider.id;

if (!provider.modelProviderStatus.configured)
text += " (not configured)";
if (!provider.configured) text += " (not configured)";

return text;
};
Expand All @@ -110,10 +118,7 @@ export function ModelForm(props: ModelFormProps) {
<SelectItem
key={provider.id}
value={provider.id}
disabled={
!provider.modelProviderStatus
.configured
}
disabled={!provider.configured}
>
{providerName(provider)}
</SelectItem>
Expand All @@ -129,26 +134,21 @@ export function ModelForm(props: ModelFormProps) {
label="Target Model"
>
{({ field: { ref: _, ...field }, className }) => {
const models = getModelsForProvider(
form.watch("modelProvider")
);

return (
<Select
{...field}
disabled={!form.watch("modelProvider")}
onValueChange={field.onChange}
onValueChange={(value) => {
field.onChange(value);
updateUsageFromModel(value);
}}
>
<SelectTrigger className={className}>
<SelectValue placeholder="Select a Model" />
</SelectTrigger>

<SelectContent>
{models.map((model) => (
<SelectItem key={model} value={model}>
{model}
</SelectItem>
))}
{getModelOptions()}
</SelectContent>
</Select>
);
Expand Down Expand Up @@ -191,6 +191,50 @@ export function ModelForm(props: ModelFormProps) {
</Form>
);

function updateUsageFromModel(value: string) {
const model = getAvailableModels.data?.find((m) => m.id === value);

const usage = model?.metadata?.usage ?? ModelUsage.Other;

form.setValue("usage", usage);
}

function getModelOptions() {
if (getAvailableModels.data) {
return getAvailableModels.data.map((model) => (
<SelectItem key={model.id} value={model.id}>
{model.id}
</SelectItem>
));
}

const options: React.ReactNode[] = [];

const targetModel = form.watch("targetModel");
if (targetModel)
options.push(
<SelectItem key={targetModel} value={targetModel}>
{targetModel}
</SelectItem>
);

if (getAvailableModels.isLoading) {
options.push(
<SelectItem key="loading" value="loading" disabled>
Loading models...
</SelectItem>
);
} else if (!getAvailableModels.data) {
options.push(
<SelectItem key="no-models" value="no-models" disabled>
No models available
</SelectItem>
);
}

return options;
}

function getSubmitInfo() {
if (model) {
return {
Expand Down
10 changes: 10 additions & 0 deletions ui/admin/app/lib/model/availableModels.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { ModelUsage } from "~/lib/model/models";
import { EntityMeta } from "~/lib/model/primitives";

export type AvailableModel = EntityMeta<{ usage?: ModelUsage }> & {
object: string;
owned_by: string;
permission: string[];
root: string;
parent: string;
};
51 changes: 5 additions & 46 deletions ui/admin/app/lib/model/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,50 +47,14 @@ export const ModelManifestSchema = z.object({
usage: z.nativeEnum(ModelUsage),
});

export type ModelProvider = EntityMeta & {
description?: string;
builtin: boolean;
active: boolean;
modelProviderStatus: ModelProviderStatus;
type ModelProviderManifest = {
name: string;
reference: string;
toolType: "modelProvider";
toolReference: string;
};

// note(ryanhopperlowe): these values are hardcoded for now
// ideally they should come from the backend
const ModelToProviderMap = {
"openai-model-provider": [
"text-embedding-3-small",
"dall-e-3",
"gpt-4o-mini",
"gpt-3.5-turbo",
"text-embedding-ada-002",
"gpt-4o",
],
"azure-openai-model-provider": [
"text-embedding-3-small",
"dall-e-3",
"gpt-4o-mini",
"gpt-3.5-turbo",
"text-embedding-ada-002",
"gpt-4o",
],
"anthropic-model-provider": [
"claude-3-opus-latest",
"claude-3-5-sonnet-latest",
"claude-3-5-haiku-latest",
],
"ollama-model-provider": ["llama3.2"],
"voyage-model-provider": [
"voyage-3",
"voyage-3-lite",
"voyage-finance-2",
"voyage-multilingual-2",
"voyage-law-2",
"voyage-code-2",
],
};
export type ModelProvider = EntityMeta &
ModelProviderManifest &
ModelProviderStatus;

export const ModelAliasToUsageMap = {
llm: ModelUsage.LLM,
Expand All @@ -104,8 +68,3 @@ export function getModelUsageFromAlias(alias: string) {

return ModelAliasToUsageMap[alias as keyof typeof ModelAliasToUsageMap];
}

export function getModelsForProvider(providerId: string) {
if (!providerId || !(providerId in ModelToProviderMap)) return [];
return ModelToProviderMap[providerId as keyof typeof ModelToProviderMap];
}
6 changes: 6 additions & 0 deletions ui/admin/app/lib/routers/apiRoutes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,12 @@ export const ApiRoutes = {
createModel: () => buildUrl(`/models`),
updateModel: (modelId: string) => buildUrl(`/models/${modelId}`),
deleteModel: (modelId: string) => buildUrl(`/models/${modelId}`),
getAvailableModels: () => buildUrl("/available-models"),
getAvailableModelsByProvider: (provider: string) =>
buildUrl(`/available-models/${provider}`),
},
modelProviders: {
getModelProviders: () => buildUrl("/model-providers"),
},
defaultModelAliases: {
base: () => buildUrl("/default-model-aliases"),
Expand Down
27 changes: 15 additions & 12 deletions ui/admin/app/lib/service/api/modelApiService.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Model, ModelManifest, ModelProvider } from "~/lib/model/models";
import { AvailableModel } from "~/lib/model/availableModels";
import { Model, ModelManifest } from "~/lib/model/models";
import { ApiRoutes } from "~/lib/routers/apiRoutes";
import { request } from "~/lib/service/api/primitives";

Expand Down Expand Up @@ -27,19 +28,23 @@ getModelById.key = (modelId?: string) => {
};
};

async function getModelProviders() {
const { data } = await request<{ items?: ModelProvider[] }>({
url: ApiRoutes.toolReferences.base({ type: "modelProvider" }).url,
async function getAvailableModelsByProvider(provider: string) {
const { data } = await request<{ data?: AvailableModel[] }>({
url: ApiRoutes.models.getAvailableModelsByProvider(provider).url,
});

return data.items ?? [];
return data.data ?? [];
}
getModelProviders.key = () => ({
url: ApiRoutes.toolReferences.base({ type: "modelProvider" }).path,
});
getAvailableModelsByProvider.key = (provider?: Nullish<string>) => {
if (!provider) return null;

return {
url: ApiRoutes.models.getAvailableModelsByProvider(provider).path,
provider,
};
};

async function createModel(manifest: ModelManifest) {
await new Promise((resolve) => setTimeout(resolve, 1000));
const { data } = await request<Model>({
url: ApiRoutes.models.createModel().url,
method: "POST",
Expand All @@ -50,8 +55,6 @@ async function createModel(manifest: ModelManifest) {
}

async function updateModel(modelId: string, manifest: ModelManifest) {
await new Promise((resolve) => setTimeout(resolve, 1000));

const { data } = await request<Model>({
url: ApiRoutes.models.updateModel(modelId).url,
method: "PUT",
Expand All @@ -71,7 +74,7 @@ async function deleteModel(modelId: string) {
export const ModelApiService = {
getModels,
getModelById,
getModelProviders,
getAvailableModelsByProvider,
createModel,
updateModel,
deleteModel,
Expand Down
16 changes: 16 additions & 0 deletions ui/admin/app/lib/service/api/modelProviderApiService.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { ModelProvider } from "~/lib/model/models";
import { ApiRoutes } from "~/lib/routers/apiRoutes";
import { request } from "~/lib/service/api/primitives";

async function getModelProviders() {
const { data } = await request<{ items?: ModelProvider[] }>({
url: ApiRoutes.modelProviders.getModelProviders().url,
});

return data.items ?? [];
}
getModelProviders.key = () => ({
url: ApiRoutes.modelProviders.getModelProviders().path,
});

export const ModelProviderApiService = { getModelProviders };
9 changes: 5 additions & 4 deletions ui/admin/app/routes/_auth.models.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import useSWR, { preload } from "swr";
import { Model } from "~/lib/model/models";
import { DefaultModelAliasApiService } from "~/lib/service/api/defaultModelAliasApiService";
import { ModelApiService } from "~/lib/service/api/modelApiService";
import { ModelProviderApiService } from "~/lib/service/api/modelProviderApiService";

import { TypographyH2 } from "~/components/Typography";
import { DataTable } from "~/components/composed/DataTable";
Expand All @@ -24,8 +25,8 @@ export async function clientLoader() {
await Promise.all([
preload(ModelApiService.getModels.key(), ModelApiService.getModels),
preload(
ModelApiService.getModelProviders.key(),
ModelApiService.getModelProviders
ModelProviderApiService.getModelProviders.key(),
ModelProviderApiService.getModelProviders
),
preload(
DefaultModelAliasApiService.getAliases.key(),
Expand All @@ -44,8 +45,8 @@ export default function Models() {
);

const { data: providers } = useSWR(
ModelApiService.getModelProviders.key(),
ModelApiService.getModelProviders
ModelProviderApiService.getModelProviders.key(),
ModelProviderApiService.getModelProviders
);

const providerMap = useMemo(() => {
Expand Down

0 comments on commit 13e341a

Please sign in to comment.