Skip to content

Commit

Permalink
✨ feat: 支持停止生成消息
Browse files Browse the repository at this point in the history
close #78
  • Loading branch information
arvinxx committed Aug 14, 2023
1 parent 604b5a8 commit 9eeca80
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 23 deletions.
16 changes: 11 additions & 5 deletions src/pages/chat/features/Conversation/Input/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@ const ChatInput = () => {
s.preference.inputHeight,
s.updatePreference,
]);
const [sendMessage, hasTopic, saveToTopic] = useSessionStore((s) => [
s.sendMessage,
!!s.activeTopicId,
s.saveToTopic,
]);
const [isLoading, hasTopic, sendMessage, saveToTopic, stopGenerateMessage] = useSessionStore(
(s) => [
!!s.chatLoadingId,
!!s.activeTopicId,
s.sendMessage,
s.saveToTopic,
s.stopGenerateMessage,
],
);

const footer = hasTopic ? null : (
<Tooltip title={t('topic.saveCurrentMessages')}>
Expand Down Expand Up @@ -60,10 +64,12 @@ const ChatInput = () => {
actionsRight={<ActionsRight />}
expand={expand}
footer={footer}
loading={isLoading}
minHeight={CHAT_TEXTAREA_HEIGHT}
onExpandChange={setExpand}
onInputChange={setText}
onSend={sendMessage}
onStop={stopGenerateMessage}
placeholder={t('sendPlaceholder')}
text={{
send: t('send'),
Expand Down
66 changes: 48 additions & 18 deletions src/store/session/slices/chat/actions/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,31 @@ export interface ChatMessageAction {
messages: ChatMessage[],
assistantMessageId: string,
) => Promise<{ isFunctionCall: boolean }>;

/**
* 实际获取 AI 响应
*
* @param messages - 聊天消息数组
* @param parentId - 父消息 ID,可选
*/
realFetchAIResponse: (messages: ChatMessage[], parentId: string) => Promise<void>;

/**
* 重新发送消息
* @param id - 消息 ID
*/
resendMessage: (id: string) => Promise<void>;

/**
* 发送消息
* @param text - 消息文本
*/
sendMessage: (text: string) => Promise<void>;

stopGenerateMessage: () => void;
toggleChatLoading: (
loading: boolean,
id?: string,
action?: string,
) => AbortController | undefined;
triggerFunctionCall: (id: string) => Promise<void>;
}

Expand Down Expand Up @@ -97,14 +103,15 @@ export const chatMessage: StateCreator<

get().dispatchSession({ chats, id: activeId, type: 'updateSessionChat' });
},

generateMessage: async (messages, assistantId) => {
const { dispatchMessage } = get();
set(
{ chatLoadingId: assistantId },
false,
t('generateMessage(start)', { assistantId, messages }),
const { dispatchMessage, toggleChatLoading } = get();

const abortController = toggleChatLoading(
true,
assistantId,
t('generateMessage(start)', { assistantId, messages }) as string,
);

const config = agentSelectors.currentAgentConfig(get());

const compiler = template(config.inputTemplate, { interpolate: /{{([\S\s]+?)}}/g });
Expand Down Expand Up @@ -138,12 +145,15 @@ export const chatMessage: StateCreator<
}

const fetcher = () =>
fetchChatModel({
messages: postMessages,
model: config.model,
...config.params,
plugins: config.plugins,
});
fetchChatModel(
{
messages: postMessages,
model: config.model,
...config.params,
plugins: config.plugins,
},
{ signal: abortController?.signal },
);

let output = '';
let isFunctionCall = false;
Expand Down Expand Up @@ -175,7 +185,7 @@ export const chatMessage: StateCreator<
},
});

set({ chatLoadingId: undefined }, false, t('generateMessage(end)'));
toggleChatLoading(false, undefined, t('generateMessage(end)') as string);

return { isFunctionCall };
},
Expand Down Expand Up @@ -281,8 +291,27 @@ export const chatMessage: StateCreator<
}
},

stopGenerateMessage: () => {
const { abortController, toggleChatLoading } = get();
if (!abortController) return;

abortController.abort();

toggleChatLoading(false);
},

toggleChatLoading: (loading, id, action) => {
if (loading) {
const abortController = new AbortController();
set({ abortController, chatLoadingId: id }, false, action);
return abortController;
} else {
set({ abortController: undefined, chatLoadingId: undefined }, false, action);
}
},

triggerFunctionCall: async (id) => {
const { dispatchMessage, realFetchAIResponse } = get();
const { dispatchMessage, realFetchAIResponse, toggleChatLoading } = get();
const session = sessionSelectors.currentSession(get());

if (!session) return;
Expand Down Expand Up @@ -317,15 +346,16 @@ export const chatMessage: StateCreator<
// type: 'addMessage',
// });

const data = await fetchPlugin(payload);
const abortController = toggleChatLoading(true, id);
const data = await fetchPlugin(payload, { signal: abortController?.signal });
toggleChatLoading(false);

dispatchMessage({ id, key: 'content', type: 'updateMessage', value: data });

const chats = chatSelectors.currentChats(get());

await realFetchAIResponse(chats, message.id);
},

// genShareUrl: () => {
// const session = sessionSelectors.currentSession(get());
// if (!session) return '';
Expand Down
1 change: 1 addition & 0 deletions src/store/session/slices/chat/initialState.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export interface ChatState {
abortController?: AbortController;
activeTopicId?: string;
chatLoadingId?: string;
renameTopicId?: string;
Expand Down

0 comments on commit 9eeca80

Please sign in to comment.