Skip to content

Commit

Permalink
[ML] Shared service for elastic curated models (#167000)
Browse files Browse the repository at this point in the history
## Summary

Adds a shared service for elastic curated models. The first use case is
to provide a default/recommended ELSER version based on the hardware of
the current cluster.

#### Why?
In 8.11 we'll provide a platform-specific version of the ELSER v2
alongside the portable one. At the moment several solutions refer to
ELSER for download/inference purposes with a `.elser_model_1` constant.
Starting 8.11 the model ID will vary, so using the `ElasticModels`
service allows retrieving the recommended version of ELSER for the
current cluster without any changes by solution teams in future
releases. It is still possible to request an older version of the model
if necessary.

#### Implementation 
- Adds a new Kibana API endpoint `/trained_models/model_downloads` that
provides a list of model definitions, with the `default` and
`recommended` flags.
- Adds a new Kibana API endpoint `/trained_models/elser_config` that
provides an ELSER configuration based on the cluster architecture.
- `getELSER` method is exposed from the plugin `setup` server-side as
part of our shared services and plugin `start` client-side.

### Checklist

- [ ]
[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)
was added for features that require explanation or tutorials
- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
  • Loading branch information
darnautov authored Sep 25, 2023
1 parent 2f1b6ac commit 2bce7bb
Show file tree
Hide file tree
Showing 17 changed files with 438 additions and 77 deletions.
6 changes: 6 additions & 0 deletions x-pack/packages/ml/trained_models_utils/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,10 @@ export {
type DeploymentState,
type SupportedPytorchTasksType,
type TrainedModelType,
ELASTIC_MODEL_DEFINITIONS,
type ElasticModelId,
type ModelDefinition,
type ModelDefinitionResponse,
type ElserVersion,
type GetElserOptions,
} from './src/constants/trained_models';
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ export const BUILT_IN_MODEL_TAG = 'prepackaged';

export const ELASTIC_MODEL_TAG = 'elastic';

export const ELASTIC_MODEL_DEFINITIONS = {
export const ELASTIC_MODEL_DEFINITIONS: Record<string, ModelDefinition> = Object.freeze({
'.elser_model_1': {
version: 1,
config: {
input: {
field_names: ['text_field'],
Expand All @@ -57,7 +58,49 @@ export const ELASTIC_MODEL_DEFINITIONS = {
defaultMessage: 'Elastic Learned Sparse EncodeR v1 (Tech Preview)',
}),
},
} as const;
'.elser_model_2_SNAPSHOT': {
version: 2,
default: true,
config: {
input: {
field_names: ['text_field'],
},
},
description: i18n.translate('xpack.ml.trainedModels.modelsList.elserV2Description', {
defaultMessage: 'Elastic Learned Sparse EncodeR v2 (Tech Preview)',
}),
},
'.elser_model_2_linux-x86_64_SNAPSHOT': {
version: 2,
os: 'Linux',
arch: 'amd64',
config: {
input: {
field_names: ['text_field'],
},
},
description: i18n.translate('xpack.ml.trainedModels.modelsList.elserV2x86Description', {
defaultMessage:
'Elastic Learned Sparse EncodeR v2, optimized for linux-x86_64 (Tech Preview)',
}),
},
} as const);

/**
 * Describes a model curated by Elastic, as listed in `ELASTIC_MODEL_DEFINITIONS`.
 */
export interface ModelDefinition {
/** Major version of the model, e.g. `1` or `2`. */
version: number;
/** Model configuration, e.g. the `input.field_names` the model expects. */
config: object;
/** Human-readable (translated) description of the model. */
description: string;
/** Operating system the model build targets, e.g. `'Linux'`. Presumably absent for portable builds — confirm. */
os?: string;
/** CPU architecture the model build targets, e.g. `'amd64'`. Presumably absent for portable builds — confirm. */
arch?: string;
/** Flags the default model definition. */
default?: boolean;
/** Flags the recommended model definition. NOTE(review): presumably set server-side based on the cluster hardware — confirm. */
recommended?: boolean;
}

/**
 * A model definition together with the name of the model it describes.
 */
export type ModelDefinitionResponse = ModelDefinition & {
/** Model name. Presumably the model ID, e.g. `.elser_model_1` — confirm against the API response. */
name: string;
};

/**
 * Identifier of a model listed in `ELASTIC_MODEL_DEFINITIONS`.
 *
 * NOTE(review): `ELASTIC_MODEL_DEFINITIONS` is annotated as `Record<string, ModelDefinition>`,
 * so this alias resolves to `string` rather than a union of the literal model IDs — confirm
 * whether that widening is intended.
 */
export type ElasticModelId = keyof typeof ELASTIC_MODEL_DEFINITIONS;

export const MODEL_STATE = {
...DEPLOYMENT_STATE,
Expand All @@ -66,3 +109,9 @@ export const MODEL_STATE = {
} as const;

export type ModelState = typeof MODEL_STATE[keyof typeof MODEL_STATE] | null;

/** ELSER model versions that can be requested. */
export type ElserVersion = 1 | 2;

/**
 * Options accepted when requesting an ELSER model configuration.
 */
export interface GetElserOptions {
/** ELSER version to request. NOTE(review): when omitted, presumably the current default version is used — confirm. */
version?: ElserVersion;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import type { ModelDefinitionResponse, GetElserOptions } from '@kbn/ml-trained-models-utils';
import { type TrainedModelsApiService } from './ml_api_service/trained_models';

export class ElasticModels {
  constructor(private readonly trainedModels: TrainedModelsApiService) {}

  /**
   * Resolves the ELSER model name and configuration to download for the current cluster.
   *
   * The current default version is 2. When running on Cloud, the Linux x86_64 optimized
   * version is returned. If any of the ML nodes runs an OS other than Linux, or its CPU
   * architecture isn't x86_64, the portable version of the model is returned instead.
   *
   * @param options - Optional request settings, e.g. the ELSER version to retrieve.
   * @returns The definition of the matching ELSER model, including its name.
   */
  public async getELSER(options?: GetElserOptions): Promise<ModelDefinitionResponse> {
    const elserConfig = await this.trainedModels.getElserConfig(options);
    return elserConfig;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { type HttpStart } from '@kbn/core-http-browser';
import { ElasticModels } from './elastic_models_service';
import { HttpService } from './http_service';
import { mlApiServicesProvider } from './ml_api_service';

export type MlSharedServices = ReturnType<typeof getMlSharedServices>;

/**
 * Builds the ML services that are exposed from the plugin start contract.
 *
 * @param httpStart - Kibana HTTP service used for the requests issued by the shared services.
 * @returns An object holding the shared services; currently only `elasticModels`.
 */
export function getMlSharedServices(httpStart: HttpStart) {
  const mlApiServices = mlApiServicesProvider(new HttpService(httpStart));
  const elasticModels = new ElasticModels(mlApiServices.trainedModels);

  return { elasticModels };
}
64 changes: 1 addition & 63 deletions x-pack/plugins/ml/public/application/services/http_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
*/

import { Observable } from 'rxjs';
import { HttpFetchOptionsWithPath, HttpFetchOptions, HttpStart } from '@kbn/core/public';
import type { HttpFetchOptionsWithPath, HttpFetchOptions, HttpStart } from '@kbn/core/public';
import { getHttp } from '../util/dependency_cache';

function getResultHeaders(headers: HeadersInit) {
Expand Down Expand Up @@ -59,68 +59,6 @@ export async function http<T>(options: HttpFetchOptionsWithPath): Promise<T> {
return getHttp().fetch<T>(path, fetchOptions);
}

/**
 * Issues an HTTP request to Kibana's backend and exposes the result as an Observable,
 * so the request can be cancelled by unsubscribing.
 *
 * @param options - Fetch options including the request path.
 * @returns An Observable of the response.
 *
 * @deprecated use {@link HttpService} instead
 */
export function http$<T>(options: HttpFetchOptionsWithPath): Observable<T> {
  const resolved = getFetchOptions(options);
  return fromHttpHandler<T>(resolved.path, resolved.fetchOptions);
}

/**
 * Creates an Observable from Kibana's HttpHandler.
 *
 * Every subscription issues its own request via `getHttp().fetch` using a dedicated
 * `AbortController`. The request is aborted when the subscriber unsubscribes before the
 * request has settled, or when the caller-provided `init.signal` is aborted.
 *
 * @param input - Path passed to `getHttp().fetch`.
 * @param init - Optional fetch options. `init.signal`, when present, is chained to the
 * per-subscriber abort signal.
 * @returns An Observable that emits the response once and then completes, or errors when
 * the request fails while the subscriber is still subscribed.
 */
function fromHttpHandler<T>(input: string, init?: RequestInit): Observable<T> {
  return new Observable<T>((subscriber) => {
    const controller = new AbortController();
    const signal = controller.signal;

    let abortable = true;
    let unsubscribed = false;

    const externalSignal = init?.signal;
    const onExternalAbort = () => {
      if (!signal.aborted) {
        controller.abort();
      }
    };
    // Detach from the caller's signal once it is no longer needed; otherwise a long-lived
    // signal accumulates one listener (and its closure) per subscription.
    const detachExternalListener = () => {
      externalSignal?.removeEventListener('abort', onExternalAbort);
    };

    if (externalSignal) {
      if (externalSignal.aborted) {
        controller.abort();
      } else {
        externalSignal.addEventListener('abort', onExternalAbort, { once: true });
      }
    }

    const perSubscriberInit: RequestInit = {
      ...(init ? init : {}),
      signal,
    };

    getHttp()
      .fetch<T>(input, perSubscriberInit)
      .then((response) => {
        abortable = false;
        detachExternalListener();
        subscriber.next(response);
        subscriber.complete();
      })
      .catch((err) => {
        abortable = false;
        detachExternalListener();
        // Errors arriving after unsubscription (e.g. the abort itself) are not reported.
        if (!unsubscribed) {
          subscriber.error(err);
        }
      });

    return () => {
      unsubscribed = true;
      detachExternalListener();
      // Only abort requests that are still in flight.
      if (abortable) {
        controller.abort();
      }
    };
  });
}

/**
* ML Http Service
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
*/

import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { IngestPipeline } from '@elastic/elasticsearch/lib/api/types';
import type { IngestPipeline } from '@elastic/elasticsearch/lib/api/types';

import { useMemo } from 'react';
import type { HttpFetchQuery } from '@kbn/core/public';
import type { ErrorType } from '@kbn/ml-error-utils';
import type { GetElserOptions, ModelDefinitionResponse } from '@kbn/ml-trained-models-utils';
import { ML_INTERNAL_BASE_PATH } from '../../../../common/constants/app';
import type { MlSavedObjectType } from '../../../../common/types/saved_objects';
import { HttpService } from '../http_service';
Expand Down Expand Up @@ -57,6 +58,29 @@ export interface InferenceStatsResponse {
*/
export function trainedModelsApiProvider(httpService: HttpService) {
return {
/**
 * Retrieves the definitions of the trained models that are available for download.
 */
getTrainedModelDownloads() {
  const path = `${ML_INTERNAL_BASE_PATH}/trained_models/model_downloads`;
  return httpService.http<ModelDefinitionResponse[]>({ path, method: 'GET', version: '1' });
},

/**
 * Retrieves the ELSER configuration to download, based on the cluster OS and CPU architecture.
 *
 * @param options - Optional settings; when provided they are sent as the request query.
 */
getElserConfig(options?: GetElserOptions) {
  // Only attach a `query` property when options were actually supplied.
  const queryPart = options ? { query: options as HttpFetchQuery } : {};
  return httpService.http<ModelDefinitionResponse>({
    path: `${ML_INTERNAL_BASE_PATH}/trained_models/elser_config`,
    method: 'GET',
    ...queryPart,
    version: '1',
  });
},

/**
* Fetches configuration information for a trained inference model.
* @param modelId - Model ID, collection of Model IDs or Model ID pattern.
Expand Down
18 changes: 17 additions & 1 deletion x-pack/plugins/ml/public/mocks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
*/

import { sharePluginMock } from '@kbn/share-plugin/public/mocks';
import { MlPluginSetup, MlPluginStart } from './plugin';
import { type ElasticModels } from './application/services/elastic_models_service';
import type { MlPluginSetup, MlPluginStart } from './plugin';

const createSetupContract = (): jest.Mocked<MlPluginSetup> => {
return {
Expand All @@ -17,6 +18,21 @@ const createSetupContract = (): jest.Mocked<MlPluginSetup> => {
/**
 * Builds a mocked ML plugin start contract with a mocked locator and a mocked
 * `elasticModels` service.
 */
const createStartContract = (): jest.Mocked<MlPluginStart> => {
  const locator = sharePluginMock.createLocator();

  // Every call resolves with a fresh ELSER v2 model definition.
  const getELSER = jest.fn(() =>
    Promise.resolve({
      version: 2,
      default: true,
      config: {
        input: {
          field_names: ['text_field'],
        },
      },
      description: 'Elastic Learned Sparse EncodeR v2 (Tech Preview)',
      name: '.elser_model_2',
    })
  );
  const elasticModels = { getELSER } as unknown as jest.Mocked<ElasticModels>;

  return { locator, elasticModels };
};

Expand Down
10 changes: 10 additions & 0 deletions x-pack/plugins/ml/public/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ import type { ChartsPluginStart } from '@kbn/charts-plugin/public';
import type { CasesUiSetup, CasesUiStart } from '@kbn/cases-plugin/public';
import type { SavedSearchPublicPluginStart } from '@kbn/saved-search-plugin/public';
import type { PresentationUtilPluginStart } from '@kbn/presentation-util-plugin/public';
import {
getMlSharedServices,
MlSharedServices,
} from './application/services/get_shared_ml_services';
import { registerManagementSection } from './application/management';
import { MlLocatorDefinition, type MlLocator } from './locator';
import { setDependencyCache } from './application/util/dependency_cache';
Expand Down Expand Up @@ -103,13 +107,18 @@ export class MlPlugin implements Plugin<MlPluginSetup, MlPluginStart> {
private appUpdater$ = new BehaviorSubject<AppUpdater>(() => ({}));

private locator: undefined | MlLocator;

private sharedMlServices: MlSharedServices | undefined;

private isServerless: boolean = false;

constructor(private initializerContext: PluginInitializerContext) {
this.isServerless = initializerContext.env.packageInfo.buildFlavor === 'serverless';
}

setup(core: MlCoreSetup, pluginsSetup: MlSetupDependencies) {
this.sharedMlServices = getMlSharedServices(core.http);

core.application.register({
id: PLUGIN_ID,
title: i18n.translate('xpack.ml.plugin.title', {
Expand Down Expand Up @@ -249,6 +258,7 @@ export class MlPlugin implements Plugin<MlPluginSetup, MlPluginStart> {

return {
locator: this.locator,
elasticModels: this.sharedMlServices?.elasticModels,
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@
"InferTrainedModelDeployment",
"CreateInferencePipeline",
"GetIngestPipelines",
"GetTrainedModelDownloadList",
"GetElserConfig",

"Alerting",
"PreviewAlert",
Expand Down
Loading

0 comments on commit 2bce7bb

Please sign in to comment.