diff --git a/ui/admin/app/components/model/AddModel.tsx b/ui/admin/app/components/model/AddModel.tsx index 84d11589c..c1c4aba6a 100644 --- a/ui/admin/app/components/model/AddModel.tsx +++ b/ui/admin/app/components/model/AddModel.tsx @@ -24,6 +24,7 @@ export function AddModel() { Create Model + setOpen(false)} /> diff --git a/ui/admin/app/components/model/ModelForm.tsx b/ui/admin/app/components/model/ModelForm.tsx index 76cdbc9dd..7cbcf5221 100644 --- a/ui/admin/app/components/model/ModelForm.tsx +++ b/ui/admin/app/components/model/ModelForm.tsx @@ -12,9 +12,9 @@ import { ModelProvider, ModelUsage, getModelUsageLabel, - getModelsForProvider, } from "~/lib/model/models"; import { ModelApiService } from "~/lib/service/api/modelApiService"; +import { ModelProviderApiService } from "~/lib/service/api/modelProviderApiService"; import { ControlledCustomInput } from "~/components/form/controlledInputs"; import { Button } from "~/components/ui/button"; @@ -39,8 +39,8 @@ export function ModelForm(props: ModelFormProps) { const { model, onSubmit } = props; const { data: modelProviders } = useSWR( - ModelApiService.getModelProviders.key(), - ModelApiService.getModelProviders + ModelProviderApiService.getModelProviders.key(), + () => ModelProviderApiService.getModelProviders() ); const updateModel = useAsync(ModelApiService.updateModel, { @@ -76,6 +76,15 @@ export function ModelForm(props: ModelFormProps) { defaultValues, }); + const getAvailableModels = useSWR( + ModelApiService.getAvailableModelsByProvider.key( + form.watch("modelProvider") + ), + ({ provider }) => + ModelApiService.getAvailableModelsByProvider(provider), + { revalidateIfStale: false } + ); + const { loading, submit } = getSubmitInfo(); const handleSubmit = form.handleSubmit((values) => @@ -85,8 +94,7 @@ export function ModelForm(props: ModelFormProps) { const providerName = (provider: ModelProvider) => { let text = provider.name || provider.id; - if (!provider.modelProviderStatus.configured) - text += " (not configured)"; + if (!provider.configured) text += " (not configured)"; return text; }; @@ -110,10 +118,7 @@ export function ModelForm(props: ModelFormProps) { {providerName(provider)} @@ -129,26 +134,21 @@ export function ModelForm(props: ModelFormProps) { label="Target Model" > {({ field: { ref: _, ...field }, className }) => { - const models = getModelsForProvider( - form.watch("modelProvider") - ); - return ( ); @@ -191,6 +191,50 @@ export function ModelForm(props: ModelFormProps) { ); + function updateUsageFromModel(value: string) { + const model = getAvailableModels.data?.find((m) => m.id === value); + + const usage = model?.metadata?.usage ?? ModelUsage.Other; + + form.setValue("usage", usage); + } + + function getModelOptions() { + if (getAvailableModels.data) { + return getAvailableModels.data.map((model) => ( + + {model.id} + + )); + } + + const options: React.ReactNode[] = []; + + const targetModel = form.watch("targetModel"); + if (targetModel) + options.push( + + {targetModel} + + ); + + if (getAvailableModels.isLoading) { + options.push( + + Loading models... + + ); + } else if (!getAvailableModels.data) { + options.push( + + No models available + + ); + } + + return options; + } + function getSubmitInfo() { if (model) { return { diff --git a/ui/admin/app/lib/model/availableModels.ts b/ui/admin/app/lib/model/availableModels.ts new file mode 100644 index 000000000..4f0971825 --- /dev/null +++ b/ui/admin/app/lib/model/availableModels.ts @@ -0,0 +1,10 @@ +import { ModelUsage } from "~/lib/model/models"; +import { EntityMeta } from "~/lib/model/primitives"; + +export type AvailableModel = EntityMeta<{ usage?: ModelUsage }> & { + object: string; + owned_by: string; + permission: string[]; + root: string; + parent: string; +}; diff --git a/ui/admin/app/lib/model/models.ts b/ui/admin/app/lib/model/models.ts index 27aa2a673..3e8640ef0 100644 --- a/ui/admin/app/lib/model/models.ts +++ b/ui/admin/app/lib/model/models.ts @@ -47,50 +47,14 @@ export const ModelManifestSchema = z.object({ usage: z.nativeEnum(ModelUsage), }); -export type ModelProvider = EntityMeta & { - description?: string; - builtin: boolean; - active: boolean; - modelProviderStatus: ModelProviderStatus; +type ModelProviderManifest = { name: string; - reference: string; - toolType: "modelProvider"; + toolReference: string; }; -// note(ryanhopperlowe): these values are hardcoded for now -// ideally they should come from the backend -const ModelToProviderMap = { - "openai-model-provider": [ - "text-embedding-3-small", - "dall-e-3", - "gpt-4o-mini", - "gpt-3.5-turbo", - "text-embedding-ada-002", - "gpt-4o", - ], - "azure-openai-model-provider": [ - "text-embedding-3-small", - "dall-e-3", - "gpt-4o-mini", - "gpt-3.5-turbo", - "text-embedding-ada-002", - "gpt-4o", - ], - "anthropic-model-provider": [ - "claude-3-opus-latest", - "claude-3-5-sonnet-latest", - "claude-3-5-haiku-latest", - ], - "ollama-model-provider": ["llama3.2"], - "voyage-model-provider": [ - "voyage-3", - "voyage-3-lite", - "voyage-finance-2", - "voyage-multilingual-2", - "voyage-law-2", - "voyage-code-2", - ], -}; +export type ModelProvider = EntityMeta & + ModelProviderManifest & + ModelProviderStatus; export const ModelAliasToUsageMap = { llm: ModelUsage.LLM, @@ -104,8 +68,3 @@ export function getModelUsageFromAlias(alias: string) { return ModelAliasToUsageMap[alias as keyof typeof ModelAliasToUsageMap]; } - -export function getModelsForProvider(providerId: string) { - if (!providerId || !(providerId in ModelToProviderMap)) return []; - return ModelToProviderMap[providerId as keyof typeof ModelToProviderMap]; -} diff --git a/ui/admin/app/lib/routers/apiRoutes.ts b/ui/admin/app/lib/routers/apiRoutes.ts index ae2d1b6ac..bf4e75f9e 100644 --- a/ui/admin/app/lib/routers/apiRoutes.ts +++ b/ui/admin/app/lib/routers/apiRoutes.ts @@ -172,6 +172,12 @@ export const ApiRoutes = { createModel: () => buildUrl(`/models`), updateModel: (modelId: string) => buildUrl(`/models/${modelId}`), deleteModel: (modelId: string) => buildUrl(`/models/${modelId}`), + getAvailableModels: () => buildUrl("/available-models"), + getAvailableModelsByProvider: (provider: string) => + buildUrl(`/available-models/${provider}`), + }, + modelProviders: { + getModelProviders: () => buildUrl("/model-providers"), }, defaultModelAliases: { base: () => buildUrl("/default-model-aliases"), diff --git a/ui/admin/app/lib/service/api/modelApiService.ts b/ui/admin/app/lib/service/api/modelApiService.ts index c135e41f3..1918763b4 100644 --- a/ui/admin/app/lib/service/api/modelApiService.ts +++ b/ui/admin/app/lib/service/api/modelApiService.ts @@ -1,4 +1,5 @@ -import { Model, ModelManifest, ModelProvider } from "~/lib/model/models"; +import { AvailableModel } from "~/lib/model/availableModels"; +import { Model, ModelManifest } from "~/lib/model/models"; import { ApiRoutes } from "~/lib/routers/apiRoutes"; import { request } from "~/lib/service/api/primitives"; @@ -27,19 +28,23 @@ getModelById.key = (modelId?: string) => { }; }; -async function getModelProviders() { - const { data } = await request<{ items?: ModelProvider[] }>({ - url: ApiRoutes.toolReferences.base({ type: "modelProvider" }).url, +async function getAvailableModelsByProvider(provider: string) { + const { data } = await request<{ data?: AvailableModel[] }>({ + url: ApiRoutes.models.getAvailableModelsByProvider(provider).url, }); - return data.items ?? []; + return data.data ?? []; } -getModelProviders.key = () => ({ - url: ApiRoutes.toolReferences.base({ type: "modelProvider" }).path, -}); +getAvailableModelsByProvider.key = (provider?: Nullish) => { + if (!provider) return null; + + return { + url: ApiRoutes.models.getAvailableModelsByProvider(provider).path, + provider, + }; +}; async function createModel(manifest: ModelManifest) { - await new Promise((resolve) => setTimeout(resolve, 1000)); const { data } = await request({ url: ApiRoutes.models.createModel().url, method: "POST", @@ -50,8 +55,6 @@ async function createModel(manifest: ModelManifest) { } async function updateModel(modelId: string, manifest: ModelManifest) { - await new Promise((resolve) => setTimeout(resolve, 1000)); - const { data } = await request({ url: ApiRoutes.models.updateModel(modelId).url, method: "PUT", @@ -71,7 +74,7 @@ async function deleteModel(modelId: string) { export const ModelApiService = { getModels, getModelById, - getModelProviders, + getAvailableModelsByProvider, createModel, updateModel, deleteModel, diff --git a/ui/admin/app/lib/service/api/modelProviderApiService.ts b/ui/admin/app/lib/service/api/modelProviderApiService.ts new file mode 100644 index 000000000..06afc83d5 --- /dev/null +++ b/ui/admin/app/lib/service/api/modelProviderApiService.ts @@ -0,0 +1,16 @@ +import { ModelProvider } from "~/lib/model/models"; +import { ApiRoutes } from "~/lib/routers/apiRoutes"; +import { request } from "~/lib/service/api/primitives"; + +async function getModelProviders() { + const { data } = await request<{ items?: ModelProvider[] }>({ + url: ApiRoutes.modelProviders.getModelProviders().url, + }); + + return data.items ?? []; +} +getModelProviders.key = () => ({ + url: ApiRoutes.modelProviders.getModelProviders().path, +}); + +export const ModelProviderApiService = { getModelProviders }; diff --git a/ui/admin/app/routes/_auth.models.tsx b/ui/admin/app/routes/_auth.models.tsx index 668d5dd45..214dda0c1 100644 --- a/ui/admin/app/routes/_auth.models.tsx +++ b/ui/admin/app/routes/_auth.models.tsx @@ -6,6 +6,7 @@ import useSWR, { preload } from "swr"; import { Model } from "~/lib/model/models"; import { DefaultModelAliasApiService } from "~/lib/service/api/defaultModelAliasApiService"; import { ModelApiService } from "~/lib/service/api/modelApiService"; +import { ModelProviderApiService } from "~/lib/service/api/modelProviderApiService"; import { TypographyH2 } from "~/components/Typography"; import { DataTable } from "~/components/composed/DataTable"; @@ -24,8 +25,8 @@ export async function clientLoader() { await Promise.all([ preload(ModelApiService.getModels.key(), ModelApiService.getModels), preload( - ModelApiService.getModelProviders.key(), - ModelApiService.getModelProviders + ModelProviderApiService.getModelProviders.key(), + ModelProviderApiService.getModelProviders ), preload( DefaultModelAliasApiService.getAliases.key(), @@ -44,8 +45,8 @@ export default function Models() { ); const { data: providers } = useSWR( - ModelApiService.getModelProviders.key(), - ModelApiService.getModelProviders + ModelProviderApiService.getModelProviders.key(), + ModelProviderApiService.getModelProviders ); const providerMap = useMemo(() => {