From 9eeca80c835acc4a65b41ff378cd8cad312a238d Mon Sep 17 00:00:00 2001 From: arvinxx Date: Tue, 15 Aug 2023 00:28:32 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20=E6=94=AF=E6=8C=81=E5=81=9C?= =?UTF-8?q?=E6=AD=A2=E7=94=9F=E6=88=90=E6=B6=88=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit close #78 --- .../features/Conversation/Input/index.tsx | 16 +++-- .../session/slices/chat/actions/message.ts | 66 ++++++++++++++----- src/store/session/slices/chat/initialState.ts | 1 + 3 files changed, 60 insertions(+), 23 deletions(-) diff --git a/src/pages/chat/features/Conversation/Input/index.tsx b/src/pages/chat/features/Conversation/Input/index.tsx index ff29ae46fee6..08255b7ef488 100644 --- a/src/pages/chat/features/Conversation/Input/index.tsx +++ b/src/pages/chat/features/Conversation/Input/index.tsx @@ -21,11 +21,15 @@ const ChatInput = () => { s.preference.inputHeight, s.updatePreference, ]); - const [sendMessage, hasTopic, saveToTopic] = useSessionStore((s) => [ - s.sendMessage, - !!s.activeTopicId, - s.saveToTopic, - ]); + const [isLoading, hasTopic, sendMessage, saveToTopic, stopGenerateMessage] = useSessionStore( + (s) => [ + !!s.chatLoadingId, + !!s.activeTopicId, + s.sendMessage, + s.saveToTopic, + s.stopGenerateMessage, + ], + ); const footer = hasTopic ? null : ( @@ -60,10 +64,12 @@ const ChatInput = () => { actionsRight={} expand={expand} footer={footer} + loading={isLoading} minHeight={CHAT_TEXTAREA_HEIGHT} onExpandChange={setExpand} onInputChange={setText} onSend={sendMessage} + onStop={stopGenerateMessage} placeholder={t('sendPlaceholder')} text={{ send: t('send'), diff --git a/src/store/session/slices/chat/actions/message.ts b/src/store/session/slices/chat/actions/message.ts index 38ae31fd76c4..7d18749677cb 100644 --- a/src/store/session/slices/chat/actions/message.ts +++ b/src/store/session/slices/chat/actions/message.ts @@ -46,7 +46,6 @@ export interface ChatMessageAction { messages: ChatMessage[], assistantMessageId: string, ) => Promise<{ isFunctionCall: boolean }>; - /** * 实际获取 AI 响应 * @@ -54,17 +53,24 @@ export interface ChatMessageAction { * @param parentId - 父消息 ID,可选 */ realFetchAIResponse: (messages: ChatMessage[], parentId: string) => Promise; + /** * 重新发送消息 * @param id - 消息 ID */ resendMessage: (id: string) => Promise; - /** * 发送消息 * @param text - 消息文本 */ sendMessage: (text: string) => Promise; + + stopGenerateMessage: () => void; + toggleChatLoading: ( + loading: boolean, + id?: string, + action?: string, + ) => AbortController | undefined; triggerFunctionCall: (id: string) => Promise; } @@ -97,14 +103,15 @@ export const chatMessage: StateCreator< get().dispatchSession({ chats, id: activeId, type: 'updateSessionChat' }); }, - generateMessage: async (messages, assistantId) => { - const { dispatchMessage } = get(); - set( - { chatLoadingId: assistantId }, - false, - t('generateMessage(start)', { assistantId, messages }), + const { dispatchMessage, toggleChatLoading } = get(); + + const abortController = toggleChatLoading( + true, + assistantId, + t('generateMessage(start)', { assistantId, messages }) as string, ); + const config = agentSelectors.currentAgentConfig(get()); const compiler = template(config.inputTemplate, { interpolate: /{{([\S\s]+?)}}/g }); @@ -138,12 +145,15 @@ export const chatMessage: StateCreator< } const fetcher = () => - fetchChatModel({ - messages: postMessages, - model: config.model, - ...config.params, - plugins: config.plugins, - }); + fetchChatModel( + { + messages: postMessages, + model: config.model, + ...config.params, + plugins: config.plugins, + }, + { signal: abortController?.signal }, + ); let output = ''; let isFunctionCall = false; @@ -175,7 +185,7 @@ export const chatMessage: StateCreator< }, }); - set({ chatLoadingId: undefined }, false, t('generateMessage(end)')); + toggleChatLoading(false, undefined, t('generateMessage(end)') as string); return { isFunctionCall }; }, @@ -281,8 +291,27 @@ export const chatMessage: StateCreator< } }, + stopGenerateMessage: () => { + const { abortController, toggleChatLoading } = get(); + if (!abortController) return; + + abortController.abort(); + + toggleChatLoading(false); + }, + + toggleChatLoading: (loading, id, action) => { + if (loading) { + const abortController = new AbortController(); + set({ abortController, chatLoadingId: id }, false, action); + return abortController; + } else { + set({ abortController: undefined, chatLoadingId: undefined }, false, action); + } + }, + triggerFunctionCall: async (id) => { - const { dispatchMessage, realFetchAIResponse } = get(); + const { dispatchMessage, realFetchAIResponse, toggleChatLoading } = get(); const session = sessionSelectors.currentSession(get()); if (!session) return; @@ -317,7 +346,9 @@ export const chatMessage: StateCreator< // type: 'addMessage', // }); - const data = await fetchPlugin(payload); + const abortController = toggleChatLoading(true, id); + const data = await fetchPlugin(payload, { signal: abortController?.signal }); + toggleChatLoading(false); dispatchMessage({ id, key: 'content', type: 'updateMessage', value: data }); @@ -325,7 +356,6 @@ export const chatMessage: StateCreator< await realFetchAIResponse(chats, message.id); }, - // genShareUrl: () => { // const session = sessionSelectors.currentSession(get()); // if (!session) return ''; diff --git a/src/store/session/slices/chat/initialState.ts b/src/store/session/slices/chat/initialState.ts index 5bc75bcca141..615c995405f3 100644 --- a/src/store/session/slices/chat/initialState.ts +++ b/src/store/session/slices/chat/initialState.ts @@ -1,4 +1,5 @@ export interface ChatState { + abortController?: AbortController; activeTopicId?: string; chatLoadingId?: string; renameTopicId?: string;