Skip to content

Commit

Permalink
Use a dedicated interface to check for binding of inference and model…
Browse files Browse the repository at this point in the history
… registries
  • Loading branch information
carlosdelest committed Feb 2, 2024
1 parent 6539c2a commit fd70765
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 23 deletions.
25 changes: 11 additions & 14 deletions server/src/main/java/org/elasticsearch/node/NodeConstruction.java
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@
import org.elasticsearch.plugins.ClusterPlugin;
import org.elasticsearch.plugins.DiscoveryPlugin;
import org.elasticsearch.plugins.HealthPlugin;
import org.elasticsearch.plugins.InferenceRegistryPlugin;
import org.elasticsearch.plugins.IngestPlugin;
import org.elasticsearch.plugins.MapperPlugin;
import org.elasticsearch.plugins.MetadataUpgrader;
Expand Down Expand Up @@ -1090,26 +1091,22 @@ record PluginServiceInstances(
}

// Register noop versions of inference services if Inference plugin is not available
if (isPluginComponentDefined(pluginComponents, InferenceServiceRegistry.class) == false) {
logger.warn("Inference service is not available");
modules.bindToInstance(InferenceServiceRegistry.class, new InferenceServiceRegistry.NoopInferenceServiceRegistry());
}
if (isPluginComponentDefined(pluginComponents, ModelRegistry.class) == false) {
logger.warn("Model registry is not available");
modules.bindToInstance(ModelRegistry.class, new ModelRegistry.NoopModelRegistry());
}
Optional<InferenceRegistryPlugin> inferenceRegistryPlugin = getSinglePlugin(InferenceRegistryPlugin.class);
modules.bindToInstance(
InferenceServiceRegistry.class,
inferenceRegistryPlugin.map(InferenceRegistryPlugin::getInferenceServiceRegistry)
.orElse(new InferenceServiceRegistry.NoopInferenceServiceRegistry())
);
modules.bindToInstance(
ModelRegistry.class,
inferenceRegistryPlugin.map(InferenceRegistryPlugin::getModelRegistry).orElse(new ModelRegistry.NoopModelRegistry())
);

injector = modules.createInjector();

postInjection(clusterModule, actionModule, clusterService, transportService, featureService);
}

private static boolean isPluginComponentDefined(Collection<?> pluginComponents, Class<?> clazz) {
return pluginComponents.stream()
.map(p -> p instanceof PluginComponentBinding ? ((PluginComponentBinding) p).impl() : p)
.anyMatch(p -> clazz.isAssignableFrom(clazz));
}

private ClusterService createClusterService(SettingsModule settingsModule, ThreadPool threadPool, TaskManager taskManager) {
ClusterService clusterService = new ClusterService(
settingsModule.getSettings(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.plugins;

import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.ModelRegistry;

/**
* Plugins that provide inference services should implement this interface.
* There should be a single one in the classpath, as we currently support a single instance for ModelRegistry / InfereceServiceRegistry.
*/
public interface InferenceRegistryPlugin {
InferenceServiceRegistry getInferenceServiceRegistry();

ModelRegistry getModelRegistry();
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.node.PluginComponentBinding;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.ExtensiblePlugin;
import org.elasticsearch.plugins.InferenceRegistryPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SystemIndexPlugin;
import org.elasticsearch.rest.RestController;
Expand Down Expand Up @@ -70,7 +71,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin {
public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin, InferenceRegistryPlugin {

public static final String NAME = "inference";
public static final String UTILITY_THREAD_POOL_NAME = "inference_utility";
Expand All @@ -79,6 +80,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
private final SetOnce<ServiceComponents> serviceComponents = new SetOnce<>();

private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
private final SetOnce<ModelRegistry> modelRegistry = new SetOnce<>();

private List<InferenceServiceExtension> inferenceServiceExtensions;

Expand Down Expand Up @@ -130,7 +132,7 @@ public Collection<?> createComponents(PluginServices services) {
);
httpFactory.set(httpRequestSenderFactory);

ModelRegistry modelRegistry = new ModelRegistryImpl(services.client());
ModelRegistry modelReg = new ModelRegistryImpl(services.client());

if (inferenceServiceExtensions == null) {
inferenceServiceExtensions = new ArrayList<>();
Expand All @@ -139,14 +141,13 @@ public Collection<?> createComponents(PluginServices services) {
inferenceServices.add(this::getInferenceServiceFactories);

var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client());
var registry = new InferenceServiceRegistryImpl(inferenceServices, factoryContext);
registry.init(services.client());
inferenceServiceRegistry.set(registry);
var inferenceRegistry = new InferenceServiceRegistryImpl(inferenceServices, factoryContext);
inferenceRegistry.init(services.client());
inferenceServiceRegistry.set(inferenceRegistry);
modelRegistry.set(modelReg);

return List.of(
new PluginComponentBinding<>(ModelRegistry.class, modelRegistry),
new PluginComponentBinding<>(InferenceServiceRegistry.class, registry)
);
// Don't return components as they will be registered using InferenceRegistryPlugin methods to retrieve them
return List.of();
}

@Override
Expand Down Expand Up @@ -241,4 +242,14 @@ public void close() {

IOUtils.closeWhileHandlingException(inferenceServiceRegistry.get(), throttlerToClose);
}

@Override
public InferenceServiceRegistry getInferenceServiceRegistry() {
return inferenceServiceRegistry.get();
}

@Override
public ModelRegistry getModelRegistry() {
return modelRegistry.get();
}
}

0 comments on commit fd70765

Please sign in to comment.