Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Search] [Playground] Semantic text support #186268

Merged
merged 10 commits into from
Jun 18, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,177 @@
* 2.0.
*/

import { SearchResponse } from '@elastic/elasticsearch/lib/api/types';
import { IndicesGetMappingResponse, SearchResponse } from '@elastic/elasticsearch/lib/api/types';

export const SPARSE_SEMANTIC_FIELD_FIELD_CAPS = {
indices: ['test-index2'],
fields: {
infer_field: {
semantic_text: {
type: 'semantic_text',
metadata_field: false,
searchable: false,
aggregatable: false,
},
},
'infer_field.inference.chunks.embeddings': {
sparse_vector: {
type: 'sparse_vector',
metadata_field: false,
searchable: true,
aggregatable: false,
},
},
non_infer_field: {
text: {
type: 'text',
metadata_field: false,
searchable: true,
aggregatable: false,
},
},
'infer_field.inference.chunks.text': {
keyword: {
type: 'keyword',
metadata_field: false,
searchable: false,
aggregatable: false,
},
},
'infer_field.inference': {
object: {
type: 'object',
metadata_field: false,
searchable: false,
aggregatable: false,
},
},
'infer_field.inference.chunks': {
nested: {
type: 'nested',
metadata_field: false,
searchable: false,
aggregatable: false,
},
},
},
};

export const SPARSE_SEMANTIC_FIELD_MAPPINGS = {
'test-index2': {
mappings: {
properties: {
infer_field: {
type: 'semantic_text',
inference_id: 'elser-endpoint',
model_settings: {
task_type: 'sparse_embedding',
},
},
non_infer_field: {
type: 'text',
},
},
},
},
} as any as IndicesGetMappingResponse;

export const DENSE_SEMANTIC_FIELD_MAPPINGS = {
'test-index2': {
mappings: {
properties: {
infer_field: {
type: 'semantic_text',
inference_id: 'cohere',
model_settings: {
task_type: 'text_embedding',
dimensions: 1536,
similarity: 'dot_product',
},
},
non_infer_field: {
type: 'text',
},
},
},
},
} as any as IndicesGetMappingResponse;

// for when semantic_text field hasn't been mapped with task_type
// when theres no data / no inference has been performed in the field
export const DENSE_SEMANTIC_FIELD_MAPPINGS_MISSING_TASK_TYPE = {
'test-index2': {
mappings: {
properties: {
infer_field: {
type: 'semantic_text',
inference_id: 'cohere',
model_settings: {
dimensions: 1536,
similarity: 'dot_product',
},
},
non_infer_field: {
type: 'text',
},
},
},
},
} as any as IndicesGetMappingResponse;

export const DENSE_SEMANTIC_FIELD_FIELD_CAPS = {
indices: ['test-index2'],
fields: {
infer_field: {
semantic_text: {
type: 'semantic_text',
metadata_field: false,
searchable: false,
aggregatable: false,
},
},
'infer_field.inference.chunks.embeddings': {
sparse_vector: {
type: 'dense_vector',
metadata_field: false,
searchable: true,
aggregatable: false,
},
},
non_infer_field: {
text: {
type: 'text',
metadata_field: false,
searchable: true,
aggregatable: false,
},
},
'infer_field.inference.chunks.text': {
keyword: {
type: 'keyword',
metadata_field: false,
searchable: false,
aggregatable: false,
},
},
'infer_field.inference': {
object: {
type: 'object',
metadata_field: false,
searchable: false,
aggregatable: false,
},
},
'infer_field.inference.chunks': {
nested: {
type: 'nested',
metadata_field: false,
searchable: false,
aggregatable: false,
},
},
},
};

