diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 196d8f3f00056..298c680fc70d4 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1673,6 +1673,11 @@ x-pack/test/security_solution_cypress/cypress/tasks/expandable_flyout @elastic/ /x-pack/plugins/stack_connectors/server/connector_types/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra /x-pack/plugins/stack_connectors/common/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra +# Inference API +/x-pack/plugins/stack_connectors/public/connector_types/inference @elastic/appex-ai-infra @elastic/security-generative-ai @elastic/obs-ai-assistant +/x-pack/plugins/stack_connectors/server/connector_types/inference @elastic/appex-ai-infra @elastic/security-generative-ai @elastic/obs-ai-assistant +/x-pack/plugins/stack_connectors/common/inference @elastic/appex-ai-infra @elastic/security-generative-ai @elastic/obs-ai-assistant + ## Defend Workflows owner connectors /x-pack/plugins/stack_connectors/public/connector_types/sentinelone @elastic/security-defend-workflows /x-pack/plugins/stack_connectors/server/connector_types/sentinelone @elastic/security-defend-workflows diff --git a/docs/management/action-types.asciidoc b/docs/management/action-types.asciidoc index 30bf4f791e5d8..361892e430afd 100644 --- a/docs/management/action-types.asciidoc +++ b/docs/management/action-types.asciidoc @@ -28,6 +28,10 @@ a| <> | Send a request to {gemini}. +a| <> + +| Send a request to {inference}. + a| <> | Send email from your server. diff --git a/docs/management/connectors/action-types/inference.asciidoc b/docs/management/connectors/action-types/inference.asciidoc new file mode 100644 index 0000000000000..8c7f2840f9c5c --- /dev/null +++ b/docs/management/connectors/action-types/inference.asciidoc @@ -0,0 +1,126 @@ +[[inference-action-type]] +== {infer-cap} connector and action +++++ +{inference} +++++ +:frontmatter-description: Add a connector that can send requests to {inference}. +:frontmatter-tags-products: [kibana] +:frontmatter-tags-content-type: [how-to] +:frontmatter-tags-user-goals: [configure] + + +The {infer} connector uses the {es} client to send requests to an {infer} service. The connector uses the <> to send the request. + +[float] +[[define-inference-ui]] +=== Create connectors in {kib} + +You can create connectors in *{stack-manage-app} > {connectors-ui}*. For example: + +[role="screenshot"] +image::management/connectors/images/inference-connector.png[{inference} connector] +// NOTE: This is an autogenerated screenshot. Do not edit it directly. + +[float] +[[inference-connector-configuration]] +==== Connector configuration + +{infer-cap} connectors have the following configuration properties: + +Name:: The name of the connector. +Service:: The supported {infer} service provider. +Task type:: The {infer} task type, it depends on the selected service. +Inference ID:: The unique identifier of the {infer} endpoint. +Provider configuration:: Settings for service configuration. +Provider secrets:: Configuration for authentication. +Task type configuration:: Settings for task type configuration. + +[float] +[[inference-action-configuration]] +=== Test connectors + +You can test connectors using the <> or +while creating or editing the connector in {kib}. For example: + +[role="screenshot"] +image::management/connectors/images/inference-completion-params.png[{infer} params test] +// NOTE: This is an autogenerated screenshot. Do not edit it directly. +[float] +[[inference-connector-actions]] +=== {infer-cap} connector actions + +The {infer} actions have the following configuration properties. Properties depend on the selected task type. + +[float] +[[inference-connector-perform-completion]] +==== Completion + +The following example performs a completion task on the example question. +Input:: +The text on which you want to perform the {infer} task. For example: ++ +[source,text] +-- +{ + input: 'What is Elastic?' +} +-- + +[float] +[[inference-connector-perform-text-embedding]] +==== Text embedding + +The following example performs a text embedding task. +Input:: +The text on which you want to perform the {infer} task. For example: ++ +[source,text] +-- +{ + input: 'The sky above the port was the color of television tuned to a dead channel.', + task_settings: { + input_type: 'ingest' + } +} +-- +Input type:: +An optional string that overwrites the connector's default model. + +[float] +[[inference-connector-perform-rerank]] +==== Reranking + +The following example performs a reranking task on the example input. +Input:: +The text on which you want to perform the {infer} task. Should be a string array. For example: ++ +[source,text] +-- +{ + input: ['luke', 'like', 'leia', 'chewy', 'r2d2', 'star', 'wars'], + query: 'star wars main character' +} +-- +Query:: +The search query text. + +[float] +[[inference-connector-perform-sparse-embedding]] +==== Sparse embedding + +The following example performs a sparse embedding task on the example sentence. +Input:: +The text on which you want to perform the {infer} task. For example: ++ +[source,text] +-- +{ + input: 'The sky above the port was the color of television tuned to a dead channel.' +} +-- + +[float] +[[inference-connector-networking-configuration]] +=== Connector networking configuration + +Use the <> to customize connector networking configurations, such as proxies, certificates, or TLS settings. You can apply these settings to all your connectors or use `xpack.actions.customHostSettings` to set per-host configurations. diff --git a/docs/management/connectors/images/inference-completion-params.png b/docs/management/connectors/images/inference-completion-params.png new file mode 100644 index 0000000000000..686ee9771c2f5 Binary files /dev/null and b/docs/management/connectors/images/inference-completion-params.png differ diff --git a/docs/management/connectors/images/inference-connector.png b/docs/management/connectors/images/inference-connector.png new file mode 100644 index 0000000000000..dcd37f0865c54 Binary files /dev/null and b/docs/management/connectors/images/inference-connector.png differ diff --git a/docs/management/connectors/index.asciidoc b/docs/management/connectors/index.asciidoc index 18f2c28d10f04..c5233ad4f4934 100644 --- a/docs/management/connectors/index.asciidoc +++ b/docs/management/connectors/index.asciidoc @@ -4,6 +4,7 @@ include::action-types/crowdstrike.asciidoc[leveloffset=+1] include::action-types/d3security.asciidoc[leveloffset=+1] include::action-types/email.asciidoc[leveloffset=+1] include::action-types/gemini.asciidoc[leveloffset=+1] +include::action-types/inference.asciidoc[leveloffset=+1] include::action-types/resilient.asciidoc[leveloffset=+1] include::action-types/index.asciidoc[leveloffset=+1] include::action-types/jira.asciidoc[leveloffset=+1] diff --git a/docs/settings/alert-action-settings.asciidoc b/docs/settings/alert-action-settings.asciidoc index 18273319ecd08..db36248ef194f 100644 --- a/docs/settings/alert-action-settings.asciidoc +++ b/docs/settings/alert-action-settings.asciidoc @@ -283,6 +283,7 @@ A configuration URL that varies by connector: -- * For an <>, specifies the {bedrock} request URL. * For an <>, specifies the {gemini} request URL. +* For an <>, specifies the Elastic {inference} request. * For a <>, specifies the OpenAI request URL. * For a <>, specifies the {ibm-r} instance URL. * For a <>, specifies the Jira instance URL. diff --git a/oas_docs/examples/get_connector_types_generativeai_response.yaml b/oas_docs/examples/get_connector_types_generativeai_response.yaml index a97199e0a3927..8299da3558150 100644 --- a/oas_docs/examples/get_connector_types_generativeai_response.yaml +++ b/oas_docs/examples/get_connector_types_generativeai_response.yaml @@ -31,3 +31,14 @@ value: supported_feature_ids: - generativeAIForSecurity is_system_action_type: false + - id: .inference + name: Inference API + enabled: true + enabled_in_config: true + enabled_in_license: true + minimum_license_required: enterprise + supported_feature_ids: + - generativeAIForSecurity + - generativeAIForObservability + - generativeAIForSearchPlayground + is_system_action_type: false diff --git a/x-pack/packages/kbn-elastic-assistant/impl/connectorland/use_load_action_types/index.tsx b/x-pack/packages/kbn-elastic-assistant/impl/connectorland/use_load_action_types/index.tsx index 3213322463d51..81c8c2a4ea7e4 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/connectorland/use_load_action_types/index.tsx +++ b/x-pack/packages/kbn-elastic-assistant/impl/connectorland/use_load_action_types/index.tsx @@ -40,8 +40,18 @@ export const useLoadActionTypes = ({ http, featureId: GenerativeAIForSecurityConnectorFeatureId, }); - const sortedData = queryResult.sort((a, b) => a.name.localeCompare(b.name)); + const actionTypeKey = { + bedrock: '.bedrock', + openai: '.gen-ai', + gemini: '.gemini', + }; + + const sortedData = queryResult + .filter((p) => + [actionTypeKey.bedrock, actionTypeKey.openai, actionTypeKey.gemini].includes(p.id) + ) + .sort((a, b) => a.name.localeCompare(b.name)); return sortedData; }, { diff --git a/x-pack/plugins/actions/common/connector_feature_config.ts b/x-pack/plugins/actions/common/connector_feature_config.ts index 5ba316f47d59b..cffa4c433b8f7 100644 --- a/x-pack/plugins/actions/common/connector_feature_config.ts +++ b/x-pack/plugins/actions/common/connector_feature_config.ts @@ -46,7 +46,7 @@ const compatibilityGenerativeAIForObservability = i18n.translate( const compatibilityGenerativeAIForSearchPlayground = i18n.translate( 'xpack.actions.availableConnectorFeatures.compatibility.generativeAIForSearchPlayground', { - defaultMessage: 'Generative AI for Search Playground', + defaultMessage: 'Generative AI for Search', } ); diff --git a/x-pack/plugins/actions/docs/openapi/components/schemas/connector_types.yaml b/x-pack/plugins/actions/docs/openapi/components/schemas/connector_types.yaml index db6262f04c010..1db9e155f2eec 100644 --- a/x-pack/plugins/actions/docs/openapi/components/schemas/connector_types.yaml +++ b/x-pack/plugins/actions/docs/openapi/components/schemas/connector_types.yaml @@ -4,6 +4,7 @@ description: The type of connector. For example, `.email`, `.index`, `.jira`, `. enum: - .bedrock - .gemini + - .inference - .cases-webhook - .d3security - .email diff --git a/x-pack/plugins/actions/docs/openapi/components/schemas/inference_config.yaml b/x-pack/plugins/actions/docs/openapi/components/schemas/inference_config.yaml new file mode 100644 index 0000000000000..8b1219d079f32 --- /dev/null +++ b/x-pack/plugins/actions/docs/openapi/components/schemas/inference_config.yaml @@ -0,0 +1,23 @@ +title: Connector request properties for an Inference API connector +description: Defines properties for connectors when type is `.inference`. +type: object +required: + - provider + - taskType + - inferenceId +properties: + provider: + type: string + description: The Inference API service provider. + taskType: + type: string + description: The Inference task type supported by provider. + providerConfig: + type: object + description: The provider settings. + taskTypeConfig: + type: object + description: The task type settings. + inferenceId: + type: string + description: The task type settings. diff --git a/x-pack/plugins/actions/docs/openapi/components/schemas/inference_secrets.yaml b/x-pack/plugins/actions/docs/openapi/components/schemas/inference_secrets.yaml new file mode 100644 index 0000000000000..5630a18097633 --- /dev/null +++ b/x-pack/plugins/actions/docs/openapi/components/schemas/inference_secrets.yaml @@ -0,0 +1,9 @@ +title: Connector secrets properties for an AI Connector +description: Defines secrets for connectors when type is `.inference`. +type: object +required: + - providerSecrets +properties: + providerSecrets: + type: object + description: The service account credentials. The service account could have different type of properties to encode. \ No newline at end of file diff --git a/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap b/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap index 94bc911557c21..d778849347d18 100644 --- a/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap +++ b/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap @@ -5617,6 +5617,353 @@ Object { } `; +exports[`Connector type config checks detect connector type changes for: .inference 1`] = ` +Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "input": Object { + "flags": Object { + "error": [Function], + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + }, + "type": "object", +} +`; + +exports[`Connector type config checks detect connector type changes for: .inference 2`] = ` +Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "input": Object { + "flags": Object { + "default": Array [], + "error": [Function], + "presence": "optional", + }, + "items": Array [ + Object { + "flags": Object { + "error": [Function], + "presence": "optional", + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + ], + "type": "array", + }, + "query": Object { + "flags": Object { + "error": [Function], + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + }, + "type": "object", +} +`; + +exports[`Connector type config checks detect connector type changes for: .inference 3`] = ` +Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "input": Object { + "flags": Object { + "error": [Function], + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + }, + "type": "object", +} +`; + +exports[`Connector type config checks detect connector type changes for: .inference 4`] = ` +Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "input": Object { + "flags": Object { + "error": [Function], + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + "inputType": Object { + "flags": Object { + "error": [Function], + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + }, + "type": "object", +} +`; + +exports[`Connector type config checks detect connector type changes for: .inference 5`] = ` +Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "input": Object { + "flags": Object { + "error": [Function], + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + }, + "type": "object", +} +`; + +exports[`Connector type config checks detect connector type changes for: .inference 6`] = ` +Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "inferenceId": Object { + "flags": Object { + "error": [Function], + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + "provider": Object { + "flags": Object { + "error": [Function], + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + "providerConfig": Object { + "flags": Object { + "default": Object {}, + "error": [Function], + "presence": "optional", + "unknown": true, + }, + "keys": Object {}, + "preferences": Object { + "stripUnknown": Object { + "objects": false, + }, + }, + "type": "object", + }, + "taskType": Object { + "flags": Object { + "error": [Function], + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + "taskTypeConfig": Object { + "flags": Object { + "default": Object {}, + "error": [Function], + "presence": "optional", + "unknown": true, + }, + "keys": Object {}, + "preferences": Object { + "stripUnknown": Object { + "objects": false, + }, + }, + "type": "object", + }, + }, + "type": "object", +} +`; + +exports[`Connector type config checks detect connector type changes for: .inference 7`] = ` +Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "providerSecrets": Object { + "flags": Object { + "default": Object {}, + "error": [Function], + "presence": "optional", + "unknown": true, + }, + "keys": Object {}, + "preferences": Object { + "stripUnknown": Object { + "objects": false, + }, + }, + "type": "object", + }, + }, + "type": "object", +} +`; + +exports[`Connector type config checks detect connector type changes for: .inference 8`] = ` +Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "subAction": Object { + "flags": Object { + "error": [Function], + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + "subActionParams": Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + "unknown": true, + }, + "keys": Object {}, + "preferences": Object { + "stripUnknown": Object { + "objects": false, + }, + }, + "type": "object", + }, + }, + "type": "object", +} +`; + exports[`Connector type config checks detect connector type changes for: .jira 1`] = ` Object { "flags": Object { diff --git a/x-pack/plugins/actions/server/integration_tests/mocks/connector_types.ts b/x-pack/plugins/actions/server/integration_tests/mocks/connector_types.ts index a26c775a74a5b..fff112de59f16 100644 --- a/x-pack/plugins/actions/server/integration_tests/mocks/connector_types.ts +++ b/x-pack/plugins/actions/server/integration_tests/mocks/connector_types.ts @@ -32,6 +32,7 @@ export const connectorTypes: string[] = [ '.thehive', '.sentinelone', '.crowdstrike', + '.inference', '.cases', '.observability-ai-assistant', ]; diff --git a/x-pack/plugins/stack_connectors/common/experimental_features.ts b/x-pack/plugins/stack_connectors/common/experimental_features.ts index 495921a95c60e..7adcad74aad85 100644 --- a/x-pack/plugins/stack_connectors/common/experimental_features.ts +++ b/x-pack/plugins/stack_connectors/common/experimental_features.ts @@ -15,6 +15,7 @@ export const allowedExperimentalValues = Object.freeze({ isMustacheAutocompleteOn: false, sentinelOneConnectorOn: true, crowdstrikeConnectorOn: true, + inferenceConnectorOn: true, }); export type ExperimentalConfigKeys = Array; diff --git a/x-pack/plugins/stack_connectors/common/inference/constants.ts b/x-pack/plugins/stack_connectors/common/inference/constants.ts new file mode 100644 index 0000000000000..b795e54f5d32a --- /dev/null +++ b/x-pack/plugins/stack_connectors/common/inference/constants.ts @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { i18n } from '@kbn/i18n'; + +export const INFERENCE_CONNECTOR_TITLE = i18n.translate( + 'xpack.stackConnectors.components.inference.connectorTypeTitle', + { + defaultMessage: 'AI Connector', + } +); + +export enum ServiceProviderKeys { + amazonbedrock = 'amazonbedrock', + azureopenai = 'azureopenai', + azureaistudio = 'azureaistudio', + cohere = 'cohere', + elasticsearch = 'elasticsearch', + googleaistudio = 'googleaistudio', + googlevertexai = 'googlevertexai', + hugging_face = 'hugging_face', + mistral = 'mistral', + openai = 'openai', + anthropic = 'anthropic', + watsonxai = 'watsonxai', + 'alibabacloud-ai-search' = 'alibabacloud-ai-search', +} + +export const INFERENCE_CONNECTOR_ID = '.inference'; +export enum SUB_ACTION { + COMPLETION = 'completion', + RERANK = 'rerank', + TEXT_EMBEDDING = 'text_embedding', + SPARSE_EMBEDDING = 'sparse_embedding', + COMPLETION_STREAM = 'completion_stream', +} + +export const DEFAULT_PROVIDER = 'openai'; +export const DEFAULT_TASK_TYPE = 'completion'; diff --git a/x-pack/plugins/stack_connectors/common/inference/schema.ts b/x-pack/plugins/stack_connectors/common/inference/schema.ts new file mode 100644 index 0000000000000..07b51cf9a5aa3 --- /dev/null +++ b/x-pack/plugins/stack_connectors/common/inference/schema.ts @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { schema } from '@kbn/config-schema'; + +export const ConfigSchema = schema.object({ + provider: schema.string(), + taskType: schema.string(), + inferenceId: schema.string(), + providerConfig: schema.object({}, { unknowns: 'allow', defaultValue: {} }), + taskTypeConfig: schema.object({}, { unknowns: 'allow', defaultValue: {} }), +}); + +export const SecretsSchema = schema.object({ + providerSecrets: schema.object({}, { unknowns: 'allow', defaultValue: {} }), +}); + +export const ChatCompleteParamsSchema = schema.object({ + input: schema.string(), +}); + +export const ChatCompleteResponseSchema = schema.arrayOf( + schema.object({ + result: schema.string(), + }), + { defaultValue: [] } +); + +export const RerankParamsSchema = schema.object({ + input: schema.arrayOf(schema.string(), { defaultValue: [] }), + query: schema.string(), +}); + +export const RerankResponseSchema = schema.arrayOf( + schema.object({ + text: schema.maybe(schema.string()), + index: schema.number(), + score: schema.number(), + }), + { defaultValue: [] } +); + +export const SparseEmbeddingParamsSchema = schema.object({ + input: schema.string(), +}); + +export const SparseEmbeddingResponseSchema = schema.arrayOf( + schema.object({}, { unknowns: 'allow' }), + { defaultValue: [] } +); + +export const TextEmbeddingParamsSchema = schema.object({ + input: schema.string(), + inputType: schema.string(), +}); + +export const TextEmbeddingResponseSchema = schema.arrayOf( + schema.object({ + embedding: schema.arrayOf(schema.any(), { defaultValue: [] }), + }), + { defaultValue: [] } +); + +export const StreamingResponseSchema = schema.stream(); diff --git a/x-pack/plugins/stack_connectors/common/inference/types.ts b/x-pack/plugins/stack_connectors/common/inference/types.ts new file mode 100644 index 0000000000000..9dbd447cb4578 --- /dev/null +++ b/x-pack/plugins/stack_connectors/common/inference/types.ts @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { TypeOf } from '@kbn/config-schema'; +import { + ConfigSchema, + SecretsSchema, + StreamingResponseSchema, + ChatCompleteParamsSchema, + ChatCompleteResponseSchema, + RerankParamsSchema, + RerankResponseSchema, + SparseEmbeddingParamsSchema, + SparseEmbeddingResponseSchema, + TextEmbeddingParamsSchema, + TextEmbeddingResponseSchema, +} from './schema'; + +export type Config = TypeOf; +export type Secrets = TypeOf; + +export type ChatCompleteParams = TypeOf; +export type ChatCompleteResponse = TypeOf; + +export type RerankParams = TypeOf; +export type RerankResponse = TypeOf; + +export type SparseEmbeddingParams = TypeOf; +export type SparseEmbeddingResponse = TypeOf; + +export type TextEmbeddingParams = TypeOf; +export type TextEmbeddingResponse = TypeOf; + +export type StreamingResponse = TypeOf; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/index.ts b/x-pack/plugins/stack_connectors/public/connector_types/index.ts index dd1c5e5c63a2a..92c10bc6ccd57 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/index.ts +++ b/x-pack/plugins/stack_connectors/public/connector_types/index.ts @@ -14,6 +14,7 @@ import { getJiraConnectorType } from './jira'; import { getOpenAIConnectorType } from './openai'; import { getBedrockConnectorType } from './bedrock'; import { getGeminiConnectorType } from './gemini'; +import { getInferenceConnectorType } from './inference'; import { getOpsgenieConnectorType } from './opsgenie'; import { getPagerDutyConnectorType } from './pagerduty'; import { getResilientConnectorType } from './resilient'; @@ -80,4 +81,7 @@ export function registerConnectorTypes({ if (ExperimentalFeaturesService.get().crowdstrikeConnectorOn) { connectorTypeRegistry.register(getCrowdStrikeConnectorType()); } + if (ExperimentalFeaturesService.get().inferenceConnectorOn) { + connectorTypeRegistry.register(getInferenceConnectorType()); + } } diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/additional_options_fields.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/additional_options_fields.tsx new file mode 100644 index 0000000000000..8973f3124bc86 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/additional_options_fields.tsx @@ -0,0 +1,360 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React, { useMemo, useCallback } from 'react'; +import { css } from '@emotion/react'; + +import { + EuiFormRow, + EuiSpacer, + EuiTitle, + EuiAccordion, + EuiFieldText, + useEuiTheme, + EuiTextColor, + EuiButtonGroup, + EuiPanel, + EuiHorizontalRule, + EuiButtonEmpty, + EuiCopy, + EuiButton, + useEuiFontSize, +} from '@elastic/eui'; +import { + getFieldValidityAndErrorMessage, + UseField, + useFormContext, +} from '@kbn/es-ui-shared-plugin/static/forms/hook_form_lib'; +import { FormattedMessage } from '@kbn/i18n-react'; + +import { fieldValidators } from '@kbn/es-ui-shared-plugin/static/forms/helpers'; +import { ConnectorConfigurationFormItems } from '../lib/dynamic_config/connector_configuration_form_items'; +import * as i18n from './translations'; +import { DEFAULT_TASK_TYPE } from './constants'; +import { ConfigEntryView } from '../lib/dynamic_config/types'; +import { Config } from './types'; +import { TaskTypeOption } from './helpers'; + +// Custom trigger button CSS +const buttonCss = css` + &:hover { + text-decoration: none; + } +`; + +interface AdditionalOptionsConnectorFieldsProps { + config: Config; + readOnly: boolean; + isEdit: boolean; + optionalProviderFormFields: ConfigEntryView[]; + onSetProviderConfigEntry: (key: string, value: unknown) => Promise; + onTaskTypeOptionsSelect: (taskType: string, provider?: string) => Promise; + selectedTaskType?: string; + taskTypeFormFields: ConfigEntryView[]; + taskTypeSchema: ConfigEntryView[]; + taskTypeOptions: TaskTypeOption[]; +} + +export const AdditionalOptionsConnectorFields: React.FC = ({ + config, + readOnly, + isEdit, + taskTypeOptions, + optionalProviderFormFields, + taskTypeFormFields, + taskTypeSchema, + selectedTaskType, + onSetProviderConfigEntry, + onTaskTypeOptionsSelect, +}) => { + const xsFontSize = useEuiFontSize('xs').fontSize; + const { euiTheme } = useEuiTheme(); + const { setFieldValue, validateFields } = useFormContext(); + + const onSetTaskTypeConfigEntry = useCallback( + async (key: string, value: unknown) => { + if (taskTypeSchema) { + const entry: ConfigEntryView | undefined = taskTypeSchema.find( + (p: ConfigEntryView) => p.key === key + ); + if (entry) { + if (!config.taskTypeConfig) { + config.taskTypeConfig = {}; + } + const newConfig = { ...config.taskTypeConfig }; + newConfig[key] = value; + setFieldValue('config.taskTypeConfig', newConfig); + await validateFields(['config.taskTypeConfig']); + } + } + }, + [config, setFieldValue, taskTypeSchema, validateFields] + ); + + const taskTypeSettings = useMemo( + () => + selectedTaskType || config.taskType?.length ? ( + <> + +

