From 51ab1f135d72021c142415d6a93ef0bcbc4708d1 Mon Sep 17 00:00:00 2001 From: alarv Date: Wed, 18 Dec 2024 23:15:19 +0200 Subject: [PATCH] fix(JAQPOT-432): LLM fixes --- src/app/SessionChecker.tsx | 4 +- .../models/[modelId]/components/ModelTabs.tsx | 2 +- .../[modelId]/components/llm/LLMForm.tsx | 129 ++++++++++++++---- .../components/llm/LLMNavigation.tsx | 21 ++- src/app/util/dataset.tsx | 18 --- 5 files changed, 114 insertions(+), 60 deletions(-) diff --git a/src/app/SessionChecker.tsx b/src/app/SessionChecker.tsx index 2ca3712..79c9542 100644 --- a/src/app/SessionChecker.tsx +++ b/src/app/SessionChecker.tsx @@ -26,14 +26,14 @@ export default function SessionChecker() { isLoading, error, } = useSWR(`/api/auth/validate`, fetcher, { - revalidateOnFocus: false, // Since we're persisting to localStorage + revalidateOnFocus: false, revalidateOnReconnect: false, }); const silentlySignOut = useCallback(async () => { await signOut({ redirect: false }); clearUserSettings(); - }, []); + }, [clearUserSettings]); if (hasRun) return null; diff --git a/src/app/dashboard/models/[modelId]/components/ModelTabs.tsx b/src/app/dashboard/models/[modelId]/components/ModelTabs.tsx index 77e1386..defdde9 100644 --- a/src/app/dashboard/models/[modelId]/components/ModelTabs.tsx +++ b/src/app/dashboard/models/[modelId]/components/ModelTabs.tsx @@ -86,7 +86,7 @@ export default function ModelTabs({ model }: ModelTabsProps) { Chat } - href={`${pathnameWithoutTab}/chat`} + href={`${pathnameWithoutTab}/chat/new`} > diff --git a/src/app/dashboard/models/[modelId]/components/llm/LLMForm.tsx b/src/app/dashboard/models/[modelId]/components/llm/LLMForm.tsx index abcac76..de1c00e 100644 --- a/src/app/dashboard/models/[modelId]/components/llm/LLMForm.tsx +++ b/src/app/dashboard/models/[modelId]/components/llm/LLMForm.tsx @@ -1,6 +1,6 @@ import { Textarea } from '@nextui-org/input'; import { DatasetDto, ModelDto } from '@/app/api.types'; -import React, { useState, useEffect } from 'react'; +import React, { useState, useEffect, useRef } from 'react'; import { Button } from '@nextui-org/button'; import { ArrowUpIcon } from '@heroicons/react/24/solid'; import { KeyboardEvent } from '@react-types/shared/src/events'; @@ -11,6 +11,7 @@ import { Spinner } from '@nextui-org/spinner'; import SWRClientFetchError from '@/app/components/SWRClientFetchError'; import ChatGrid from '@/app/dashboard/models/[modelId]/components/llm/ChatMessage'; import toast from 'react-hot-toast'; +import { useRouter } from 'next/navigation'; interface LLMFormProps { model: ModelDto; @@ -30,7 +31,24 @@ const placeholders = [ 'Debate whether hot dogs are sandwiches...', ]; +const createDataset = async (modelId: string, datasetDto: DatasetDto) => { + const res = await fetch(`/api/user/models/${modelId}/datasets`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(datasetDto), + }); + + if (!res.ok) { + throw new Error('Failed to create dataset'); + } + + return res.json(); +}; + export function LLMForm({ model, datasetId }: LLMFormProps) { + const router = useRouter(); + + const chatContainerRef = useRef(null); const [isFormLoading, setIsFormLoading] = useState(false); const [textareaContent, setTextareaContent] = useState(''); const [placeholder] = useState( @@ -46,6 +64,10 @@ export function LLMForm({ model, datasetId }: LLMFormProps) { } = useSWR( datasetId !== 'new' ? `/api/datasets/${datasetId}` : null, datasetFetcher, + { + revalidateOnFocus: false, + revalidateOnReconnect: false, + }, ); useEffect(() => { @@ -59,8 +81,39 @@ export function LLMForm({ model, datasetId }: LLMFormProps) { } }, [apiResponse]); + useEffect(() => { + scrollToBottom(); + }, [chatHistory]); + + useEffect(() => { + scrollToBottom(); + }, [currentResponse]); + + const scrollToBottom = () => { + if (chatContainerRef.current) { + const container = chatContainerRef.current; + const scrollHeight = container.scrollHeight; + const height = container.clientHeight; + const maxScroll = scrollHeight - height; + + // Only scroll if we're already near the bottom + const shouldScroll = + (container.scrollTop + container.clientHeight) / + container.scrollHeight >= + 0.5; + + if (shouldScroll) { + container.scrollTo({ + top: maxScroll, + behavior: 'smooth', + }); + } + } + }; + const createStreamingPrediction = async ( modelId: string, + datasetId: string, streamingPredictionRequestDto: { prompt: string }, ) => { const apiResponse = await fetch( @@ -106,18 +159,37 @@ export function LLMForm({ model, datasetId }: LLMFormProps) { const prompt = textareaContent.trim(); setIsFormLoading(true); - try { - const dataset = apiResponse?.data; - const response = await createStreamingPrediction(model.id!.toString(), { - prompt, + let dataset = apiResponse?.data; + if (datasetId === 'new') { + const { data } = await createDataset(model.id!.toString(), { + name: prompt.substring(0, 50), + entryType: 'ARRAY', + type: 'CHAT', + input: [], }); + dataset = data; + } + + if (!dataset) { + toast.error('An error occurred while processing your request.'); + return; + } + + try { + const response = await createStreamingPrediction( + model.id!.toString(), + dataset.id!.toString(), + { + prompt, + }, + ); setChatHistory((prev) => [ + ...prev, { prompt, output: response, }, - ...prev, ]); } catch (error) { toast.error('An error occurred while processing your request.'); @@ -125,6 +197,9 @@ export function LLMForm({ model, datasetId }: LLMFormProps) { setIsFormLoading(false); setCurrentResponse(undefined); setTextareaContent(''); + if (datasetId === 'new') { + router.replace(`/dashboard/models/${model.id}/chat/${dataset.id}`); + } } }; @@ -134,7 +209,28 @@ export function LLMForm({ model, datasetId }: LLMFormProps) { return ( -
e.preventDefault()}> +
+ {chatHistory.length > 0 && ( +
+ +
+ )} + + {currentResponse && ( +
+ +
+ )} +
+ + e.preventDefault()} className="mt-5">