Skip to content

Commit

Permalink
🐛 fix: fix remove tts and translate not working (lobehub#818)
Browse files Browse the repository at this point in the history
* ♻️ refactor: clean message service methods

* 🐛 fix: fix remove tts and translate not working
  • Loading branch information
arvinxx authored Dec 26, 2023
1 parent 6ec425b commit 4a275e9
Show file tree
Hide file tree
Showing 11 changed files with 44 additions and 151 deletions.
6 changes: 1 addition & 5 deletions src/database/models/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,7 @@ class _MessageModel extends BaseModel {
tts,
...item
}: DBModel<DB_Message>): ChatMessage => {
return {
...item,
extra: { fromModel: fromModel, translate: translate, tts: tts },
meta: {},
};
return { ...item, extra: { fromModel, translate, tts }, meta: {} };
};
}

Expand Down
2 changes: 1 addition & 1 deletion src/database/schemas/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export const DB_MessageSchema = z.object({
plugin: PluginSchema.optional(),
pluginState: z.any().optional(),
fromModel: z.string().optional(),
translate: TranslateSchema.optional().or(z.null()),
translate: TranslateSchema.optional().or(z.literal(false)),
tts: z.any().optional(),

// foreign keys
Expand Down
60 changes: 0 additions & 60 deletions src/services/__tests__/message.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -153,21 +153,6 @@ describe('MessageService', () => {
});
});

describe('updateMessageContent', () => {
it('should update the content of a message', async () => {
// Setup
const newContent = 'Updated message content';
(MessageModel.update as Mock).mockResolvedValue({ ...mockMessage, content: newContent });

// Execute
const result = await messageService.updateMessageContent(mockMessageId, newContent);

// Assert
expect(MessageModel.update).toHaveBeenCalledWith(mockMessageId, { content: newContent });
expect(result).toEqual({ ...mockMessage, content: newContent });
});
});

describe('removeMessages', () => {
it('should batch remove messages by assistantId and topicId', async () => {
// Setup
Expand Down Expand Up @@ -243,51 +228,6 @@ describe('MessageService', () => {
});
});

describe('updateMessageTranslate', () => {
it('should update the translate field of a message', async () => {
// Setup
const newTranslate = { content: 'Translated text', to: 'es' } as ChatTranslate;
(MessageModel.update as Mock).mockResolvedValue({ ...mockMessage, translate: newTranslate });

// Execute
const result = await messageService.updateMessageTranslate(mockMessageId, newTranslate);

// Assert
expect(MessageModel.update).toHaveBeenCalledWith(mockMessageId, { translate: newTranslate });
expect(result).toEqual({ ...mockMessage, translate: newTranslate });
});
});

describe('updateMessageTTS', () => {
it('should update the tts field of a message', async () => {
// Setup
const newTTS = { init: false } as ChatTTS;
(MessageModel.update as Mock).mockResolvedValue({ ...mockMessage, tts: newTTS });

// Execute
const result = await messageService.updateMessageTTS(mockMessageId, newTTS);

// Assert
expect(MessageModel.update).toHaveBeenCalledWith(mockMessageId, { tts: newTTS });
expect(result).toEqual({ ...mockMessage, tts: newTTS });
});
});

describe('updateMessageRole', () => {
it('should update the role of a message', async () => {
// Setup
const newRole = 'user';
(MessageModel.update as Mock).mockResolvedValue({ ...mockMessage, role: newRole });

// Execute
const result = await messageService.updateMessageRole(mockMessageId, newRole);

// Assert
expect(MessageModel.update).toHaveBeenCalledWith(mockMessageId, { role: newRole });
expect(result).toEqual({ ...mockMessage, role: newRole });
});
});

describe('updateMessagePlugin', () => {
it('should update the plugin payload of a message', async () => {
// Setup
Expand Down
25 changes: 1 addition & 24 deletions src/services/message.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
import { CreateMessageParams, MessageModel } from '@/database/models/message';
import { DB_Message } from '@/database/schemas/message';
import { LLMRoleType } from '@/types/llm';
import {
ChatMessage,
ChatMessageError,
ChatPluginPayload,
ChatTTS,
ChatTranslate,
} from '@/types/message';
import { ChatMessage, ChatMessageError, ChatPluginPayload } from '@/types/message';

export class MessageService {
async create(data: CreateMessageParams) {
Expand Down Expand Up @@ -37,22 +30,10 @@ export class MessageService {
return MessageModel.queryBySessionId(sessionId);
}

async updateMessageContent(id: string, content: string) {
return MessageModel.update(id, { content });
}

async updateMessageError(id: string, error: ChatMessageError) {
return MessageModel.update(id, { error });
}

async updateMessageTranslate(id: string, data: Partial<ChatTranslate> | null) {
return MessageModel.update(id, { translate: data as ChatTranslate });
}

async updateMessageTTS(id: string, data: Partial<ChatTTS> | null) {
return MessageModel.update(id, { tts: data as ChatTTS });
}

async removeMessages(assistantId: string, topicId?: string) {
return MessageModel.batchDelete(assistantId, topicId);
}
Expand All @@ -69,10 +50,6 @@ export class MessageService {
return MessageModel.update(id, message);
}

async updateMessageRole(id: string, role: LLMRoleType) {
return MessageModel.update(id, { role });
}

async updateMessagePlugin(id: string, plugin: ChatPluginPayload) {
return MessageModel.update(id, { plugin });
}
Expand Down
26 changes: 4 additions & 22 deletions src/store/chat/slices/enchance/action.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { useChatStore } from '../../store';
vi.mock('@/services/message', () => ({
messageService: {
updateMessageTTS: vi.fn(),
updateMessageTranslate: vi.fn(),
updateMessage: vi.fn(),
},
}));

Expand Down Expand Up @@ -59,7 +59,7 @@ describe('ChatEnhanceAction', () => {
await result.current.clearTTS(messageId);
});

expect(messageService.updateMessageTTS).toHaveBeenCalledWith(messageId, null);
expect(messageService.updateMessage).toHaveBeenCalledWith(messageId, { tts: false });
});
});

Expand Down Expand Up @@ -102,7 +102,7 @@ describe('ChatEnhanceAction', () => {
await result.current.translateMessage(messageId, targetLang);
});

expect(messageService.updateMessageTranslate).toHaveBeenCalled();
expect(messageService.updateMessage).toHaveBeenCalled();
});
});

