Skip to content

Commit

Permalink
Support mTLS in Elastic Inference Service plugin (elastic#116423) (el…
Browse files Browse the repository at this point in the history
…astic#119679)

* Support mTLS in Elastic Inference Service plugin (elastic#116423)

* Introduce new SSL settings under `xpack.inference.elastic.http.ssl`.

* Support mTLS connection between Elasticsearch and Elastic Inference Service.

* Update docs/changelog/119679.yaml

* Apply new changelog

* [CI] Auto commit changes from spotless

---------

Co-authored-by: elasticsearchmachine <[email protected]>
  • Loading branch information
vidok and elasticsearchmachine authored Jan 15, 2025
1 parent f7f8ab0 commit 03424aa
Show file tree
Hide file tree
Showing 22 changed files with 314 additions and 76 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/119679.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 119679
summary: Support mTLS for the Elastic Inference Service integration inside the inference API
area: Machine Learning
type: feature
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,8 @@ static Map<String, Settings> getSSLSettingsMap(Settings settings) {
sslSettingsMap.put(WatcherField.EMAIL_NOTIFICATION_SSL_PREFIX, settings.getByPrefix(WatcherField.EMAIL_NOTIFICATION_SSL_PREFIX));
sslSettingsMap.put(XPackSettings.TRANSPORT_SSL_PREFIX, settings.getByPrefix(XPackSettings.TRANSPORT_SSL_PREFIX));
sslSettingsMap.putAll(getTransportProfileSSLSettings(settings));
// Mount Elastic Inference Service (part of the Inference plugin) configuration
sslSettingsMap.put("xpack.inference.elastic.http.ssl", settings.getByPrefix("xpack.inference.elastic.http.ssl."));
// Only build remote cluster server SSL if the port is enabled
if (REMOTE_CLUSTER_SERVER_ENABLED.get(settings)) {
sslSettingsMap.put(XPackSettings.REMOTE_CLUSTER_SERVER_SSL_PREFIX, getRemoteClusterServerSslSettings(settings));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ public Map<String, IndexStorePlugin.SnapshotCommitSupplier> getSnapshotCommitSup
}

@SuppressWarnings("unchecked")
private <T> List<T> filterPlugins(Class<T> type) {
protected <T> List<T> filterPlugins(Class<T> type) {
return plugins.stream().filter(x -> type.isAssignableFrom(x.getClass())).map(p -> ((T) p)).collect(Collectors.toList());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,8 @@ public void testGetConfigurationByContextName() throws Exception {
"xpack.security.authc.realms.ldap.realm1.ssl",
"xpack.security.authc.realms.saml.realm2.ssl",
"xpack.monitoring.exporters.mon1.ssl",
"xpack.monitoring.exporters.mon2.ssl" };
"xpack.monitoring.exporters.mon2.ssl",
"xpack.inference.elastic.http.ssl" };

assumeTrue("Not enough cipher suites are available to support this test", getCipherSuites.length >= contextNames.length);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
Expand Down Expand Up @@ -74,7 +74,7 @@ public void setup() throws Exception {

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Arrays.asList(Utils.TestInferencePlugin.class, LocalStateCompositeXPackPlugin.class);
return Arrays.asList(LocalStateInferencePlugin.class);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel;
Expand Down Expand Up @@ -77,7 +76,7 @@ public void createComponents() {

@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return pluginList(ReindexPlugin.class, InferencePlugin.class, LocalStateCompositeXPackPlugin.class);
return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class);
}

public void testStoreModel() throws Exception {
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugin/inference/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
requires software.amazon.awssdk.retries.api;
requires org.reactivestreams;
requires org.elasticsearch.logging;
requires org.elasticsearch.sslconfig;

exports org.elasticsearch.xpack.inference.action;
exports org.elasticsearch.xpack.inference.registry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.license.License;
import org.elasticsearch.license.LicensedFeature;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.node.PluginComponentBinding;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.ExtensiblePlugin;
Expand All @@ -49,6 +50,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.XPackPlugin;
import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction;
import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction;
import org.elasticsearch.xpack.core.inference.action.GetInferenceDiagnosticsAction;
Expand All @@ -58,6 +60,7 @@
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction;
import org.elasticsearch.xpack.core.ssl.SSLService;
import org.elasticsearch.xpack.inference.action.TransportDeleteInferenceEndpointAction;
import org.elasticsearch.xpack.inference.action.TransportGetInferenceDiagnosticsAction;
import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction;
Expand Down Expand Up @@ -126,7 +129,6 @@
import java.util.Map;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.Collections.singletonList;
Expand Down Expand Up @@ -166,6 +168,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
private final Settings settings;
private final SetOnce<HttpRequestSender.Factory> httpFactory = new SetOnce<>();
private final SetOnce<AmazonBedrockRequestSender.Factory> amazonBedrockFactory = new SetOnce<>();
private final SetOnce<HttpRequestSender.Factory> elasicInferenceServiceFactory = new SetOnce<>();
private final SetOnce<ServiceComponents> serviceComponents = new SetOnce<>();
// This is mainly so that the rest handlers can access the ThreadPool in a way that avoids potential null pointers from it
// not being initialized yet
Expand Down Expand Up @@ -252,31 +255,31 @@ public Collection<?> createComponents(PluginServices services) {
var inferenceServices = new ArrayList<>(inferenceServiceExtensions);
inferenceServices.add(this::getInferenceServiceFactories);

// Set elasticInferenceUrl based on feature flags to support transitioning to the new Elastic Inference Service URL without exposing
// internal names like "eis" or "gateway".
ElasticInferenceServiceSettings inferenceServiceSettings = new ElasticInferenceServiceSettings(settings);

String elasticInferenceUrl = null;
if (isElasticInferenceServiceEnabled()) {
// Create a separate instance of HttpClientManager with its own SSL configuration (`xpack.inference.elastic.http.ssl.*`).
var elasticInferenceServiceHttpClientManager = HttpClientManager.create(
settings,
services.threadPool(),
services.clusterService(),
throttlerManager,
getSslService()
);

if (ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) {
elasticInferenceUrl = inferenceServiceSettings.getElasticInferenceServiceUrl();
} else if (DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) {
log.warn(
"Deprecated flag {} detected for enabling {}. Please use {}.",
ELASTIC_INFERENCE_SERVICE_IDENTIFIER,
DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG,
ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG
var elasticInferenceServiceRequestSenderFactory = new HttpRequestSender.Factory(
serviceComponents.get(),
elasticInferenceServiceHttpClientManager,
services.clusterService()
);
elasticInferenceUrl = inferenceServiceSettings.getEisGatewayUrl();
}
elasicInferenceServiceFactory.set(elasticInferenceServiceRequestSenderFactory);

if (elasticInferenceUrl != null) {
ElasticInferenceServiceSettings inferenceServiceSettings = new ElasticInferenceServiceSettings(settings);
String elasticInferenceUrl = this.getElasticInferenceServiceUrl(inferenceServiceSettings);
elasticInferenceServiceComponents.set(new ElasticInferenceServiceComponents(elasticInferenceUrl));

inferenceServices.add(
() -> List.of(
context -> new ElasticInferenceService(
httpFactory.get(),
elasicInferenceServiceFactory.get(),
serviceComponents.get(),
elasticInferenceServiceComponents.get()
)
Expand Down Expand Up @@ -400,16 +403,21 @@ public static ExecutorBuilder<?> inferenceUtilityExecutor(Settings settings) {

@Override
public List<Setting<?>> getSettings() {
return Stream.of(
HttpSettings.getSettingsDefinitions(),
HttpClientManager.getSettingsDefinitions(),
ThrottlerManager.getSettingsDefinitions(),
RetrySettings.getSettingsDefinitions(),
ElasticInferenceServiceSettings.getSettingsDefinitions(),
Truncator.getSettingsDefinitions(),
RequestExecutorServiceSettings.getSettingsDefinitions(),
List.of(SKIP_VALIDATE_AND_START)
).flatMap(Collection::stream).collect(Collectors.toList());
ArrayList<Setting<?>> settings = new ArrayList<>();
settings.addAll(HttpSettings.getSettingsDefinitions());
settings.addAll(HttpClientManager.getSettingsDefinitions());
settings.addAll(ThrottlerManager.getSettingsDefinitions());
settings.addAll(RetrySettings.getSettingsDefinitions());
settings.addAll(Truncator.getSettingsDefinitions());
settings.addAll(RequestExecutorServiceSettings.getSettingsDefinitions());
settings.add(SKIP_VALIDATE_AND_START);

// Register Elastic Inference Service settings definitions if the corresponding feature flag is enabled.
if (isElasticInferenceServiceEnabled()) {
settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions());
}

return settings;
}

@Override
Expand Down Expand Up @@ -466,7 +474,10 @@ public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
@Override
public List<RetrieverSpec<?>> getRetrievers() {
return List.of(
new RetrieverSpec<>(new ParseField(TextSimilarityRankBuilder.NAME), TextSimilarityRankRetrieverBuilder::fromXContent),
new RetrieverSpec<>(
new ParseField(TextSimilarityRankBuilder.NAME),
(parser, context) -> TextSimilarityRankRetrieverBuilder.fromXContent(parser, context, getLicenseState())
),
new RetrieverSpec<>(new ParseField(RandomRankBuilder.NAME), RandomRankRetrieverBuilder::fromXContent)
);
}
Expand All @@ -475,4 +486,36 @@ public List<RetrieverSpec<?>> getRetrievers() {
public Map<String, Highlighter> getHighlighters() {
return Map.of(SemanticTextHighlighter.NAME, new SemanticTextHighlighter());
}

/**
 * Resolves the Elastic Inference Service URL from the feature flags, to support the transition
 * to the new Elastic Inference Service URL.
 * <p>
 * The current feature flag takes precedence. The deprecated flag is only consulted when the current
 * one is disabled, and relying on it logs a warning that points at the replacement flag.
 *
 * @param settings the Elastic Inference Service settings to read the URL from
 * @return the URL selected by the enabled flag, or {@code null} when neither flag is enabled
 */
private String getElasticInferenceServiceUrl(ElasticInferenceServiceSettings settings) {
    if (ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) {
        return settings.getElasticInferenceServiceUrl();
    }

    if (DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled() == false) {
        return null;
    }

    // Arguments are listed in placeholder order: the deprecated flag that was detected,
    // the service it enables, and the flag that should be used instead.
    log.warn(
        "Deprecated flag {} detected for enabling {}. Please use {}.",
        DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG,
        ELASTIC_INFERENCE_SERVICE_IDENTIFIER,
        ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG
    );
    return settings.getEisGatewayUrl();
}

/**
 * Reports whether the Elastic Inference Service integration is switched on, i.e. whether the current
 * feature flag or, failing that, the deprecated feature flag is enabled.
 */
protected Boolean isElasticInferenceServiceEnabled() {
    if (ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) {
        return true;
    }
    return DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled();
}

/**
 * Returns the SSL service obtained from {@link XPackPlugin#getSharedSslService()}.
 * NOTE(review): protected, presumably so that subclasses can supply a different instance; confirm.
 */
protected SSLService getSslService() {
return XPackPlugin.getSharedSslService();
}

/**
 * Returns the license state obtained from {@link XPackPlugin#getSharedLicenseState()}.
 * NOTE(review): protected, presumably so that subclasses can supply a different instance; confirm.
 */
protected XPackLicenseState getLicenseState() {
return XPackPlugin.getSharedLicenseState();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@

package org.elasticsearch.xpack.inference.external.http;

import org.apache.http.config.Registry;
import org.apache.http.config.RegistryBuilder;
import org.apache.http.impl.nio.conn.PoolingNHttpClientConnectionManager;
import org.apache.http.impl.nio.reactor.DefaultConnectingIOReactor;
import org.apache.http.impl.nio.reactor.IOReactorConfig;
import org.apache.http.nio.conn.NoopIOSessionStrategy;
import org.apache.http.nio.conn.SchemeIOSessionStrategy;
import org.apache.http.nio.conn.ssl.SSLIOSessionStrategy;
import org.apache.http.nio.reactor.ConnectingIOReactor;
import org.apache.http.nio.reactor.IOReactorException;
import org.apache.http.pool.PoolStats;
Expand All @@ -21,18 +26,21 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ssl.SSLService;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import java.io.Closeable;
import java.io.IOException;
import java.util.List;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_SSL_CONFIGURATION_PREFIX;

public class HttpClientManager implements Closeable {
private static final Logger logger = LogManager.getLogger(HttpClientManager.class);
/**
* The maximum number of total connections the connection pool can lease to all routes.
* The configuration applies to each instance of HttpClientManager (max_total_connections=10 and instances=5 leads to 50 connections).
* From googling around, the connection pool's maxTotal value should be close to the number of available threads.
*
* https://stackoverflow.com/questions/30989637/how-to-decide-optimal-settings-for-setmaxtotal-and-setdefaultmaxperroute
Expand All @@ -47,6 +55,7 @@ public class HttpClientManager implements Closeable {

/**
* The max number of connections a single route can lease.
* This configuration applies to each instance of HttpClientManager.
*/
public static final Setting<Integer> MAX_ROUTE_CONNECTIONS = Setting.intSetting(
"xpack.inference.http.max_route_connections",
Expand Down Expand Up @@ -98,6 +107,22 @@ public static HttpClientManager create(
return new HttpClientManager(settings, connectionManager, threadPool, clusterService, throttlerManager);
}

/**
 * Builds an {@link HttpClientManager} whose "https" connections use the SSL configuration that the
 * given {@link SSLService} holds under {@code ELASTIC_INFERENCE_SERVICE_SSL_CONFIGURATION_PREFIX}.
 *
 * @param settings         node settings handed to the manager
 * @param threadPool       thread pool handed to the manager
 * @param clusterService   cluster service handed to the manager
 * @param throttlerManager throttler manager handed to the manager
 * @param sslService       source of the SSL configuration and of the session strategy derived from it
 * @return a new manager backed by an SSL-aware connection manager
 */
public static HttpClientManager create(
    Settings settings,
    ThreadPool threadPool,
    ClusterService clusterService,
    ThrottlerManager throttlerManager,
    SSLService sslService
) {
    // Elastic Inference Service requires an encrypted connection, so look up its dedicated SSL
    // configuration first and derive the session strategy from it.
    var eisSslConfiguration = sslService.getSSLConfiguration(ELASTIC_INFERENCE_SERVICE_SSL_CONFIGURATION_PREFIX);
    SSLIOSessionStrategy eisSslStrategy = sslService.sslIOSessionStrategy(eisSslConfiguration);

    PoolingNHttpClientConnectionManager sslAwareConnectionManager = createConnectionManager(eisSslStrategy);
    return new HttpClientManager(settings, sslAwareConnectionManager, threadPool, clusterService, throttlerManager);
}

// Default for testing
HttpClientManager(
Settings settings,
Expand All @@ -121,6 +146,25 @@ public static HttpClientManager create(
this.addSettingsUpdateConsumers(clusterService);
}

/**
 * Creates a pooling connection manager in which the "https" scheme is handled by the supplied SSL
 * strategy and the "http" scheme by {@link NoopIOSessionStrategy#INSTANCE}.
 *
 * @param sslStrategy session strategy registered for the "https" scheme
 * @return a connection manager built on a keep-alive enabled I/O reactor
 * @throws ElasticsearchException if the I/O reactor cannot be created
 */
private static PoolingNHttpClientConnectionManager createConnectionManager(SSLIOSessionStrategy sslStrategy) {
    final IOReactorConfig reactorConfig = IOReactorConfig.custom().setSoKeepAlive(true).build();

    final ConnectingIOReactor reactor;
    try {
        reactor = new DefaultConnectingIOReactor(reactorConfig);
    } catch (IOReactorException e) {
        final String failure = "Failed to initialize HTTP client manager with SSL.";
        logger.error(failure, e);
        throw new ElasticsearchException(failure, e);
    }

    final Registry<SchemeIOSessionStrategy> schemeRegistry = RegistryBuilder.<SchemeIOSessionStrategy>create()
        .register("http", NoopIOSessionStrategy.INSTANCE)
        .register("https", sslStrategy)
        .build();

    return new PoolingNHttpClientConnectionManager(reactor, schemeRegistry);
}

private static PoolingNHttpClientConnectionManager createConnectionManager() {
ConnectingIOReactor ioReactor;
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
Expand All @@ -21,7 +22,6 @@
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.XPackPlugin;

import java.io.IOException;
import java.util.List;
Expand Down Expand Up @@ -78,8 +78,11 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
RetrieverBuilder.declareBaseParserFields(TextSimilarityRankBuilder.NAME, PARSER);
}

public static TextSimilarityRankRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context)
throws IOException {
public static TextSimilarityRankRetrieverBuilder fromXContent(
XContentParser parser,
RetrieverParserContext context,
XPackLicenseState licenceState
) throws IOException {
if (context.clusterSupportsFeature(TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED) == false) {
throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + TextSimilarityRankBuilder.NAME + "]");
}
Expand All @@ -88,7 +91,7 @@ public static TextSimilarityRankRetrieverBuilder fromXContent(XContentParser par
"[text_similarity_reranker] retriever composition feature is not supported by all nodes in the cluster"
);
}
if (TextSimilarityRankBuilder.TEXT_SIMILARITY_RERANKER_FEATURE.check(XPackPlugin.getSharedLicenseState()) == false) {
if (TextSimilarityRankBuilder.TEXT_SIMILARITY_RERANKER_FEATURE.check(licenceState) == false) {
throw LicenseUtils.newComplianceException(TextSimilarityRankBuilder.NAME);
}
return PARSER.apply(parser, context);
Expand Down
Loading

0 comments on commit 03424aa

Please sign in to comment.