Skip to content

Commit

Permalink
[ML] Do not create the .inference index as a side effect of calling u…
Browse files Browse the repository at this point in the history
…sage (elastic#115023)

The Inference usage API calls GET _inference/_all; because the default
configs are persisted on read, this causes the creation of the .inference index.
This side effect is undesirable and causes test failures, because the system index
leaks past the tests' clean-up code.
# Conflicts:
#	muted-tests.yml
#	server/src/main/java/org/elasticsearch/TransportVersions.java
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java
  • Loading branch information
davidkyle committed Oct 22, 2024
1 parent 8a00fce commit 8238caf
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,40 @@ public GetInferenceModelAction() {

public static class Request extends AcknowledgedRequest<GetInferenceModelAction.Request> {

// Default behaviour: default endpoint configurations are persisted on first read.
// Declared final: this is a constant (UPPER_SNAKE_CASE), never reassigned.
private static final boolean PERSIST_DEFAULT_CONFIGS = true;

private final String inferenceEntityId;
private final TaskType taskType;
// Default endpoint configurations are persisted on first read.
// Set to false to avoid persisting on read.
// This setting only applies to GET * requests. It has
// no effect when getting a single model
private final boolean persistDefaultConfig;

/**
 * Creates a request with the default persist-on-read behaviour
 * ({@code PERSIST_DEFAULT_CONFIGS}).
 *
 * @param inferenceEntityId the endpoint id, or a wildcard; must not be null
 * @param taskType the task type to match; must not be null
 */
public Request(String inferenceEntityId, TaskType taskType) {
    this(inferenceEntityId, taskType, PERSIST_DEFAULT_CONFIGS);
}

/**
 * Creates a request with explicit control over whether default endpoint
 * configurations are persisted when read (only relevant to GET * requests).
 *
 * @param inferenceEntityId the endpoint id, or a wildcard; must not be null
 * @param taskType the task type to match; must not be null
 * @param persistDefaultConfig false to avoid writing default configs on read
 */
public Request(String inferenceEntityId, TaskType taskType, boolean persistDefaultConfig) {
    super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
    // Validate in declaration order so the null-check behaviour matches callers' expectations.
    Objects.requireNonNull(inferenceEntityId);
    Objects.requireNonNull(taskType);
    this.inferenceEntityId = inferenceEntityId;
    this.taskType = taskType;
    this.persistDefaultConfig = persistDefaultConfig;
}

/**
 * Deserialization constructor. Fields are read in exactly the order
 * {@code writeTo(StreamOutput)} writes them.
 *
 * @throws IOException on a stream read failure
 */
public Request(StreamInput in) throws IOException {
super(in);
this.inferenceEntityId = in.readString();
this.taskType = TaskType.fromStream(in);
// The persist flag was added in INFERENCE_DONT_PERSIST_ON_READ. Older nodes
// never send it, so fall back to the previous (persist-on-read) behaviour.
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_DONT_PERSIST_ON_READ)) {
this.persistDefaultConfig = in.readBoolean();
} else {
this.persistDefaultConfig = PERSIST_DEFAULT_CONFIGS;
}

}

public String getInferenceEntityId() {
Expand All @@ -57,24 +78,33 @@ public TaskType getTaskType() {
return taskType;
}

/**
 * @return true if default endpoint configurations should be persisted when
 *         read. Only applies to GET * requests; no effect for a single model.
 */
public boolean isPersistDefaultConfig() {
return persistDefaultConfig;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(inferenceEntityId);
taskType.writeTo(out);
// Only send the flag to nodes that understand it; older nodes apply the
// legacy persist-on-read default on their side (see the stream constructor).
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_DONT_PERSIST_ON_READ)) {
out.writeBoolean(this.persistDefaultConfig);
}
}

/**
 * Equality over all three request fields; kept consistent with
 * {@code hashCode()}, which hashes the same fields.
 */
@Override
public boolean equals(Object o) {
    // The diff artifact left the pre-change return statement in place; only
    // the three-field comparison is kept.
    if (this == o) {
        return true;
    }
    if (o == null || getClass() != o.getClass()) {
        return false;
    }
    Request request = (Request) o;
    return Objects.equals(inferenceEntityId, request.inferenceEntityId)
        && taskType == request.taskType
        && persistDefaultConfig == request.persistDefaultConfig;
}