export const DENSE_SPARSE_SAME_FIELD_NAME_CAPS = {
indices: ['cohere-embeddings', 'elser_index'],
Expand Down
15 changes: 11 additions & 4 deletions x-pack/plugins/search_playground/common/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,25 @@

export type IndicesQuerySourceFields = Record<string, QuerySourceFields>;

interface ModelFields {
interface ModelField {
field: string;
model_id: string;
nested: boolean;
indices: string[];
}

interface SemanticField {
field: string;
inferenceId: string;
embeddingType: 'sparse_vector' | 'dense_vector';
indices: string[];
}

export interface QuerySourceFields {
elser_query_fields: ModelFields[];
dense_vector_query_fields: ModelFields[];
elser_query_fields: ModelField[];
dense_vector_query_fields: ModelField[];
bm25_query_fields: string[];
source_fields: string[];
semantic_fields: SemanticField[];
skipped_fields: number;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ jest.mock('../../hooks/use_indices_fields', () => ({
dense_vector_query_fields: [],
bm25_query_fields: ['field1', 'field2'],
source_fields: ['context_field1', 'context_field2'],
semantic_fields: [],
},
index2: {
elser_query_fields: [],
dense_vector_query_fields: [],
bm25_query_fields: ['field1', 'field2'],
source_fields: ['context_field1', 'context_field2'],
semantic_fields: [],
},
},
}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { render, fireEvent, screen } from '@testing-library/react';
import { ViewQueryFlyout } from './view_query_flyout';
import { FormProvider, useForm } from 'react-hook-form';
import { __IntlProvider as IntlProvider } from '@kbn/i18n-react';
import { ChatFormFields } from '../../types';

jest.mock('../../hooks/use_indices_fields', () => ({
useIndicesFields: () => ({
Expand All @@ -19,12 +20,14 @@ jest.mock('../../hooks/use_indices_fields', () => ({
dense_vector_query_fields: [],
bm25_query_fields: ['field1', 'field2'],
skipped_fields: 1,
semantic_fields: [],
},
index2: {
elser_query_fields: [],
dense_vector_query_fields: [],
bm25_query_fields: ['field1', 'field2'],
skipped_fields: 0,
semantic_fields: [],
},
},
}),
Expand All @@ -41,7 +44,11 @@ jest.mock('../../hooks/use_usage_tracker', () => ({
const MockFormProvider = ({ children }: { children: React.ReactElement }) => {
const methods = useForm({
values: {
indices: ['index1', 'index2'],
[ChatFormFields.indices]: ['index1', 'index2'],
[ChatFormFields.sourceFields]: {
index1: ['field1'],
index2: ['field1'],
},
},
});
return <FormProvider {...methods}>{children}</FormProvider>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ const groupTypeQueryFields = (
typeQueryFields += (typeQueryFields ? '_' : '') + 'SPARSE';
}

if (
selectedFields.some((field) => indexFields.semantic_fields.find((f) => f.field === field))
) {
typeQueryFields += (typeQueryFields ? '_' : '') + 'SEMANTIC';
}

return typeQueryFields;
});

Expand All @@ -76,6 +82,7 @@ export const ViewQueryFlyout: React.FC<ViewQueryFlyoutProps> = ({ onClose }) =>
const usageTracker = useUsageTracker();
const { getValues } = useFormContext<ChatForm>();
const selectedIndices: string[] = getValues(ChatFormFields.indices);
const sourceFields = getValues(ChatFormFields.sourceFields);
const { fields } = useIndicesFields(selectedIndices);
const defaultFields = getDefaultQueryFields(fields);

Expand Down Expand Up @@ -111,7 +118,7 @@ export const ViewQueryFlyout: React.FC<ViewQueryFlyoutProps> = ({ onClose }) =>

const saveQuery = () => {
queryFieldsOnChange(tempQueryFields);
elasticsearchQueryChange(createQuery(tempQueryFields, fields));
elasticsearchQueryChange(createQuery(tempQueryFields, sourceFields, fields));
onClose();

const groupedQueryFields = groupTypeQueryFields(fields, tempQueryFields);
Expand Down Expand Up @@ -168,7 +175,7 @@ export const ViewQueryFlyout: React.FC<ViewQueryFlyoutProps> = ({ onClose }) =>
lineNumbers
data-test-subj="ViewElasticsearchQueryResult"
>
{JSON.stringify(createQuery(tempQueryFields, fields), null, 2)}
{JSON.stringify(createQuery(tempQueryFields, sourceFields, fields), null, 2)}
</EuiCodeBlock>
</EuiFlexItem>
<EuiFlexItem grow={3}>
Expand Down Expand Up @@ -198,6 +205,7 @@ export const ViewQueryFlyout: React.FC<ViewQueryFlyoutProps> = ({ onClose }) =>
aria-label="Select query fields"
data-test-subj={`queryFieldsSelectable_${index}`}
options={[
...group.semantic_fields,
...group.elser_query_fields,
...group.dense_vector_query_fields,
...group.bm25_query_fields,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ export const useSourceIndicesFields = () => {
setNoFieldsIndicesWarning(null);
}

onElasticsearchQueryChange(createQuery(defaultFields, fields));
onElasticsearchQueryChange(createQuery(defaultFields, defaultSourceFields, fields));
onSourceFieldsChange(defaultSourceFields);
usageTracker?.count(
AnalyticsEvents.sourceFieldsLoaded,
Expand Down
Loading