Skip to content

Commit

Permalink
[ML] Do not create the .inference index as a side effect of calling u…
Browse files Browse the repository at this point in the history
…sage (elastic#115023)

The Inference usage API calls GET _inference/_all and because the default 
configs are persisted on read it causes the creation of the .inference index.
This action is undesirable and causes test failures by leaking the system index
out of the test clean up code.
  • Loading branch information
davidkyle authored and georgewallace committed Oct 25, 2024
1 parent cdc0195 commit 65161a0
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 38 deletions.
15 changes: 0 additions & 15 deletions muted-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,6 @@ tests:
- class: org.elasticsearch.smoketest.DocsClientYamlTestSuiteIT
method: test {yaml=reference/rest-api/usage/line_38}
issue: https://github.com/elastic/elasticsearch/issues/113694
- class: org.elasticsearch.xpack.eql.EqlRestIT
method: testIndexWildcardPatterns
issue: https://github.com/elastic/elasticsearch/issues/114749
- class: org.elasticsearch.xpack.enrich.EnrichIT
method: testEnrichSpecialTypes
issue: https://github.com/elastic/elasticsearch/issues/114773
- class: org.elasticsearch.xpack.security.operator.OperatorPrivilegesIT
method: testEveryActionIsEitherOperatorOnlyOrNonOperator
issue: https://github.com/elastic/elasticsearch/issues/102992
Expand All @@ -312,23 +306,14 @@ tests:
- class: org.elasticsearch.xpack.remotecluster.RemoteClusterSecurityWithApmTracingRestIT
method: testTracingCrossCluster
issue: https://github.com/elastic/elasticsearch/issues/112731
- class: org.elasticsearch.xpack.enrich.EnrichIT
method: testImmutablePolicy
issue: https://github.com/elastic/elasticsearch/issues/114839
- class: org.elasticsearch.license.LicensingTests
issue: https://github.com/elastic/elasticsearch/issues/114865
- class: org.elasticsearch.xpack.enrich.EnrichIT
method: testDeleteIsCaseSensitive
issue: https://github.com/elastic/elasticsearch/issues/114840
- class: org.elasticsearch.packaging.test.EnrollmentProcessTests
method: test20DockerAutoFormCluster
issue: https://github.com/elastic/elasticsearch/issues/114885
- class: org.elasticsearch.test.rest.ClientYamlTestSuiteIT
method: test {yaml=cluster.stats/30_ccs_stats/cross-cluster search stats search}
issue: https://github.com/elastic/elasticsearch/issues/114902
- class: org.elasticsearch.xpack.enrich.EnrichRestIT
method: test {p0=enrich/40_synthetic_source/enrich documents over _bulk}
issue: https://github.com/elastic/elasticsearch/issues/114825
- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
method: testInferDeploysDefaultElser
issue: https://github.com/elastic/elasticsearch/issues/114913
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ static TransportVersion def(int id) {
public static final TransportVersion REMOVE_MIN_COMPATIBLE_SHARD_NODE = def(8_773_00_0);
public static final TransportVersion REVERT_REMOVE_MIN_COMPATIBLE_SHARD_NODE = def(8_774_00_0);
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_PARENT_SIMPLIFIED = def(8_775_00_0);
public static final TransportVersion INFERENCE_DONT_PERSIST_ON_READ = def(8_776_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1121,8 +1121,6 @@ protected static void wipeAllIndices(boolean preserveSecurityIndices) throws IOE
if (preserveSecurityIndices) {
indexPatterns.add("-.security-*");
}
// always preserve inference index
indexPatterns.add("-.inference");
final Request deleteRequest = new Request("DELETE", Strings.collectionToCommaDelimitedString(indexPatterns));
deleteRequest.addParameter("expand_wildcards", "open,closed,hidden");
final Response response = adminClient().performRequest(deleteRequest);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,40 @@ public GetInferenceModelAction() {

public static class Request extends AcknowledgedRequest<GetInferenceModelAction.Request> {

private static boolean PERSIST_DEFAULT_CONFIGS = true;

private final String inferenceEntityId;
private final TaskType taskType;
// Default endpoint configurations are persisted on first read.
// Set to false to avoid persisting on read.
// This setting only applies to GET * requests. It has
// no effect when getting a single model
private final boolean persistDefaultConfig;

public Request(String inferenceEntityId, TaskType taskType) {
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
this.taskType = Objects.requireNonNull(taskType);
this.persistDefaultConfig = PERSIST_DEFAULT_CONFIGS;
}

public Request(String inferenceEntityId, TaskType taskType, boolean persistDefaultConfig) {
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
this.taskType = Objects.requireNonNull(taskType);
this.persistDefaultConfig = persistDefaultConfig;
}

public Request(StreamInput in) throws IOException {
super(in);
this.inferenceEntityId = in.readString();
this.taskType = TaskType.fromStream(in);
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_DONT_PERSIST_ON_READ)) {
this.persistDefaultConfig = in.readBoolean();
} else {
this.persistDefaultConfig = PERSIST_DEFAULT_CONFIGS;
}

}

public String getInferenceEntityId() {
Expand All @@ -57,24 +78,33 @@ public TaskType getTaskType() {
return taskType;
}

public boolean isPersistDefaultConfig() {
return persistDefaultConfig;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(inferenceEntityId);
taskType.writeTo(out);
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_DONT_PERSIST_ON_READ)) {
out.writeBoolean(this.persistDefaultConfig);
}
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Request request = (Request) o;
return Objects.equals(inferenceEntityId, request.inferenceEntityId) && taskType == request.taskType;
return Objects.equals(inferenceEntityId, request.inferenceEntityId)
&& taskType == request.taskType
&& persistDefaultConfig == request.persistDefaultConfig;
}

@Override
public int hashCode() {
return Objects.hash(inferenceEntityId, taskType);
return Objects.hash(inferenceEntityId, taskType, persistDefaultConfig);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -251,7 +252,7 @@ public void testGetAllModels() throws InterruptedException {
}

AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
blockingCall(listener -> modelRegistry.getAllModels(randomBoolean(), listener), modelHolder, exceptionHolder);
assertNull(exceptionHolder.get());
assertThat(modelHolder.get(), hasSize(modelCount));
var getAllModels = modelHolder.get();
Expand Down Expand Up @@ -333,14 +334,14 @@ public void testGetAllModels_WithDefaults() throws Exception {
}

AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
blockingCall(listener -> modelRegistry.getAllModels(randomBoolean(), listener), modelHolder, exceptionHolder);
assertNull(exceptionHolder.get());
assertThat(modelHolder.get(), hasSize(totalModelCount));
var getAllModels = modelHolder.get();
assertReturnModelIsModifiable(modelHolder.get().get(0));

// same result but configs should have been persisted this time
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
blockingCall(listener -> modelRegistry.getAllModels(randomBoolean(), listener), modelHolder, exceptionHolder);
assertNull(exceptionHolder.get());
assertThat(modelHolder.get(), hasSize(totalModelCount));

Expand Down Expand Up @@ -387,7 +388,7 @@ public void testGetAllModels_OnlyDefaults() throws Exception {

AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
blockingCall(listener -> modelRegistry.getAllModels(randomBoolean(), listener), modelHolder, exceptionHolder);
assertNull(exceptionHolder.get());
assertThat(modelHolder.get(), hasSize(2));
var getAllModels = modelHolder.get();
Expand All @@ -405,6 +406,44 @@ public void testGetAllModels_OnlyDefaults() throws Exception {
}
}

public void testGetAllModels_withDoNotPersist() throws Exception {
int defaultModelCount = 2;
var serviceName = "foo";
var service = mock(InferenceService.class);

var defaultConfigs = new ArrayList<Model>();
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
for (int i = 0; i < defaultModelCount; i++) {
var id = "default-" + i;
var taskType = randomFrom(TaskType.values());
defaultConfigs.add(createModel(id, taskType, serviceName));
defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
}

doAnswer(invocation -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
listener.onResponse(defaultConfigs);
return Void.TYPE;
}).when(service).defaultConfigs(any());

defaultIds.forEach(modelRegistry::addDefaultIds);

AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.getAllModels(false, listener), modelHolder, exceptionHolder);
assertNull(exceptionHolder.get());
assertThat(modelHolder.get(), hasSize(2));

expectThrows(IndexNotFoundException.class, () -> client().admin().indices().prepareGetIndex().addIndices(".inference").get());

// this time check the index is created
blockingCall(listener -> modelRegistry.getAllModels(true, listener), modelHolder, exceptionHolder);
assertNull(exceptionHolder.get());
assertThat(modelHolder.get(), hasSize(2));
assertInferenceIndexExists();
}

public void testGet_WithDefaults() throws InterruptedException {
var serviceName = "foo";
var service = mock(InferenceService.class);
Expand Down Expand Up @@ -513,6 +552,12 @@ public void testGetByTaskType_WithDefaults() throws Exception {
assertReturnModelIsModifiable(modelHolder.get().get(0));
}

private void assertInferenceIndexExists() {
var indexResponse = client().admin().indices().prepareGetIndex().addIndices(".inference").get();
assertNotNull(indexResponse.getSettings());
assertNotNull(indexResponse.getMappings());
}

@SuppressWarnings("unchecked")
private void assertReturnModelIsModifiable(UnparsedModel unparsedModel) {
var settings = unparsedModel.settings();
Expand Down Expand Up @@ -551,7 +596,6 @@ private Model buildElserModelConfig(String inferenceEntityId, TaskType taskType)
);
default -> throw new IllegalArgumentException("task type " + taskType + " is not supported");
};

}