Expand All @@ -115,25 +115,7 @@ describe('ChatEnhanceAction', () => {
await result.current.clearTranslate(messageId);
});

expect(messageService.updateMessageTranslate).toHaveBeenCalledWith(messageId, null);
});
});

describe('ttsMessage', () => {
it('should update TTS state for a message and refresh messages', async () => {
const { result } = renderHook(() => useChatStore());
const messageId = 'message-id';
const ttsState = {
contentMd5: 'some-md5',
file: 'path-to-tts-file',
voice: 'voice-type',
};

await act(async () => {
await result.current.ttsMessage(messageId, ttsState);
});

expect(messageService.updateMessageTTS).toHaveBeenCalledWith(messageId, ttsState);
expect(messageService.updateMessage).toHaveBeenCalledWith(messageId, { translate: false });
});
});
});
34 changes: 19 additions & 15 deletions src/store/chat/slices/enchance/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { imageGenerationService } from '@/services/imageGeneration';
import { messageService } from '@/services/message';
import { chatSelectors } from '@/store/chat/selectors';
import { ChatStore } from '@/store/chat/store';
import { ChatTTS, ChatTranslate } from '@/types/message';
import { DallEImageItem } from '@/types/tool/dalle';
import { setNamespace } from '@/utils/storeDebug';

Expand All @@ -31,6 +32,8 @@ export interface ChatEnhanceAction {
state?: { contentMd5?: string; file?: string; voice?: string },
) => Promise<void>;
updateImageItem: (id: string, updater: (data: DallEImageItem[]) => void) => Promise<void>;
updateMessageTTS: (id: string, data: Partial<ChatTTS> | false) => Promise<void>;
updateMessageTranslate: (id: string, data: Partial<ChatTranslate> | false) => Promise<void>;
}

export const chatEnhance: StateCreator<
Expand All @@ -40,13 +43,11 @@ export const chatEnhance: StateCreator<
ChatEnhanceAction
> = (set, get) => ({
clearTTS: async (id) => {
await messageService.updateMessageTTS(id, null);
await get().refreshMessages();
await get().updateMessageTTS(id, false);
},

clearTranslate: async (id) => {
await messageService.updateMessageTranslate(id, null);
await get().refreshMessages();
await get().updateMessageTranslate(id, false);
},

generateImageFromPrompts: async (items, messageId) => {
Expand Down Expand Up @@ -94,16 +95,14 @@ export const chatEnhance: StateCreator<
n('toggleDallEImageLoading'),
);
},

translateMessage: async (id, targetLang) => {
const { toggleChatLoading, dispatchMessage, refreshMessages } = get();
const { toggleChatLoading, updateMessageTranslate, dispatchMessage } = get();

const message = chatSelectors.getMessageById(id)(get());
if (!message) return;

// create translate extra
await messageService.updateMessageTranslate(id, { content: '', from: '', to: targetLang });
await refreshMessages();
await updateMessageTranslate(id, { content: '', from: '', to: targetLang });

toggleChatLoading(true, id, n('translateMessage(start)', { id }) as string);

Expand All @@ -118,8 +117,7 @@ export const chatEnhance: StateCreator<
.then(async (data) => {
if (data && supportLocales.includes(data)) from = data;

await messageService.updateMessageTranslate(id, { content, from, to: targetLang });
await refreshMessages();
await updateMessageTranslate(id, { content, from, to: targetLang });
});

// translate to target language
Expand All @@ -138,15 +136,13 @@ export const chatEnhance: StateCreator<
params: chainTranslate(message.content, targetLang),
});