/**
 * Hashes the same three fields that {@code equals(Object)} compares, keeping
 * the equals/hashCode contract intact. The stale pre-change two-field return
 * left by the diff rendering is removed.
 */
@Override
public int hashCode() {
    return Objects.hash(inferenceEntityId, taskType, persistDefaultConfig);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -250,7 +251,7 @@ public void testGetAllModels() throws InterruptedException {
}

AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
blockingCall(listener -> modelRegistry.getAllModels(randomBoolean(), listener), modelHolder, exceptionHolder);
assertNull(exceptionHolder.get());
assertThat(modelHolder.get(), hasSize(modelCount));
var getAllModels = modelHolder.get();
Expand Down Expand Up @@ -332,14 +333,14 @@ public void testGetAllModels_WithDefaults() throws Exception {
}

AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
blockingCall(listener -> modelRegistry.getAllModels(randomBoolean(), listener), modelHolder, exceptionHolder);
assertNull(exceptionHolder.get());
assertThat(modelHolder.get(), hasSize(totalModelCount));
var getAllModels = modelHolder.get();
assertReturnModelIsModifiable(modelHolder.get().get(0));

// same result but configs should have been persisted this time
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
blockingCall(listener -> modelRegistry.getAllModels(randomBoolean(), listener), modelHolder, exceptionHolder);
assertNull(exceptionHolder.get());
assertThat(modelHolder.get(), hasSize(totalModelCount));

Expand Down Expand Up @@ -386,7 +387,7 @@ public void testGetAllModels_OnlyDefaults() throws Exception {

AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
blockingCall(listener -> modelRegistry.getAllModels(randomBoolean(), listener), modelHolder, exceptionHolder);
assertNull(exceptionHolder.get());
assertThat(modelHolder.get(), hasSize(2));
var getAllModels = modelHolder.get();
Expand All @@ -404,6 +405,44 @@ public void testGetAllModels_OnlyDefaults() throws Exception {
}
}

// Verifies that getAllModels(false, ...) returns the default configs WITHOUT
// creating the .inference index, while getAllModels(true, ...) persists them
// and creates the index as a side effect.
public void testGetAllModels_withDoNotPersist() throws Exception {
int defaultModelCount = 2;
var serviceName = "foo";
var service = mock(InferenceService.class);

// Register two default endpoint ids; the full configs are supplied lazily
// by the mocked service below.
var defaultConfigs = new ArrayList<Model>();
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
for (int i = 0; i < defaultModelCount; i++) {
var id = "default-" + i;
var taskType = randomFrom(TaskType.values());
defaultConfigs.add(createModel(id, taskType, serviceName));
defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
}

// Stub InferenceService#defaultConfigs to answer with the canned models;
// argument 0 is the ActionListener<List<Model>> the registry passes in.
doAnswer(invocation -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
listener.onResponse(defaultConfigs);
return Void.TYPE;
}).when(service).defaultConfigs(any());

defaultIds.forEach(modelRegistry::addDefaultIds);

AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
// persistDefaultEndpoints == false: configs are returned but not written.
blockingCall(listener -> modelRegistry.getAllModels(false, listener), modelHolder, exceptionHolder);
assertNull(exceptionHolder.get());
assertThat(modelHolder.get(), hasSize(2));

// The read must NOT have created the system index as a side effect.
expectThrows(IndexNotFoundException.class, () -> client().admin().indices().prepareGetIndex().addIndices(".inference").get());

// this time check the index is created
blockingCall(listener -> modelRegistry.getAllModels(true, listener), modelHolder, exceptionHolder);
assertNull(exceptionHolder.get());
assertThat(modelHolder.get(), hasSize(2));
assertInferenceIndexExists();
}

public void testGet_WithDefaults() throws InterruptedException {
var serviceName = "foo";
var service = mock(InferenceService.class);
Expand Down Expand Up @@ -512,6 +551,12 @@ public void testGetByTaskType_WithDefaults() throws Exception {
assertReturnModelIsModifiable(modelHolder.get().get(0));
}

// Asserts that the .inference system index exists by fetching it and checking
// that settings and mappings were returned.
private void assertInferenceIndexExists() {
    var response = client().admin().indices().prepareGetIndex().addIndices(".inference").get();
    assertNotNull(response.getSettings());
    assertNotNull(response.getMappings());
}

@SuppressWarnings("unchecked")
private void assertReturnModelIsModifiable(UnparsedModel unparsedModel) {
var settings = unparsedModel.settings();
Expand Down Expand Up @@ -550,7 +595,6 @@ private Model buildElserModelConfig(String inferenceEntityId, TaskType taskType)
);
default -> throw new IllegalArgumentException("task type " + taskType + " is not supported");
};

}

