Skip to content

Commit

Permalink
[Security Solution] Security Assistant: Reduces the public API surfac…
Browse files Browse the repository at this point in the history
…e & improves code coverage (#15)

## [Security Solution] Security Assistant: Reduces the public API surface & improves code coverage

- Reduces the public API surface
- improves code coverage
  • Loading branch information
andrew-goldstein authored May 26, 2023
1 parent 42e86e3 commit 578fb89
Show file tree
Hide file tree
Showing 42 changed files with 1,531 additions and 445 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import React from 'react';
import { fireEvent, render, screen } from '@testing-library/react';
import userEvent from '@testing-library/user-event';

import { TestProviders } from '../../mock/test_providers/test_providers';
import type { PromptContext } from '../prompt_context/types';
import { ContextPills } from '.';

const mockPromptContexts: Record<string, PromptContext> = {
context1: {
category: 'alert',
description: 'Context 1',
getPromptContext: () => Promise.resolve('Context 1 data'),
id: 'context1',
tooltip: 'Context 1 tooltip',
},
context2: {
category: 'event',
description: 'Context 2',
getPromptContext: () => Promise.resolve('Context 2 data'),
id: 'context2',
tooltip: 'Context 2 tooltip',
},
};

describe('ContextPills', () => {
beforeEach(() => jest.clearAllMocks());

it('renders the context pill descriptions', () => {
render(
<TestProviders>
<ContextPills
promptContexts={mockPromptContexts}
selectedPromptContextIds={[]}
setSelectedPromptContextIds={jest.fn()}
/>
</TestProviders>
);

Object.values(mockPromptContexts).forEach(({ id, description }) => {
expect(screen.getByTestId(`pillButton-${id}`)).toHaveTextContent(description);
});
});

it('invokes setSelectedPromptContextIds() when the prompt is NOT already selected', () => {
const context = mockPromptContexts.context1;
const setSelectedPromptContextIds = jest.fn();

render(
<TestProviders>
<ContextPills
promptContexts={mockPromptContexts}
selectedPromptContextIds={[]} // <-- the prompt is NOT selected
setSelectedPromptContextIds={setSelectedPromptContextIds}
/>
</TestProviders>
);

userEvent.click(screen.getByTestId(`pillButton-${context.id}`));

expect(setSelectedPromptContextIds).toBeCalled();
});

it('it does NOT invoke setSelectedPromptContextIds() when the prompt is already selected', () => {
const context = mockPromptContexts.context1;
const setSelectedPromptContextIds = jest.fn();

render(
<TestProviders>
<ContextPills
promptContexts={mockPromptContexts}
selectedPromptContextIds={[context.id]} // <-- the context is already selected
setSelectedPromptContextIds={setSelectedPromptContextIds}
/>
</TestProviders>
);

// NOTE: this test uses `fireEvent` instead of `userEvent` to bypass the disabled button:
fireEvent.click(screen.getByTestId(`pillButton-${context.id}`));

expect(setSelectedPromptContextIds).not.toBeCalled();
});

it('disables selected context pills', () => {
const context = mockPromptContexts.context1;

render(
<TestProviders>
<ContextPills
promptContexts={mockPromptContexts}
selectedPromptContextIds={[context.id]} // <-- context1 is selected
setSelectedPromptContextIds={jest.fn()}
/>
</TestProviders>
);

expect(screen.getByTestId(`pillButton-${context.id}`)).toBeDisabled();
});

it("does NOT disable context pills that aren't selected", () => {
const context = mockPromptContexts.context1;

render(
<TestProviders>
<ContextPills
promptContexts={mockPromptContexts}
selectedPromptContextIds={['context2']} // context1 is NOT selected
setSelectedPromptContextIds={jest.fn()}
/>
</TestProviders>
);

expect(screen.getByTestId(`pillButton-${context.id}`)).not.toBeDisabled();
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ const ContextPillsComponent: React.FC<Props> = ({
<EuiFlexItem grow={false} key={id}>
<EuiToolTip content={tooltip}>
<PillButton
data-test-subj={`pillButton-${id}`}
disabled={selectedPromptContextIds.includes(id)}
iconSide="left"
iconType="plus"
Expand Down
62 changes: 0 additions & 62 deletions x-pack/packages/kbn-elastic-assistant/impl/assistant/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
import crypto from 'crypto';

import { fetchOpenAlerts, fetchVirusTotalAnalysis, sendFileToVirusTotal } from './api';
import { SYSTEM_PROMPT_CONTEXT_NON_I18N } from '../content/prompts/system/translations';
import type { PromptContext } from './prompt_context/types';
import type { Message } from '../assistant_context/types';
import type { Prompt } from './types';

/**
* Do it like this in your `kibana.dev.yml`, use 'overrides' as keys aren't actually defined
Expand Down Expand Up @@ -397,62 +394,3 @@ export const getMessageFromRawResponse = (rawResponse: string): Message => {
};
}
};

export const getSystemMessages = ({
isNewChat,
selectedSystemPrompt,
}: {
isNewChat: boolean;
selectedSystemPrompt: Prompt | undefined;
}): Message[] => {
if (!isNewChat || selectedSystemPrompt == null) {
return [];
}

const message: Message = {
content: selectedSystemPrompt.content,
role: 'system',
timestamp: new Date().toLocaleString(),
};

return [message];
};

export async function getCombinedMessage({
isNewChat,
promptContexts,
promptText,
selectedPromptContextIds,
selectedSystemPrompt,
}: {
isNewChat: boolean;
promptContexts: Record<string, PromptContext>;
promptText: string;
selectedPromptContextIds: string[];
selectedSystemPrompt: Prompt | undefined;
}): Promise<Message> {
const selectedPromptContexts = selectedPromptContextIds.reduce<PromptContext[]>((acc, id) => {
const promptContext = promptContexts[id];
return promptContext != null ? [...acc, promptContext] : acc;
}, []);

const promptContextsContent = await Promise.all(
selectedPromptContexts.map(async ({ getPromptContext, id }) => {
const promptContext = await getPromptContext();

return `\n\n${SYSTEM_PROMPT_CONTEXT_NON_I18N(promptContext)}\n\n`;
})
);

return {
content: `${
isNewChat ? `${selectedSystemPrompt?.content ?? ''}` : `${promptContextsContent}\n\n`
}
${promptContextsContent}
${promptText}`,
role: 'user', // we are combining the system and user messages into one message
timestamp: new Date().toLocaleString(),
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import styled from 'styled-components';
import { createPortal } from 'react-dom';
import { css } from '@emotion/react';

import { getCombinedMessage, getMessageFromRawResponse } from './helpers';
import { getMessageFromRawResponse } from './helpers';

import { SettingsPopover } from './settings_popover';
import { useAssistantContext } from '../assistant_context';
Expand All @@ -36,7 +36,7 @@ import { useSendMessages } from './use_send_messages';
import type { Message } from '../assistant_context/types';
import { ConversationSelector } from './conversation_selector';
import { PromptEditor } from './prompt_editor';
import { getDefaultSystemPrompt, getSuperheroPrompt } from './prompt/helpers';
import { getCombinedMessage, getDefaultSystemPrompt, getSuperheroPrompt } from './prompt/helpers';
import * as i18n from './translations';
import type { Prompt } from './types';
import { getPromptById } from './prompt_editor/helpers';
Expand All @@ -47,6 +47,7 @@ import { WELCOME_CONVERSATION_ID } from './use_conversation/sample_conversations

const CommentsContainer = styled.div`
max-height: 600px;
max-width: 100%;
overflow-y: scroll;
`;

Expand Down
Loading

0 comments on commit 578fb89

Please sign in to comment.