protected <T> void blockingCall(Consumer<ActionListener<T>> function, AtomicReference<T> response, AtomicReference<Exception> error)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ protected void doExecute(
boolean inferenceEntityIdIsWildCard = Strings.isAllOrWildcard(request.getInferenceEntityId());

if (request.getTaskType() == TaskType.ANY && inferenceEntityIdIsWildCard) {
getAllModels(listener);
getAllModels(request.isPersistDefaultConfig(), listener);
} else if (inferenceEntityIdIsWildCard) {
getModelsByTaskType(request.getTaskType(), listener);
} else {
Expand Down Expand Up @@ -100,8 +100,9 @@ private void getSingleModel(
}));
}

private void getAllModels(ActionListener<GetInferenceModelAction.Response> listener) {
private void getAllModels(boolean persistDefaultEndpoints, ActionListener<GetInferenceModelAction.Response> listener) {
modelRegistry.getAllModels(
persistDefaultEndpoints,
listener.delegateFailureAndWrap((l, models) -> executor.execute(ActionRunnable.supply(l, () -> parseModels(models))))
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ protected void masterOperation(
ClusterState state,
ActionListener<XPackUsageFeatureResponse> listener
) {
GetInferenceModelAction.Request getInferenceModelAction = new GetInferenceModelAction.Request("_all", TaskType.ANY);
GetInferenceModelAction.Request getInferenceModelAction = new GetInferenceModelAction.Request("_all", TaskType.ANY, false);
client.execute(GetInferenceModelAction.INSTANCE, getInferenceModelAction, listener.delegateFailureAndWrap((delegate, response) -> {
Map<String, InferenceFeatureSetUsage.ModelStats> stats = new TreeMap<>();
for (ModelConfigurations model : response.getEndpoints()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,17 @@

import static org.elasticsearch.core.Strings.format;

/**
* Class for persisting and reading inference endpoint configurations.
* Some inference services provide default configurations, the registry is
* made aware of these at start up via {@link #addDefaultIds(InferenceService.DefaultConfigId)}.
* Only the ids and service details are registered at this point
* as the full config definition may not be known at start up.
* The full config is lazily populated on read and persisted to the
* index. This has the effect of creating the backing index on reading
* the configs. {@link #getAllModels(boolean, ActionListener)} has an option
* to not write the default configs to index on read to avoid index creation.
*/
public class ModelRegistry {
public record ModelConfigMap(Map<String, Object> config, Map<String, Object> secrets) {}

Expand Down Expand Up @@ -132,7 +143,7 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener<Unparse
if (searchResponse.getHits().getHits().length == 0) {
var maybeDefault = idMatchedDefault(inferenceEntityId, defaultConfigIds);
if (maybeDefault.isPresent()) {
getDefaultConfig(maybeDefault.get(), listener);
getDefaultConfig(true, maybeDefault.get(), listener);
} else {
delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
}
Expand Down Expand Up @@ -163,7 +174,7 @@ public void getModel(String inferenceEntityId, ActionListener<UnparsedModel> lis
if (searchResponse.getHits().getHits().length == 0) {
var maybeDefault = idMatchedDefault(inferenceEntityId, defaultConfigIds);
if (maybeDefault.isPresent()) {
getDefaultConfig(maybeDefault.get(), listener);
getDefaultConfig(true, maybeDefault.get(), listener);
} else {
delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
}
Expand Down Expand Up @@ -199,7 +210,7 @@ public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedM
ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
var defaultConfigsForTaskType = taskTypeMatchedDefaults(taskType, defaultConfigIds);
addAllDefaultConfigsIfMissing(modelConfigs, defaultConfigsForTaskType, delegate);
addAllDefaultConfigsIfMissing(true, modelConfigs, defaultConfigsForTaskType, delegate);
});

QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(TASK_TYPE_FIELD, taskType.toString()));
Expand All @@ -216,13 +227,20 @@ public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedM

/**
* Get all models.
* If the defaults endpoint configurations have not been persisted then only
* persist them if {@code persistDefaultEndpoints == true}. Persisting the
* configs has the side effect of creating the index.
*
* Secret settings are not included
* @param persistDefaultEndpoints Persist the defaults endpoint configurations if
* not already persisted. When false this avoids the creation
* of the backing index.
* @param listener Models listener
*/
public void getAllModels(ActionListener<List<UnparsedModel>> listener) {
public void getAllModels(boolean persistDefaultEndpoints, ActionListener<List<UnparsedModel>> listener) {
ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
var foundConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
addAllDefaultConfigsIfMissing(foundConfigs, defaultConfigIds, delegate);
addAllDefaultConfigsIfMissing(persistDefaultEndpoints, foundConfigs, defaultConfigIds, delegate);
});

// In theory the index should only contain model config documents
Expand All @@ -241,6 +259,7 @@ public void getAllModels(ActionListener<List<UnparsedModel>> listener) {
}

private void addAllDefaultConfigsIfMissing(
boolean persistDefaultEndpoints,
List<UnparsedModel> foundConfigs,
List<InferenceService.DefaultConfigId> matchedDefaults,
ActionListener<List<UnparsedModel>> listener
Expand All @@ -263,18 +282,26 @@ private void addAllDefaultConfigsIfMissing(
);

for (var required : missing) {
getDefaultConfig(required, groupedListener);
getDefaultConfig(persistDefaultEndpoints, required, groupedListener);
}
}
}

private void getDefaultConfig(InferenceService.DefaultConfigId defaultConfig, ActionListener<UnparsedModel> listener) {
private void getDefaultConfig(
boolean persistDefaultEndpoints,
InferenceService.DefaultConfigId defaultConfig,
ActionListener<UnparsedModel> listener
) {
defaultConfig.service().defaultConfigs(listener.delegateFailureAndWrap((delegate, models) -> {
boolean foundModel = false;
for (var m : models) {
if (m.getInferenceEntityId().equals(defaultConfig.inferenceId())) {
foundModel = true;
storeDefaultEndpoint(m, () -> listener.onResponse(modelToUnparsedModel(m)));
if (persistDefaultEndpoints) {
storeDefaultEndpoint(m, () -> listener.onResponse(modelToUnparsedModel(m)));
} else {
listener.onResponse(modelToUnparsedModel(m));
}
break;
}
}
Expand All @@ -287,7 +314,7 @@ private void getDefaultConfig(InferenceService.DefaultConfigId defaultConfig, Ac
}));
}

public void storeDefaultEndpoint(Model preconfigured, Runnable runAfter) {
private void storeDefaultEndpoint(Model preconfigured, Runnable runAfter) {
var responseListener = ActionListener.<Boolean>wrap(success -> {
logger.debug("Added default inference endpoint [{}]", preconfigured.getInferenceEntityId());
}, exception -> {
Expand Down
Loading

0 comments on commit 65161a0

Please sign in to comment.