From 401df75b73716be62efc295d36263ee3b6fc51c9 Mon Sep 17 00:00:00 2001 From: xcatliu Date: Wed, 22 Nov 2023 23:40:26 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=20abortSendMessage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/components/History.tsx | 9 ++++- app/components/Messages.tsx | 10 ++++- app/components/TextareaForm.tsx | 39 +++++++++++--------- app/components/buttons/AttachImageButton.tsx | 9 +++-- app/context/ChatContext.tsx | 36 ++++++++++++++---- app/utils/api.ts | 6 +++ app/utils/constants.ts | 5 +++ 7 files changed, 80 insertions(+), 34 deletions(-) diff --git a/app/components/History.tsx b/app/components/History.tsx index d646ad48..8b877656 100644 --- a/app/components/History.tsx +++ b/app/components/History.tsx @@ -45,7 +45,7 @@ export const HistoryItemComp: FC<{ historyIndex: 'current' | number; isActive: b historyIndex, isActive, }) => { - const { messages, history, loadHistory } = useContext(ChatContext)!; + const { messages, history, loadHistory, abortSendMessage } = useContext(ChatContext)!; const { settings } = useContext(SettingsContext)!; let historyItem: HistoryItem; @@ -63,7 +63,12 @@ export const HistoryItemComp: FC<{ historyIndex: 'current' | number; isActive: b className={classNames('p-4 border-b-[0.5px] relative cursor-default border-gray md:-mx-4 md:px-8', { 'bg-gray-300 dark:bg-gray-700': isActive, })} - onClick={() => historyIndex !== 'current' && loadHistory(historyIndex)} + onClick={() => { + if (historyIndex !== 'current') { + abortSendMessage(); + loadHistory(historyIndex); + } + }} >

{getContentText(historyItem.messages[0])}

{ const { settings } = useContext(SettingsContext)!; - let { isLoading, messages, history, historyIndex, startNewChat } = useContext(ChatContext)!; + let { isLoading, messages, history, historyIndex, startNewChat, abortSendMessage } = useContext(ChatContext)!; // 初始化滚动事件 useEffect(initEventListenerScroll, []); @@ -52,7 +52,13 @@ export const Messages = () => { {messages.length > 1 && ( 连续对话会加倍消耗 tokens, - + { + abortSendMessage(); + startNewChat(); + }} + > 开启新对话 diff --git a/app/components/TextareaForm.tsx b/app/components/TextareaForm.tsx index 028675dd..3aa530fb 100644 --- a/app/components/TextareaForm.tsx +++ b/app/components/TextareaForm.tsx @@ -17,7 +17,7 @@ export const TextareaForm: FC = () => { const { isMobile } = useContext(DeviceContext)!; const { isLogged } = useContext(LoginContext)!; const { settings } = useContext(SettingsContext)!; - const { images, appendImages, sendMessage } = useContext(ChatContext)!; + const { images, appendImages, sendMessage, abortSendMessage } = useContext(ChatContext)!; // 是否正在中文输入 const [isComposing, setIsComposing] = useState(false); @@ -43,26 +43,27 @@ export const TextareaForm: FC = () => { }); }, []); - /** * Handle pasting images into the textarea */ - const handlePaste = useCallback(async (e: React.ClipboardEvent) => { - const items = e.clipboardData?.items; - if (items) { - for (let i = 0; i < items.length; i++) { - if (items[i].type.indexOf('image') === 0) { - const file = items[i].getAsFile(); - if (file == null) { - throw new Error("Expected file") + const handlePaste = useCallback( + async (e: React.ClipboardEvent) => { + const items = e.clipboardData?.items; + if (items) { + for (let i = 0; i < items.length; i++) { + if (items[i].type.indexOf('image') === 0) { + const file = items[i].getAsFile(); + if (file == null) { + throw new Error('Expected file'); + } + const image = await readImageFile(file); + appendImages(image); } - const image = await readImageFile(file); - appendImages(image); } } - } - }, [appendImages]); - + }, + [appendImages], + ); /** * 更新 textarea 的 empty 状态 @@ -129,9 +130,10 @@ export const TextareaForm: FC = () => { } updateTextareaHeight(); updateIsTextareaEmpty(); + abortSendMessage(); await sendMessage(value); }, - [sendMessage, updateTextareaHeight, updateIsTextareaEmpty], + [sendMessage, abortSendMessage, updateTextareaHeight, updateIsTextareaEmpty], ); /** @@ -195,7 +197,9 @@ export const TextareaForm: FC = () => { {settings.model.includes('vision') && }

0, + })} type="submit" disabled={isTextareaEmpty && images.length === 0} value="发送" @@ -206,4 +210,3 @@ export const TextareaForm: FC = () => { ); }; - diff --git a/app/components/buttons/AttachImageButton.tsx b/app/components/buttons/AttachImageButton.tsx index bdd47322..04014285 100644 --- a/app/components/buttons/AttachImageButton.tsx +++ b/app/components/buttons/AttachImageButton.tsx @@ -5,6 +5,7 @@ import { useContext } from 'react'; import { ChatContext } from '@/context/ChatContext'; import { LoginContext } from '@/context/LoginContext'; +import { MAX_GPT_VISION_IMAGES } from '@/utils/constants'; import { readImageFile } from '@/utils/image'; /** @@ -15,11 +16,11 @@ export const AttachImageButton: FC<{}> = () => { const { isLogged } = useContext(LoginContext)!; return ( -
+ <>
+ ); }; diff --git a/app/context/ChatContext.tsx b/app/context/ChatContext.tsx index d301eb54..0393c410 100644 --- a/app/context/ChatContext.tsx +++ b/app/context/ChatContext.tsx @@ -2,12 +2,12 @@ import omit from 'lodash.omit'; import type { FC, ReactNode } from 'react'; -import { createContext, useCallback, useContext, useEffect, useState } from 'react'; +import { createContext, useCallback, useContext, useEffect, useMemo, useState } from 'react'; import { fetchApiChat } from '@/utils/api'; import { getCache, setCache } from '@/utils/cache'; import type { ChatResponse, Message, StructuredMessageContentItem } from '@/utils/constants'; -import { MAX_TOKENS, MessageContentType, Model, Role } from '@/utils/constants'; +import { MAX_GPT_VISION_IMAGES, MAX_TOKENS, MessageContentType, Model, Role } from '@/utils/constants'; import type { ResError } from '@/utils/error'; import type { ImageProp } from '@/utils/image'; import { isMessage } from '@/utils/message'; @@ -31,6 +31,7 @@ export interface HistoryItem { */ export const ChatContext = createContext<{ sendMessage: (content?: string) => Promise; + abortSendMessage: () => void; isLoading: boolean; messages: (Message | ChatResponse)[]; images: ImageProp[]; @@ -52,6 +53,8 @@ export const ChatProvider: FC<{ children: ReactNode }> = ({ children }) => { const [history, setHistory] = useState(undefined); // 当前选中的对话在 history 中的 index,empty 表示未选中,current 表示选中的是当前对话 const [historyIndex, setHistoryIndex] = useState<'empty' | 'current' | number>('empty'); + // 控制请求中断 + const [abortController, setAbortController] = useState(); // 页面加载后从 cache 中读取 history 和 messages // 如果 messages 不为空,则将最近的一条消息写入 history @@ -165,6 +168,9 @@ export const ChatProvider: FC<{ children: ReactNode }> = ({ children }) => { if (settings.systemMessage) { fetchApiChatMessages.unshift(settings.systemMessage); } + // 创建一个新的 abortController + const newAbortController = new AbortController(); + setAbortController(newAbortController); // TODO 收到完整消息后,写入 cache 中 const fullContent = await fetchApiChat({ // gpt-4-vision-preview 有个 bug:不传 max_tokens 时,会中断消息 @@ -184,6 +190,7 @@ export const ChatProvider: FC<{ children: ReactNode }> = ({ children }) => { scrollToBottom(); } }, + signal: newAbortController.signal, }); // 收到完整消息后,重新设置 messages @@ -194,7 +201,11 @@ export const ChatProvider: FC<{ children: ReactNode }> = ({ children }) => { if (gapToBottom() <= 72 && !getIsScrolling()) { scrollToBottom(); } - } catch (e) { + } catch (e: any) { + // 如果是调用 abortController.abort() 捕获到的 error 则不处理 + if (e.name === 'AbortError') { + return; + } // 发生错误时,展示错误消息 setIsLoading(false); setMessages([ @@ -203,9 +214,16 @@ export const ChatProvider: FC<{ children: ReactNode }> = ({ children }) => { ]); } }, - [settings, messages, images, history, historyIndex], + [settings, messages, images, history, historyIndex, setAbortController], ); + /** + * 中断请求 + */ + const abortSendMessage = useCallback(() => { + abortController?.abort(); + }, [abortController]); + /** * 加载聊天记录 */ @@ -258,7 +276,8 @@ export const ChatProvider: FC<{ children: ReactNode }> = ({ children }) => { setCache('history', newHistory); setMessages([]); setCache('messages', []); - setHistoryIndex(index); + // 此时因为将 current 进行归档了,所以需要 +1 + setHistoryIndex(index + 1); setSettings({ model: newModel, }); @@ -339,9 +358,9 @@ export const ChatProvider: FC<{ children: ReactNode }> = ({ children }) => { const appendImages = useCallback( (...newImages: ImageProp[]) => { const finalImages = [...images, ...newImages]; - if (finalImages.length > 9) { - setImages(finalImages.slice(0, 9)); - alert('最多只能发送九张图片,超出的图片已删除'); + if (finalImages.length > MAX_GPT_VISION_IMAGES) { + setImages(finalImages.slice(0, MAX_GPT_VISION_IMAGES)); + alert(`最多只能发送 ${MAX_GPT_VISION_IMAGES} 张图片,超出的图片已删除`); return; } setImages(finalImages); @@ -362,6 +381,7 @@ export const ChatProvider: FC<{ children: ReactNode }> = ({ children }) => { void; + /** + * 控制请求中断的 AbortSignal + */ + signal?: AbortSignal; } & Partial) => { const fetchResult = await fetch('/api/chat', { method: HttpMethod.POST, headers: HttpHeaderJson, body: JSON.stringify(chatRequest), + signal, }); // 如果返回错误,则直接抛出错误 diff --git a/app/utils/constants.ts b/app/utils/constants.ts index 7e54947f..6889815e 100644 --- a/app/utils/constants.ts +++ b/app/utils/constants.ts @@ -19,6 +19,11 @@ export const HttpHeaderJson = { */ export const FULL_SPACE = ' '; +/** + * 使用 gpt-4-vision 时单次可传输的最多图片数量 + */ +export const MAX_GPT_VISION_IMAGES = 9; + /** * 角色 * 参考 https://github.com/openai/openai-node/blob/master/api.ts