diff --git a/public/chat_header_button.tsx b/public/chat_header_button.tsx
index 1c8240de..18e610bb 100644
--- a/public/chat_header_button.tsx
+++ b/public/chat_header_button.tsx
@@ -7,6 +7,7 @@ import { EuiBadge, EuiFieldText, EuiIcon } from '@elastic/eui';
import classNames from 'classnames';
import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useEffectOnce } from 'react-use';
+
import { ApplicationStart, SIDECAR_DOCKED_MODE } from '../../../src/core/public';
// TODO: Replace with getChrome().logos.Chat.url
import chatIcon from './assets/chat.svg';
diff --git a/public/components/agent_framework_traces_flyout_body.test.tsx b/public/components/agent_framework_traces_flyout_body.test.tsx
index c204a039..6cdbbcd6 100644
--- a/public/components/agent_framework_traces_flyout_body.test.tsx
+++ b/public/components/agent_framework_traces_flyout_body.test.tsx
@@ -5,10 +5,12 @@
import React from 'react';
import '@testing-library/jest-dom/extend-expect';
-import { act, waitFor, render, screen, fireEvent } from '@testing-library/react';
+import { waitFor, render, screen, fireEvent } from '@testing-library/react';
import * as chatContextExports from '../contexts/chat_context';
+import * as coreContextExports from '../contexts/core_context';
import { AgentFrameworkTracesFlyoutBody } from './agent_framework_traces_flyout_body';
import { TAB_ID } from '../utils/constants';
+import { BehaviorSubject, Subject } from 'rxjs';
jest.mock('./agent_framework_traces', () => {
return {
@@ -17,6 +19,20 @@ jest.mock('./agent_framework_traces', () => {
});
describe(' spec', () => {
+ let dataSourceIdUpdates$: Subject;
+ beforeEach(() => {
+ dataSourceIdUpdates$ = new Subject();
+ jest.spyOn(coreContextExports, 'useCore').mockImplementation(() => {
+ return {
+ services: {
+ dataSource: {
+ dataSourceIdUpdates$,
+ },
+ },
+ };
+ });
+ });
+
it('show back button if interactionId exists', async () => {
const onCloseMock = jest.fn();
jest.spyOn(chatContextExports, 'useChatContext').mockReturnValue({
@@ -70,4 +86,19 @@ describe(' spec', () => {
expect(onCloseMock).toHaveBeenCalledWith(TAB_ID.HISTORY);
});
});
+
+ it('should set tab to chat after data source changed', () => {
+ const setSelectedTabIdMock = jest.fn();
+ jest.spyOn(chatContextExports, 'useChatContext').mockReturnValue({
+ interactionId: 'test-interaction-id',
+ flyoutFullScreen: true,
+ setSelectedTabId: setSelectedTabIdMock,
+ preSelectedTabId: TAB_ID.HISTORY,
+ });
+ render();
+
+ expect(setSelectedTabIdMock).not.toHaveBeenCalled();
+ dataSourceIdUpdates$.next('foo');
+ expect(setSelectedTabIdMock).toHaveBeenCalled();
+ });
});
diff --git a/public/components/agent_framework_traces_flyout_body.tsx b/public/components/agent_framework_traces_flyout_body.tsx
index 5cf12b0c..b18f8004 100644
--- a/public/components/agent_framework_traces_flyout_body.tsx
+++ b/public/components/agent_framework_traces_flyout_body.tsx
@@ -13,14 +13,26 @@ import {
EuiButtonIcon,
EuiPageHeaderSection,
} from '@elastic/eui';
-import React from 'react';
+import React, { useEffect } from 'react';
import { useChatContext } from '../contexts/chat_context';
+import { useCore } from '../../public/contexts';
import { AgentFrameworkTraces } from './agent_framework_traces';
import { TAB_ID } from '../utils/constants';
export const AgentFrameworkTracesFlyoutBody: React.FC = () => {
+ const core = useCore();
const chatContext = useChatContext();
const interactionId = chatContext.interactionId;
+
+ useEffect(() => {
+ const subscription = core.services.dataSource.dataSourceIdUpdates$.subscribe(() => {
+ chatContext.setSelectedTabId(TAB_ID.CHAT);
+ });
+ return () => {
+ subscription.unsubscribe();
+ };
+ }, [core.services.dataSource, chatContext.setSelectedTabId]);
+
if (!interactionId) {
return null;
}
diff --git a/public/hooks/use_chat_actions.test.tsx b/public/hooks/use_chat_actions.test.tsx
index 26ca0096..eb1db675 100644
--- a/public/hooks/use_chat_actions.test.tsx
+++ b/public/hooks/use_chat_actions.test.tsx
@@ -26,7 +26,14 @@ jest.mock('../services/conversations_service', () => {
jest.mock('../services/conversation_load_service', () => {
return {
ConversationLoadService: jest.fn().mockImplementation(() => {
- return { load: jest.fn().mockReturnValue({ messages: [], interactions: [] }) };
+ const conversationLoadMock = {
+ abortController: new AbortController(),
+ load: jest.fn().mockImplementation(async () => {
+ conversationLoadMock.abortController = new AbortController();
+ return { messages: [], interactions: [] };
+ }),
+ };
+ return conversationLoadMock;
}),
};
});
@@ -126,7 +133,7 @@ describe('useChatActions hook', () => {
messages: [SEND_MESSAGE_RESPONSE.messages[0]],
input: INPUT_MESSAGE,
}),
- query: await dataSourceServiceMock.getDataSourceQuery(),
+ query: dataSourceServiceMock.getDataSourceQuery(),
});
// it should send dispatch `receive` action to remove the message without messageId
@@ -201,7 +208,7 @@ describe('useChatActions hook', () => {
messages: [],
input: { type: 'input', content: 'message that send as input', contentType: 'text' },
}),
- query: await dataSourceServiceMock.getDataSourceQuery(),
+ query: dataSourceServiceMock.getDataSourceQuery(),
});
});
@@ -264,7 +271,7 @@ describe('useChatActions hook', () => {
expect(chatStateDispatchMock).toHaveBeenCalledWith({ type: 'abort' });
expect(httpMock.post).toHaveBeenCalledWith(ASSISTANT_API.ABORT_AGENT_EXECUTION, {
body: JSON.stringify({ conversationId: 'conversation_id_to_abort' }),
- query: await dataSourceServiceMock.getDataSourceQuery(),
+ query: dataSourceServiceMock.getDataSourceQuery(),
});
});
@@ -292,7 +299,7 @@ describe('useChatActions hook', () => {
conversationId: 'conversation_id_mock',
interactionId: 'interaction_id_mock',
}),
- query: await dataSourceServiceMock.getDataSourceQuery(),
+ query: dataSourceServiceMock.getDataSourceQuery(),
});
expect(chatStateDispatchMock).toHaveBeenCalledWith(
expect.objectContaining({ type: 'receive', payload: { messages: [], interactions: [] } })
@@ -312,6 +319,7 @@ describe('useChatActions hook', () => {
it('should not handle regenerate response if the regenerate operation has already aborted', async () => {
const AbortControllerMock = jest.spyOn(window, 'AbortController').mockImplementation(() => ({
signal: { aborted: true },
+ abort: jest.fn(),
}));
httpMock.put.mockResolvedValue(SEND_MESSAGE_RESPONSE);
@@ -328,7 +336,7 @@ describe('useChatActions hook', () => {
conversationId: 'conversation_id_mock',
interactionId: 'interaction_id_mock',
}),
- query: await dataSourceServiceMock.getDataSourceQuery(),
+ query: dataSourceServiceMock.getDataSourceQuery(),
});
expect(chatStateDispatchMock).not.toHaveBeenCalledWith(
expect.objectContaining({ type: 'receive' })
@@ -353,6 +361,7 @@ describe('useChatActions hook', () => {
it('should not handle regenerate error if the regenerate operation has already aborted', async () => {
const AbortControllerMock = jest.spyOn(window, 'AbortController').mockImplementation(() => ({
signal: { aborted: true },
+ abort: jest.fn(),
}));
httpMock.put.mockImplementationOnce(() => {
throw new Error();
@@ -369,4 +378,43 @@ describe('useChatActions hook', () => {
);
AbortControllerMock.mockRestore();
});
+
+ it('should clear chat title, conversation id, flyoutComponent and call reset action', async () => {
+ const { result } = renderHook(() => useChatActions());
+ result.current.resetChat();
+
+ expect(chatContextMock.setConversationId).toHaveBeenLastCalledWith(undefined);
+ expect(chatContextMock.setTitle).toHaveBeenLastCalledWith(undefined);
+ expect(chatContextMock.setFlyoutComponent).toHaveBeenLastCalledWith(null);
+
+ expect(chatStateDispatchMock).toHaveBeenLastCalledWith({ type: 'reset' });
+ });
+
+ it('should abort send action after reset chat', async () => {
+ const abortFn = jest.fn();
+ const AbortControllerMock = jest.spyOn(window, 'AbortController').mockImplementation(() => ({
+ signal: { aborted: true },
+ abort: abortFn,
+ }));
+ const { result } = renderHook(() => useChatActions());
+ await result.current.send(INPUT_MESSAGE);
+ result.current.resetChat();
+
+ expect(abortFn).toHaveBeenCalled();
+ AbortControllerMock.mockRestore();
+ });
+
+ it('should abort load action after reset chat', async () => {
+ const abortFn = jest.fn();
+ const AbortControllerMock = jest.spyOn(window, 'AbortController').mockImplementation(() => ({
+ signal: { aborted: true },
+ abort: abortFn,
+ }));
+ const { result } = renderHook(() => useChatActions());
+ await result.current.loadChat('conversation_id_mock');
+ result.current.resetChat();
+
+ expect(abortFn).toHaveBeenCalled();
+ AbortControllerMock.mockRestore();
+ });
});
diff --git a/public/hooks/use_chat_actions.tsx b/public/hooks/use_chat_actions.tsx
index 962420f2..5c3c239a 100644
--- a/public/hooks/use_chat_actions.tsx
+++ b/public/hooks/use_chat_actions.tsx
@@ -35,7 +35,7 @@ export const useChatActions = (): AssistantActions => {
...(!chatContext.conversationId && { messages: chatState.messages }), // include all previous messages for new chats
input,
}),
- query: await core.services.dataSource.getDataSourceQuery(),
+ query: core.services.dataSource.getDataSourceQuery(),
});
if (abortController.signal.aborted) return;
// Refresh history list after new conversation created if new conversation saved and history list page visible
@@ -106,6 +106,15 @@ export const useChatActions = (): AssistantActions => {
}
};
+ const resetChat = () => {
+ abortControllerRef?.abort();
+ core.services.conversationLoad.abortController?.abort();
+ chatContext.setConversationId(undefined);
+ chatContext.setTitle(undefined);
+ chatContext.setFlyoutComponent(null);
+ chatStateDispatch({ type: 'reset' });
+ };
+
const openChatUI = () => {
chatContext.setFlyoutVisible(true);
chatContext.setSelectedTabId(TAB_ID.CHAT);
@@ -163,7 +172,7 @@ export const useChatActions = (): AssistantActions => {
// abort agent execution
await core.services.http.post(`${ASSISTANT_API.ABORT_AGENT_EXECUTION}`, {
body: JSON.stringify({ conversationId }),
- query: await core.services.dataSource.getDataSourceQuery(),
+ query: core.services.dataSource.getDataSourceQuery(),
});
}
};
@@ -180,7 +189,7 @@ export const useChatActions = (): AssistantActions => {
conversationId: chatContext.conversationId,
interactionId,
}),
- query: await core.services.dataSource.getDataSourceQuery(),
+ query: core.services.dataSource.getDataSourceQuery(),
});
if (abortController.signal.aborted) {
@@ -225,5 +234,5 @@ export const useChatActions = (): AssistantActions => {
}
};
- return { send, loadChat, executeAction, openChatUI, abortAction, regenerate };
+ return { send, loadChat, resetChat, executeAction, openChatUI, abortAction, regenerate };
};
diff --git a/public/hooks/use_conversations.ts b/public/hooks/use_conversations.ts
index 2828a4af..516a982e 100644
--- a/public/hooks/use_conversations.ts
+++ b/public/hooks/use_conversations.ts
@@ -14,13 +14,13 @@ export const useDeleteConversation = () => {
const abortControllerRef = useRef();
const deleteConversation = useCallback(
- async (conversationId: string) => {
+ (conversationId: string) => {
abortControllerRef.current = new AbortController();
dispatch({ type: 'request' });
return core.services.http
.delete(`${ASSISTANT_API.CONVERSATION}/${conversationId}`, {
signal: abortControllerRef.current.signal,
- query: await core.services.dataSource.getDataSourceQuery(),
+ query: core.services.dataSource.getDataSourceQuery(),
})
.then((payload) => {
dispatch({ type: 'success', payload });
@@ -53,7 +53,7 @@ export const usePatchConversation = () => {
const abortControllerRef = useRef();
const patchConversation = useCallback(
- async (conversationId: string, title: string) => {
+ (conversationId: string, title: string) => {
abortControllerRef.current = new AbortController();
dispatch({ type: 'request' });
return core.services.http
@@ -61,7 +61,7 @@ export const usePatchConversation = () => {
body: JSON.stringify({
title,
}),
- query: await core.services.dataSource.getDataSourceQuery(),
+ query: core.services.dataSource.getDataSourceQuery(),
signal: abortControllerRef.current.signal,
})
.then((payload) => dispatch({ type: 'success', payload }))
diff --git a/public/hooks/use_feed_back.test.tsx b/public/hooks/use_feed_back.test.tsx
index 74de2031..943aab2f 100644
--- a/public/hooks/use_feed_back.test.tsx
+++ b/public/hooks/use_feed_back.test.tsx
@@ -84,7 +84,7 @@ describe('useFeedback hook', () => {
body: JSON.stringify({
satisfaction: true,
}),
- query: await dataSourceMock.getDataSourceQuery(),
+ query: dataSourceMock.getDataSourceQuery(),
}
);
expect(result.current.feedbackResult).toBe(true);
@@ -119,7 +119,7 @@ describe('useFeedback hook', () => {
body: JSON.stringify({
satisfaction: true,
}),
- query: await dataSourceMock.getDataSourceQuery(),
+ query: dataSourceMock.getDataSourceQuery(),
}
);
expect(result.current.feedbackResult).toBe(undefined);
diff --git a/public/hooks/use_feed_back.tsx b/public/hooks/use_feed_back.tsx
index 2222b301..7a37bd37 100644
--- a/public/hooks/use_feed_back.tsx
+++ b/public/hooks/use_feed_back.tsx
@@ -38,7 +38,7 @@ export const useFeedback = (interaction?: Interaction | null) => {
try {
await core.services.http.put(`${ASSISTANT_API.FEEDBACK}/${message.interactionId}`, {
body: JSON.stringify(body),
- query: await core.services.dataSource.getDataSourceQuery(),
+ query: core.services.dataSource.getDataSourceQuery(),
});
setFeedbackResult(correct);
} catch (error) {
diff --git a/public/hooks/use_fetch_agentframework_traces.ts b/public/hooks/use_fetch_agentframework_traces.ts
index 43925855..7ad43668 100644
--- a/public/hooks/use_fetch_agentframework_traces.ts
+++ b/public/hooks/use_fetch_agentframework_traces.ts
@@ -22,26 +22,24 @@ export const useFetchAgentFrameworkTraces = (interactionId: string) => {
return;
}
- core.services.dataSource.getDataSourceQuery().then((query) => {
- core.services.http
- .get(`${ASSISTANT_API.TRACE}/${interactionId}`, {
- signal: abortController.signal,
- query,
+ core.services.http
+ .get(`${ASSISTANT_API.TRACE}/${interactionId}`, {
+ signal: abortController.signal,
+ query: core.services.dataSource.getDataSourceQuery(),
+ })
+ .then((payload) =>
+ dispatch({
+ type: 'success',
+ payload,
})
- .then((payload) =>
- dispatch({
- type: 'success',
- payload,
- })
- )
- .catch((error) => {
- if (error.name === 'AbortError') return;
- dispatch({ type: 'failure', error });
- });
- });
+ )
+ .catch((error) => {
+ if (error.name === 'AbortError') return;
+ dispatch({ type: 'failure', error });
+ });
return () => abortController.abort();
- }, [core.services.http, interactionId]);
+ }, [core.services.http, interactionId, core.services.dataSource]);
return { ...state };
};
diff --git a/public/plugin.tsx b/public/plugin.tsx
index 8e5c1151..89c4f545 100644
--- a/public/plugin.tsx
+++ b/public/plugin.tsx
@@ -5,6 +5,7 @@
import { EuiLoadingSpinner } from '@elastic/eui';
import React, { lazy, Suspense } from 'react';
+import { Subscription } from 'rxjs';
import { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '../../../src/core/public';
import {
createOpenSearchDashboardsReactContext,
@@ -61,6 +62,7 @@ export class AssistantPlugin
private config: ConfigSchema;
incontextInsightRegistry: IncontextInsightRegistry | undefined;
private dataSourceService: DataSourceService;
+ private resetChatSubscription: Subscription | undefined;
constructor(initializerContext: PluginInitializerContext) {
this.config = initializerContext.config.get();
@@ -108,6 +110,12 @@ export class AssistantPlugin
const username = account.user_name;
this.incontextInsightRegistry?.setIsEnabled(this.config.incontextInsight.enabled);
+ if (this.dataSourceService.isMDSEnabled()) {
+ this.resetChatSubscription = this.dataSourceService.dataSourceIdUpdates$.subscribe(() => {
+ assistantActions.resetChat?.();
+ });
+ }
+
coreStart.chrome.navControls.registerRight({
order: 10000,
mount: toMountPoint(
@@ -163,5 +171,6 @@ export class AssistantPlugin
public stop() {
this.dataSourceService.stop();
+ this.resetChatSubscription?.unsubscribe();
}
}
diff --git a/public/services/__tests__/data_source_service.test.ts b/public/services/__tests__/data_source_service.test.ts
index d1db20dc..a14e5faf 100644
--- a/public/services/__tests__/data_source_service.test.ts
+++ b/public/services/__tests__/data_source_service.test.ts
@@ -111,6 +111,22 @@ describe('DataSourceService', () => {
});
expect(await dataSource.getDataSourceId$().pipe(first()).toPromise()).toBe('foo');
});
+
+ it('should not fire change for same data source id', async () => {
+ const { dataSource, defaultDataSourceSelection$ } = setup({
+ dataSourceSelection: new Map(),
+ defaultDataSourceId: 'foo',
+ });
+ const observerFn = jest.fn();
+ dataSource.getDataSourceId$().subscribe(observerFn);
+
+ expect(observerFn).toHaveBeenCalledTimes(1);
+ dataSource.setDataSourceId('foo');
+ expect(observerFn).toHaveBeenCalledTimes(1);
+
+ defaultDataSourceSelection$.next('foo');
+ expect(observerFn).toHaveBeenCalledTimes(1);
+ });
});
describe('isMDSEnabled', () => {
@@ -126,23 +142,23 @@ describe('DataSourceService', () => {
describe('getDataSourceQuery', () => {
it('should return empty object if MDS not enabled', async () => {
const { dataSource } = setup({ dataSourceManagement: undefined });
- expect(await dataSource.getDataSourceQuery()).toEqual({});
+ expect(dataSource.getDataSourceQuery()).toEqual({});
});
it('should return empty object if data source id is empty', async () => {
const { dataSource } = setup({
dataSourceSelection: new Map([['test', [{ label: '', id: '' }]]]),
});
- expect(await dataSource.getDataSourceQuery()).toEqual({});
+ expect(dataSource.getDataSourceQuery()).toEqual({});
});
it('should return query object with provided data source id', async () => {
const { dataSource } = setup({ defaultDataSourceId: 'foo' });
- expect(await dataSource.getDataSourceQuery()).toEqual({ dataSourceId: 'foo' });
+ expect(dataSource.getDataSourceQuery()).toEqual({ dataSourceId: 'foo' });
});
it('should throw error if data source id not exists', async () => {
const { dataSource } = setup();
let error;
try {
- await dataSource.getDataSourceQuery();
+ dataSource.getDataSourceQuery();
} catch (e) {
error = e;
}
@@ -209,4 +225,15 @@ describe('DataSourceService', () => {
dataSource.setDataSourceId('bar');
expect(observerFn).toHaveBeenCalledTimes(3);
});
+
+ it('should emit new data source id updates after data source id change', () => {
+ const { dataSource } = setup();
+ const observerFn = jest.fn();
+ dataSource.dataSourceIdUpdates$.subscribe(observerFn);
+ dataSource.setDataSourceId('foo');
+ expect(observerFn).toHaveBeenCalledTimes(1);
+
+ dataSource.setDataSourceId('bar');
+ expect(observerFn).toHaveBeenCalledTimes(2);
+ });
});
diff --git a/public/services/conversation_load_service.ts b/public/services/conversation_load_service.ts
index 866ad6d6..b2875a20 100644
--- a/public/services/conversation_load_service.ts
+++ b/public/services/conversation_load_service.ts
@@ -26,7 +26,7 @@ export class ConversationLoadService {
`${ASSISTANT_API.CONVERSATION}/${conversationId}`,
{
signal: this.abortController.signal,
- query: await this._dataSource.getDataSourceQuery(),
+ query: this._dataSource.getDataSourceQuery(),
}
);
this.status$.next('idle');
diff --git a/public/services/conversations_service.ts b/public/services/conversations_service.ts
index ab188070..52973fea 100644
--- a/public/services/conversations_service.ts
+++ b/public/services/conversations_service.ts
@@ -29,7 +29,10 @@ export class ConversationsService {
}
load = async (
- query?: Pick
+ query?: Pick<
+ SavedObjectsFindOptions,
+ 'page' | 'perPage' | 'fields' | 'sortField' | 'sortOrder' | 'search' | 'searchFields'
+ >
) => {
this.abortController?.abort();
this.abortController = new AbortController();
@@ -40,7 +43,7 @@ export class ConversationsService {
await this._http.get(ASSISTANT_API.CONVERSATIONS, {
query: {
...this._options,
- ...(await this._dataSource.getDataSourceQuery()),
+ ...this._dataSource.getDataSourceQuery(),
} as HttpFetchQuery,
signal: this.abortController.signal,
})
diff --git a/public/services/data_source_service.mock.ts b/public/services/data_source_service.mock.ts
index 0ca0bea1..d8496607 100644
--- a/public/services/data_source_service.mock.ts
+++ b/public/services/data_source_service.mock.ts
@@ -10,14 +10,11 @@ export class DataSourceServiceMock {
}
getDataSourceQuery() {
- const result = this._isMDSEnabled
+ return this._isMDSEnabled
? {
dataSourceId: '',
}
: {};
- return new Promise((resolve) => {
- resolve(result);
- });
}
isMDSEnabled() {
diff --git a/public/services/data_source_service.ts b/public/services/data_source_service.ts
index f89565e4..02f4bc21 100644
--- a/public/services/data_source_service.ts
+++ b/public/services/data_source_service.ts
@@ -3,8 +3,8 @@
* SPDX-License-Identifier: Apache-2.0
*/
-import { BehaviorSubject, Subscription, combineLatest, of } from 'rxjs';
-import { first, map } from 'rxjs/operators';
+import { BehaviorSubject, Observable, Subject, Subscription, combineLatest, of } from 'rxjs';
+import { distinctUntilChanged, map } from 'rxjs/operators';
import type { IUiSettingsClient } from '../../../../src/core/public';
import type { DataSourceOption } from '../../../../src/plugins/data_source_management/public/components/data_source_menu/types';
@@ -36,6 +36,9 @@ export class DataSourceService {
private uiSettings: IUiSettingsClient | undefined;
private dataSourceManagement: DataSourceManagementPluginSetup | undefined;
private dataSourceSelectionSubscription: Subscription | undefined;
+ private finalDataSourceId: string | null = null;
+ dataSourceIdUpdates$ = new Subject();
+ private getDataSourceIdSubscription: Subscription | undefined;
constructor() {}
@@ -54,11 +57,11 @@ export class DataSourceService {
});
}
- async getDataSourceQuery() {
+ getDataSourceQuery() {
if (!this.isMDSEnabled()) {
return {};
}
- const dataSourceId = await this.getDataSourceId$().pipe(first()).toPromise();
+ const dataSourceId = this.finalDataSourceId;
if (dataSourceId === null) {
throw new Error('No data source id');
}
@@ -74,24 +77,24 @@ export class DataSourceService {
}
setDataSourceId(newDataSourceId: string | null) {
- if (this.dataSourceId$.getValue() === newDataSourceId) {
- return;
- }
this.dataSourceId$.next(newDataSourceId);
}
getDataSourceId$() {
return combineLatest([
this.dataSourceId$,
- this.dataSourceManagement?.getDefaultDataSourceId$?.(this.uiSettings) ?? of(null),
- ]).pipe(
- map(([selectedDataSourceId, defaultDataSourceId]) => {
- if (selectedDataSourceId !== null) {
- return selectedDataSourceId;
- }
- return defaultDataSourceId;
- })
- );
+ (this.dataSourceManagement?.getDefaultDataSourceId$?.(this.uiSettings) ??
+ of(null)) as Observable,
+ ])
+ .pipe(
+ map(([selectedDataSourceId, defaultDataSourceId]) => {
+ if (selectedDataSourceId !== null) {
+ return selectedDataSourceId;
+ }
+ return defaultDataSourceId;
+ })
+ )
+ .pipe(distinctUntilChanged());
}
setup({
@@ -104,6 +107,10 @@ export class DataSourceService {
this.uiSettings = uiSettings;
this.dataSourceManagement = dataSourceManagement;
this.init();
+ this.getDataSourceIdSubscription = this.getDataSourceId$().subscribe((finalDataSourceId) => {
+ this.finalDataSourceId = finalDataSourceId;
+ this.dataSourceIdUpdates$.next(finalDataSourceId);
+ });
return {
setDataSourceId: (newDataSourceId: string | null) => {
@@ -122,6 +129,8 @@ export class DataSourceService {
public stop() {
this.dataSourceSelectionSubscription?.unsubscribe();
+ this.getDataSourceIdSubscription?.unsubscribe();
+ this.dataSourceIdUpdates$.complete();
this.dataSourceId$.complete();
}
}
diff --git a/public/tabs/history/__tests__/chat_history_page.test.tsx b/public/tabs/history/__tests__/chat_history_page.test.tsx
index 944fedf6..2f87af01 100644
--- a/public/tabs/history/__tests__/chat_history_page.test.tsx
+++ b/public/tabs/history/__tests__/chat_history_page.test.tsx
@@ -6,6 +6,7 @@
import React from 'react';
import { act, fireEvent, render, waitFor } from '@testing-library/react';
import { I18nProvider } from '@osd/i18n/react';
+import { BehaviorSubject, Subject } from 'rxjs';
import { coreMock } from '../../../../../../src/core/public/mocks';
import { HttpStart } from '../../../../../../src/core/public';
@@ -14,7 +15,6 @@ import * as useChatStateExports from '../../../hooks/use_chat_state';
import * as chatContextExports from '../../../contexts/chat_context';
import * as coreContextExports from '../../../contexts/core_context';
import { ConversationsService } from '../../../services/conversations_service';
-import { DataSourceServiceMock } from '../../../services/data_source_service.mock';
import { ChatHistoryPage } from '../chat_history_page';
@@ -27,7 +27,7 @@ const mockGetConversationsHttp = () => {
title: 'foo',
},
],
- total: 1,
+ total: 100,
}));
return http;
};
@@ -35,11 +35,16 @@ const mockGetConversationsHttp = () => {
const setup = ({
http = mockGetConversationsHttp(),
chatContext = {},
+ shouldRefresh = false,
}: {
http?: HttpStart;
chatContext?: { flyoutFullScreen?: boolean };
+ shouldRefresh?: boolean;
} = {}) => {
- const dataSourceMock = new DataSourceServiceMock();
+ const dataSourceMock = {
+ dataSourceIdUpdates$: new Subject(),
+ getDataSourceQuery: jest.fn(() => ({ dataSourceId: 'foo' })),
+ };
const useCoreMock = {
services: {
...coreMock.createStart(),
@@ -65,7 +70,7 @@ const setup = ({
const renderResult = render(
-
+
);
@@ -73,6 +78,7 @@ const setup = ({
useCoreMock,
useChatStateMock,
useChatContextMock,
+ dataSourceMock,
renderResult,
};
};
@@ -240,4 +246,100 @@ describe('', () => {
expect(abortMock).toHaveBeenCalled();
});
});
+
+ it('should call conversations.reload after data source changed', async () => {
+ const { useCoreMock, dataSourceMock } = setup({ shouldRefresh: true });
+
+ jest.spyOn(useCoreMock.services.conversations, 'load');
+
+ expect(useCoreMock.services.conversations.load).not.toHaveBeenCalled();
+
+ act(() => {
+ dataSourceMock.dataSourceIdUpdates$.next('bar');
+ });
+
+ await waitFor(() => {
+ expect(useCoreMock.services.conversations.load).toHaveBeenCalledTimes(1);
+ });
+ });
+
+ it('should not call conversations.load after unmount', async () => {
+ const { useCoreMock, dataSourceMock, renderResult } = setup({ shouldRefresh: true });
+
+ jest.spyOn(useCoreMock.services.conversations, 'reload');
+
+ expect(useCoreMock.services.conversations.reload).not.toHaveBeenCalled();
+ renderResult.unmount();
+
+ dataSourceMock.dataSourceIdUpdates$.next('bar');
+ expect(useCoreMock.services.conversations.reload).not.toHaveBeenCalled();
+ });
+
+ it('should load conversations with empty search after data source changed', async () => {
+ const { useCoreMock, dataSourceMock, renderResult } = setup({ shouldRefresh: true });
+
+ jest.spyOn(useCoreMock.services.conversations, 'load');
+
+ fireEvent.change(renderResult.getByPlaceholderText('Search by conversation name'), {
+ target: {
+ value: 'bar',
+ },
+ });
+
+ await waitFor(() => {
+ expect(useCoreMock.services.conversations.load).toHaveBeenLastCalledWith(
+ expect.objectContaining({
+ search: 'bar',
+ })
+ );
+ });
+
+ act(() => {
+ dataSourceMock.dataSourceIdUpdates$.next('baz');
+ });
+
+ await waitFor(() => {
+ expect(useCoreMock.services.conversations.load).toHaveBeenLastCalledWith({
+ fields: expect.any(Array),
+ page: 1,
+ perPage: 10,
+ sortField: 'updatedTimeMs',
+ sortOrder: 'DESC',
+ searchFields: ['title'],
+ });
+ expect(useCoreMock.services.conversations.load).toHaveBeenCalledTimes(2);
+ });
+ });
+
+ it('should load conversations with first page after data source changed', async () => {
+ const { useCoreMock, dataSourceMock, renderResult } = setup({ shouldRefresh: true });
+
+ jest.spyOn(useCoreMock.services.conversations, 'load');
+
+ await waitFor(() => {
+ expect(renderResult.getByTestId('pagination-button-1')).toBeInTheDocument();
+ });
+
+ fireEvent.click(renderResult.getByTestId('pagination-button-1'));
+
+ await waitFor(() => {
+ expect(useCoreMock.services.conversations.load).toHaveBeenLastCalledWith(
+ expect.objectContaining({
+ page: 2,
+ })
+ );
+ });
+
+ act(() => {
+ dataSourceMock.dataSourceIdUpdates$.next('baz');
+ });
+
+ await waitFor(() => {
+ expect(useCoreMock.services.conversations.load).toHaveBeenLastCalledWith(
+ expect.objectContaining({
+ page: 1,
+ })
+ );
+ });
+ });
});
diff --git a/public/tabs/history/chat_history_page.tsx b/public/tabs/history/chat_history_page.tsx
index 98662891..92b175c8 100644
--- a/public/tabs/history/chat_history_page.tsx
+++ b/public/tabs/history/chat_history_page.tsx
@@ -18,8 +18,9 @@ import {
} from '@elastic/eui';
import React, { useCallback, useEffect, useMemo, useState } from 'react';
import { FormattedMessage } from '@osd/i18n/react';
-import { useDebounce, useObservable } from 'react-use';
+import { useDebounce, useObservable, useUpdateEffect } from 'react-use';
import cs from 'classnames';
+
import { useChatActions, useChatState } from '../../hooks';
import { useChatContext, useCore } from '../../contexts';
import { TAB_ID } from '../../utils/constants';
@@ -41,34 +42,36 @@ export const ChatHistoryPage: React.FC = React.memo((props
setConversationId,
setTitle,
} = useChatContext();
- const [pageIndex, setPageIndex] = useState(0);
- const [pageSize, setPageSize] = useState(10);
const [searchName, setSearchName] = useState('');
- const [debouncedSearchName, setDebouncedSearchName] = useState('');
- const bulkGetOptions = useMemo(
- () => ({
- page: pageIndex + 1,
- perPage: pageSize,
- fields: ['createdTimeMs', 'updatedTimeMs', 'title'],
- sortField: 'updatedTimeMs',
- sortOrder: 'DESC',
- ...(debouncedSearchName ? { search: debouncedSearchName, searchFields: ['title'] } : {}),
- }),
- [pageIndex, pageSize, debouncedSearchName]
- );
+ const [bulkGetOptions, setBulkGetOptions] = useState<{
+ page: number;
+ perPage: number;
+ fields: string[];
+ sortField: string;
+ sortOrder: string;
+ searchFields: string[];
+ search?: string;
+ }>({
+ page: 1,
+ perPage: 10,
+ fields: ['createdTimeMs', 'updatedTimeMs', 'title'],
+ sortField: 'updatedTimeMs',
+ sortOrder: 'DESC',
+ searchFields: ['title'],
+ });
const conversations = useObservable(services.conversations.conversations$);
const loading = useObservable(services.conversations.status$) === 'loading';
const chatHistories = useMemo(() => conversations?.objects || [], [conversations]);
const hasNoConversations =
- !debouncedSearchName && !!conversations && conversations.total === 0 && !loading;
+ !bulkGetOptions.search && !!conversations && conversations.total === 0 && !loading;
+ const dataSourceUpdate = useObservable(services.dataSource.dataSourceIdUpdates$);
const handleSearchChange = useCallback((e) => {
setSearchName(e.target.value);
}, []);
const handleItemsPerPageChange = useCallback((itemsPerPage: number) => {
- setPageIndex(0);
- setPageSize(itemsPerPage);
+ setBulkGetOptions((prevOptions) => ({ ...prevOptions, page: 1, perPage: itemsPerPage }));
}, []);
const handleBack = useCallback(() => {
@@ -87,19 +90,49 @@ export const ChatHistoryPage: React.FC = React.memo((props
[conversationId, setConversationId, setTitle, chatStateDispatch]
);
+ const handlePageChange = useCallback((newPage) => {
+ setBulkGetOptions((prevOptions) => ({
+ ...prevOptions,
+ page: newPage + 1,
+ }));
+ }, []);
+
useDebounce(
() => {
- setPageIndex(0);
- setDebouncedSearchName(searchName);
+ setBulkGetOptions((prevOptions) => {
+ if (prevOptions.search === searchName || (!prevOptions.search && searchName === '')) {
+ return prevOptions;
+ }
+ const { search, ...rest } = prevOptions;
+ return {
+ ...rest,
+ page: 1,
+ ...(searchName ? { search: searchName } : {}),
+ };
+ });
},
150,
[searchName]
);
- useEffect(() => {
- if (props.shouldRefresh) services.conversations.reload();
+ useUpdateEffect(() => {
+ if (!props.shouldRefresh) {
+ return;
+ }
+ services.conversations.reload();
+ return () => {
+ services.conversations.abortController?.abort();
+ };
}, [props.shouldRefresh, services.conversations]);
+ useUpdateEffect(() => {
+ setSearchName('');
+ setBulkGetOptions(({ search, page, ...rest }) => ({
+ ...rest,
+ page: 1,
+ }));
+ }, [dataSourceUpdate]);
+
useEffect(() => {
services.conversations.load(bulkGetOptions);
return () => {
@@ -150,11 +183,13 @@ export const ChatHistoryPage: React.FC = React.memo((props
onLoadChat={loadChat}
onRefresh={services.conversations.reload}
histories={chatHistories}
- activePage={pageIndex}
- itemsPerPage={pageSize}
+ activePage={bulkGetOptions.page - 1}
+ itemsPerPage={bulkGetOptions.perPage}
onChangeItemsPerPage={handleItemsPerPageChange}
- onChangePage={setPageIndex}
- {...(conversations ? { pageCount: Math.ceil(conversations.total / pageSize) } : {})}
+ onChangePage={handlePageChange}
+ {...(conversations
+ ? { pageCount: Math.ceil(conversations.total / bulkGetOptions.perPage) }
+ : {})}
onHistoryDeleted={handleHistoryDeleted}
/>
)}
diff --git a/public/types.ts b/public/types.ts
index ca6505b2..4b05b4ad 100644
--- a/public/types.ts
+++ b/public/types.ts
@@ -21,6 +21,7 @@ export type ActionExecutor = (params: Record) => void;
export interface AssistantActions {
send: (input: IMessage) => Promise;
loadChat: (conversationId?: string, title?: string) => Promise;
+ resetChat: () => void;
openChatUI: (conversationId?: string) => void;
executeAction: (suggestedAction: ISuggestedAction, message: IMessage) => Promise;
abortAction: (conversationId?: string) => Promise;