protected <T> void blockingCall(Consumer<ActionListener<T>> function, AtomicReference<T> response, AtomicReference<Exception> error)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ protected void doExecute(
boolean inferenceEntityIdIsWildCard = Strings.isAllOrWildcard(request.getInferenceEntityId());

if (request.getTaskType() == TaskType.ANY && inferenceEntityIdIsWildCard) {
getAllModels(listener);
getAllModels(request.isPersistDefaultConfig(), listener);
} else if (inferenceEntityIdIsWildCard) {
getModelsByTaskType(request.getTaskType(), listener);
} else {
Expand Down Expand Up @@ -114,8 +114,11 @@ private void getSingleModel(
}));
}

private void getAllModels(ActionListener<GetInferenceModelAction.Response> listener) {
modelRegistry.getAllModels(listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, listener))));
private void getAllModels(boolean persistDefaultEndpoints, ActionListener<GetInferenceModelAction.Response> listener) {
modelRegistry.getAllModels(
persistDefaultEndpoints,
listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, l)))
);
}

private void getModelsByTaskType(TaskType taskType, ActionListener<GetInferenceModelAction.Response> listener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ protected void masterOperation(
ClusterState state,
ActionListener<XPackUsageFeatureResponse> listener
) {
GetInferenceModelAction.Request getInferenceModelAction = new GetInferenceModelAction.Request("_all", TaskType.ANY);
GetInferenceModelAction.Request getInferenceModelAction = new GetInferenceModelAction.Request("_all", TaskType.ANY, false);
client.execute(GetInferenceModelAction.INSTANCE, getInferenceModelAction, listener.delegateFailureAndWrap((delegate, response) -> {
Map<String, InferenceFeatureSetUsage.ModelStats> stats = new TreeMap<>();
for (ModelConfigurations model : response.getEndpoints()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,17 @@

import static org.elasticsearch.core.Strings.format;

/**
* Class for persisting and reading inference endpoint configurations.
* Some inference services provide default configurations, the registry is
* made aware of these at start up via {@link #addDefaultIds(InferenceService.DefaultConfigId)}.
* Only the ids and service details are registered at this point
* as the full config definition may not be known at start up.
* The full config is lazily populated on read and persisted to the
* index. This has the effect of creating the backing index on reading
* the configs. {@link #getAllModels(boolean, ActionListener)} has an option
* to not write the default configs to index on read to avoid index creation.
*/
public class ModelRegistry {
// Holds the raw configuration map and the raw secrets map for a single endpoint.
public record ModelConfigMap(Map<String, Object> config, Map<String, Object> secrets) {}

Expand Down Expand Up @@ -132,7 +143,7 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener<Unparse
if (searchResponse.getHits().getHits().length == 0) {
var maybeDefault = idMatchedDefault(inferenceEntityId, defaultConfigIds);
if (maybeDefault.isPresent()) {
getDefaultConfig(maybeDefault.get(), listener);
getDefaultConfig(true, maybeDefault.get(), listener);
} else {
delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
}
Expand Down Expand Up @@ -163,7 +174,7 @@ public void getModel(String inferenceEntityId, ActionListener<UnparsedModel> lis
if (searchResponse.getHits().getHits().length == 0) {
var maybeDefault = idMatchedDefault(inferenceEntityId, defaultConfigIds);
if (maybeDefault.isPresent()) {
getDefaultConfig(maybeDefault.get(), listener);
getDefaultConfig(true, maybeDefault.get(), listener);
} else {
delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
}
Expand Down Expand Up @@ -199,7 +210,7 @@ public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedM
ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
var defaultConfigsForTaskType = taskTypeMatchedDefaults(taskType, defaultConfigIds);
addAllDefaultConfigsIfMissing(modelConfigs, defaultConfigsForTaskType, delegate);
addAllDefaultConfigsIfMissing(true, modelConfigs, defaultConfigsForTaskType, delegate);
});

QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(TASK_TYPE_FIELD, taskType.toString()));
Expand All @@ -216,13 +227,20 @@ public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedM

/**
* Get all models.
* If the defaults endpoint configurations have not been persisted then only
* persist them if {@code persistDefaultEndpoints == true}. Persisting the
* configs has the side effect of creating the index.
*
* Secret settings are not included
* @param persistDefaultEndpoints Persist the defaults endpoint configurations if
* not already persisted. When false this avoids the creation
* of the backing index.
* @param listener Models listener
*/
public void getAllModels(ActionListener<List<UnparsedModel>> listener) {
public void getAllModels(boolean persistDefaultEndpoints, ActionListener<List<UnparsedModel>> listener) {
ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
var foundConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
addAllDefaultConfigsIfMissing(foundConfigs, defaultConfigIds, delegate);
addAllDefaultConfigsIfMissing(persistDefaultEndpoints, foundConfigs, defaultConfigIds, delegate);
});

// In theory the index should only contain model config documents
Expand All @@ -241,6 +259,7 @@ public void getAllModels(ActionListener<List<UnparsedModel>> listener) {
}

private void addAllDefaultConfigsIfMissing(
boolean persistDefaultEndpoints,
List<UnparsedModel> foundConfigs,
List<InferenceService.DefaultConfigId> matchedDefaults,
ActionListener<List<UnparsedModel>> listener
Expand All @@ -263,18 +282,26 @@ private void addAllDefaultConfigsIfMissing(
);

for (var required : missing) {
getDefaultConfig(required, groupedListener);
getDefaultConfig(persistDefaultEndpoints, required, groupedListener);
}
}
}

private void getDefaultConfig(InferenceService.DefaultConfigId defaultConfig, ActionListener<UnparsedModel> listener) {
private void getDefaultConfig(
boolean persistDefaultEndpoints,
InferenceService.DefaultConfigId defaultConfig,
ActionListener<UnparsedModel> listener
) {
defaultConfig.service().defaultConfigs(listener.delegateFailureAndWrap((delegate, models) -> {
boolean foundModel = false;
for (var m : models) {
if (m.getInferenceEntityId().equals(defaultConfig.inferenceId())) {
foundModel = true;
storeDefaultEndpoint(m, () -> listener.onResponse(modelToUnparsedModel(m)));
if (persistDefaultEndpoints) {
storeDefaultEndpoint(m, () -> listener.onResponse(modelToUnparsedModel(m)));
} else {
listener.onResponse(modelToUnparsedModel(m));
}
break;
}
}
Expand All @@ -287,7 +314,7 @@ private void getDefaultConfig(InferenceService.DefaultConfigId defaultConfig, Ac
}));
}

public void storeDefaultEndpoint(Model preconfigured, Runnable runAfter) {
private void storeDefaultEndpoint(Model preconfigured, Runnable runAfter) {
var responseListener = ActionListener.<Boolean>wrap(success -> {
logger.debug("Added default inference endpoint [{}]", preconfigured.getInferenceEntityId());
}, exception -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
public class GetInferenceModelRequestTests extends AbstractWireSerializingTestCase<GetInferenceModelAction.Request> {

/**
 * Builds a random request covering all three fields, including the new
 * persistDefaultConfig flag. The stale pre-change two-arg return left by the
 * diff rendering is removed (two returns would not compile).
 */
public static GetInferenceModelAction.Request randomTestInstance() {
    return new GetInferenceModelAction.Request(randomAlphaOfLength(8), randomFrom(TaskType.values()), randomBoolean());
}

@Override
Expand All @@ -30,12 +30,17 @@ protected GetInferenceModelAction.Request createTestInstance() {

@Override
protected GetInferenceModelAction.Request mutateInstance(GetInferenceModelAction.Request instance) {
return switch (randomIntBetween(0, 1)) {
return switch (randomIntBetween(0, 2)) {
case 0 -> new GetInferenceModelAction.Request(instance.getInferenceEntityId() + "foo", instance.getTaskType());
case 1 -> {
var nextTaskType = TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length];
yield new GetInferenceModelAction.Request(instance.getInferenceEntityId(), nextTaskType);
}
case 2 -> new GetInferenceModelAction.Request(
instance.getInferenceEntityId(),
instance.getTaskType(),
instance.isPersistDefaultConfig() == false
);
default -> throw new UnsupportedOperationException();
};
}
Expand Down

0 comments on commit 8238caf

Please sign in to comment.