fix(JAQPOT-432): LLM fixes #88

Merged · 1 commit · Dec 18, 2024
4 changes: 2 additions & 2 deletions src/app/SessionChecker.tsx
```diff
@@ -26,14 +26,14 @@ export default function SessionChecker() {
     isLoading,
     error,
   } = useSWR(`/api/auth/validate`, fetcher, {
-    revalidateOnFocus: false, // Since we're persisting to localStorage
+    revalidateOnFocus: false,
     revalidateOnReconnect: false,
   });

   const silentlySignOut = useCallback(async () => {
     await signOut({ redirect: false });
     clearUserSettings();
-  }, []);
+  }, [clearUserSettings]);

   if (hasRun) return null;
```
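The second change here follows React's exhaustive-deps rule: a `useCallback` body that reads a value from component scope should list it as a dependency, or the memoized callback can keep using a stale reference. A minimal illustration of that failure mode (hypothetical component, not from this PR):

```tsx
import { useCallback, useEffect, useState } from 'react';

function StaleExample() {
  const [count, setCount] = useState(0);

  // With an empty deps array this callback is created once and closes over
  // count === 0 forever; with [count] it would be re-created on every change.
  const logCount = useCallback(() => console.log(count), []);

  useEffect(() => {
    const id = setInterval(logCount, 1000); // keeps printing 0 as count grows
    return () => clearInterval(id);
  }, [logCount]);

  return <button onClick={() => setCount((c) => c + 1)}>{count}</button>;
}
```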
```diff
@@ -86,7 +86,7 @@ export default function ModelTabs({ model }: ModelTabsProps) {
             <span>Chat</span>
           </div>
         }
-        href={`${pathnameWithoutTab}/chat`}
+        href={`${pathnameWithoutTab}/chat/new`}
       >
         <ModelChatTab model={model} datasetId={params.datasetId} />
       </Tab>
```
129 changes: 103 additions & 26 deletions src/app/dashboard/models/[modelId]/components/llm/LLMForm.tsx
```diff
@@ -1,6 +1,6 @@
 import { Textarea } from '@nextui-org/input';
 import { DatasetDto, ModelDto } from '@/app/api.types';
-import React, { useState, useEffect } from 'react';
+import React, { useState, useEffect, useRef } from 'react';
 import { Button } from '@nextui-org/button';
 import { ArrowUpIcon } from '@heroicons/react/24/solid';
 import { KeyboardEvent } from '@react-types/shared/src/events';
@@ -11,6 +11,7 @@ import { Spinner } from '@nextui-org/spinner';
 import SWRClientFetchError from '@/app/components/SWRClientFetchError';
 import ChatGrid from '@/app/dashboard/models/[modelId]/components/llm/ChatMessage';
 import toast from 'react-hot-toast';
+import { useRouter } from 'next/navigation';

 interface LLMFormProps {
   model: ModelDto;
```
```diff
@@ -30,7 +31,24 @@ const placeholders = [
   'Debate whether hot dogs are sandwiches...',
 ];

+const createDataset = async (modelId: string, datasetDto: DatasetDto) => {
+  const res = await fetch(`/api/user/models/${modelId}/datasets`, {
+    method: 'POST',
+    headers: { 'Content-Type': 'application/json' },
+    body: JSON.stringify(datasetDto),
+  });
+
+  if (!res.ok) {
+    throw new Error('Failed to create dataset');
+  }
+
+  return res.json();
+};
+
 export function LLMForm({ model, datasetId }: LLMFormProps) {
+  const router = useRouter();
+
+  const chatContainerRef = useRef<HTMLDivElement>(null);
   const [isFormLoading, setIsFormLoading] = useState(false);
   const [textareaContent, setTextareaContent] = useState<string>('');
   const [placeholder] = useState(
```
```diff
@@ -46,6 +64,10 @@ export function LLMForm({ model, datasetId }: LLMFormProps) {
   } = useSWR(
     datasetId !== 'new' ? `/api/datasets/${datasetId}` : null,
     datasetFetcher,
+    {
+      revalidateOnFocus: false,
+      revalidateOnReconnect: false,
+    },
   );

   useEffect(() => {
```
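Passing `null` as the key is SWR's conditional-fetching mechanism: the request is skipped entirely, so a brand-new chat (`datasetId === 'new'`) renders without a dataset round-trip. A minimal sketch of the same pattern in isolation (generic JSON fetcher assumed, not the project's `datasetFetcher`):

```tsx
import useSWR from 'swr';

// A null key disables the request; SWR returns { data: undefined } and never
// fires a fetch until a real dataset id is supplied.
function useDataset(datasetId: string) {
  return useSWR(
    datasetId !== 'new' ? `/api/datasets/${datasetId}` : null,
    (url: string) => fetch(url).then((res) => res.json()),
    { revalidateOnFocus: false, revalidateOnReconnect: false },
  );
}
```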
```diff
@@ -59,8 +81,39 @@ export function LLMForm({ model, datasetId }: LLMFormProps) {
     }
   }, [apiResponse]);

+  useEffect(() => {
+    scrollToBottom();
+  }, [chatHistory]);
+
+  useEffect(() => {
+    scrollToBottom();
+  }, [currentResponse]);
+
+  const scrollToBottom = () => {
+    if (chatContainerRef.current) {
+      const container = chatContainerRef.current;
+      const scrollHeight = container.scrollHeight;
+      const height = container.clientHeight;
+      const maxScroll = scrollHeight - height;
+
+      // Only scroll if we're already near the bottom
+      const shouldScroll =
+        (container.scrollTop + container.clientHeight) /
+          container.scrollHeight >=
+        0.5;
+
+      if (shouldScroll) {
+        container.scrollTo({
+          top: maxScroll,
+          behavior: 'smooth',
+        });
+      }
+    }
+  };
+
   const createStreamingPrediction = async (
     modelId: string,
+    datasetId: string,
     streamingPredictionRequestDto: { prompt: string },
   ) => {
     const apiResponse = await fetch(
```
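The body of `createStreamingPrediction` is collapsed in this diff, so the following is only a sketch of how such a streaming endpoint is typically consumed with the Fetch API; the route path and the plain-text chunk framing are assumptions, not the PR's actual code:

```tsx
// Hedged sketch, not the PR's implementation: assumes the endpoint streams
// plain text chunks and that the route below exists.
async function streamPrediction(
  modelId: string,
  datasetId: string,
  prompt: string,
  onChunk: (partial: string) => void, // e.g. setCurrentResponse
): Promise<string> {
  const res = await fetch(
    `/api/models/${modelId}/datasets/${datasetId}/predict/stream`, // hypothetical path
    {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ prompt }),
    },
  );
  if (!res.ok || !res.body) throw new Error('Streaming prediction failed');

  const reader = res.body.getReader();
  const decoder = new TextDecoder();
  let output = '';
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    output += decoder.decode(value, { stream: true }); // append each chunk
    onChunk(output); // progressive render of the partial response
  }
  return output;
}
```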
```diff
@@ -106,25 +159,47 @@ export function LLMForm({ model, datasetId }: LLMFormProps) {
     const prompt = textareaContent.trim();
     setIsFormLoading(true);

-    try {
-      const dataset = apiResponse?.data;
-      const response = await createStreamingPrediction(model.id!.toString(), {
-        prompt,
+    let dataset = apiResponse?.data;
+    if (datasetId === 'new') {
+      const { data } = await createDataset(model.id!.toString(), {
+        name: prompt.substring(0, 50),
+        entryType: 'ARRAY',
+        type: 'CHAT',
+        input: [],
       });
+      dataset = data;
+    }
+
+    if (!dataset) {
+      toast.error('An error occurred while processing your request.');
+      return;
+    }
+
+    try {
+      const response = await createStreamingPrediction(
+        model.id!.toString(),
+        dataset.id!.toString(),
+        {
+          prompt,
+        },
+      );

       setChatHistory((prev) => [
+        ...prev,
         {
           prompt,
           output: response,
         },
-        ...prev,
       ]);
     } catch (error) {
       toast.error('An error occurred while processing your request.');
     } finally {
       setIsFormLoading(false);
       setCurrentResponse(undefined);
       setTextareaContent('');
+      if (datasetId === 'new') {
+        router.replace(`/dashboard/models/${model.id}/chat/${dataset.id}`);
+      }
     }
   };
```

```diff
@@ -134,7 +209,28 @@ export function LLMForm({ model, datasetId }: LLMFormProps) {
   return (
     <Card className="w-full">
       <CardBody>
-        <form onSubmit={(e) => e.preventDefault()}>
+        <div ref={chatContainerRef} className="max-h-[600px] overflow-y-auto">
+          {chatHistory.length > 0 && (
+            <div className="mt-4">
+              <ChatGrid messages={chatHistory} />
+            </div>
+          )}
+
+          {currentResponse && (
+            <div className="mt-4">
+              <ChatGrid
+                messages={[
+                  {
+                    prompt: textareaContent || '',
+                    output: currentResponse,
+                  },
+                ]}
+              />
+            </div>
+          )}
+        </div>
+
+        <form onSubmit={(e) => e.preventDefault()} className="mt-5">
           <Textarea
             name="prompt"
             placeholder={placeholder}
```
```diff
@@ -156,25 +252,6 @@ export function LLMForm({ model, datasetId }: LLMFormProps) {
             isRequired
           />
         </form>
-
-        {currentResponse && (
-          <div className="mt-4">
-            <ChatGrid
-              messages={[
-                {
-                  prompt: textareaContent || '',
-                  output: currentResponse,
-                },
-              ]}
-            />
-          </div>
-        )}
-
-        {chatHistory.length > 0 && (
-          <div className="mt-4">
-            <ChatGrid messages={chatHistory} />
-          </div>
-        )}
       </CardBody>
     </Card>
   );
```
```diff
@@ -61,8 +61,13 @@ export default function LLMNavigation({ model }: LLMTabsProps) {
     data: apiResponse,
     error,
     isLoading,
-  } = useSWR([model.id!.toString(), page], ([modelId, page]) =>
-    fetchDatasets(modelId, page - 1),
+  } = useSWR(
+    [model.id!.toString(), page],
+    ([modelId, page]) => fetchDatasets(modelId, page - 1),
+    {
+      revalidateOnFocus: false,
+      revalidateOnReconnect: false,
+    },
   );

   if (error) return <SWRClientFetchError error={error} />;
```
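The same two revalidation options now appear at every `useSWR` call site touched by this PR. If that repetition grows, SWR's `SWRConfig` provider can hoist them app-wide; a sketch of that option (adopting it is not part of this PR):

```tsx
import { SWRConfig } from 'swr';
import type { ReactNode } from 'react';

// Children inherit these defaults; individual useSWR calls can still
// override them per key.
export function StaticDataProvider({ children }: { children: ReactNode }) {
  return (
    <SWRConfig
      value={{ revalidateOnFocus: false, revalidateOnReconnect: false }}
    >
      {children}
    </SWRConfig>
  );
}
```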
```diff
@@ -80,17 +85,7 @@ export default function LLMNavigation({ model }: LLMTabsProps) {
         <Button
           color="primary"
           onPress={async () => {
-            const datasetDto: DatasetDto = {
-              type: 'CHAT',
-              input: [],
-              entryType: 'ARRAY',
-            };
-            const res = await fetch(`/api/user/models/${model.id}/datasets`, {
-              method: 'POST',
-              body: JSON.stringify(datasetDto),
-            });
-            const { data } = await res.json();
-            router.push(`${pathnameWithoutDatasetId}/${data.id}`);
+            router.push(`${pathnameWithoutDatasetId}/new`);
           }}
         >
           Start new chat
```
18 changes: 0 additions & 18 deletions src/app/util/dataset.tsx
```diff
@@ -183,24 +183,6 @@ function generateResultTableRow(
   return resultTableRow;
 }

-export const createDatasetFetcher: Fetcher<
-  ApiResponse<DatasetDto>,
-  string
-> = async (url) => {
-  const res = await fetch(url, { method: 'POST' });
-
-  // If the status code is not in the range 200-299,
-  // we still try to parse and throw it.
-  if (!res.ok) {
-    const message = (await res.json()).message;
-    const status = res.status;
-    // Attach extra info to the error object.
-    throw new CustomError(message, status);
-  }
-
-  return res.json();
-};
-
 export const datasetFetcher: Fetcher<ApiResponse<DatasetDto>, string> = async (
   url,
 ) => {
```