await messageService.updateMessageTranslate(id, { content, from, to: targetLang });
await refreshMessages();
await updateMessageTranslate(id, { content, from, to: targetLang });

toggleChatLoading(false);
},

ttsMessage: async (id, state = {}) => {
await messageService.updateMessageTTS(id, state);
await get().refreshMessages();
await get().updateMessageTTS(id, state);
},

updateImageItem: async (id, updater) => {
Expand All @@ -156,7 +152,15 @@ export const chatEnhance: StateCreator<
const data: DallEImageItem[] = JSON.parse(message.content);

const nextContent = produce(data, updater);
await messageService.updateMessageContent(id, JSON.stringify(nextContent));
await get().updateMessageContent(id, JSON.stringify(nextContent));
},

updateMessageTTS: async (id, data) => {
await messageService.updateMessage(id, { tts: data as ChatTTS });
await get().refreshMessages();
},
updateMessageTranslate: async (id, data) => {
await messageService.updateMessage(id, { translate: data as ChatTranslate });
await get().refreshMessages();
},
});
4 changes: 2 additions & 2 deletions src/store/chat/slices/message/action.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ vi.mock('@/services/message', () => ({
createAssistantMessage: vi.fn(() => Promise.resolve('content-content-content')),
removeMessages: vi.fn(() => Promise.resolve()),
create: vi.fn(() => Promise.resolve('new-message-id')),
updateMessageContent: vi.fn(),
updateMessage: vi.fn(),
clearAllMessage: vi.fn(() => Promise.resolve()),
},
}));
Expand Down Expand Up @@ -303,7 +303,7 @@ describe('chatMessage actions', () => {
await result.current.updateMessageContent(messageId, newContent);
});

expect(messageService.updateMessageContent).toHaveBeenCalledWith(messageId, newContent);
expect(messageService.updateMessage).toHaveBeenCalledWith(messageId, { content: newContent });
});

it('should dispatch message update action', async () => {
Expand Down
12 changes: 5 additions & 7 deletions src/store/chat/slices/message/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ export const chatMessage: StateCreator<
// refs: https://medium.com/@kyledeguzmanx/what-are-optimistic-updates-483662c3e171
dispatchMessage({ id, key: 'content', type: 'updateMessage', value: content });

await messageService.updateMessageContent(id, content);
await messageService.updateMessage(id, { content });
await refreshMessages();
},
useFetchMessages: (sessionId, activeTopicId) =>
Expand Down Expand Up @@ -265,7 +265,7 @@ export const chatMessage: StateCreator<
if (functionCallAtEnd) {
// create a new separate message and remove the function call from the prev message

await messageService.updateMessageContent(mid, content.replace(functionCallContent, ''));
await get().updateMessageContent(mid, content.replace(functionCallContent, ''));

const functionMessage: CreateMessageParams = {
role: 'function',
Expand Down Expand Up @@ -294,7 +294,7 @@ export const chatMessage: StateCreator<
set({ messages }, false, n(`dispatchMessage/${payload.type}`, payload));
},
fetchAIChatMessage: async (messages, assistantId) => {
const { toggleChatLoading, refreshMessages } = get();
const { toggleChatLoading, refreshMessages, updateMessageContent } = get();

const abortController = toggleChatLoading(
true,
Expand Down Expand Up @@ -369,14 +369,12 @@ export const chatMessage: StateCreator<
},
onFinish: async (content) => {
// update the content after fetch result
await messageService.updateMessageContent(assistantId, content);
await refreshMessages();
await updateMessageContent(assistantId, content);
},
onMessageHandle: async (text) => {
output += text;

await messageService.updateMessageContent(assistantId, output);
await refreshMessages();
await updateMessageContent(assistantId, output);

// is this message is just a function call
if (isFunctionMessageAtStart(output)) isFunctionCall = true;
Expand Down
9 changes: 4 additions & 5 deletions src/store/chat/slices/tool/action.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ describe('ChatPluginAction', () => {
});

// 验证 messageService.updateMessageContent 是否被正确调用
expect(messageService.updateMessageContent).toHaveBeenCalledWith(messageId, newContent);
expect(messageService.updateMessage).toHaveBeenCalledWith(messageId, { content: newContent });

// 验证 refreshMessages 是否被调用
expect(result.current.refreshMessages).toHaveBeenCalled();
Expand Down Expand Up @@ -95,10 +95,9 @@ describe('ChatPluginAction', () => {
expect.any(String),
);
expect(chatService.runPluginApi).toHaveBeenCalledWith(pluginPayload, { signal: undefined });
expect(messageService.updateMessageContent).toHaveBeenCalledWith(
messageId,
pluginApiResponse,
);
expect(messageService.updateMessage).toHaveBeenCalledWith(messageId, {
content: pluginApiResponse,
});
expect(initialState.refreshMessages).toHaveBeenCalled();
expect(initialState.coreProcessMessage).toHaveBeenCalled();
expect(initialState.toggleChatLoading).toHaveBeenCalledWith(false);
Expand Down
Loading

0 comments on commit 4a275e9

Please sign in to comment.