Skip to content

Commit

Permalink
feat: Allow users to cancel message generation
Browse files Browse the repository at this point in the history
- temporary: ditch the stream reader when the user clicks stop
- will require rework when the backend supports cancellation via API

Signed-off-by: Ryan Hopper-Lowe <[email protected]>
  • Loading branch information
ryanhopperlowe committed Oct 22, 2024
1 parent d8dd4f6 commit 135f901
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 55 deletions.
11 changes: 9 additions & 2 deletions ui/admin/app/components/chat/ChatContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ interface ChatContextType {
generatingMessage: Message | null;
invoke: (prompt?: string) => void;
readOnly?: boolean;
cancelMessage?: () => void;
}

const ChatContext = createContext<ChatContextType | undefined>(undefined);
Expand All @@ -57,6 +58,7 @@ export function ChatProvider({
const [generatingMessage, setGeneratingMessage] = useState<string | null>(
null
);
const [cancelMessage, setCancelMessage] = useState<() => void>();
const isRunningToolCall = useRef(false);
// todo(tylerslaton): this is a huge hack to get the generating message and runId to be
// interactable during workflow invokes. take a look at invokeWorkflow to see why this is
Expand Down Expand Up @@ -145,7 +147,7 @@ export function ChatProvider({
onSuccess: ({ reader, threadId: responseThreadId }) => {
clearGeneratingMessage();

readStream<ChatEvent>({
const invokeStream = readStream<ChatEvent>({
reader,
onChunk: (chunk) =>
// use a transition for performance
Expand Down Expand Up @@ -206,16 +208,20 @@ export function ChatProvider({

invokeAgent.clear();
generatingRunIdRef.current = null;
setCancelMessage(undefined);
},
});

setCancelMessage(() => invokeStream.cancel);

invokeStream.start();
},
});

const outGeneratingMessage = useMemo<Message | null>(() => {
if (invokeAgent.isLoading)
return { sender: "agent", text: "", isLoading: true };

// slice the first character because it is always a newline for some reason
if (!generatingMessage) {
if (invokeAgent.data?.reader && !isRunningToolCall.current) {
return {
Expand Down Expand Up @@ -250,6 +256,7 @@ export function ChatProvider({
threadId,
generatingMessage: outGeneratingMessage,
invoke,
cancelMessage,
readOnly,
}}
>
Expand Down
37 changes: 26 additions & 11 deletions ui/admin/app/components/chat/Chatbar.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { CircleArrowUpIcon } from "lucide-react";
import { CircleArrowUpIcon, StopCircleIcon } from "lucide-react";
import { useState } from "react";

import { cn } from "~/lib/utils";
Expand All @@ -13,10 +13,13 @@ type ChatbarProps = {

export function Chatbar({ className }: ChatbarProps) {
const [input, setInput] = useState("");
const { processUserMessage } = useChat();
const { processUserMessage, cancelMessage } = useChat();

const handleSubmit = (e: React.FormEvent) => {
e.preventDefault();

if (cancelMessage) return;

if (input.trim()) {
processUserMessage(input, "user");
setInput("");
Expand Down Expand Up @@ -45,15 +48,27 @@ export function Chatbar({ className }: ChatbarProps) {
/>
</div>

<Button
size="icon"
variant="secondary"
className="rounded-full"
type="submit"
disabled={!input}
>
<CircleArrowUpIcon />
</Button>
{cancelMessage ? (
<Button
size="icon"
variant="secondary"
className="rounded-full"
type="button"
onClick={cancelMessage}
>
<StopCircleIcon />
</Button>
) : (
<Button
size="icon"
variant="secondary"
className="rounded-full"
type="submit"
disabled={!input}
>
<CircleArrowUpIcon />
</Button>
)}
</form>
);
}
94 changes: 52 additions & 42 deletions ui/admin/app/lib/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
This approach ensures that we don't lose data between chunks and can
handle messages that might be split across multiple chunks.
*/
export async function readStream<T>({
export function readStream<T>({
reader,
onChunk,
onComplete,
Expand All @@ -16,33 +16,51 @@ export async function readStream<T>({
onChunk: (data: T) => void;
onComplete?: (data: T[]) => void;
}) {
const collected: T[] = [];
const decoder = new TextDecoder();
let buffer = "";
let isCanceled = false;

try {
// eslint-disable-next-line no-constant-condition
while (true) {
// Read from the stream
const { value, done } = await reader.read();
if (done) break;
async function start() {
const collected: T[] = [];
const decoder = new TextDecoder();
let buffer = "";

// Decode the chunk and add to buffer
buffer += decoder.decode(new TextEncoder().encode(value), {
stream: true,
});
try {
// eslint-disable-next-line no-constant-condition
while (!isCanceled) {
// Read from the stream
const { value, done } = await reader.read();
if (done) break;

// Split buffer into complete messages
const messages = buffer.split("\n\n");
// Keep the last (potentially incomplete) message in the buffer
buffer = messages.pop() || "";
// Decode the chunk and add to buffer
buffer += decoder.decode(new TextEncoder().encode(value), {
stream: true,
});

// Process complete messages
for (const message of messages) {
const dataString = message
.replace(/^id:.*\n/, "")
.replace(/^data: /, "")
.trim();
// Split buffer into complete messages
const messages = buffer.split("\n\n");
// Keep the last (potentially incomplete) message in the buffer
buffer = messages.pop() || "";

// Process complete messages
for (const message of messages) {
const dataString = message
.replace(/^id:.*\n/, "")
.replace(/^data: /, "")
.trim();
if (dataString) {
try {
const data = JSON.parse(dataString) as T;
onChunk(data);
collected.push(data);
} catch (error) {
console.error("Error parsing JSON:", error);
}
}
}
}

// Process any remaining data in the buffer after stream closes
if (buffer.trim()) {
const dataString = buffer.replace(/^data: /, "").trim();
if (dataString) {
try {
const data = JSON.parse(dataString) as T;
Expand All @@ -53,25 +71,17 @@ export async function readStream<T>({
}
}
}
} catch (error) {
console.error("Error reading stream:", error);
} finally {
// Always call onComplete, even if there was an error
onComplete?.(collected);
}
}

// Process any remaining data in the buffer after stream closes
if (buffer.trim()) {
const dataString = buffer.replace(/^data: /, "").trim();
if (dataString) {
try {
const data = JSON.parse(dataString) as T;
onChunk(data);
collected.push(data);
} catch (error) {
console.error("Error parsing JSON:", error);
}
}
}
} catch (error) {
console.error("Error reading stream:", error);
} finally {
// Always call onComplete, even if there was an error
onComplete?.(collected);
function cancel() {
isCanceled = true;
}

return { start, cancel };
}

0 comments on commit 135f901

Please sign in to comment.