+ +

+
+ +
+ +
+ + + {(field) => { + const { isInvalid, errorMessage } = getFieldValidityAndErrorMessage(field); + + return ( + + } + isInvalid={isInvalid} + error={errorMessage} + > + {isEdit || readOnly ? ( + + {config.taskType} + + ) : taskTypeOptions.length === 1 ? ( + onTaskTypeOptionsSelect(config.taskType)} + > + {config.taskType} + + ) : ( + onTaskTypeOptionsSelect(id)} + options={taskTypeOptions} + color="text" + type="single" + /> + )} + + ); + }} + + + + + ) : null, + [ + selectedTaskType, + config?.taskType, + xsFontSize, + euiTheme.colors, + taskTypeFormFields, + onSetTaskTypeConfigEntry, + isEdit, + readOnly, + taskTypeOptions, + onTaskTypeOptionsSelect, + ] + ); + + const inferenceUri = useMemo(() => `_inference/${selectedTaskType}/`, [selectedTaskType]); + + return ( + + + + } + initialIsOpen={true} + > + + + {optionalProviderFormFields.length > 0 ? ( + <> + +

+ +

+
+ +
+ +
+ + + + + ) : null} + + {taskTypeSettings} + + +

+ +

+
+ +
+ +
+ + + + {(field) => { + const { isInvalid, errorMessage } = getFieldValidityAndErrorMessage(field); + + return ( + + } + > + { + setFieldValue('config.inferenceId', e.target.value); + }} + prepend={inferenceUri} + append={ + + {(copy) => ( + + + + )} + + } + /> + + ); + }} + +
+
+ ); +}; + +// eslint-disable-next-line import/no-default-export +export { AdditionalOptionsConnectorFields as default }; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/connector.test.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/connector.test.tsx new file mode 100644 index 0000000000000..44632e8b08331 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/connector.test.tsx @@ -0,0 +1,353 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React from 'react'; + +import ConnectorFields from './connector'; +import { ConnectorFormTestProvider } from '../lib/test_utils'; +import { render, waitFor } from '@testing-library/react'; +import userEvent from '@testing-library/user-event'; +import { createStartServicesMock } from '@kbn/triggers-actions-ui-plugin/public/common/lib/kibana/kibana_react.mock'; +import { DisplayType, FieldType } from '../lib/dynamic_config/types'; +import { useProviders } from './providers/get_providers'; +import { getTaskTypes } from './get_task_types'; +import { HttpSetup } from '@kbn/core-http-browser'; + +jest.mock('./providers/get_providers'); +jest.mock('./get_task_types'); + +const mockUseKibanaReturnValue = createStartServicesMock(); +jest.mock('@kbn/triggers-actions-ui-plugin/public/common/lib/kibana', () => ({ + __esModule: true, + useKibana: jest.fn(() => ({ + services: mockUseKibanaReturnValue, + })), +})); + +jest.mock('@faker-js/faker', () => ({ + faker: { + string: { + alpha: jest.fn().mockReturnValue('123'), + }, + }, +})); + +const mockProviders = useProviders as jest.Mock; +const mockTaskTypes = getTaskTypes as jest.Mock; + +const providersSchemas = [ + { + provider: 'openai', + logo: '', // should be openai logo here, the hardcoded uses assets/images + taskTypes: ['completion', 'text_embedding'], + configuration: { + api_key: { + display: DisplayType.TEXTBOX, + label: 'API Key', + order: 3, + required: true, + sensitive: true, + tooltip: `The OpenAI API authentication key. For more details about generating OpenAI API keys, refer to the https://platform.openai.com/account/api-keys.`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + model_id: { + display: DisplayType.TEXTBOX, + label: 'Model ID', + order: 2, + required: true, + sensitive: false, + tooltip: 'The name of the model.', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + organization_id: { + display: DisplayType.TEXTBOX, + label: 'Organization ID', + order: 4, + required: false, + sensitive: false, + tooltip: '', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + url: { + display: DisplayType.TEXTBOX, + label: 'URL', + order: 1, + required: true, + sensitive: false, + tooltip: '', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: 'https://api.openai.com/v1/chat/completions', + depends_on: [], + }, + }, + }, + { + provider: 'googleaistudio', + logo: '', // should be googleaistudio logo here, the hardcoded uses assets/images + taskTypes: ['completion', 'text_embedding'], + configuration: { + api_key: { + display: DisplayType.TEXTBOX, + label: 'API Key', + order: 1, + required: true, + sensitive: true, + tooltip: `API Key for the provider you're connecting to`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + model_id: { + display: DisplayType.TEXTBOX, + label: 'Model ID', + order: 2, + required: true, + sensitive: false, + tooltip: `ID of the LLM you're using`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, +]; +const taskTypesSchemas: Record = { + googleaistudio: [ + { + task_type: 'completion', + configuration: {}, + }, + { + task_type: 'text_embedding', + configuration: {}, + }, + ], + openai: [ + { + task_type: 'completion', + configuration: { + user: { + display: DisplayType.TEXTBOX, + label: 'User', + order: 1, + required: false, + sensitive: false, + tooltip: 'Specifies the user issuing the request.', + type: FieldType.STRING, + validations: [], + value: '', + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + ], +}; + +const openAiConnector = { + actionTypeId: '.inference', + name: 'AI Connector', + id: '123', + config: { + provider: 'openai', + taskType: 'completion', + providerConfig: { + url: 'https://openaiurl.com', + model_id: 'gpt-4o', + organization_id: 'test-org', + }, + taskTypeConfig: { + user: 'elastic', + }, + }, + secrets: { + secretsConfig: { + api_key: 'thats-a-nice-looking-key', + }, + }, + isDeprecated: false, +}; + +const googleaistudioConnector = { + ...openAiConnector, + config: { + ...openAiConnector.config, + provider: 'googleaistudio', + providerConfig: { + ...openAiConnector.config.providerConfig, + model_id: 'somemodel', + }, + taskTypeConfig: {}, + }, + secrets: { + secretsConfig: { + api_key: 'thats-google-key', + }, + }, +}; + +describe('ConnectorFields renders', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockProviders.mockReturnValue({ + isLoading: false, + data: providersSchemas, + }); + mockTaskTypes.mockImplementation( + (http: HttpSetup, provider: string) => taskTypesSchemas[provider] + ); + }); + test('openai provider fields are rendered', async () => { + const { getAllByTestId } = render( + + {}} /> + + ); + expect(getAllByTestId('provider-select')[0]).toBeInTheDocument(); + expect(getAllByTestId('provider-select')[0]).toHaveValue('OpenAI'); + + expect(getAllByTestId('url-input')[0]).toBeInTheDocument(); + expect(getAllByTestId('url-input')[0]).toHaveValue(openAiConnector.config?.providerConfig?.url); + expect(getAllByTestId('taskTypeSelectDisabled')[0]).toBeInTheDocument(); + expect(getAllByTestId('taskTypeSelectDisabled')[0]).toHaveTextContent('completion'); + }); + + test('googleaistudio provider fields are rendered', async () => { + const { getAllByTestId } = render( + + {}} /> + + ); + expect(getAllByTestId('api_key-password')[0]).toBeInTheDocument(); + expect(getAllByTestId('api_key-password')[0]).toHaveValue(''); + expect(getAllByTestId('provider-select')[0]).toBeInTheDocument(); + expect(getAllByTestId('provider-select')[0]).toHaveValue('Google AI Studio'); + expect(getAllByTestId('model_id-input')[0]).toBeInTheDocument(); + expect(getAllByTestId('model_id-input')[0]).toHaveValue( + googleaistudioConnector.config?.providerConfig.model_id + ); + }); + + describe('Validation', () => { + const onSubmit = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + jest.spyOn(global.Math, 'random').mockReturnValue(0.123456789); + }); + + it('connector validation succeeds when connector config is valid', async () => { + const { getByTestId } = render( + + {}} /> + + ); + + await userEvent.click(getByTestId('form-test-provide-submit')); + + await waitFor(async () => { + expect(onSubmit).toHaveBeenCalled(); + }); + + expect(onSubmit).toBeCalledWith({ + data: { + config: { + inferenceId: 'openai-completion-4fzzzxjylrx', + ...openAiConnector.config, + }, + actionTypeId: openAiConnector.actionTypeId, + name: openAiConnector.name, + id: openAiConnector.id, + isDeprecated: openAiConnector.isDeprecated, + }, + isValid: true, + }); + }); + + it('validates correctly if the provider config url is empty', async () => { + const connector = { + ...openAiConnector, + config: { + ...openAiConnector.config, + providerConfig: { + url: '', + modelId: 'gpt-4o', + }, + }, + }; + + const res = render( + + {}} /> + + ); + + await userEvent.click(res.getByTestId('form-test-provide-submit')); + await waitFor(async () => { + expect(onSubmit).toHaveBeenCalled(); + }); + + expect(onSubmit).toHaveBeenCalledWith({ data: {}, isValid: false }); + }); + + const tests: Array<[string, string]> = [ + ['url-input', 'not-valid'], + ['api_key-password', ''], + ]; + it.each(tests)('validates correctly %p', async (field, value) => { + const connector = { + ...openAiConnector, + config: { + ...openAiConnector.config, + headers: [], + }, + }; + + const res = render( + + {}} /> + + ); + + await userEvent.type(res.getByTestId(field), `{selectall}{backspace}${value}`, { + delay: 10, + }); + + await userEvent.click(res.getByTestId('form-test-provide-submit')); + await waitFor(async () => { + expect(onSubmit).toHaveBeenCalled(); + }); + + expect(onSubmit).toHaveBeenCalledWith({ data: {}, isValid: false }); + }); + }); +}); diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/connector.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/connector.tsx new file mode 100644 index 0000000000000..35314dc06167d --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/connector.tsx @@ -0,0 +1,445 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React, { useState, useEffect, useCallback } from 'react'; +import { + EuiFormRow, + EuiSpacer, + EuiInputPopover, + EuiFieldText, + EuiFieldTextProps, + EuiSelectableOption, + EuiFormControlLayout, + keys, + EuiHorizontalRule, +} from '@elastic/eui'; +import { + getFieldValidityAndErrorMessage, + UseField, + useFormContext, + useFormData, +} from '@kbn/es-ui-shared-plugin/static/forms/hook_form_lib'; +import { FormattedMessage } from '@kbn/i18n-react'; +import { + ConnectorFormSchema, + type ActionConnectorFieldsProps, +} from '@kbn/triggers-actions-ui-plugin/public'; +import { useKibana } from '@kbn/triggers-actions-ui-plugin/public'; + +import { fieldValidators } from '@kbn/es-ui-shared-plugin/static/forms/helpers'; +import { ServiceProviderKeys } from '../../../common/inference/constants'; +import { ConnectorConfigurationFormItems } from '../lib/dynamic_config/connector_configuration_form_items'; +import { getTaskTypes } from './get_task_types'; +import * as i18n from './translations'; +import { DEFAULT_TASK_TYPE } from './constants'; +import { ConfigEntryView } from '../lib/dynamic_config/types'; +import { SelectableProvider } from './providers/selectable'; +import { Config, Secrets } from './types'; +import { generateInferenceEndpointId, getTaskTypeOptions, TaskTypeOption } from './helpers'; +import { useProviders } from './providers/get_providers'; +import { SERVICE_PROVIDERS } from './providers/render_service_provider/service_provider'; +import { AdditionalOptionsConnectorFields } from './additional_options_fields'; +import { + getProviderConfigHiddenField, + getProviderSecretsHiddenField, + getTaskTypeConfigHiddenField, +} from './hidden_fields'; + +const InferenceAPIConnectorFields: React.FunctionComponent = ({ + readOnly, + isEdit, +}) => { + const { + http, + notifications: { toasts }, + } = useKibana().services; + + const { updateFieldValues, setFieldValue, validateFields, isSubmitting } = useFormContext(); + const [{ config, secrets }] = useFormData>({ + watch: [ + 'secrets.providerSecrets', + 'config.taskType', + 'config.taskTypeConfig', + 'config.inferenceId', + 'config.provider', + 'config.providerConfig', + ], + }); + + const { data: providers, isLoading } = useProviders(http, toasts); + + const [isProviderPopoverOpen, setProviderPopoverOpen] = useState(false); + + const [providerSchema, setProviderSchema] = useState([]); + const [optionalProviderFormFields, setOptionalProviderFormFields] = useState( + [] + ); + const [requiredProviderFormFields, setRequiredProviderFormFields] = useState( + [] + ); + + const [taskTypeSchema, setTaskTypeSchema] = useState([]); + const [taskTypeOptions, setTaskTypeOptions] = useState([]); + const [selectedTaskType, setSelectedTaskType] = useState(DEFAULT_TASK_TYPE); + const [taskTypeFormFields, setTaskTypeFormFields] = useState([]); + + const handleProviderClosePopover = useCallback(() => { + setProviderPopoverOpen(false); + }, []); + + const handleProviderPopover = useCallback(() => { + setProviderPopoverOpen((isOpen) => !isOpen); + }, []); + + const handleProviderKeyboardOpen: EuiFieldTextProps['onKeyDown'] = useCallback((event: any) => { + if (event.key === keys.ENTER) { + setProviderPopoverOpen(true); + } + }, []); + + useEffect(() => { + if (!isEdit && config && !config.inferenceId) { + generateInferenceEndpointId(config, setFieldValue); + } + }, [isEdit, setFieldValue, config]); + + useEffect(() => { + if (isSubmitting) { + validateFields(['config.providerConfig']); + validateFields(['secrets.providerSecrets']); + validateFields(['config.taskTypeConfig']); + } + }, [isSubmitting, config, validateFields]); + + const onTaskTypeOptionsSelect = useCallback( + async (taskType: string, provider?: string) => { + // Get task type settings + const currentTaskTypes = await getTaskTypes(http, provider ?? config?.provider); + const newTaskType = currentTaskTypes?.find((p) => p.task_type === taskType); + + setSelectedTaskType(taskType); + generateInferenceEndpointId(config, setFieldValue); + + // transform the schema + const newTaskTypeSchema = Object.keys(newTaskType?.configuration ?? {}).map((k) => ({ + key: k, + isValid: true, + ...newTaskType?.configuration[k], + })) as ConfigEntryView[]; + setTaskTypeSchema(newTaskTypeSchema); + + const configDefaults = Object.keys(newTaskType?.configuration ?? {}).reduce( + (res: Record, k) => { + if (newTaskType?.configuration[k] && !!newTaskType?.configuration[k].default_value) { + res[k] = newTaskType.configuration[k].default_value; + } else { + res[k] = null; + } + return res; + }, + {} + ); + + updateFieldValues({ + config: { + taskType, + taskTypeConfig: configDefaults, + }, + }); + }, + [config, http, setFieldValue, updateFieldValues] + ); + + const onProviderChange = useCallback( + async (provider?: string) => { + const newProvider = providers?.find((p) => p.provider === provider); + + // Update task types list available for the selected provider + const providerTaskTypes = newProvider?.taskTypes ?? []; + setTaskTypeOptions(getTaskTypeOptions(providerTaskTypes)); + if (providerTaskTypes.length > 0) { + await onTaskTypeOptionsSelect(providerTaskTypes[0], provider); + } + + // Update connector providerSchema + const newProviderSchema = Object.keys(newProvider?.configuration ?? {}).map((k) => ({ + key: k, + isValid: true, + ...newProvider?.configuration[k], + })) as ConfigEntryView[]; + + setProviderSchema(newProviderSchema); + + const defaultProviderConfig: Record = {}; + const defaultProviderSecrets: Record = {}; + + Object.keys(newProvider?.configuration ?? {}).forEach((k) => { + if (!newProvider?.configuration[k].sensitive) { + if (newProvider?.configuration[k] && !!newProvider?.configuration[k].default_value) { + defaultProviderConfig[k] = newProvider.configuration[k].default_value; + } else { + defaultProviderConfig[k] = null; + } + } else { + defaultProviderSecrets[k] = null; + } + }); + + updateFieldValues({ + config: { + provider: newProvider?.provider, + providerConfig: defaultProviderConfig, + }, + secrets: { + providerSecrets: defaultProviderSecrets, + }, + }); + }, + [onTaskTypeOptionsSelect, providers, updateFieldValues] + ); + + useEffect(() => { + const getTaskTypeSchema = async () => { + const currentTaskTypes = await getTaskTypes(http, config?.provider ?? ''); + const newTaskType = currentTaskTypes?.find((p) => p.task_type === config?.taskType); + + // transform the schema + const newTaskTypeSchema = Object.keys(newTaskType?.configuration ?? {}).map((k) => ({ + key: k, + isValid: true, + ...newTaskType?.configuration[k], + })) as ConfigEntryView[]; + + setTaskTypeSchema(newTaskTypeSchema); + }; + + if (config?.provider && isEdit) { + const newProvider = providers?.find((p) => p.provider === config.provider); + // Update connector providerSchema + const newProviderSchema = Object.keys(newProvider?.configuration ?? {}).map((k) => ({ + key: k, + isValid: true, + ...newProvider?.configuration[k], + })) as ConfigEntryView[]; + + setProviderSchema(newProviderSchema); + + getTaskTypeSchema(); + } + }, [config?.provider, config?.taskType, http, isEdit, providers]); + + useEffect(() => { + // Set values from the provider secrets and config to the schema + const existingConfiguration = providerSchema + ? providerSchema.map((item: ConfigEntryView) => { + const itemValue = item; + itemValue.isValid = true; + if (item.sensitive && secrets?.providerSecrets) { + itemValue.value = secrets?.providerSecrets[item.key] as any; + } else if (config?.providerConfig) { + itemValue.value = config?.providerConfig[item.key] as any; + } + return itemValue; + }) + : []; + + existingConfiguration.sort((a, b) => (a.order ?? 0) - (b.order ?? 0)); + setOptionalProviderFormFields(existingConfiguration.filter((p) => !p.required && !p.sensitive)); + setRequiredProviderFormFields(existingConfiguration.filter((p) => p.required || p.sensitive)); + }, [config?.providerConfig, providerSchema, secrets]); + + useEffect(() => { + // Set values from the task type config to the schema + const existingTaskTypeConfiguration = taskTypeSchema + ? taskTypeSchema.map((item: ConfigEntryView) => { + const itemValue = item; + itemValue.isValid = true; + if (config?.taskTypeConfig) { + itemValue.value = config?.taskTypeConfig[item.key] as any; + } + return itemValue; + }) + : []; + existingTaskTypeConfiguration.sort((a, b) => (a.order ?? 0) - (b.order ?? 0)); + setTaskTypeFormFields(existingTaskTypeConfiguration); + }, [config, taskTypeSchema]); + + const getProviderOptions = useCallback(() => { + return providers?.map((p) => ({ + label: p.provider, + key: p.provider, + })) as EuiSelectableOption[]; + }, [providers]); + + const onSetProviderConfigEntry = useCallback( + async (key: string, value: unknown) => { + const entry: ConfigEntryView | undefined = providerSchema.find( + (p: ConfigEntryView) => p.key === key + ); + if (entry) { + if (entry.sensitive) { + if (!secrets.providerSecrets) { + secrets.providerSecrets = {}; + } + const newSecrets = { ...secrets.providerSecrets }; + newSecrets[key] = value; + setFieldValue('secrets.providerSecrets', newSecrets); + await validateFields(['secrets.providerSecrets']); + } else { + if (!config.providerConfig) { + config.providerConfig = {}; + } + const newConfig = { ...config.providerConfig }; + newConfig[key] = value; + setFieldValue('config.providerConfig', newConfig); + await validateFields(['config.providerConfig']); + } + } + }, + [config, providerSchema, secrets, setFieldValue, validateFields] + ); + + const onClearProvider = useCallback(() => { + onProviderChange(); + setFieldValue('config.taskType', ''); + setFieldValue('config.provider', ''); + }, [onProviderChange, setFieldValue]); + + const providerSuperSelect = useCallback( + (isInvalid: boolean) => ( + + + + ), + [ + isEdit, + readOnly, + onClearProvider, + config?.provider, + handleProviderPopover, + handleProviderKeyboardOpen, + isProviderPopoverOpen, + ] + ); + + return ( + <> + + {(field) => { + const { isInvalid, errorMessage } = getFieldValidityAndErrorMessage(field); + const selectInput = providerSuperSelect(isInvalid); + return ( + + } + isInvalid={isInvalid} + error={errorMessage} + > + + + + + ); + }} + + {config?.provider ? ( + <> + + + + + + + {getProviderSecretsHiddenField( + providerSchema, + setRequiredProviderFormFields, + isSubmitting + )} + {getProviderConfigHiddenField( + providerSchema, + setRequiredProviderFormFields, + isSubmitting + )} + {getTaskTypeConfigHiddenField(taskTypeSchema, setTaskTypeFormFields, isSubmitting)} + + ) : null} + + ); +}; + +// eslint-disable-next-line import/no-default-export +export { InferenceAPIConnectorFields as default }; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/constants.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/constants.tsx new file mode 100644 index 0000000000000..8427caaf49ffc --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/constants.tsx @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { SUB_ACTION } from '../../../common/inference/constants'; + +export const DEFAULT_CHAT_COMPLETE_BODY = { + input: 'What is Elastic?', +}; + +export const DEFAULT_RERANK_BODY = { + input: ['luke', 'like', 'leia', 'chewy', 'r2d2', 'star', 'wars'], + query: 'star wars main character', +}; + +export const DEFAULT_SPARSE_EMBEDDING_BODY = { + input: 'The sky above the port was the color of television tuned to a dead channel.', +}; + +export const DEFAULT_TEXT_EMBEDDING_BODY = { + input: 'The sky above the port was the color of television tuned to a dead channel.', + inputType: 'ingest', +}; + +export const DEFAULTS_BY_TASK_TYPE: Record = { + [SUB_ACTION.COMPLETION]: DEFAULT_CHAT_COMPLETE_BODY, + [SUB_ACTION.RERANK]: DEFAULT_RERANK_BODY, + [SUB_ACTION.SPARSE_EMBEDDING]: DEFAULT_SPARSE_EMBEDDING_BODY, + [SUB_ACTION.TEXT_EMBEDDING]: DEFAULT_TEXT_EMBEDDING_BODY, +}; + +export const DEFAULT_TASK_TYPE = 'completion'; + +export const DEFAULT_PROVIDER = 'elasticsearch'; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/get_task_types.test.ts b/x-pack/plugins/stack_connectors/public/connector_types/inference/get_task_types.test.ts new file mode 100644 index 0000000000000..201df82f6412c --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/get_task_types.test.ts @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { httpServiceMock } from '@kbn/core/public/mocks'; +import { DisplayType, FieldType } from '../lib/dynamic_config/types'; +import { getTaskTypes } from './get_task_types'; + +const http = httpServiceMock.createStartContract(); + +beforeEach(() => jest.resetAllMocks()); + +describe.skip('getTaskTypes', () => { + test('should call get inference task types api', async () => { + const apiResponse = { + amazonbedrock: [ + { + task_type: 'completion', + configuration: { + max_new_tokens: { + display: DisplayType.NUMERIC, + label: 'Max new tokens', + order: 1, + required: false, + sensitive: false, + tooltip: 'Sets the maximum number for the output tokens to be generated.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + task_type: 'text_embedding', + configuration: {}, + }, + ], + }; + http.get.mockResolvedValueOnce(apiResponse); + + const result = await getTaskTypes(http, 'amazonbedrock'); + expect(result).toEqual(apiResponse); + }); +}); diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/get_task_types.ts b/x-pack/plugins/stack_connectors/public/connector_types/inference/get_task_types.ts new file mode 100644 index 0000000000000..a4fbbd6a6288b --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/get_task_types.ts @@ -0,0 +1,606 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { HttpSetup } from '@kbn/core-http-browser'; +import { DisplayType, FieldType } from '../lib/dynamic_config/types'; +import { FieldsConfiguration } from './types'; + +export interface InferenceTaskType { + task_type: string; + configuration: FieldsConfiguration; +} + +// this http param is for the future migrating to real API +export const getTaskTypes = (http: HttpSetup, provider: string): Promise => { + const providersTaskTypes: Record = { + openai: [ + { + task_type: 'completion', + configuration: { + user: { + display: DisplayType.TEXTBOX, + label: 'User', + order: 1, + required: false, + sensitive: false, + tooltip: 'Specifies the user issuing the request.', + type: FieldType.STRING, + validations: [], + value: '', + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + task_type: 'text_embedding', + configuration: { + user: { + display: DisplayType.TEXTBOX, + label: 'User', + order: 1, + required: false, + sensitive: false, + tooltip: 'Specifies the user issuing the request.', + type: FieldType.STRING, + validations: [], + value: '', + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + ], + mistral: [ + { + task_type: 'text_embedding', + configuration: {}, + }, + ], + hugging_face: [ + { + task_type: 'text_embedding', + configuration: {}, + }, + ], + googlevertexai: [ + { + task_type: 'text_embedding', + configuration: { + auto_truncate: { + display: DisplayType.TOGGLE, + label: 'Auto truncate', + order: 1, + required: false, + sensitive: false, + tooltip: + 'Specifies if the API truncates inputs longer than the maximum token length automatically.', + type: FieldType.BOOLEAN, + validations: [], + value: false, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + task_type: 'rerank', + configuration: { + top_n: { + display: DisplayType.TOGGLE, + label: 'Top N', + order: 1, + required: false, + sensitive: false, + tooltip: 'Specifies the number of the top n documents, which should be returned.', + type: FieldType.BOOLEAN, + validations: [], + value: false, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + ], + googleaistudio: [ + { + task_type: 'completion', + configuration: {}, + }, + { + task_type: 'text_embedding', + configuration: {}, + }, + ], + elasticsearch: [ + { + task_type: 'rerank', + configuration: { + return_documents: { + display: DisplayType.TOGGLE, + label: 'Return documents', + options: [], + order: 1, + required: false, + sensitive: false, + tooltip: 'Returns the document instead of only the index.', + type: FieldType.BOOLEAN, + validations: [], + value: true, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + task_type: 'sparse_embedding', + configuration: {}, + }, + { + task_type: 'text_embedding', + configuration: {}, + }, + ], + cohere: [ + { + task_type: 'completion', + configuration: {}, + }, + { + task_type: 'text_embedding', + configuration: { + input_type: { + display: DisplayType.DROPDOWN, + label: 'Input type', + order: 1, + required: false, + sensitive: false, + tooltip: 'Specifies the type of input passed to the model.', + type: FieldType.STRING, + validations: [], + options: [ + { + label: 'classification', + value: 'classification', + }, + { + label: 'clusterning', + value: 'clusterning', + }, + { + label: 'ingest', + value: 'ingest', + }, + { + label: 'search', + value: 'search', + }, + ], + value: '', + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + truncate: { + display: DisplayType.DROPDOWN, + options: [ + { + label: 'NONE', + value: 'NONE', + }, + { + label: 'START', + value: 'START', + }, + { + label: 'END', + value: 'END', + }, + ], + label: 'Truncate', + order: 2, + required: false, + sensitive: false, + tooltip: 'Specifies how the API handles inputs longer than the maximum token length.', + type: FieldType.STRING, + validations: [], + value: '', + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + task_type: 'rerank', + configuration: { + return_documents: { + display: DisplayType.TOGGLE, + label: 'Return documents', + order: 1, + required: false, + sensitive: false, + tooltip: 'Specify whether to return doc text within the results.', + type: FieldType.BOOLEAN, + validations: [], + value: false, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + top_n: { + display: DisplayType.NUMERIC, + label: 'Top N', + order: 1, + required: false, + sensitive: false, + tooltip: + 'The number of most relevant documents to return, defaults to the number of the documents.', + type: FieldType.INTEGER, + validations: [], + value: false, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + ], + azureopenai: [ + { + task_type: 'completion', + configuration: { + user: { + display: DisplayType.TEXTBOX, + label: 'User', + order: 1, + required: false, + sensitive: false, + tooltip: 'Specifies the user issuing the request.', + type: FieldType.STRING, + validations: [], + value: '', + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + task_type: 'text_embedding', + configuration: { + user: { + display: DisplayType.TEXTBOX, + label: 'User', + order: 1, + required: false, + sensitive: false, + tooltip: 'Specifies the user issuing the request.', + type: FieldType.STRING, + validations: [], + value: '', + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + ], + azureaistudio: [ + { + task_type: 'completion', + configuration: { + user: { + display: DisplayType.TEXTBOX, + label: 'User', + order: 1, + required: false, + sensitive: false, + tooltip: 'Specifies the user issuing the request.', + type: FieldType.STRING, + validations: [], + value: '', + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + task_type: 'text_embedding', + configuration: { + do_sample: { + display: DisplayType.NUMERIC, + label: 'Do sample', + order: 1, + required: false, + sensitive: false, + tooltip: 'Instructs the inference process to perform sampling or not.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + max_new_tokens: { + display: DisplayType.NUMERIC, + label: 'Max new tokens', + order: 1, + required: false, + sensitive: false, + tooltip: 'Provides a hint for the maximum number of output tokens to be generated.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + temperature: { + display: DisplayType.NUMERIC, + label: 'Temperature', + order: 1, + required: false, + sensitive: false, + tooltip: 'A number in the range of 0.0 to 2.0 that specifies the sampling temperature.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + top_p: { + display: DisplayType.NUMERIC, + label: 'Top P', + order: 1, + required: false, + sensitive: false, + tooltip: + 'A number in the range of 0.0 to 2.0 that is an alternative value to temperature. Should not be used if temperature is specified.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + ], + amazonbedrock: [ + { + task_type: 'completion', + configuration: { + max_new_tokens: { + display: DisplayType.NUMERIC, + label: 'Max new tokens', + order: 1, + required: false, + sensitive: false, + tooltip: 'Sets the maximum number for the output tokens to be generated.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + temperature: { + display: DisplayType.NUMERIC, + label: 'Temperature', + order: 1, + required: false, + sensitive: false, + tooltip: + 'A number between 0.0 and 1.0 that controls the apparent creativity of the results.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + top_p: { + display: DisplayType.NUMERIC, + label: 'Top P', + order: 1, + required: false, + sensitive: false, + tooltip: + 'Alternative to temperature. A number in the range of 0.0 to 1.0, to eliminate low-probability tokens.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + top_k: { + display: DisplayType.NUMERIC, + label: 'Top K', + order: 1, + required: false, + sensitive: false, + tooltip: + 'Only available for anthropic, cohere, and mistral providers. Alternative to temperature.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + task_type: 'text_embedding', + configuration: {}, + }, + ], + anthropic: [ + { + task_type: 'completion', + configuration: { + max_tokens: { + display: DisplayType.NUMERIC, + label: 'Max tokens', + order: 1, + required: true, + sensitive: false, + tooltip: 'The maximum number of tokens to generate before stopping.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + temperature: { + display: DisplayType.TEXTBOX, + label: 'Temperature', + order: 2, + required: false, + sensitive: false, + tooltip: 'The amount of randomness injected into the response.', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + top_p: { + display: DisplayType.NUMERIC, + label: 'Top P', + order: 4, + required: false, + sensitive: false, + tooltip: 'Specifies to use Anthropic’s nucleus sampling.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + top_k: { + display: DisplayType.NUMERIC, + label: 'Top K', + order: 3, + required: false, + sensitive: false, + tooltip: 'Specifies to only sample from the top K options for each subsequent token.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + ], + 'alibabacloud-ai-search': [ + { + task_type: 'text_embedding', + configuration: { + input_type: { + display: DisplayType.DROPDOWN, + label: 'Input type', + order: 1, + required: false, + sensitive: false, + tooltip: 'Specifies the type of input passed to the model.', + type: FieldType.STRING, + validations: [], + options: [ + { + label: 'ingest', + value: 'ingest', + }, + { + label: 'search', + value: 'search', + }, + ], + value: '', + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + task_type: 'sparse_embedding', + configuration: { + input_type: { + display: DisplayType.DROPDOWN, + label: 'Input type', + order: 1, + required: false, + sensitive: false, + tooltip: 'Specifies the type of input passed to the model.', + type: FieldType.STRING, + validations: [], + options: [ + { + label: 'ingest', + value: 'ingest', + }, + { + label: 'search', + value: 'search', + }, + ], + value: '', + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + return_token: { + display: DisplayType.TOGGLE, + label: 'Return token', + options: [], + order: 1, + required: false, + sensitive: false, + tooltip: + 'If `true`, the token name will be returned in the response. Defaults to `false` which means only the token ID will be returned in the response.', + type: FieldType.BOOLEAN, + validations: [], + value: true, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + task_type: 'completion', + configuration: {}, + }, + { + task_type: 'rerank', + configuration: {}, + }, + ], + watsonxai: [ + { + task_type: 'text_embedding', + configuration: {}, + }, + ], + }; + return Promise.resolve(providersTaskTypes[provider]); +}; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/helpers.ts b/x-pack/plugins/stack_connectors/public/connector_types/inference/helpers.ts new file mode 100644 index 0000000000000..0e1e4cdaa41ad --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/helpers.ts @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { isEmpty } from 'lodash/fp'; +import { ValidationFunc } from '@kbn/es-ui-shared-plugin/static/forms/hook_form_lib'; +import { ConfigEntryView } from '../lib/dynamic_config/types'; +import { Config } from './types'; +import * as i18n from './translations'; + +export interface TaskTypeOption { + id: string; + value: string; + label: string; +} + +export const getTaskTypeOptions = (taskTypes: string[]): TaskTypeOption[] => + taskTypes.map((taskType) => ({ + id: taskType, + label: taskType, + value: taskType, + })); + +export const generateInferenceEndpointId = ( + config: Config, + setFieldValue: (fieldName: string, value: unknown) => void +) => { + const taskTypeSuffix = config.taskType ? `${config.taskType}-` : ''; + const inferenceEndpointId = `${config.provider}-${taskTypeSuffix}${Math.random() + .toString(36) + .slice(2)}`; + config.inferenceId = inferenceEndpointId; + setFieldValue('config.inferenceId', inferenceEndpointId); +}; + +export const getNonEmptyValidator = ( + schema: ConfigEntryView[], + validationEventHandler: (fieldsWithErrors: ConfigEntryView[]) => void, + isSubmitting: boolean = false, + isSecrets: boolean = false +) => { + return (...args: Parameters): ReturnType => { + const [{ value, path }] = args; + const newSchema: ConfigEntryView[] = []; + + const configData = (value ?? {}) as Record; + let hasErrors = false; + if (schema) { + schema + .filter((f: ConfigEntryView) => f.required) + .forEach((field: ConfigEntryView) => { + // validate if submitting or on field edit - value is not default to null + if (configData[field.key] !== null || isSubmitting) { + // validate secrets fields separately from regular + if (isSecrets ? field.sensitive : !field.sensitive) { + if ( + !configData[field.key] || + (typeof configData[field.key] === 'string' && isEmpty(configData[field.key])) + ) { + field.validationErrors = [i18n.getRequiredMessage(field.label)]; + field.isValid = false; + hasErrors = true; + } else { + field.validationErrors = []; + field.isValid = true; + } + } + } + newSchema.push(field); + }); + + validationEventHandler(newSchema.sort((a, b) => (a.order ?? 0) - (b.order ?? 0))); + if (hasErrors) { + return { + code: 'ERR_FIELD_MISSING', + path, + message: i18n.getRequiredMessage('Action'), + }; + } + } + }; +}; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/hidden_fields.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/hidden_fields.tsx new file mode 100644 index 0000000000000..9b28d35aaaf3a --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/hidden_fields.tsx @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React from 'react'; +import { HiddenField } from '@kbn/es-ui-shared-plugin/static/forms/components'; +import { UseField } from '@kbn/es-ui-shared-plugin/static/forms/hook_form_lib'; +import { getNonEmptyValidator } from './helpers'; +import { ConfigEntryView } from '../lib/dynamic_config/types'; + +export const getProviderSecretsHiddenField = ( + providerSchema: ConfigEntryView[], + setRequiredProviderFormFields: React.Dispatch>, + isSubmitting: boolean +) => ( + +); + +export const getProviderConfigHiddenField = ( + providerSchema: ConfigEntryView[], + setRequiredProviderFormFields: React.Dispatch>, + isSubmitting: boolean +) => ( + +); + +export const getTaskTypeConfigHiddenField = ( + taskTypeSchema: ConfigEntryView[], + setTaskTypeFormFields: React.Dispatch>, + isSubmitting: boolean +) => ( + { + const formFields = [ + ...requiredFormFields, + ...(taskTypeSchema ?? []).filter((f) => !f.required), + ]; + setTaskTypeFormFields(formFields.sort((a, b) => (a.order ?? 0) - (b.order ?? 0))); + }, + isSubmitting + ), + isBlocking: true, + }, + ], + }} + /> +); diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/index.ts b/x-pack/plugins/stack_connectors/public/connector_types/inference/index.ts new file mode 100644 index 0000000000000..cdac663d16b89 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/index.ts @@ -0,0 +1,8 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export { getConnectorType as getInferenceConnectorType } from './inference'; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/inference.test.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/inference.test.tsx new file mode 100644 index 0000000000000..0f37564fd560c --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/inference.test.tsx @@ -0,0 +1,142 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { TypeRegistry } from '@kbn/triggers-actions-ui-plugin/public/application/type_registry'; +import { registerConnectorTypes } from '..'; +import type { ActionTypeModel } from '@kbn/triggers-actions-ui-plugin/public/types'; +import { experimentalFeaturesMock, registrationServicesMock } from '../../mocks'; +import { SUB_ACTION } from '../../../common/inference/constants'; +import { ExperimentalFeaturesService } from '../../common/experimental_features_service'; + +const ACTION_TYPE_ID = '.inference'; +let actionTypeModel: ActionTypeModel; + +beforeAll(() => { + ExperimentalFeaturesService.init({ experimentalFeatures: experimentalFeaturesMock }); + const connectorTypeRegistry = new TypeRegistry(); + registerConnectorTypes({ connectorTypeRegistry, services: registrationServicesMock }); + const getResult = connectorTypeRegistry.get(ACTION_TYPE_ID); + if (getResult !== null) { + actionTypeModel = getResult; + } +}); + +describe('actionTypeRegistry.get() works', () => { + test('connector type static data is as expected', () => { + expect(actionTypeModel.id).toEqual(ACTION_TYPE_ID); + expect(actionTypeModel.selectMessage).toBe( + 'Send requests to AI providers such as Amazon Bedrock, OpenAI and more.' + ); + expect(actionTypeModel.actionTypeTitle).toBe('AI Connector'); + }); +}); + +describe('OpenAI action params validation', () => { + test.each([ + { + subAction: SUB_ACTION.RERANK, + subActionParams: { input: ['message test'], query: 'foobar' }, + }, + { + subAction: SUB_ACTION.COMPLETION, + subActionParams: { input: 'message test' }, + }, + { + subAction: SUB_ACTION.TEXT_EMBEDDING, + subActionParams: { input: 'message test', inputType: 'foobar' }, + }, + { + subAction: SUB_ACTION.SPARSE_EMBEDDING, + subActionParams: { input: 'message test' }, + }, + ])( + 'validation succeeds when params are valid for subAction $subAction', + async ({ subAction, subActionParams }) => { + const actionParams = { + subAction, + subActionParams, + }; + expect(await actionTypeModel.validateParams(actionParams)).toEqual({ + errors: { input: [], subAction: [], inputType: [], query: [] }, + }); + } + ); + + test('params validation fails when params is a wrong object', async () => { + const actionParams = { + subAction: SUB_ACTION.COMPLETION, + subActionParams: { body: 'message {test}' }, + }; + + expect(await actionTypeModel.validateParams(actionParams)).toEqual({ + errors: { input: ['Input is required.'], inputType: [], query: [], subAction: [] }, + }); + }); + + test('params validation fails when subAction is missing', async () => { + const actionParams = { + subActionParams: { input: 'message test' }, + }; + + expect(await actionTypeModel.validateParams(actionParams)).toEqual({ + errors: { + input: [], + inputType: [], + query: [], + subAction: ['Action is required.'], + }, + }); + }); + + test('params validation fails when subAction is not in the list of the supported', async () => { + const actionParams = { + subAction: 'wrong', + subActionParams: { input: 'message test' }, + }; + + expect(await actionTypeModel.validateParams(actionParams)).toEqual({ + errors: { + input: [], + inputType: [], + query: [], + subAction: ['Invalid action name.'], + }, + }); + }); + + test('params validation fails when subActionParams is missing', async () => { + const actionParams = { + subAction: SUB_ACTION.RERANK, + subActionParams: {}, + }; + + expect(await actionTypeModel.validateParams(actionParams)).toEqual({ + errors: { + input: ['Input is required.', 'Input does not have a valid Array format.'], + inputType: [], + query: ['Query is required.'], + subAction: [], + }, + }); + }); + + test('params validation fails when text_embedding inputType is missing', async () => { + const actionParams = { + subAction: SUB_ACTION.TEXT_EMBEDDING, + subActionParams: { input: 'message test' }, + }; + + expect(await actionTypeModel.validateParams(actionParams)).toEqual({ + errors: { + input: [], + inputType: ['Input type is required.'], + query: [], + subAction: [], + }, + }); + }); +}); diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/inference.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/inference.tsx new file mode 100644 index 0000000000000..e16d03306c166 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/inference.tsx @@ -0,0 +1,92 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { lazy } from 'react'; +import { i18n } from '@kbn/i18n'; +import type { GenericValidationResult } from '@kbn/triggers-actions-ui-plugin/public/types'; +import { RerankParams, TextEmbeddingParams } from '../../../common/inference/types'; +import { SUB_ACTION } from '../../../common/inference/constants'; +import { + INFERENCE_CONNECTOR_ID, + INFERENCE_CONNECTOR_TITLE, +} from '../../../common/inference/constants'; +import { InferenceActionParams, InferenceConnector } from './types'; + +interface ValidationErrors { + subAction: string[]; + input: string[]; + // rerank only + query: string[]; + // text_embedding only + inputType: string[]; +} +export function getConnectorType(): InferenceConnector { + return { + id: INFERENCE_CONNECTOR_ID, + iconClass: 'sparkles', + isExperimental: true, + selectMessage: i18n.translate('xpack.stackConnectors.components.inference.selectMessageText', { + defaultMessage: 'Send requests to AI providers such as Amazon Bedrock, OpenAI and more.', + }), + actionTypeTitle: INFERENCE_CONNECTOR_TITLE, + validateParams: async ( + actionParams: InferenceActionParams + ): Promise> => { + const { subAction, subActionParams } = actionParams; + const translations = await import('./translations'); + const errors: ValidationErrors = { + input: [], + subAction: [], + inputType: [], + query: [], + }; + + if ( + subAction === SUB_ACTION.RERANK || + subAction === SUB_ACTION.COMPLETION || + subAction === SUB_ACTION.TEXT_EMBEDDING || + subAction === SUB_ACTION.SPARSE_EMBEDDING + ) { + if (!subActionParams.input?.length) { + errors.input.push(translations.getRequiredMessage('Input')); + } + } + if (subAction === SUB_ACTION.RERANK) { + if (!Array.isArray(subActionParams.input)) { + errors.input.push(translations.INPUT_INVALID); + } + + if (!(subActionParams as RerankParams).query?.length) { + errors.query.push(translations.getRequiredMessage('Query')); + } + } + if (subAction === SUB_ACTION.TEXT_EMBEDDING) { + if (!(subActionParams as TextEmbeddingParams).inputType?.length) { + errors.inputType.push(translations.getRequiredMessage('Input type')); + } + } + if (errors.input.length) return { errors }; + + // The internal "subAction" param should always be valid, ensure it is only if "subActionParams" are valid + if (!subAction) { + errors.subAction.push(translations.getRequiredMessage('Action')); + } else if ( + ![ + SUB_ACTION.COMPLETION, + SUB_ACTION.SPARSE_EMBEDDING, + SUB_ACTION.RERANK, + SUB_ACTION.TEXT_EMBEDDING, + ].includes(subAction) + ) { + errors.subAction.push(translations.INVALID_ACTION); + } + return { errors }; + }, + actionConnectorFields: lazy(() => import('./connector')), + actionParamsFields: lazy(() => import('./params')), + }; +} diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/params.test.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/params.test.tsx new file mode 100644 index 0000000000000..49773edc2246a --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/params.test.tsx @@ -0,0 +1,167 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React from 'react'; +import { fireEvent, render } from '@testing-library/react'; +import ParamsFields from './params'; +import { SUB_ACTION } from '../../../common/inference/constants'; + +describe('Inference Params Fields renders', () => { + test('all params fields are rendered', () => { + const { getByTestId } = render( + {}} + index={0} + /> + ); + expect(getByTestId('inferenceInput')).toBeInTheDocument(); + expect(getByTestId('inferenceInput')).toHaveProperty('value', 'What is Elastic?'); + }); + + test.each(['openai', 'googleaistudio'])( + 'useEffect handles the case when subAction and subActionParams are undefined and provider is %p', + (provider) => { + const actionParams = { + subAction: undefined, + subActionParams: undefined, + }; + const editAction = jest.fn(); + const errors = {}; + const actionConnector = { + secrets: { + providerSecrets: { apiKey: 'apiKey' }, + }, + id: 'test', + actionTypeId: '.inference', + isPreconfigured: false, + isSystemAction: false as const, + isDeprecated: false, + name: 'My OpenAI Connector', + config: { + provider, + providerConfig: { + url: 'https://api.openai.com/v1/embeddings', + }, + taskType: 'completion', + }, + }; + render( + + ); + expect(editAction).toHaveBeenCalledTimes(2); + expect(editAction).toHaveBeenCalledWith('subAction', SUB_ACTION.COMPLETION, 0); + if (provider === 'openai') { + expect(editAction).toHaveBeenCalledWith( + 'subActionParams', + { input: 'What is Elastic?' }, + 0 + ); + } + if (provider === 'googleaistudio') { + expect(editAction).toHaveBeenCalledWith( + 'subActionParams', + { input: 'What is Elastic?' }, + 0 + ); + } + } + ); + + it('handles the case when subAction only is undefined', () => { + const actionParams = { + subAction: undefined, + subActionParams: { + input: '{"key": "value"}', + }, + }; + const editAction = jest.fn(); + const errors = {}; + render( + + ); + expect(editAction).toHaveBeenCalledTimes(1); + expect(editAction).toHaveBeenCalledWith('subAction', SUB_ACTION.COMPLETION, 0); + }); + + it('calls editAction function with the correct arguments ', () => { + const editAction = jest.fn(); + const errors = {}; + const { getByTestId } = render( + + ); + const jsonEditor = getByTestId('inputJsonEditor'); + fireEvent.change(jsonEditor, { target: { value: `[\"apple\",\"banana\",\"tomato\"]` } }); + expect(editAction).toHaveBeenCalledWith( + 'subActionParams', + { input: '["apple","banana","tomato"]', query: 'test' }, + 0 + ); + }); +}); diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/params.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/params.tsx new file mode 100644 index 0000000000000..0013f943e3639 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/params.tsx @@ -0,0 +1,238 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React, { useCallback, useEffect } from 'react'; +import { + JsonEditorWithMessageVariables, + type ActionParamsProps, +} from '@kbn/triggers-actions-ui-plugin/public'; +import { EuiTextArea, EuiFormRow, EuiSpacer, EuiSelect } from '@elastic/eui'; +import { RuleFormParamsErrors } from '@kbn/alerts-ui-shared'; +import { + ChatCompleteParams, + RerankParams, + SparseEmbeddingParams, + TextEmbeddingParams, +} from '../../../common/inference/types'; +import { DEFAULTS_BY_TASK_TYPE } from './constants'; +import * as i18n from './translations'; +import { SUB_ACTION } from '../../../common/inference/constants'; +import { InferenceActionConnector, InferenceActionParams } from './types'; + +const InferenceServiceParamsFields: React.FunctionComponent< + ActionParamsProps +> = ({ actionParams, editAction, index, errors, actionConnector }) => { + const { subAction, subActionParams } = actionParams; + + const { taskType } = (actionConnector as unknown as InferenceActionConnector).config; + + useEffect(() => { + if (!subAction) { + editAction('subAction', taskType, index); + } + }, [editAction, index, subAction, taskType]); + + useEffect(() => { + if (!subActionParams) { + editAction( + 'subActionParams', + { + ...(DEFAULTS_BY_TASK_TYPE[taskType] ?? {}), + }, + index + ); + } + }, [editAction, index, subActionParams, taskType]); + + const editSubActionParams = useCallback( + (params: Partial) => { + editAction('subActionParams', { ...subActionParams, ...params }, index); + }, + [editAction, index, subActionParams] + ); + + if (subAction === SUB_ACTION.COMPLETION) { + return ( + + ); + } + + if (subAction === SUB_ACTION.RERANK) { + return ( + + ); + } + + if (subAction === SUB_ACTION.SPARSE_EMBEDDING) { + return ( + + ); + } + + if (subAction === SUB_ACTION.TEXT_EMBEDDING) { + return ( + + ); + } + + return <>; +}; + +const InferenceInput: React.FunctionComponent<{ + input?: string; + inputError?: string; + editSubActionParams: (params: Partial) => void; +}> = ({ input, inputError, editSubActionParams }) => { + return ( + + { + editSubActionParams({ input: e.target.value }); + }} + isInvalid={false} + fullWidth={true} + /> + + ); +}; + +const CompletionParamsFields: React.FunctionComponent<{ + subActionParams: ChatCompleteParams; + errors: RuleFormParamsErrors; + editSubActionParams: (params: Partial) => void; +}> = ({ subActionParams, editSubActionParams, errors }) => { + const { input } = subActionParams; + + return ( + + ); +}; + +const SparseEmbeddingParamsFields: React.FunctionComponent<{ + subActionParams: SparseEmbeddingParams; + errors: RuleFormParamsErrors; + editSubActionParams: (params: Partial) => void; +}> = ({ subActionParams, editSubActionParams, errors }) => { + const { input } = subActionParams; + + return ( + + ); +}; + +const TextEmbeddingParamsFields: React.FunctionComponent<{ + subActionParams: TextEmbeddingParams; + errors: RuleFormParamsErrors; + editSubActionParams: (params: Partial) => void; +}> = ({ subActionParams, editSubActionParams, errors }) => { + const { input, inputType } = subActionParams; + + const options = [ + { value: 'ingest', text: 'ingest' }, + { value: 'search', text: 'search' }, + ]; + + return ( + <> + + { + editSubActionParams({ inputType: e.target.value }); + }} + /> + + + + + ); +}; + +const RerankParamsFields: React.FunctionComponent<{ + subActionParams: RerankParams; + errors: RuleFormParamsErrors; + editSubActionParams: (params: Partial) => void; +}> = ({ subActionParams, editSubActionParams, errors }) => { + const { input, query } = subActionParams; + + return ( + <> + { + editSubActionParams({ input: json.trim() }); + }} + onBlur={() => { + if (!input) { + editSubActionParams({ input: [] }); + } + }} + dataTestSubj="inference-inputJsonEditor" + /> + + + { + editSubActionParams({ query: e.target.value }); + }} + isInvalid={false} + fullWidth={true} + /> + + + ); +}; + +// eslint-disable-next-line import/no-default-export +export { InferenceServiceParamsFields as default }; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/alibaba_cloud.svg b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/alibaba_cloud.svg new file mode 100644 index 0000000000000..1ae552d509c3a --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/alibaba_cloud.svg @@ -0,0 +1,3 @@ + + + diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/amazon_bedrock.svg b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/amazon_bedrock.svg new file mode 100644 index 0000000000000..f8815d4f75ec5 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/amazon_bedrock.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/anthropic.svg b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/anthropic.svg new file mode 100644 index 0000000000000..c361cda86a7df --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/anthropic.svg @@ -0,0 +1,3 @@ + + + diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/azure_ai_studio.svg b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/azure_ai_studio.svg new file mode 100644 index 0000000000000..405e182a10394 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/azure_ai_studio.svg @@ -0,0 +1,44 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/azure_open_ai.svg b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/azure_open_ai.svg new file mode 100644 index 0000000000000..122c0c65af13c --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/azure_open_ai.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/cohere.svg b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/cohere.svg new file mode 100644 index 0000000000000..69953809fec35 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/cohere.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/elastic.svg b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/elastic.svg new file mode 100644 index 0000000000000..e763c2e2f2ab6 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/elastic.svg @@ -0,0 +1,16 @@ + + + + + + + + diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/google_ai_studio.svg b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/google_ai_studio.svg new file mode 100644 index 0000000000000..b6e34ae15c9e4 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/google_ai_studio.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/hugging_face.svg b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/hugging_face.svg new file mode 100644 index 0000000000000..87ac70c5a18f4 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/hugging_face.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/ibm_watsonx.svg b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/ibm_watsonx.svg new file mode 100644 index 0000000000000..5883eff3884d6 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/ibm_watsonx.svg @@ -0,0 +1,3 @@ + + + diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/mistral.svg b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/mistral.svg new file mode 100644 index 0000000000000..f62258a327594 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/mistral.svg @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/open_ai.svg b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/open_ai.svg new file mode 100644 index 0000000000000..9ddc8f8fd63b8 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/assets/images/open_ai.svg @@ -0,0 +1,3 @@ + + + diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/get_providers.test.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/get_providers.test.tsx new file mode 100644 index 0000000000000..06548b36a635a --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/get_providers.test.tsx @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import React from 'react'; +import * as ReactQuery from '@tanstack/react-query'; +import { renderHook } from '@testing-library/react-hooks/dom'; +import { waitFor } from '@testing-library/react'; +import { QueryClient, QueryClientProvider } from '@tanstack/react-query'; +import { httpServiceMock, notificationServiceMock } from '@kbn/core/public/mocks'; +import { useProviders } from './get_providers'; + +const http = httpServiceMock.createStartContract(); +const toasts = notificationServiceMock.createStartContract(); +const useQuerySpy = jest.spyOn(ReactQuery, 'useQuery'); + +beforeEach(() => jest.resetAllMocks()); + +const { getProviders } = jest.requireMock('./get_providers'); + +const queryClient = new QueryClient(); + +const wrapper = ({ children }: { children: React.ReactNode }) => ( + {children} +); + +describe('useProviders', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should call useQuery', async () => { + renderHook(() => useProviders(http, toasts.toasts), { + wrapper, + }); + + await waitFor(() => { + return expect(useQuerySpy).toBeCalled(); + }); + }); + + it('should return isError = true if api fails', async () => { + getProviders.mockResolvedValue('This is an error.'); + + renderHook(() => useProviders(http, toasts.toasts), { + wrapper, + }); + + await waitFor(() => expect(useQuerySpy).toHaveBeenCalled()); + }); +}); diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/get_providers.ts b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/get_providers.ts new file mode 100644 index 0000000000000..109266c1273fc --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/get_providers.ts @@ -0,0 +1,1054 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { HttpSetup } from '@kbn/core-http-browser'; +import { i18n } from '@kbn/i18n'; +import { useQuery } from '@tanstack/react-query'; +import type { ToastsStart } from '@kbn/core-notifications-browser'; +import { DisplayType, FieldType } from '../../lib/dynamic_config/types'; +import { FieldsConfiguration } from '../types'; + +export interface InferenceProvider { + provider: string; + taskTypes: string[]; + logo?: string; + configuration: FieldsConfiguration; +} + +export const getProviders = (http: HttpSetup): Promise => { + const providers = [ + { + provider: 'openai', + logo: '', // should be openai logo here, the hardcoded uses assets/images + taskTypes: ['completion', 'text_embedding'], + configuration: { + api_key: { + display: DisplayType.TEXTBOX, + label: 'API Key', + order: 3, + required: true, + sensitive: true, + tooltip: `The OpenAI API authentication key. For more details about generating OpenAI API keys, refer to the https://platform.openai.com/account/api-keys.`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + model_id: { + display: DisplayType.TEXTBOX, + label: 'Model ID', + order: 2, + required: true, + sensitive: false, + tooltip: 'The name of the model to use for the inference task.', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + organization_id: { + display: DisplayType.TEXTBOX, + label: 'Organization ID', + order: 4, + required: false, + sensitive: false, + tooltip: 'The unique identifier of your organization.', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + url: { + display: DisplayType.TEXTBOX, + label: 'URL', + order: 1, + required: true, + sensitive: false, + tooltip: + 'The OpenAI API endpoint URL. For more information on the URL, refer to the https://platform.openai.com/docs/api-reference.', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: 'https://api.openai.com/v1/chat/completions', + depends_on: [], + }, + 'rate_limit.requests_per_minute': { + display: DisplayType.NUMERIC, + label: 'Rate limit', + order: 5, + required: false, + sensitive: false, + tooltip: + 'Default number of requests allowed per minute. For text_embedding is 3000. For completion is 500.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + provider: 'googleaistudio', + logo: '', // should be googleaistudio logo here, the hardcoded uses assets/images + taskTypes: ['completion', 'text_embedding'], + configuration: { + api_key: { + display: DisplayType.TEXTBOX, + label: 'API Key', + order: 1, + required: true, + sensitive: true, + tooltip: `API Key for the provider you're connecting to`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + model_id: { + display: DisplayType.TEXTBOX, + label: 'Model ID', + order: 2, + required: true, + sensitive: false, + tooltip: `ID of the LLM you're using`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + 'rate_limit.requests_per_minute': { + display: DisplayType.NUMERIC, + label: 'Rate limit', + order: 5, + required: false, + sensitive: false, + tooltip: 'Minimize the number of rate limit errors.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + provider: 'amazonbedrock', + logo: '', // should be amazonbedrock logo here, the hardcoded uses assets/images + taskTypes: ['completion', 'text_embedding'], + configuration: { + access_key: { + display: DisplayType.TEXTBOX, + label: 'Access Key', + order: 1, + required: true, + sensitive: true, + tooltip: `A valid AWS access key that has permissions to use Amazon Bedrock.`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + secret_key: { + display: DisplayType.TEXTBOX, + label: 'Secret Key', + order: 2, + required: true, + sensitive: true, + tooltip: `A valid AWS secret key that is paired with the access_key.`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + provider: { + display: DisplayType.DROPDOWN, + label: 'Provider', + order: 3, + required: true, + options: [ + { + label: 'amazontitan', + value: 'amazontitan', + }, + { + label: 'anthropic', + value: 'anthropic', + }, + { + label: 'ai21labs', + value: 'ai21labs', + }, + { + label: 'cohere', + value: 'cohere', + }, + { + label: 'meta', + value: 'meta', + }, + { + label: 'mistral', + value: 'mistral', + }, + ], + sensitive: false, + tooltip: 'The model provider for your deployment.', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + model: { + display: DisplayType.TEXTBOX, + label: 'Model', + order: 4, + required: true, + sensitive: false, + tooltip: `The base model ID or an ARN to a custom model based on a foundational model.`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + region: { + display: DisplayType.TEXTBOX, + label: 'Region', + order: 5, + required: true, + sensitive: false, + tooltip: `The region that your model or ARN is deployed in.`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + 'rate_limit.requests_per_minute': { + display: DisplayType.NUMERIC, + label: 'Rate limit', + order: 6, + required: false, + sensitive: false, + tooltip: + 'By default, the amazonbedrock service sets the number of requests allowed per minute to 240.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + provider: 'googlevertexai', + logo: '', // should be googlevertexai logo here, the hardcoded uses assets/images + taskTypes: ['text_embedding', 'rerank'], + configuration: { + service_account_json: { + display: DisplayType.TEXTBOX, + label: 'Credentials JSON', + order: 1, + required: true, + sensitive: true, + tooltip: `API Key for the provider you're connecting to`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + model_id: { + display: DisplayType.TEXTBOX, + label: 'Model ID', + order: 2, + required: true, + sensitive: false, + tooltip: `ID of the LLM you're using`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + location: { + display: DisplayType.TEXTBOX, + label: 'GCP Region', + order: 2, + required: true, + sensitive: false, + tooltip: `Please provide the GCP region where the Vertex AI API(s) is enabled. For more information, refer to the {geminiVertexAIDocs}.`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + project_id: { + display: DisplayType.TEXTBOX, + label: 'GCP Project', + order: 3, + required: true, + sensitive: false, + tooltip: + 'The GCP Project ID which has Vertex AI API(s) enabled. For more information on the URL, refer to the {geminiVertexAIDocs}.', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + 'rate_limit.requests_per_minute': { + display: DisplayType.NUMERIC, + label: 'Rate limit', + order: 5, + required: false, + sensitive: false, + tooltip: 'Minimize the number of rate limit errors.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + provider: 'mistral', + logo: '', // should be misral logo here, the hardcoded uses assets/images + taskTypes: ['text_embedding'], + configuration: { + api_key: { + display: DisplayType.TEXTBOX, + label: 'API Key', + order: 1, + required: true, + sensitive: true, + tooltip: `API Key for the provider you're connecting to`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + model: { + display: DisplayType.TEXTBOX, + label: 'Model', + order: 2, + required: true, + sensitive: false, + tooltip: `Refer to the Mistral models documentation for the list of available text embedding models`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + 'rate_limit.requests_per_minute': { + display: DisplayType.NUMERIC, + label: 'Rate limit', + order: 4, + required: false, + sensitive: false, + tooltip: 'Minimize the number of rate limit errors.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: 240, + depends_on: [], + }, + max_input_tokens: { + display: DisplayType.NUMERIC, + label: 'Maximum input tokens', + order: 3, + required: false, + sensitive: false, + tooltip: 'Allows you to specify the maximum number of tokens per input.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + provider: 'hugging_face', + logo: '', // should be hugging_face logo here, the hardcoded uses assets/images + taskTypes: ['text_embedding'], + configuration: { + api_key: { + display: DisplayType.TEXTBOX, + label: 'API Key', + order: 2, + required: true, + sensitive: true, + tooltip: `API Key for the provider you're connecting to`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + url: { + display: DisplayType.TEXTBOX, + label: 'URL', + order: 1, + required: true, + sensitive: false, + tooltip: 'The URL endpoint to use for the requests.', + type: FieldType.STRING, + validations: [], + value: 'https://api.openai.com/v1/embeddings', + ui_restrictions: [], + default_value: 'https://api.openai.com/v1/embeddings', + depends_on: [], + }, + 'rate_limit.requests_per_minute': { + display: DisplayType.NUMERIC, + label: 'Rate limit', + order: 3, + required: false, + sensitive: false, + tooltip: 'Minimize the number of rate limit errors.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + provider: 'elasticsearch', + logo: '', // elasticsearch logo here + taskTypes: ['sparse_embedding', 'text_embedding', 'rerank'], + configuration: { + model_id: { + display: DisplayType.DROPDOWN, + label: 'Model ID', + order: 1, + required: true, + sensitive: false, + tooltip: `The name of the model to use for the inference task.`, + type: FieldType.STRING, + validations: [], + options: [ + { + label: '.elser_model_1', + value: '.elser_model_1', + }, + { + label: '.elser_model_2', + value: '.elser_model_2', + }, + { + label: '.elser_model_2_linux-x86_64', + value: '.elser_model_2_linux-x86_64', + }, + { + label: '.multilingual-e5-small', + value: '.multilingual-e5-small', + }, + { + label: '.multilingual-e5-small_linux-x86_64', + value: '.multilingual-e5-small_linux-x86_64', + }, + ], + value: null, + ui_restrictions: [], + default_value: '.multilingual-e5-small', + depends_on: [], + }, + num_allocations: { + display: DisplayType.NUMERIC, + label: 'Number allocations', + order: 2, + required: true, + sensitive: false, + tooltip: + 'The total number of allocations this model is assigned across machine learning nodes.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: 1, + depends_on: [], + }, + num_threads: { + display: DisplayType.NUMERIC, + label: 'Number threads', + order: 3, + required: true, + sensitive: false, + tooltip: 'Sets the number of threads used by each model allocation during inference.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: 2, + depends_on: [], + }, + }, + }, + { + provider: 'cohere', + logo: '', // should be cohere logo here, the hardcoded uses assets/images + taskTypes: ['completion', 'text_embedding', 'rerank'], + configuration: { + api_key: { + display: DisplayType.TEXTBOX, + label: 'API Key', + order: 1, + required: true, + sensitive: true, + tooltip: `API Key for the provider you're connecting to`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + 'rate_limit.requests_per_minute': { + display: DisplayType.NUMERIC, + label: 'Rate limit', + order: 5, + required: false, + sensitive: false, + tooltip: 'Minimize the number of rate limit errors.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + provider: 'azureopenai', + logo: '', // should be azureopenai logo here, the hardcoded uses assets/images + taskTypes: ['completion', 'text_embedding'], + configuration: { + api_key: { + display: DisplayType.TEXTBOX, + label: 'API Key', + order: 1, + required: false, + sensitive: true, + tooltip: `You must provide either an API key or an Entra ID.`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + entra_id: { + display: DisplayType.TEXTBOX, + label: 'Entra ID', + order: 2, + required: false, + sensitive: true, + tooltip: `You must provide either an API key or an Entra ID.`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + resource_name: { + display: DisplayType.TEXTBOX, + label: 'Resource Name', + order: 3, + required: true, + sensitive: false, + tooltip: `The name of your Azure OpenAI resource`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + api_version: { + display: DisplayType.TEXTBOX, + label: 'API version', + order: 4, + required: true, + sensitive: false, + tooltip: 'The Azure API version ID to use.', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + deployment_id: { + display: DisplayType.TEXTBOX, + label: 'Deployment ID', + order: 5, + required: true, + sensitive: false, + tooltip: 'The deployment name of your deployed models.', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + 'rate_limit.requests_per_minute': { + display: DisplayType.NUMERIC, + label: 'Rate limit', + order: 5, + required: false, + sensitive: false, + tooltip: + 'The azureopenai service sets a default number of requests allowed per minute depending on the task type.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + provider: 'azureaistudio', + logo: '', // should be azureaistudio logo here, the hardcoded uses assets/images + taskTypes: ['completion', 'text_embedding'], + configuration: { + api_key: { + display: DisplayType.TEXTBOX, + label: 'API Key', + order: 1, + required: true, + sensitive: true, + tooltip: `API Key for the provider you're connecting to`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + target: { + display: DisplayType.TEXTBOX, + label: 'Target', + order: 2, + required: true, + sensitive: false, + tooltip: `The target URL of your Azure AI Studio model deployment.`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + endpoint_type: { + display: DisplayType.DROPDOWN, + label: 'Endpoint type', + order: 3, + required: true, + sensitive: false, + tooltip: 'Specifies the type of endpoint that is used in your model deployment.', + type: FieldType.STRING, + options: [ + { + label: 'token', + value: 'token', + }, + { + label: 'realtime', + value: 'realtime', + }, + ], + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + provider: { + display: DisplayType.DROPDOWN, + label: 'Provider', + order: 3, + required: true, + options: [ + { + label: 'cohere', + value: 'cohere', + }, + { + label: 'meta', + value: 'meta', + }, + { + label: 'microsoft_phi', + value: 'microsoft_phi', + }, + { + label: 'mistral', + value: 'mistral', + }, + { + label: 'openai', + value: 'openai', + }, + { + label: 'databricks', + value: 'databricks', + }, + ], + sensitive: false, + tooltip: 'The model provider for your deployment.', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + 'rate_limit.requests_per_minute': { + display: DisplayType.NUMERIC, + label: 'Rate limit', + order: 5, + required: false, + sensitive: false, + tooltip: 'Minimize the number of rate limit errors.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + provider: 'anthropic', + logo: '', // should be anthropic logo here, the hardcoded uses assets/images + taskTypes: ['completion'], + configuration: { + api_key: { + display: DisplayType.TEXTBOX, + label: 'API Key', + order: 1, + required: true, + sensitive: true, + tooltip: `API Key for the provider you're connecting to`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + model_id: { + display: DisplayType.TEXTBOX, + label: 'Model ID', + order: 2, + required: true, + sensitive: false, + tooltip: `The name of the model to use for the inference task.`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + 'rate_limit.requests_per_minute': { + display: DisplayType.NUMERIC, + label: 'Rate limit', + order: 5, + required: false, + sensitive: false, + tooltip: + 'By default, the anthropic service sets the number of requests allowed per minute to 50.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + provider: 'watsonxai', + logo: '', // should be anthropic logo here, the hardcoded uses assets/images + taskTypes: ['text_embedding'], + configuration: { + api_version: { + display: DisplayType.TEXTBOX, + label: 'API version', + order: 1, + required: true, + sensitive: false, + tooltip: 'The IBM Watsonx API version ID to use.', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + project_id: { + display: DisplayType.TEXTBOX, + label: 'Project ID', + order: 2, + required: true, + sensitive: false, + tooltip: '', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + model_id: { + display: DisplayType.TEXTBOX, + label: 'Model ID', + order: 3, + required: true, + sensitive: false, + tooltip: `The name of the model to use for the inference task.`, + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + url: { + display: DisplayType.TEXTBOX, + label: 'URL', + order: 4, + required: true, + sensitive: false, + tooltip: '', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + max_input_tokens: { + display: DisplayType.NUMERIC, + label: 'Maximum input tokens', + order: 5, + required: false, + sensitive: false, + tooltip: 'Allows you to specify the maximum number of tokens per input.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + { + provider: 'alibabacloud-ai-search', + logo: '', // should be anthropic logo here, the hardcoded uses assets/images + taskTypes: ['completion', 'sparse_embedding', 'text_embedding', 'rerank'], + configuration: { + api_key: { + display: DisplayType.TEXTBOX, + label: 'API Key', + order: 1, + required: true, + sensitive: true, + tooltip: 'A valid API key for the AlibabaCloud AI Search API.', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + service_id: { + display: DisplayType.DROPDOWN, + label: 'Project ID', + order: 2, + required: true, + sensitive: false, + tooltip: 'The name of the model service to use for the {infer} task.', + type: FieldType.STRING, + options: [ + { + label: 'ops-text-embedding-001', + value: 'ops-text-embedding-001', + }, + { + label: 'ops-text-embedding-zh-001', + value: 'ops-text-embedding-zh-001', + }, + { + label: 'ops-text-embedding-en-001', + value: 'ops-text-embedding-en-001', + }, + { + label: 'ops-text-embedding-002', + value: 'ops-text-embedding-002', + }, + { + label: 'ops-text-sparse-embedding-001', + value: 'ops-text-sparse-embedding-001', + }, + { + label: 'ops-bge-reranker-larger', + value: 'ops-bge-reranker-larger', + }, + ], + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + host: { + display: DisplayType.TEXTBOX, + label: 'Host', + order: 3, + required: true, + sensitive: false, + tooltip: + 'The name of the host address used for the {infer} task. You can find the host address at https://opensearch.console.aliyun.com/cn-shanghai/rag/api-key[ the API keys section] of the documentation.', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + http_schema: { + display: DisplayType.DROPDOWN, + label: 'HTTP Schema', + order: 4, + required: true, + sensitive: false, + tooltip: '', + type: FieldType.STRING, + options: [ + { + label: 'https', + value: 'https', + }, + { + label: 'http', + value: 'http', + }, + ], + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + workspace: { + display: DisplayType.TEXTBOX, + label: 'Workspace', + order: 5, + required: true, + sensitive: false, + tooltip: 'The name of the workspace used for the {infer} task.', + type: FieldType.STRING, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + 'rate_limit.requests_per_minute': { + display: DisplayType.NUMERIC, + label: 'Rate limit', + order: 6, + required: false, + sensitive: false, + tooltip: 'Minimize the number of rate limit errors.', + type: FieldType.INTEGER, + validations: [], + value: null, + ui_restrictions: [], + default_value: null, + depends_on: [], + }, + }, + }, + ] as InferenceProvider[]; + return Promise.resolve( + providers.sort((a, b) => (a.provider > b.provider ? 1 : b.provider > a.provider ? -1 : 0)) + ); +}; + +export const useProviders = (http: HttpSetup, toasts: ToastsStart) => { + const onErrorFn = (error: Error) => { + if (error) { + toasts.addDanger( + i18n.translate( + 'xpack.stackConnectors.components.inference.unableToFindProvidersQueryMessage', + { + defaultMessage: 'Unable to find providers', + } + ) + ); + } + }; + + const query = useQuery(['user-profile'], { + queryFn: () => getProviders(http), + staleTime: Infinity, + refetchOnWindowFocus: false, + onError: onErrorFn, + }); + return query; +}; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/render_service_provider/service_provider.test.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/render_service_provider/service_provider.test.tsx new file mode 100644 index 0000000000000..84a32286b7532 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/render_service_provider/service_provider.test.tsx @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { render, screen } from '@testing-library/react'; +import React from 'react'; +import { ServiceProviderIcon, ServiceProviderName } from './service_provider'; +import { ServiceProviderKeys } from '../../../../../common/inference/constants'; + +jest.mock('../assets/images/elastic.svg', () => 'elasticIcon.svg'); +jest.mock('../assets/images/hugging_face.svg', () => 'huggingFaceIcon.svg'); +jest.mock('../assets/images/cohere.svg', () => 'cohereIcon.svg'); +jest.mock('../assets/images/open_ai.svg', () => 'openAIIcon.svg'); + +describe('ServiceProviderIcon component', () => { + it('renders Hugging Face icon and name when providerKey is hugging_face', () => { + render(); + const icon = screen.getByTestId('icon-service-provider-hugging_face'); + expect(icon).toBeInTheDocument(); + }); + + it('renders Open AI icon and name when providerKey is openai', () => { + render(); + const icon = screen.getByTestId('icon-service-provider-openai'); + expect(icon).toBeInTheDocument(); + }); +}); + +describe('ServiceProviderName component', () => { + it('renders Hugging Face icon and name when providerKey is hugging_face', () => { + render(); + expect(screen.getByText('Hugging Face')).toBeInTheDocument(); + }); + + it('renders Open AI icon and name when providerKey is openai', () => { + render(); + expect(screen.getByText('OpenAI')).toBeInTheDocument(); + }); +}); diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/render_service_provider/service_provider.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/render_service_provider/service_provider.tsx new file mode 100644 index 0000000000000..5d2c99ffd92ce --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/render_service_provider/service_provider.tsx @@ -0,0 +1,126 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { EuiHighlight, EuiIcon } from '@elastic/eui'; +import React from 'react'; +import { ServiceProviderKeys } from '../../../../../common/inference/constants'; +import elasticIcon from '../assets/images/elastic.svg'; +import huggingFaceIcon from '../assets/images/hugging_face.svg'; +import cohereIcon from '../assets/images/cohere.svg'; +import openAIIcon from '../assets/images/open_ai.svg'; +import azureAIStudioIcon from '../assets/images/azure_ai_studio.svg'; +import azureOpenAIIcon from '../assets/images/azure_open_ai.svg'; +import googleAIStudioIcon from '../assets/images/google_ai_studio.svg'; +import mistralIcon from '../assets/images/mistral.svg'; +import amazonBedrockIcon from '../assets/images/amazon_bedrock.svg'; +import anthropicIcon from '../assets/images/anthropic.svg'; +import alibabaCloudIcon from '../assets/images/alibaba_cloud.svg'; +import ibmWatsonxIcon from '../assets/images/ibm_watsonx.svg'; + +interface ServiceProviderProps { + providerKey: ServiceProviderKeys; + searchValue?: string; +} + +type ProviderSolution = 'Observability' | 'Security' | 'Search'; + +interface ServiceProviderRecord { + icon: string; + name: string; + solutions: ProviderSolution[]; +} + +export const SERVICE_PROVIDERS: Record = { + [ServiceProviderKeys.amazonbedrock]: { + icon: amazonBedrockIcon, + name: 'Amazon Bedrock', + solutions: ['Observability', 'Security', 'Search'], + }, + [ServiceProviderKeys.azureaistudio]: { + icon: azureAIStudioIcon, + name: 'Azure AI Studio', + solutions: ['Search'], + }, + [ServiceProviderKeys.azureopenai]: { + icon: azureOpenAIIcon, + name: 'Azure OpenAI', + solutions: ['Observability', 'Security', 'Search'], + }, + [ServiceProviderKeys.anthropic]: { + icon: anthropicIcon, + name: 'Anthropic', + solutions: ['Search'], + }, + [ServiceProviderKeys.cohere]: { + icon: cohereIcon, + name: 'Cohere', + solutions: ['Search'], + }, + [ServiceProviderKeys.elasticsearch]: { + icon: elasticIcon, + name: 'Elasticsearch', + solutions: ['Search'], + }, + [ServiceProviderKeys.googleaistudio]: { + icon: googleAIStudioIcon, + name: 'Google AI Studio', + solutions: ['Search'], + }, + [ServiceProviderKeys.googlevertexai]: { + icon: googleAIStudioIcon, + name: 'Google Vertex AI', + solutions: ['Observability', 'Security', 'Search'], + }, + [ServiceProviderKeys.hugging_face]: { + icon: huggingFaceIcon, + name: 'Hugging Face', + solutions: ['Search'], + }, + [ServiceProviderKeys.mistral]: { + icon: mistralIcon, + name: 'Mistral', + solutions: ['Search'], + }, + [ServiceProviderKeys.openai]: { + icon: openAIIcon, + name: 'OpenAI', + solutions: ['Observability', 'Security', 'Search'], + }, + [ServiceProviderKeys['alibabacloud-ai-search']]: { + icon: alibabaCloudIcon, + name: 'AlibabaCloud AI Search', + solutions: ['Search'], + }, + [ServiceProviderKeys.watsonxai]: { + icon: ibmWatsonxIcon, + name: 'IBM Watsonx', + solutions: ['Search'], + }, +}; + +export const ServiceProviderIcon: React.FC = ({ providerKey }) => { + const provider = SERVICE_PROVIDERS[providerKey]; + + return provider ? ( + + ) : ( + {providerKey} + ); +}; + +export const ServiceProviderName: React.FC = ({ + providerKey, + searchValue, +}) => { + const provider = SERVICE_PROVIDERS[providerKey]; + + return provider ? ( + {provider.name} + ) : ( + {providerKey} + ); +}; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/selectable/index.test.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/selectable/index.test.tsx new file mode 100644 index 0000000000000..f83d4bcd9ea4c --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/selectable/index.test.tsx @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { EuiSelectableProps } from '@elastic/eui'; +import React from 'react'; +import type { ShallowWrapper } from 'enzyme'; +import { shallow } from 'enzyme'; + +import { SelectableProvider } from '.'; + +describe('SelectableProvider', () => { + const props = { + isLoading: false, + onClosePopover: jest.fn(), + onProviderChange: jest.fn(), + getSelectableOptions: jest.fn().mockReturnValue([]), + }; + + describe('should render', () => { + let wrapper: ShallowWrapper; + + describe('provider', () => { + beforeAll(() => { + wrapper = shallow(); + }); + + afterAll(() => { + jest.clearAllMocks(); + }); + + test('render placeholder', () => { + const searchProps: EuiSelectableProps['searchProps'] = wrapper + .find('[data-test-subj="selectable-provider-input"]') + .prop('searchProps'); + expect(searchProps?.placeholder).toEqual('Search'); + }); + }); + + describe('template', () => { + beforeAll(() => { + wrapper = shallow(); + }); + + afterAll(() => { + jest.clearAllMocks(); + }); + + test('render placeholder', () => { + const searchProps: EuiSelectableProps['searchProps'] = wrapper + .find('[data-test-subj="selectable-provider-input"]') + .prop('searchProps'); + expect(searchProps?.placeholder).toEqual('Search'); + }); + }); + }); +}); diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/selectable/index.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/selectable/index.tsx new file mode 100644 index 0000000000000..d4527e9c7b9a4 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/providers/selectable/index.tsx @@ -0,0 +1,136 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { EuiSelectableOption, EuiSelectableProps } from '@elastic/eui'; +import { EuiSelectable, EuiFlexGroup, EuiFlexItem, EuiBadge } from '@elastic/eui'; +import React, { memo, useCallback, useMemo, useState } from 'react'; +import { i18n } from '@kbn/i18n'; +import { ServiceProviderKeys } from '../../../../../common/inference/constants'; +import { + SERVICE_PROVIDERS, + ServiceProviderIcon, + ServiceProviderName, +} from '../render_service_provider/service_provider'; + +/** + * Modifies options by creating new property `providerTitle`(with value of `title`), and by setting `title` to undefined. + * Thus prevents appearing default browser tooltip on option hover (attribute `title` that gets rendered on li element) + * + * @param {EuiSelectableOption[]} options + * @returns {EuiSelectableOption[]} modified options + */ + +export interface SelectableProviderProps { + isLoading: boolean; + getSelectableOptions: (searchProviderValue?: string) => EuiSelectableOption[]; + onClosePopover: () => void; + onProviderChange: (provider?: string) => void; +} + +const SelectableProviderComponent: React.FC = ({ + isLoading, + getSelectableOptions, + onClosePopover, + onProviderChange, +}) => { + const [searchProviderValue, setSearchProviderValue] = useState(''); + const onSearchProvider = useCallback( + (val: string) => { + setSearchProviderValue(val); + }, + [setSearchProviderValue] + ); + + const renderProviderOption = useCallback>( + (option, searchValue) => { + const provider = SERVICE_PROVIDERS[option.label as ServiceProviderKeys]; + return ( + + + + + + + + + + + + + + {provider && + provider.solutions.map((solution) => ( + + {solution} + + ))} + + + + ); + }, + [] + ); + + const handleProviderChange = useCallback>( + (options) => { + const selectedProvider = options.filter((option) => option.checked === 'on'); + if (selectedProvider != null && selectedProvider.length > 0) { + onProviderChange(selectedProvider[0].label); + } + onClosePopover(); + }, + [onClosePopover, onProviderChange] + ); + + const EuiSelectableContent = useCallback>( + (list, search) => ( + <> + {search} + {list} + + ), + [] + ); + + const searchProps: EuiSelectableProps['searchProps'] = useMemo( + () => ({ + 'data-test-subj': 'provider-super-select-search-box', + placeholder: i18n.translate( + 'xpack.stackConnectors.components.inference.selectable.providerSearch', + { + defaultMessage: 'Search', + } + ), + onSearch: onSearchProvider, + incremental: false, + compressed: true, + fullWidth: true, + }), + [onSearchProvider] + ); + + return ( + + {EuiSelectableContent} + + ); +}; + +export const SelectableProvider = memo(SelectableProviderComponent); diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/translations.ts b/x-pack/plugins/stack_connectors/public/connector_types/inference/translations.ts new file mode 100644 index 0000000000000..d73164ed2db9e --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/translations.ts @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { i18n } from '@kbn/i18n'; + +export const getRequiredMessage = (field: string) => { + return i18n.translate('xpack.stackConnectors.components.inference.requiredGenericTextField', { + defaultMessage: '{field} is required.', + values: { field }, + }); +}; + +export const INPUT_INVALID = i18n.translate( + 'xpack.stackConnectors.inference.params.error.invalidInputText', + { + defaultMessage: 'Input does not have a valid Array format.', + } +); + +export const INVALID_ACTION = i18n.translate( + 'xpack.stackConnectors.components.inference.invalidActionText', + { + defaultMessage: 'Invalid action name.', + } +); + +export const BODY = i18n.translate('xpack.stackConnectors.components.inference.bodyFieldLabel', { + defaultMessage: 'Body', +}); + +export const INPUT = i18n.translate( + 'xpack.stackConnectors.components.inference.completionInputLabel', + { + defaultMessage: 'Input', + } +); + +export const INPUT_TYPE = i18n.translate( + 'xpack.stackConnectors.components.inference.completionInputTypeLabel', + { + defaultMessage: 'Input type', + } +); + +export const QUERY = i18n.translate('xpack.stackConnectors.components.inference.rerankQueryLabel', { + defaultMessage: 'Query', +}); + +export const BODY_DESCRIPTION = i18n.translate( + 'xpack.stackConnectors.components.inference.bodyCodeEditorAriaLabel', + { + defaultMessage: 'Code editor', + } +); + +export const TASK_TYPE = i18n.translate( + 'xpack.stackConnectors.components.inference.taskTypeFieldLabel', + { + defaultMessage: 'Task type', + } +); + +export const PROVIDER = i18n.translate( + 'xpack.stackConnectors.components.inference.providerFieldLabel', + { + defaultMessage: 'Provider', + } +); + +export const PROVIDER_REQUIRED = i18n.translate( + 'xpack.stackConnectors.components.inference.error.requiredProviderText', + { + defaultMessage: 'Provider is required.', + } +); + +export const DOCUMENTATION = i18n.translate( + 'xpack.stackConnectors.components.inference.documentation', + { + defaultMessage: 'Inference API documentation', + } +); + +export const SELECT_PROVIDER = i18n.translate( + 'xpack.stackConnectors.components.inference.selectProvider', + { + defaultMessage: 'Select a service', + } +); + +export const COPY_TOOLTIP = i18n.translate( + 'xpack.stackConnectors.components.inference.copy.tooltip', + { + defaultMessage: 'Copy to clipboard', + } +); + +export const COPIED_TOOLTIP = i18n.translate( + 'xpack.stackConnectors.components.inference.copied.tooltip', + { + defaultMessage: 'Copied!', + } +); diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/types.ts b/x-pack/plugins/stack_connectors/public/connector_types/inference/types.ts new file mode 100644 index 0000000000000..150292894b643 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/types.ts @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { UserConfiguredActionConnector } from '@kbn/triggers-actions-ui-plugin/public/types'; +import { ActionTypeModel as ConnectorTypeModel } from '@kbn/triggers-actions-ui-plugin/public'; +import { SUB_ACTION } from '../../../common/inference/constants'; +import { + ChatCompleteParams, + RerankParams, + SparseEmbeddingParams, + TextEmbeddingParams, +} from '../../../common/inference/types'; +import { ConfigProperties } from '../lib/dynamic_config/types'; + +export type InferenceActionParams = + | { subAction: SUB_ACTION.COMPLETION; subActionParams: ChatCompleteParams } + | { subAction: SUB_ACTION.RERANK; subActionParams: RerankParams } + | { subAction: SUB_ACTION.SPARSE_EMBEDDING; subActionParams: SparseEmbeddingParams } + | { subAction: SUB_ACTION.TEXT_EMBEDDING; subActionParams: TextEmbeddingParams }; + +export type FieldsConfiguration = Record; + +export interface Config { + taskType: string; + taskTypeConfig?: Record; + inferenceId: string; + provider: string; + providerConfig?: Record; +} + +export interface Secrets { + providerSecrets?: Record; +} + +export type InferenceConnector = ConnectorTypeModel; + +export type InferenceActionConnector = UserConfiguredActionConnector; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/lib/dynamic_config/connector_configuration_field.tsx b/x-pack/plugins/stack_connectors/public/connector_types/lib/dynamic_config/connector_configuration_field.tsx new file mode 100644 index 0000000000000..79ae552a9528a --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/lib/dynamic_config/connector_configuration_field.tsx @@ -0,0 +1,360 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React, { useEffect, useState } from 'react'; + +import { + EuiAccordion, + EuiFieldText, + EuiFieldPassword, + EuiSwitch, + EuiTextArea, + EuiFlexGroup, + EuiFlexItem, + EuiIcon, + EuiFieldNumber, + EuiCheckableCard, + useGeneratedHtmlId, + EuiSpacer, + EuiSuperSelect, + EuiText, +} from '@elastic/eui'; + +import { isEmpty } from 'lodash/fp'; +import { + ensureBooleanType, + ensureCorrectTyping, + ensureStringType, +} from './connector_configuration_utils'; +import { ConfigEntryView, DisplayType } from './types'; + +interface ConnectorConfigurationFieldProps { + configEntry: ConfigEntryView; + isLoading: boolean; + setConfigValue: (value: number | string | boolean | null) => void; +} + +interface ConfigInputFieldProps { + configEntry: ConfigEntryView; + isLoading: boolean; + validateAndSetConfigValue: (value: string | boolean) => void; +} +export const ConfigInputField: React.FC = ({ + configEntry, + isLoading, + validateAndSetConfigValue, +}) => { + // eslint-disable-next-line @typescript-eslint/naming-convention + const { isValid, placeholder, value, default_value, key } = configEntry; + const [innerValue, setInnerValue] = useState( + !value || value.toString().length === 0 ? default_value : value + ); + + useEffect(() => { + setInnerValue(!value || value.toString().length === 0 ? default_value : value); + }, [default_value, value]); + return ( + { + setInnerValue(event.target.value); + validateAndSetConfigValue(event.target.value); + }} + placeholder={placeholder} + /> + ); +}; + +export const ConfigSwitchField: React.FC = ({ + configEntry, + isLoading, + validateAndSetConfigValue, +}) => { + // eslint-disable-next-line @typescript-eslint/naming-convention + const { label, value, default_value, key } = configEntry; + const [innerValue, setInnerValue] = useState(value ?? default_value); + useEffect(() => { + setInnerValue(value ?? default_value); + }, [default_value, value]); + return ( + {label}

} + onChange={(event) => { + setInnerValue(event.target.checked); + validateAndSetConfigValue(event.target.checked); + }} + /> + ); +}; + +export const ConfigInputTextArea: React.FC = ({ + isLoading, + configEntry, + validateAndSetConfigValue, +}) => { + // eslint-disable-next-line @typescript-eslint/naming-convention + const { isValid, placeholder, value, default_value, key } = configEntry; + const [innerValue, setInnerValue] = useState(value ?? default_value); + useEffect(() => { + setInnerValue(value ?? default_value); + }, [default_value, value]); + return ( + { + setInnerValue(event.target.value); + validateAndSetConfigValue(event.target.value); + }} + placeholder={placeholder} + /> + ); +}; + +export const ConfigNumberField: React.FC = ({ + configEntry, + isLoading, + validateAndSetConfigValue, +}) => { + // eslint-disable-next-line @typescript-eslint/naming-convention + const { isValid, placeholder, value, default_value, key } = configEntry; + const [innerValue, setInnerValue] = useState(value ?? default_value); + useEffect(() => { + setInnerValue(!value || value.toString().length === 0 ? default_value : value); + }, [default_value, value]); + return ( + { + const newValue = isEmpty(event.target.value) ? '0' : event.target.value; + setInnerValue(newValue); + validateAndSetConfigValue(newValue); + }} + placeholder={placeholder} + /> + ); +}; + +export const ConfigCheckableField: React.FC = ({ + configEntry, + validateAndSetConfigValue, +}) => { + const radioCardId = useGeneratedHtmlId({ prefix: 'radioCard' }); + // eslint-disable-next-line @typescript-eslint/naming-convention + const { value, options, default_value } = configEntry; + const [innerValue, setInnerValue] = useState(value ?? default_value); + useEffect(() => { + setInnerValue(value ?? default_value); + }, [default_value, value]); + return ( + <> + {options?.map((o) => ( + <> + { + setInnerValue(o.value); + validateAndSetConfigValue(o.value); + }} + /> + + + ))} + + ); +}; + +export const ConfigSensitiveTextArea: React.FC = ({ + isLoading, + configEntry, + validateAndSetConfigValue, +}) => { + const { key, label } = configEntry; + return ( + {label}

}> + +
+ ); +}; + +export const ConfigInputPassword: React.FC = ({ + isLoading, + configEntry, + validateAndSetConfigValue, +}) => { + const { value, key } = configEntry; + const [innerValue, setInnerValue] = useState(value ?? null); + useEffect(() => { + setInnerValue(value ?? null); + }, [value]); + return ( + <> + { + setInnerValue(event.target.value); + validateAndSetConfigValue(event.target.value); + }} + /> + + ); +}; + +export const ConfigSelectField: React.FC = ({ + configEntry, + isLoading, + validateAndSetConfigValue, +}) => { + // eslint-disable-next-line @typescript-eslint/naming-convention + const { isValid, options, value, default_value } = configEntry; + const [innerValue, setInnerValue] = useState(value ?? default_value); + const optionsRes = options?.map((o) => ({ + value: o.value, + inputDisplay: ( + + {o.icon ? ( + + + + ) : null} + + {o.label} + + + ), + })); + return ( + { + setInnerValue(newValue); + validateAndSetConfigValue(newValue); + }} + /> + ); +}; + +export const ConnectorConfigurationField: React.FC = ({ + configEntry, + isLoading, + setConfigValue, +}) => { + const validateAndSetConfigValue = (value: number | string | boolean) => { + setConfigValue(ensureCorrectTyping(configEntry.type, value)); + }; + + const { key, display, sensitive } = configEntry; + + switch (display) { + case DisplayType.DROPDOWN: + return ( + + ); + + case DisplayType.CHECKABLE: + return ( + + ); + + case DisplayType.NUMERIC: + return ( + + ); + + case DisplayType.TEXTAREA: + const textarea = ( + + ); + + return sensitive ? ( + <> + + + ) : ( + textarea + ); + + case DisplayType.TOGGLE: + return ( + + ); + + default: + return sensitive ? ( + + ) : ( + + ); + } +}; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/lib/dynamic_config/connector_configuration_form_items.tsx b/x-pack/plugins/stack_connectors/public/connector_types/lib/dynamic_config/connector_configuration_form_items.tsx new file mode 100644 index 0000000000000..3190ac80275f1 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/lib/dynamic_config/connector_configuration_form_items.tsx @@ -0,0 +1,136 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React from 'react'; + +import { + EuiCallOut, + EuiFlexGroup, + EuiFlexItem, + EuiFormRow, + EuiPanel, + EuiSpacer, + EuiText, +} from '@elastic/eui'; + +import { i18n } from '@kbn/i18n'; +import { ConfigEntryView, DisplayType } from './types'; +import { ConnectorConfigurationField } from './connector_configuration_field'; + +interface ConnectorConfigurationFormItemsProps { + isLoading: boolean; + items: ConfigEntryView[]; + setConfigEntry: (key: string, value: string | number | boolean | null) => void; + direction?: 'column' | 'row' | 'rowReverse' | 'columnReverse' | undefined; + itemsGrow?: boolean; +} + +export const ConnectorConfigurationFormItems: React.FC = ({ + isLoading, + items, + setConfigEntry, + direction, + itemsGrow, +}) => { + return ( + + {items.map((configEntry) => { + const { + depends_on: dependencies, + key, + display, + isValid, + label, + sensitive, + tooltip, + validationErrors, + required, + } = configEntry; + + const helpText = tooltip; + // toggle and sensitive textarea labels go next to the element, not in the row + const rowLabel = + display === DisplayType.TOGGLE || (display === DisplayType.TEXTAREA && sensitive) ? ( + <> + ) : tooltip ? ( + + +

{label}

+
+
+ ) : ( +

{label}

+ ); + + const optionalLabel = !required ? ( + + {i18n.translate('xpack.stackConnectors.components.inference.config.optionalValue', { + defaultMessage: 'Optional', + })} + + ) : undefined; + + if (dependencies?.length > 0) { + return ( + + + + { + setConfigEntry(configEntry.key, value); + }} + /> + + + + ); + } + return ( + + + { + setConfigEntry(configEntry.key, value); + }} + /> + + {configEntry.sensitive ? ( + <> + + + + ) : null} + + ); + })} +
+ ); +}; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/lib/dynamic_config/connector_configuration_utils.ts b/x-pack/plugins/stack_connectors/public/connector_types/lib/dynamic_config/connector_configuration_utils.ts new file mode 100644 index 0000000000000..182327a180a63 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/lib/dynamic_config/connector_configuration_utils.ts @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { ConfigProperties, FieldType } from './types'; + +export type ConnectorConfigEntry = ConfigProperties & { key: string }; + +export const validIntInput = (value: string | number | boolean | null): boolean => { + // reject non integers (including x.0 floats), but don't validate if empty + return (value !== null || value !== '') && + (isNaN(Number(value)) || + !Number.isSafeInteger(Number(value)) || + ensureStringType(value).indexOf('.') >= 0) + ? false + : true; +}; + +export const ensureCorrectTyping = ( + type: FieldType, + value: string | number | boolean | null +): string | number | boolean | null => { + switch (type) { + case FieldType.INTEGER: + return validIntInput(value) ? ensureIntType(value) : value; + case FieldType.BOOLEAN: + return ensureBooleanType(value); + default: + return ensureStringType(value); + } +}; + +export const ensureStringType = (value: string | number | boolean | null): string => { + return value !== null ? String(value) : ''; +}; + +export const ensureIntType = (value: string | number | boolean | null): number | null => { + // int is null-safe to prevent empty values from becoming zeroes + if (value === null || value === '') { + return null; + } + + return parseInt(String(value), 10); +}; + +export const ensureBooleanType = (value: string | number | boolean | null): boolean => { + return Boolean(value); +}; + +export const hasUiRestrictions = (configEntry: Partial) => { + return (configEntry.ui_restrictions ?? []).length > 0; +}; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/lib/dynamic_config/types.ts b/x-pack/plugins/stack_connectors/public/connector_types/lib/dynamic_config/types.ts new file mode 100644 index 0000000000000..40e17a1989075 --- /dev/null +++ b/x-pack/plugins/stack_connectors/public/connector_types/lib/dynamic_config/types.ts @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export enum DisplayType { + TEXTBOX = 'textbox', + TEXTAREA = 'textarea', + NUMERIC = 'numeric', + TOGGLE = 'toggle', + DROPDOWN = 'dropdown', + CHECKABLE = 'checkable', +} + +export interface SelectOption { + label: string; + value: string; + icon?: string; +} + +export interface Dependency { + field: string; + value: string | number | boolean | null; +} + +export enum FieldType { + STRING = 'str', + INTEGER = 'int', + LIST = 'list', + BOOLEAN = 'bool', +} + +export interface ConfigCategoryProperties { + label: string; + order: number; + type: 'category'; +} + +export interface Validation { + constraint: string | number; + type: string; +} + +export interface ConfigProperties { + category?: string; + default_value: string | number | boolean | null; + depends_on: Dependency[]; + display: DisplayType; + label: string; + options?: SelectOption[]; + order?: number | null; + placeholder?: string; + required: boolean; + sensitive: boolean; + tooltip: string | null; + type: FieldType; + ui_restrictions: string[]; + validations: Validation[]; + value: string | number | boolean | null; +} + +interface ConfigEntry extends ConfigProperties { + key: string; +} + +export interface ConfigEntryView extends ConfigEntry { + isValid: boolean; + validationErrors: string[]; +} diff --git a/x-pack/plugins/stack_connectors/server/connector_types/index.ts b/x-pack/plugins/stack_connectors/server/connector_types/index.ts index 2c905471761ed..09d8a44c2a287 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/index.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/index.ts @@ -20,6 +20,7 @@ import { getConnectorType as getIndexConnectorType } from './es_index'; import { getConnectorType as getOpenAIConnectorType } from './openai'; import { getConnectorType as getBedrockConnectorType } from './bedrock'; import { getConnectorType as getGeminiConnectorType } from './gemini'; +import { getConnectorType as getInferenceConnectorType } from './inference'; import { getConnectorType as getPagerDutyConnectorType } from './pagerduty'; import { getConnectorType as getSwimlaneConnectorType } from './swimlane'; import { getConnectorType as getServerLogConnectorType } from './server_log'; @@ -118,4 +119,7 @@ export function registerConnectorTypes({ if (experimentalFeatures.crowdstrikeConnectorOn) { actions.registerSubActionConnectorType(getCrowdstrikeConnectorType()); } + if (experimentalFeatures.inferenceConnectorOn) { + actions.registerSubActionConnectorType(getInferenceConnectorType()); + } } diff --git a/x-pack/plugins/stack_connectors/server/connector_types/inference/index.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/inference/index.test.ts new file mode 100644 index 0000000000000..b764a318df5dd --- /dev/null +++ b/x-pack/plugins/stack_connectors/server/connector_types/inference/index.test.ts @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { actionsConfigMock } from '@kbn/actions-plugin/server/actions_config.mock'; +import { ActionsConfigurationUtilities } from '@kbn/actions-plugin/server/actions_config'; +import { configValidator, getConnectorType } from '.'; +import { Config, Secrets } from '../../../common/inference/types'; +import { SubActionConnectorType } from '@kbn/actions-plugin/server/sub_action_framework/types'; +import { DEFAULT_PROVIDER, DEFAULT_TASK_TYPE } from '../../../common/inference/constants'; +import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks'; +import { InferencePutResponse } from '@elastic/elasticsearch/lib/api/types'; + +let connectorType: SubActionConnectorType; +let configurationUtilities: jest.Mocked; + +const mockEsClient = elasticsearchClientMock.createClusterClient().asScoped().asInternalUser; + +const mockResponse: Promise = Promise.resolve({ + inference_id: 'test', + service: 'openai', + service_settings: {}, + task_settings: {}, + task_type: 'completion', +}); + +describe('AI Connector', () => { + beforeEach(() => { + configurationUtilities = actionsConfigMock.create(); + connectorType = getConnectorType(); + }); + test('exposes the connector as `AI Connector` with id `.inference`', () => { + mockEsClient.inference.put.mockResolvedValue(mockResponse); + expect(connectorType.id).toEqual('.inference'); + expect(connectorType.name).toEqual('AI Connector'); + }); + describe('config validation', () => { + test('config validation passes when only required fields are provided', () => { + const config: Config = { + providerConfig: { + url: 'https://api.openai.com/v1/chat/completions', + }, + provider: DEFAULT_PROVIDER, + taskType: DEFAULT_TASK_TYPE, + inferenceId: 'test', + taskTypeConfig: {}, + }; + + expect(configValidator(config, { configurationUtilities })).toEqual(config); + }); + + test('config validation failed when the task type is empty', () => { + const config: Config = { + providerConfig: {}, + provider: 'openai', + taskType: '', + inferenceId: 'test', + taskTypeConfig: {}, + }; + expect(() => { + configValidator(config, { configurationUtilities }); + }).toThrowErrorMatchingInlineSnapshot( + `"Error configuring Inference API action: Error: Task type is not supported by Inference Endpoint."` + ); + }); + + test('config validation failed when the provider is empty', () => { + const config: Config = { + providerConfig: {}, + provider: '', + taskType: DEFAULT_TASK_TYPE, + inferenceId: 'test', + taskTypeConfig: {}, + }; + expect(() => { + configValidator(config, { configurationUtilities }); + }).toThrowErrorMatchingInlineSnapshot( + `"Error configuring Inference API action: Error: API Provider is not supported by Inference Endpoint."` + ); + }); + }); +}); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/inference/index.ts b/x-pack/plugins/stack_connectors/server/connector_types/inference/index.ts new file mode 100644 index 0000000000000..18af48bc18a51 --- /dev/null +++ b/x-pack/plugins/stack_connectors/server/connector_types/inference/index.ts @@ -0,0 +1,182 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { i18n } from '@kbn/i18n'; +import { + SubActionConnectorType, + ValidatorType, +} from '@kbn/actions-plugin/server/sub_action_framework/types'; +import { + GenerativeAIForSearchPlaygroundConnectorFeatureId, + GenerativeAIForSecurityConnectorFeatureId, +} from '@kbn/actions-plugin/common'; +import { ValidatorServices } from '@kbn/actions-plugin/server/types'; +import { GenerativeAIForObservabilityConnectorFeatureId } from '@kbn/actions-plugin/common/connector_feature_config'; +import { InferenceTaskType } from '@elastic/elasticsearch/lib/api/types'; +import { ElasticsearchClient, Logger } from '@kbn/core/server'; +import { + INFERENCE_CONNECTOR_TITLE, + INFERENCE_CONNECTOR_ID, + ServiceProviderKeys, + SUB_ACTION, +} from '../../../common/inference/constants'; +import { ConfigSchema, SecretsSchema } from '../../../common/inference/schema'; +import { Config, Secrets } from '../../../common/inference/types'; +import { InferenceConnector } from './inference'; +import { unflattenObject } from '../lib/unflatten_object'; + +const deleteInferenceEndpoint = async ( + inferenceId: string, + taskType: InferenceTaskType, + logger: Logger, + esClient: ElasticsearchClient +) => { + try { + await esClient.inference.delete({ + task_type: taskType, + inference_id: inferenceId, + }); + logger.debug( + `Inference endpoint for task type "${taskType}" and inference id ${inferenceId} was successfuly deleted` + ); + } catch (e) { + logger.warn( + `Failed to delete inference endpoint for task type "${taskType}" and inference id ${inferenceId}. Error: ${e.message}` + ); + throw e; + } +}; + +export const getConnectorType = (): SubActionConnectorType => ({ + id: INFERENCE_CONNECTOR_ID, + name: INFERENCE_CONNECTOR_TITLE, + getService: (params) => new InferenceConnector(params), + schema: { + config: ConfigSchema, + secrets: SecretsSchema, + }, + validators: [{ type: ValidatorType.CONFIG, validator: configValidator }], + supportedFeatureIds: [ + GenerativeAIForSecurityConnectorFeatureId, + GenerativeAIForSearchPlaygroundConnectorFeatureId, + GenerativeAIForObservabilityConnectorFeatureId, + ], + minimumLicenseRequired: 'enterprise' as const, + preSaveHook: async ({ config, secrets, logger, services, isUpdate }) => { + const esClient = services.scopedClusterClient.asInternalUser; + try { + const taskSettings = config?.taskTypeConfig + ? { + ...unflattenObject(config?.taskTypeConfig), + } + : {}; + const serviceSettings = { + ...unflattenObject(config?.providerConfig ?? {}), + ...unflattenObject(secrets?.providerSecrets ?? {}), + }; + + let inferenceExists = false; + try { + await esClient?.inference.get({ + inference_id: config?.inferenceId, + task_type: config?.taskType as InferenceTaskType, + }); + inferenceExists = true; + } catch (e) { + /* throws error if inference endpoint by id does not exist */ + } + if (!isUpdate && inferenceExists) { + throw new Error( + `Inference with id ${config?.inferenceId} and task type ${config?.taskType} already exists.` + ); + } + + if (isUpdate && inferenceExists && config && config.provider) { + // TODO: replace, when update API for inference endpoint exists + await deleteInferenceEndpoint( + config.inferenceId, + config.taskType as InferenceTaskType, + logger, + esClient + ); + } + + await esClient?.inference.put({ + inference_id: config?.inferenceId ?? '', + task_type: config?.taskType as InferenceTaskType, + inference_config: { + service: config!.provider, + service_settings: serviceSettings, + task_settings: taskSettings, + }, + }); + logger.debug( + `Inference endpoint for task type "${config?.taskType}" and inference id ${ + config?.inferenceId + } was successfuly ${isUpdate ? 'updated' : 'created'}` + ); + } catch (e) { + logger.warn( + `Failed to ${isUpdate ? 'update' : 'create'} inference endpoint for task type "${ + config?.taskType + }" and inference id ${config?.inferenceId}. Error: ${e.message}` + ); + throw e; + } + }, + postSaveHook: async ({ config, logger, services, wasSuccessful, isUpdate }) => { + if (!wasSuccessful && !isUpdate) { + const esClient = services.scopedClusterClient.asInternalUser; + await deleteInferenceEndpoint( + config.inferenceId, + config.taskType as InferenceTaskType, + logger, + esClient + ); + } + }, + postDeleteHook: async ({ config, logger, services }) => { + const esClient = services.scopedClusterClient.asInternalUser; + await deleteInferenceEndpoint( + config.inferenceId, + config.taskType as InferenceTaskType, + logger, + esClient + ); + }, +}); + +export const configValidator = (configObject: Config, validatorServices: ValidatorServices) => { + try { + const { provider, taskType } = configObject; + if (!Object.keys(ServiceProviderKeys).includes(provider)) { + throw new Error( + `API Provider is not supported${ + provider && provider.length ? `: ${provider}` : `` + } by Inference Endpoint.` + ); + } + + if (!Object.keys(SUB_ACTION).includes(taskType.toUpperCase())) { + throw new Error( + `Task type is not supported${ + taskType && taskType.length ? `: ${taskType}` : `` + } by Inference Endpoint.` + ); + } + return configObject; + } catch (err) { + throw new Error( + i18n.translate('xpack.stackConnectors.inference.configurationErrorApiProvider', { + defaultMessage: 'Error configuring Inference API action: {err}', + values: { + err: err.toString(), + }, + }) + ); + } +}; diff --git a/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.test.ts new file mode 100644 index 0000000000000..a79bd0360598b --- /dev/null +++ b/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.test.ts @@ -0,0 +1,310 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { InferenceConnector } from './inference'; +import { actionsConfigMock } from '@kbn/actions-plugin/server/actions_config.mock'; +import { loggingSystemMock } from '@kbn/core-logging-server-mocks'; +import { actionsMock } from '@kbn/actions-plugin/server/mocks'; +import { PassThrough, Transform } from 'stream'; +import {} from '@kbn/actions-plugin/server/types'; +import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks'; +import { InferenceInferenceResponse } from '@elastic/elasticsearch/lib/api/types'; + +const OPENAI_CONNECTOR_ID = '123'; +const DEFAULT_OPENAI_MODEL = 'gpt-4o'; + +describe('InferenceConnector', () => { + let mockError: jest.Mock; + const logger = loggingSystemMock.createLogger(); + const mockResponse: InferenceInferenceResponse = { + completion: [ + { + result: + 'Elastic is a company known for developing the Elasticsearch search and analytics engine, which allows for real-time data search, analysis, and visualization. Elasticsearch is part of the larger Elastic Stack (also known as the ELK Stack), which includes:\n\n1. **Elasticsearch**: A distributed, RESTful search and analytics engine capable of addressing a growing number of use cases. As the heart of the Elastic Stack, it centrally stores your data so you can discover the expected and uncover the unexpected.\n \n2. **Logstash**: A server-side data processing pipeline that ingests data from multiple sources simultaneously, transforms it, and sends it to your preferred "stash," such as Elasticsearch.\n \n3. **Kibana**: A data visualization dashboard for Elasticsearch. It allows you to search, view, and interact with data stored in Elasticsearch indices. You can perform advanced data analysis and visualize data in various charts, tables, and maps.\n\n4. **Beats**: Lightweight data shippers for different types of data. They send data from hundreds or thousands of machines and systems to Elasticsearch or Logstash.\n\nThe Elastic Stack is commonly used for various applications, such as log and event data analysis, full-text search, security analytics, business analytics, and more. It is employed across many industries to derive insights from large volumes of structured and unstructured data.\n\nElastic offers both open-source and paid versions of its software, providing a variety of features ranging from basic data ingestion and visualization to advanced machine learning and security capabilities.', + }, + ], + }; + + describe('performApiCompletion', () => { + const mockEsClient = elasticsearchClientMock.createClusterClient().asScoped().asInternalUser; + + beforeEach(() => { + mockEsClient.inference.inference.mockResolvedValue(mockResponse); + mockError = jest.fn().mockImplementation(() => { + throw new Error('API Error'); + }); + }); + + const services = actionsMock.createServices(); + services.scopedClusterClient = mockEsClient; + const connector = new InferenceConnector({ + configurationUtilities: actionsConfigMock.create(), + connector: { id: '1', type: OPENAI_CONNECTOR_ID }, + config: { + provider: 'openai', + providerConfig: { + url: 'https://api.openai.com/v1/chat/completions', + model_id: DEFAULT_OPENAI_MODEL, + }, + taskType: 'completion', + inferenceId: 'test', + taskTypeConfig: {}, + }, + secrets: { providerSecrets: { api_key: '123' } }, + logger, + services, + }); + + it('uses the completion task_type is supplied', async () => { + const response = await connector.performApiCompletion({ + input: 'What is Elastic?', + }); + expect(mockEsClient.inference.inference).toBeCalledTimes(1); + expect(mockEsClient.inference.inference).toHaveBeenCalledWith( + { + inference_id: 'test', + input: 'What is Elastic?', + task_type: 'completion', + }, + { asStream: false } + ); + expect(response).toEqual(mockResponse.completion); + }); + + it('errors during API calls are properly handled', async () => { + // @ts-ignore + mockEsClient.inference.inference = mockError; + + await expect(connector.performApiCompletion({ input: 'What is Elastic?' })).rejects.toThrow( + 'API Error' + ); + }); + }); + + describe('performApiRerank', () => { + const mockEsClient = elasticsearchClientMock.createClusterClient().asScoped().asInternalUser; + const mockResponseRerank = { + rerank: [ + { + index: 2, + score: 0.011597361, + text: 'leia', + }, + { + index: 0, + score: 0.006338922, + text: 'luke', + }, + ], + }; + + beforeEach(() => { + mockEsClient.inference.inference.mockResolvedValue(mockResponseRerank); + mockError = jest.fn().mockImplementation(() => { + throw new Error('API Error'); + }); + }); + const services = actionsMock.createServices(); + services.scopedClusterClient = mockEsClient; + it('the API call is successful with correct parameters', async () => { + const connectorRerank = new InferenceConnector({ + configurationUtilities: actionsConfigMock.create(), + connector: { id: '1', type: '123' }, + config: { + provider: 'googlevertexai', + providerConfig: { + model_id: DEFAULT_OPENAI_MODEL, + }, + taskType: 'rerank', + inferenceId: 'test-rerank', + taskTypeConfig: {}, + }, + secrets: { providerSecrets: { api_key: '123' } }, + logger, + services, + }); + const response = await connectorRerank.performApiRerank({ + input: ['apple', 'banana'], + query: 'test', + }); + expect(mockEsClient.inference.inference).toHaveBeenCalledWith( + { + inference_id: 'test-rerank', + input: ['apple', 'banana'], + query: 'test', + task_type: 'rerank', + }, + { asStream: false } + ); + expect(response).toEqual(mockResponseRerank.rerank); + }); + }); + + describe('performApiTextEmbedding', () => { + const mockEsClient = elasticsearchClientMock.createClusterClient().asScoped().asInternalUser; + + beforeEach(() => { + mockEsClient.inference.inference.mockResolvedValue(mockResponse); + mockError = jest.fn().mockImplementation(() => { + throw new Error('API Error'); + }); + }); + + const services = actionsMock.createServices(); + services.scopedClusterClient = mockEsClient; + const connectorTextEmbedding = new InferenceConnector({ + configurationUtilities: actionsConfigMock.create(), + connector: { id: '1', type: OPENAI_CONNECTOR_ID }, + config: { + providerConfig: { + url: 'https://My-test-resource-123.openai.azure.com/openai/deployments/NEW-DEPLOYMENT-321/chat/completions?api-version=2023-05-15', + }, + provider: 'elasticsearch', + taskType: '', + inferenceId: '', + taskTypeConfig: {}, + }, + secrets: { providerSecrets: {} }, + logger: loggingSystemMock.createLogger(), + services, + }); + + it('test the AzureAI API call is successful with correct parameters', async () => { + const response = await connectorTextEmbedding.performApiTextEmbedding({ + input: 'Hello world', + inputType: 'ingest', + }); + expect(mockEsClient.inference.inference).toHaveBeenCalledWith( + { + inference_id: '', + input: 'Hello world', + task_settings: { + input_type: 'ingest', + }, + task_type: 'text_embedding', + }, + { asStream: false } + ); + expect(response).toEqual(mockResponse.text_embedding); + }); + + it('errors during API calls are properly handled', async () => { + // @ts-ignore + mockEsClient.inference.inference = mockError; + + await expect( + connectorTextEmbedding.performApiTextEmbedding({ + input: 'Hello world', + inputType: 'ingest', + }) + ).rejects.toThrow('API Error'); + }); + }); + + describe('performApiCompletionStream', () => { + const mockEsClient = elasticsearchClientMock.createClusterClient().asScoped().asInternalUser; + + const mockStream = ( + dataToStream: string[] = [ + 'data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"}}]}', + ] + ) => { + const streamMock = createStreamMock(); + dataToStream.forEach((chunk) => { + streamMock.write(chunk); + }); + streamMock.complete(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + mockEsClient.inference.inference.mockResolvedValue(streamMock.transform as any); + }; + + beforeEach(() => { + // @ts-ignore + mockStream(); + }); + + const services = actionsMock.createServices(); + services.scopedClusterClient = mockEsClient; + const connector = new InferenceConnector({ + configurationUtilities: actionsConfigMock.create(), + connector: { id: '1', type: OPENAI_CONNECTOR_ID }, + config: { + providerConfig: { + url: 'https://My-test-resource-123.openai.azure.com/openai/deployments/NEW-DEPLOYMENT-321/chat/completions?api-version=2023-05-15', + }, + provider: 'elasticsearch', + taskType: 'completion', + inferenceId: '', + taskTypeConfig: {}, + }, + secrets: { providerSecrets: {} }, + logger: loggingSystemMock.createLogger(), + services, + }); + + it('the API call is successful with correct request parameters', async () => { + await connector.performApiCompletionStream({ input: 'Hello world' }); + expect(mockEsClient.inference.inference).toBeCalledTimes(1); + expect(mockEsClient.inference.inference).toHaveBeenCalledWith( + { + inference_id: '', + input: 'Hello world', + task_type: 'completion', + }, + { asStream: true } + ); + }); + + it('signal is properly passed to streamApi', async () => { + const signal = jest.fn() as unknown as AbortSignal; + await connector.performApiCompletionStream({ input: 'Hello world', signal }); + + expect(mockEsClient.inference.inference).toHaveBeenCalledWith( + { + inference_id: '', + input: 'Hello world', + task_type: 'completion', + }, + { asStream: true, signal } + ); + }); + + it('errors during API calls are properly handled', async () => { + // @ts-ignore + mockEsClient.inference.inference = mockError; + + await expect( + connector.performApiCompletionStream({ input: 'What is Elastic?' }) + ).rejects.toThrow('API Error'); + }); + + it('responds with a readable stream', async () => { + const response = await connector.performApiCompletionStream({ + input: 'What is Elastic?', + }); + expect(response instanceof PassThrough).toEqual(true); + }); + }); +}); + +function createStreamMock() { + const transform: Transform = new Transform({}); + + return { + write: (data: string) => { + transform.push(data); + }, + fail: () => { + transform.emit('error', new Error('Stream failed')); + transform.end(); + }, + transform, + complete: () => { + transform.end(); + }, + }; +} diff --git a/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.ts b/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.ts new file mode 100644 index 0000000000000..d9aa4bf044e1d --- /dev/null +++ b/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.ts @@ -0,0 +1,232 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { ServiceParams, SubActionConnector } from '@kbn/actions-plugin/server'; + +import { PassThrough, Stream } from 'stream'; +import { IncomingMessage } from 'http'; + +import { AxiosError } from 'axios'; +import { + InferenceInferenceRequest, + InferenceInferenceResponse, + InferenceTaskType, +} from '@elastic/elasticsearch/lib/api/types'; +import { + ChatCompleteParamsSchema, + RerankParamsSchema, + SparseEmbeddingParamsSchema, + TextEmbeddingParamsSchema, +} from '../../../common/inference/schema'; +import { + Config, + Secrets, + ChatCompleteParams, + ChatCompleteResponse, + StreamingResponse, + RerankParams, + RerankResponse, + SparseEmbeddingParams, + SparseEmbeddingResponse, + TextEmbeddingParams, + TextEmbeddingResponse, +} from '../../../common/inference/types'; +import { SUB_ACTION } from '../../../common/inference/constants'; + +export class InferenceConnector extends SubActionConnector { + // Not using Axios + protected getResponseErrorMessage(error: AxiosError): string { + throw new Error('Method not implemented.'); + } + + private inferenceId; + private taskType; + + constructor(params: ServiceParams) { + super(params); + + this.provider = this.config.provider; + this.taskType = this.config.taskType; + this.inferenceId = this.config.inferenceId; + this.logger = this.logger; + this.connectorID = this.connector.id; + this.connectorTokenClient = params.services.connectorTokenClient; + + this.registerSubActions(); + } + + private registerSubActions() { + this.registerSubAction({ + name: SUB_ACTION.COMPLETION, + method: 'performApiCompletion', + schema: ChatCompleteParamsSchema, + }); + + this.registerSubAction({ + name: SUB_ACTION.RERANK, + method: 'performApiRerank', + schema: RerankParamsSchema, + }); + + this.registerSubAction({ + name: SUB_ACTION.SPARSE_EMBEDDING, + method: 'performApiSparseEmbedding', + schema: SparseEmbeddingParamsSchema, + }); + + this.registerSubAction({ + name: SUB_ACTION.TEXT_EMBEDDING, + method: 'performApiTextEmbedding', + schema: TextEmbeddingParamsSchema, + }); + + this.registerSubAction({ + name: SUB_ACTION.COMPLETION_STREAM, + method: 'performApiCompletionStream', + schema: ChatCompleteParamsSchema, + }); + } + + /** + * responsible for making a esClient inference method to perform chat completetion task endpoint and returning the service response data + * @param input the text on which you want to perform the inference task. + * @signal abort signal + */ + public async performApiCompletion({ + input, + signal, + }: ChatCompleteParams & { signal?: AbortSignal }): Promise { + const response = await this.performInferenceApi( + { inference_id: this.inferenceId, input, task_type: 'completion' }, + false, + signal + ); + return response.completion!; + } + + /** + * responsible for making a esClient inference method to rerank task endpoint and returning the response data + * @param input the text on which you want to perform the inference task. input can be a single string or an array. + * @query the search query text + * @signal abort signal + */ + public async performApiRerank({ + input, + query, + signal, + }: RerankParams & { signal?: AbortSignal }): Promise { + const response = await this.performInferenceApi( + { + query, + inference_id: this.inferenceId, + input, + task_type: 'rerank', + }, + false, + signal + ); + return response.rerank!; + } + + /** + * responsible for making a esClient inference method sparse embedding task endpoint and returning the response data + * @param input the text on which you want to perform the inference task. + * @signal abort signal + */ + public async performApiSparseEmbedding({ + input, + signal, + }: SparseEmbeddingParams & { signal?: AbortSignal }): Promise { + const response = await this.performInferenceApi( + { inference_id: this.inferenceId, input, task_type: 'sparse_embedding' }, + false, + signal + ); + return response.sparse_embedding!; + } + + /** + * responsible for making a esClient inference method text embedding task endpoint and returning the response data + * @param input the text on which you want to perform the inference task. + * @signal abort signal + */ + public async performApiTextEmbedding({ + input, + inputType, + signal, + }: TextEmbeddingParams & { signal?: AbortSignal }): Promise { + const response = await this.performInferenceApi( + { + inference_id: this.inferenceId, + input, + task_type: 'text_embedding', + task_settings: { + input_type: inputType, + }, + }, + false, + signal + ); + return response.text_embedding!; + } + + /** + * private generic method to avoid duplication esClient inference inference execute. + * @param params InferenceInferenceRequest params. + * @param asStream defines the type of the responce, regular or stream + * @signal abort signal + */ + private async performInferenceApi( + params: InferenceInferenceRequest, + asStream: boolean = false, + signal?: AbortSignal + ): Promise { + try { + const response = await this.esClient?.inference.inference(params, { asStream, signal }); + this.logger.info( + `Perform Inference endpoint for task type "${this.taskType}" and inference id ${this.inferenceId}` + ); + // TODO: const usageMetadata = response?.data?.usageMetadata; + return response; + } catch (err) { + this.logger.error(`error perform inference endpoint API: ${err}`); + throw err; + } + } + + private async streamAPI({ + input, + signal, + }: ChatCompleteParams & { signal?: AbortSignal }): Promise { + const response = await this.performInferenceApi( + { inference_id: this.inferenceId, input, task_type: this.taskType as InferenceTaskType }, + true, + signal + ); + + return (response as unknown as Stream).pipe(new PassThrough()); + } + + /** + * takes input. It calls the streamApi method to make a + * request to the Inference API with the message. It then returns a Transform stream + * that pipes the response from the API through the transformToString function, + * which parses the proprietary response into a string of the response text alone + * @param input A message to be sent to the API + * @signal abort signal + */ + public async performApiCompletionStream({ + input, + signal, + }: ChatCompleteParams & { signal?: AbortSignal }): Promise { + const res = (await this.streamAPI({ + input, + signal, + })) as unknown as IncomingMessage; + return res; + } +} diff --git a/x-pack/plugins/stack_connectors/server/connector_types/lib/gen_ai/create_gen_ai_dashboard.ts b/x-pack/plugins/stack_connectors/server/connector_types/lib/gen_ai/create_gen_ai_dashboard.ts index ed15c31dc29c7..ee5a3471a2d34 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/lib/gen_ai/create_gen_ai_dashboard.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/lib/gen_ai/create_gen_ai_dashboard.ts @@ -24,7 +24,7 @@ export const initDashboard = async ({ logger: Logger; savedObjectsClient: SavedObjectsClientContract; dashboardId: string; - genAIProvider: 'OpenAI' | 'Bedrock' | 'Gemini'; + genAIProvider: 'OpenAI' | 'Bedrock' | 'Gemini' | 'Inference'; }): Promise<{ success: boolean; error?: OutputError; diff --git a/x-pack/plugins/stack_connectors/server/connector_types/lib/gen_ai/gen_ai_dashboard.ts b/x-pack/plugins/stack_connectors/server/connector_types/lib/gen_ai/gen_ai_dashboard.ts index 144704b8af677..5805dd7728ccf 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/lib/gen_ai/gen_ai_dashboard.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/lib/gen_ai/gen_ai_dashboard.ts @@ -11,11 +11,15 @@ import { SavedObject } from '@kbn/core-saved-objects-common/src/server_types'; import { OPENAI_TITLE, OPENAI_CONNECTOR_ID } from '../../../../common/openai/constants'; import { BEDROCK_TITLE, BEDROCK_CONNECTOR_ID } from '../../../../common/bedrock/constants'; import { GEMINI_TITLE, GEMINI_CONNECTOR_ID } from '../../../../common/gemini/constants'; +import { + INFERENCE_CONNECTOR_TITLE, + INFERENCE_CONNECTOR_ID, +} from '../../../../common/inference/constants'; export const getDashboardTitle = (title: string) => `${title} Token Usage`; export const getDashboard = ( - genAIProvider: 'OpenAI' | 'Bedrock' | 'Gemini', + genAIProvider: 'OpenAI' | 'Bedrock' | 'Gemini' | 'Inference', dashboardId: string ): SavedObject => { let attributes = { @@ -42,6 +46,12 @@ export const getDashboard = ( dashboardTitle: getDashboardTitle(GEMINI_TITLE), actionTypeId: GEMINI_CONNECTOR_ID, }; + } else if (genAIProvider === 'Inference') { + attributes = { + provider: INFERENCE_CONNECTOR_TITLE, + dashboardTitle: getDashboardTitle(INFERENCE_CONNECTOR_TITLE), + actionTypeId: INFERENCE_CONNECTOR_ID, + }; } const ids: Record = { diff --git a/x-pack/plugins/stack_connectors/server/connector_types/lib/unflatten_object.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/lib/unflatten_object.test.ts new file mode 100644 index 0000000000000..ccd11f7e92947 --- /dev/null +++ b/x-pack/plugins/stack_connectors/server/connector_types/lib/unflatten_object.test.ts @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { unflattenObject } from './unflatten_object'; + +describe('unflattenObject', () => { + test('should unflatten an object', () => { + const obj = { + a: true, + 'b.baz[0].a': false, + 'b.baz[0].b': 'foo', + 'b.baz[1]': 'bar', + 'b.baz[2]': true, + 'b.foo': 'bar', + 'b.baz[3][0]': 1, + 'b.baz[3][1]': 2, + 'c.b.foo': 'cheese', + }; + + expect(unflattenObject(obj)).toEqual({ + a: true, + b: { + foo: 'bar', + baz: [ + { + a: false, + b: 'foo', + }, + 'bar', + true, + [1, 2], + ], + }, + c: { + b: { + foo: 'cheese', + }, + }, + }); + }); +}); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/lib/unflatten_object.ts b/x-pack/plugins/stack_connectors/server/connector_types/lib/unflatten_object.ts new file mode 100644 index 0000000000000..23f2553223a35 --- /dev/null +++ b/x-pack/plugins/stack_connectors/server/connector_types/lib/unflatten_object.ts @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import { set } from '@kbn/safer-lodash-set'; + +interface GenericObject { + [key: string]: unknown; +} +export const unflattenObject = (object: object): T => + Object.entries(object).reduce((acc, [key, value]) => { + set(acc, key, value); + return acc; + }, {} as T); diff --git a/x-pack/plugins/stack_connectors/server/plugin.test.ts b/x-pack/plugins/stack_connectors/server/plugin.test.ts index 0c9551f787ca0..7d11d152f8316 100644 --- a/x-pack/plugins/stack_connectors/server/plugin.test.ts +++ b/x-pack/plugins/stack_connectors/server/plugin.test.ts @@ -131,7 +131,7 @@ describe('Stack Connectors Plugin', () => { name: 'Torq', }) ); - expect(actionsSetup.registerSubActionConnectorType).toHaveBeenCalledTimes(10); + expect(actionsSetup.registerSubActionConnectorType).toHaveBeenCalledTimes(11); expect(actionsSetup.registerSubActionConnectorType).toHaveBeenNthCalledWith( 1, expect.objectContaining({ diff --git a/x-pack/plugins/stack_connectors/tsconfig.json b/x-pack/plugins/stack_connectors/tsconfig.json index 8a37f4edaa0b0..66c84d17408bc 100644 --- a/x-pack/plugins/stack_connectors/tsconfig.json +++ b/x-pack/plugins/stack_connectors/tsconfig.json @@ -42,6 +42,8 @@ "@kbn/utility-types", "@kbn/task-manager-plugin", "@kbn/alerting-types", + "@kbn/alerts-ui-shared", + "@kbn/core-notifications-browser", ], "exclude": [ "target/**/*", diff --git a/x-pack/plugins/task_manager/server/integration_tests/__snapshots__/task_cost_check.test.ts.snap b/x-pack/plugins/task_manager/server/integration_tests/__snapshots__/task_cost_check.test.ts.snap index 754d9f0c66b4b..96ac62b9f03df 100644 --- a/x-pack/plugins/task_manager/server/integration_tests/__snapshots__/task_cost_check.test.ts.snap +++ b/x-pack/plugins/task_manager/server/integration_tests/__snapshots__/task_cost_check.test.ts.snap @@ -106,6 +106,10 @@ Array [ "cost": 1, "taskType": "actions:.crowdstrike", }, + Object { + "cost": 1, + "taskType": "actions:.inference", + }, Object { "cost": 1, "taskType": "actions:.cases", diff --git a/x-pack/plugins/triggers_actions_ui/public/application/sections/action_connector_form/create_connector_flyout/header.tsx b/x-pack/plugins/triggers_actions_ui/public/application/sections/action_connector_form/create_connector_flyout/header.tsx index 0b252ca6660c6..e3fb79520b93b 100644 --- a/x-pack/plugins/triggers_actions_ui/public/application/sections/action_connector_form/create_connector_flyout/header.tsx +++ b/x-pack/plugins/triggers_actions_ui/public/application/sections/action_connector_form/create_connector_flyout/header.tsx @@ -51,13 +51,17 @@ const FlyoutHeaderComponent: React.FC = ({

- + {actionTypeName && actionTypeName.toLowerCase().includes('connector') ? ( + actionTypeName + ) : ( + + )}

diff --git a/x-pack/plugins/triggers_actions_ui/public/application/sections/action_connector_form/encrypted_fields_callout.test.tsx b/x-pack/plugins/triggers_actions_ui/public/application/sections/action_connector_form/encrypted_fields_callout.test.tsx index 78039c753a276..22eb2673fd353 100644 --- a/x-pack/plugins/triggers_actions_ui/public/application/sections/action_connector_form/encrypted_fields_callout.test.tsx +++ b/x-pack/plugins/triggers_actions_ui/public/application/sections/action_connector_form/encrypted_fields_callout.test.tsx @@ -15,17 +15,23 @@ const renderWithSecretFields = ({ isEdit, isMissingSecrets, numberOfSecretFields, + isSecretFieldsHidden = false, }: { isEdit: boolean; isMissingSecrets: boolean; numberOfSecretFields: number; + isSecretFieldsHidden?: boolean; }): RenderResult => { return render( {Array.from({ length: numberOfSecretFields }).map((_, index) => { return ( - + ); })} @@ -67,10 +73,16 @@ describe('EncryptedFieldsCallout', () => { ], ]; - const noSecretsTests: Array<[{ isEdit: boolean; isMissingSecrets: boolean }, string]> = [ + const noSecretsTests: Array< + [{ isEdit: boolean; isMissingSecrets: boolean; isSecretFieldsHidden?: boolean }, string] + > = [ [{ isEdit: false, isMissingSecrets: false }, 'create-connector-secrets-callout'], [{ isEdit: true, isMissingSecrets: false }, 'edit-connector-secrets-callout'], [{ isEdit: false, isMissingSecrets: true }, 'missing-secrets-callout'], + [ + { isEdit: true, isMissingSecrets: true, isSecretFieldsHidden: true }, + 'edit-connector-secrets-callout', + ], ]; it.each(isCreateTests)( diff --git a/x-pack/plugins/triggers_actions_ui/public/application/sections/action_connector_form/encrypted_fields_callout.tsx b/x-pack/plugins/triggers_actions_ui/public/application/sections/action_connector_form/encrypted_fields_callout.tsx index 35cbb473c6a3f..161fed865a4e6 100644 --- a/x-pack/plugins/triggers_actions_ui/public/application/sections/action_connector_form/encrypted_fields_callout.tsx +++ b/x-pack/plugins/triggers_actions_ui/public/application/sections/action_connector_form/encrypted_fields_callout.tsx @@ -96,7 +96,7 @@ const EncryptedFieldsCalloutComponent: React.FC = ( ); } - if (!isEdit) { + if (!isEdit && secretFieldsLabel.length) { return ( = ( ); } - if (isEdit) { + if (isEdit && secretFieldsLabel.length) { return ( + ) { + if (this.returnError) { + return InferenceSimulator.sendErrorResponse(response); + } + + return InferenceSimulator.sendResponse(response); + } + + private static sendResponse(response: http.ServerResponse) { + response.statusCode = 202; + response.setHeader('Content-Type', 'application/json'); + response.end(JSON.stringify(inferenceSuccessResponse, null, 4)); + } + + private static sendErrorResponse(response: http.ServerResponse) { + response.statusCode = 422; + response.setHeader('Content-Type', 'application/json;charset=UTF-8'); + response.end(JSON.stringify(inferenceFailedResponse, null, 4)); + } +} + +export const inferenceSuccessResponse = { + refid: '80be4a0d-5f0e-4d6c-b00e-8cb918f7df1f', +}; +export const inferenceFailedResponse = { + error: { + statusMessage: 'Bad job', + }, +}; diff --git a/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/inference.ts b/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/inference.ts new file mode 100644 index 0000000000000..f3f6361e84db4 --- /dev/null +++ b/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/inference.ts @@ -0,0 +1,518 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import expect from '@kbn/expect'; +import { IValidatedEvent } from '@kbn/event-log-plugin/server'; + +import { + InferenceSimulator, + inferenceSuccessResponse, +} from '@kbn/actions-simulators-plugin/server/inference_simulation'; +import { TaskErrorSource } from '@kbn/task-manager-plugin/common'; +import { FtrProviderContext } from '../../../../../common/ftr_provider_context'; +import { getUrlPrefix, ObjectRemover } from '../../../../../common/lib'; +import { getEventLog } from '../../../../../common/lib'; + +const connectorTypeId = '.inference'; +const name = 'AI connector action'; +const secrets = { + apiKey: 'genAiApiKey', +}; + +const defaultConfig = { provider: 'openai' }; + +// eslint-disable-next-line import/no-default-export +export default function InferenceConnectorTest({ getService }: FtrProviderContext) { + const supertest = getService('supertest'); + const objectRemover = new ObjectRemover(supertest); + const supertestWithoutAuth = getService('supertestWithoutAuth'); + const configService = getService('config'); + const retry = getService('retry'); + const createConnector = async (apiUrl: string, spaceId?: string) => { + const { body } = await supertest + .post(`${getUrlPrefix(spaceId ?? 'default')}/api/actions/connector`) + .set('kbn-xsrf', 'foo') + .send({ + name, + connector_type_id: connectorTypeId, + config: { ...defaultConfig, apiUrl }, + secrets, + }) + .expect(200); + + objectRemover.add(spaceId ?? 'default', body.id, 'connector', 'actions'); + + return body.id; + }; + + describe('OpenAI', () => { + after(async () => { + await objectRemover.removeAll(); + }); + describe('action creation', () => { + const simulator = new InferenceSimulator({ + returnError: false, + proxy: { + config: configService.get('kbnTestServer.serverArgs'), + }, + }); + const config = { ...defaultConfig, apiUrl: '' }; + + before(async () => { + config.apiUrl = await simulator.start(); + }); + + after(() => { + simulator.close(); + }); + + it('should return 200 when creating the connector without a default model', async () => { + const { body: createdAction } = await supertest + .post('/api/actions/connector') + .set('kbn-xsrf', 'foo') + .send({ + name, + connector_type_id: connectorTypeId, + config, + secrets, + }) + .expect(200); + + expect(createdAction).to.eql({ + id: createdAction.id, + is_preconfigured: false, + is_system_action: false, + is_deprecated: false, + name, + connector_type_id: connectorTypeId, + is_missing_secrets: false, + config: { + ...config, + defaultModel: 'gpt-4o', + }, + }); + }); + + it('should return 200 when creating the connector with a default model', async () => { + const { body: createdAction } = await supertest + .post('/api/actions/connector') + .set('kbn-xsrf', 'foo') + .send({ + name, + connector_type_id: connectorTypeId, + config: { + ...config, + defaultModel: 'gpt-3.5-turbo', + }, + secrets, + }) + .expect(200); + + expect(createdAction).to.eql({ + id: createdAction.id, + is_preconfigured: false, + is_system_action: false, + is_deprecated: false, + name, + connector_type_id: connectorTypeId, + is_missing_secrets: false, + config: { + ...config, + defaultModel: 'gpt-3.5-turbo', + }, + }); + }); + + it('should return 400 Bad Request when creating the connector without the apiProvider', async () => { + await supertest + .post('/api/actions/connector') + .set('kbn-xsrf', 'foo') + .send({ + name: 'A GenAi action', + connector_type_id: '.gen-ai', + config: { + apiUrl: config.apiUrl, + }, + secrets: { + apiKey: '123', + }, + }) + .expect(400) + .then((resp: any) => { + expect(resp.body).to.eql({ + statusCode: 400, + error: 'Bad Request', + message: + 'error validating action type config: types that failed validation:\n- [0.apiProvider]: expected at least one defined value but got [undefined]\n- [1.apiProvider]: expected at least one defined value but got [undefined]', + }); + }); + }); + + it('should return 400 Bad Request when creating the connector without the apiUrl', async () => { + await supertest + .post('/api/actions/connector') + .set('kbn-xsrf', 'foo') + .send({ + name, + connector_type_id: connectorTypeId, + config: defaultConfig, + secrets, + }) + .expect(400) + .then((resp: any) => { + expect(resp.body).to.eql({ + statusCode: 400, + error: 'Bad Request', + message: + 'error validating action type config: types that failed validation:\n- [0.apiProvider]: expected value to equal [Azure OpenAI]\n- [1.apiUrl]: expected value of type [string] but got [undefined]', + }); + }); + }); + + it('should return 400 Bad Request when creating the connector with a apiUrl that is not allowed', async () => { + await supertest + .post('/api/actions/connector') + .set('kbn-xsrf', 'foo') + .send({ + name, + connector_type_id: connectorTypeId, + config: { + ...defaultConfig, + apiUrl: 'http://genAi.mynonexistent.com', + }, + secrets, + }) + .expect(400) + .then((resp: any) => { + expect(resp.body).to.eql({ + statusCode: 400, + error: 'Bad Request', + message: + 'error validating action type config: Error configuring OpenAI action: Error: error validating url: target url "http://genAi.mynonexistent.com" is not added to the Kibana config xpack.actions.allowedHosts', + }); + }); + }); + + it('should return 400 Bad Request when creating the connector without secrets', async () => { + await supertest + .post('/api/actions/connector') + .set('kbn-xsrf', 'foo') + .send({ + name, + connector_type_id: connectorTypeId, + config, + }) + .expect(400) + .then((resp: any) => { + expect(resp.body).to.eql({ + statusCode: 400, + error: 'Bad Request', + message: + 'error validating action type secrets: [apiKey]: expected value of type [string] but got [undefined]', + }); + }); + }); + }); + + describe('executor', () => { + describe('validation', () => { + const simulator = new InferenceSimulator({ + proxy: { + config: configService.get('kbnTestServer.serverArgs'), + }, + }); + let genAiActionId: string; + + before(async () => { + const apiUrl = await simulator.start(); + genAiActionId = await createConnector(apiUrl); + }); + + after(() => { + simulator.close(); + }); + + it('should fail when the params is empty', async () => { + const { body } = await supertest + .post(`/api/actions/connector/${genAiActionId}/_execute`) + .set('kbn-xsrf', 'foo') + .send({ + params: {}, + }); + expect(200); + + expect(body).to.eql({ + status: 'error', + connector_id: genAiActionId, + message: + 'error validating action params: [subAction]: expected value of type [string] but got [undefined]', + retry: false, + errorSource: TaskErrorSource.FRAMEWORK, + }); + }); + + it('should fail when the subAction is invalid', async () => { + const { body } = await supertest + .post(`/api/actions/connector/${genAiActionId}/_execute`) + .set('kbn-xsrf', 'foo') + .send({ + params: { subAction: 'invalidAction' }, + }) + .expect(200); + + expect(body).to.eql({ + connector_id: genAiActionId, + status: 'error', + retry: true, + message: 'an error occurred while running the action', + errorSource: TaskErrorSource.FRAMEWORK, + service_message: `Sub action "invalidAction" is not registered. Connector id: ${genAiActionId}. Connector name: OpenAI. Connector type: .gen-ai`, + }); + }); + }); + + describe('execution', () => { + describe('successful response simulator', () => { + const simulator = new InferenceSimulator({ + proxy: { + config: configService.get('kbnTestServer.serverArgs'), + }, + }); + let apiUrl: string; + let genAiActionId: string; + + before(async () => { + apiUrl = await simulator.start(); + genAiActionId = await createConnector(apiUrl); + }); + + after(() => { + simulator.close(); + }); + + it('should send a stringified JSON object', async () => { + const { body } = await supertest + .post(`/api/actions/connector/${genAiActionId}/_execute`) + .set('kbn-xsrf', 'foo') + .send({ + params: { + subAction: 'test', + subActionParams: { + body: '{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"Hello world"}]}', + }, + }, + }) + .expect(200); + + expect(simulator.requestData).to.eql({ + model: 'gpt-3.5-turbo', + messages: [{ role: 'user', content: 'Hello world' }], + }); + expect(body).to.eql({ + status: 'ok', + connector_id: genAiActionId, + data: inferenceSuccessResponse, + }); + + const events: IValidatedEvent[] = await retry.try(async () => { + return await getEventLog({ + getService, + spaceId: 'default', + type: 'action', + id: genAiActionId, + provider: 'actions', + actions: new Map([ + ['execute-start', { equal: 1 }], + ['execute', { equal: 1 }], + ]), + }); + }); + + const executeEvent = events[1]; + expect(executeEvent?.kibana?.action?.execution?.usage?.request_body_bytes).to.be(78); + }); + describe('Token tracking dashboard', () => { + const dashboardId = 'specific-dashboard-id-default'; + + it('should not create a dashboard when user does not have kibana event log permissions', async () => { + const { body } = await supertestWithoutAuth + .post(`/api/actions/connector/${genAiActionId}/_execute`) + .auth('global_read', 'global_read-password') + .set('kbn-xsrf', 'foo') + .send({ + params: { + subAction: 'getDashboard', + subActionParams: { + dashboardId, + }, + }, + }) + .expect(200); + + // check dashboard has not been created + await supertest + .get(`/api/saved_objects/dashboard/${dashboardId}`) + .set('kbn-xsrf', 'foo') + .expect(404); + + expect(body).to.eql({ + status: 'ok', + connector_id: genAiActionId, + data: { available: false }, + }); + }); + + it('should create a dashboard when user has correct permissions', async () => { + const { body } = await supertest + .post(`/api/actions/connector/${genAiActionId}/_execute`) + .set('kbn-xsrf', 'foo') + .send({ + params: { + subAction: 'getDashboard', + subActionParams: { + dashboardId, + }, + }, + }) + .expect(200); + + // check dashboard has been created + await retry.try(async () => + supertest + .get(`/api/saved_objects/dashboard/${dashboardId}`) + .set('kbn-xsrf', 'foo') + .expect(200) + ); + + objectRemover.add('default', dashboardId, 'dashboard', 'saved_objects'); + + expect(body).to.eql({ + status: 'ok', + connector_id: genAiActionId, + data: { available: true }, + }); + }); + }); + }); + describe('non-default space simulator', () => { + const simulator = new InferenceSimulator({ + proxy: { + config: configService.get('kbnTestServer.serverArgs'), + }, + }); + let apiUrl: string; + let genAiActionId: string; + + before(async () => { + apiUrl = await simulator.start(); + genAiActionId = await createConnector(apiUrl, 'space1'); + }); + after(() => { + simulator.close(); + }); + + const dashboardId = 'specific-dashboard-id-space1'; + + it('should create a dashboard in non-default space', async () => { + const { body } = await supertest + .post(`${getUrlPrefix('space1')}/api/actions/connector/${genAiActionId}/_execute`) + .set('kbn-xsrf', 'foo') + .send({ + params: { + subAction: 'getDashboard', + subActionParams: { + dashboardId, + }, + }, + }) + .expect(200); + + // check dashboard has been created + await retry.try( + async () => + await supertest + .get(`${getUrlPrefix('space1')}/api/saved_objects/dashboard/${dashboardId}`) + .set('kbn-xsrf', 'foo') + .expect(200) + ); + objectRemover.add('space1', dashboardId, 'dashboard', 'saved_objects'); + + expect(body).to.eql({ + status: 'ok', + connector_id: genAiActionId, + data: { available: true }, + }); + }); + }); + + describe('error response simulator', () => { + const simulator = new InferenceSimulator({ + returnError: true, + proxy: { + config: configService.get('kbnTestServer.serverArgs'), + }, + }); + + let genAiActionId: string; + + before(async () => { + const apiUrl = await simulator.start(); + genAiActionId = await createConnector(apiUrl); + }); + + after(() => { + simulator.close(); + }); + + it('should return a failure when error happens', async () => { + const { body } = await supertest + .post(`/api/actions/connector/${genAiActionId}/_execute`) + .set('kbn-xsrf', 'foo') + .send({ + params: {}, + }) + .expect(200); + + expect(body).to.eql({ + status: 'error', + connector_id: genAiActionId, + message: + 'error validating action params: [subAction]: expected value of type [string] but got [undefined]', + retry: false, + errorSource: TaskErrorSource.FRAMEWORK, + }); + }); + + it('should return a error when error happens', async () => { + const { body } = await supertest + .post(`/api/actions/connector/${genAiActionId}/_execute`) + .set('kbn-xsrf', 'foo') + .send({ + params: { + subAction: 'test', + subActionParams: { + body: '{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"Hello world"}]}', + }, + }, + }) + .expect(200); + + expect(body).to.eql({ + status: 'error', + connector_id: genAiActionId, + message: 'an error occurred while running the action', + retry: true, + errorSource: TaskErrorSource.FRAMEWORK, + service_message: + 'Status code: 422. Message: API Error: Unprocessable Entity - The model `bad model` does not exist', + }); + }); + }); + }); + }); + }); +} diff --git a/x-pack/test/alerting_api_integration/spaces_only/tests/actions/check_registered_connector_types.ts b/x-pack/test/alerting_api_integration/spaces_only/tests/actions/check_registered_connector_types.ts index 4b7dd28d63b5c..10449613f6ef6 100644 --- a/x-pack/test/alerting_api_integration/spaces_only/tests/actions/check_registered_connector_types.ts +++ b/x-pack/test/alerting_api_integration/spaces_only/tests/actions/check_registered_connector_types.ts @@ -53,6 +53,7 @@ export default function createRegisteredConnectorTypeTests({ getService }: FtrPr '.gen-ai', '.bedrock', '.gemini', + '.inference', '.sentinelone', '.cases', '.crowdstrike', diff --git a/x-pack/test/plugin_api_integration/test_suites/task_manager/check_registered_task_types.ts b/x-pack/test/plugin_api_integration/test_suites/task_manager/check_registered_task_types.ts index fcb782b069dbf..f5ccfcc3ca56f 100644 --- a/x-pack/test/plugin_api_integration/test_suites/task_manager/check_registered_task_types.ts +++ b/x-pack/test/plugin_api_integration/test_suites/task_manager/check_registered_task_types.ts @@ -61,6 +61,7 @@ export default function ({ getService }: FtrProviderContext) { 'actions:.gemini', 'actions:.gen-ai', 'actions:.index', + 'actions:.inference', 'actions:.jira', 'actions:.observability-ai-assistant', 'actions:.opsgenie',