Skip to content

Commit

Permalink
Use FieldInferenceMetadata structure in dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Mar 13, 2024
1 parent 547ea9a commit 588e67f
Show file tree
Hide file tree
Showing 11 changed files with 48 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.metadata.FieldInferenceMetadata;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.regex.Regex;
import org.elasticsearch.common.util.concurrent.CountDown;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.analysis.IndexAnalyzers;
Expand Down Expand Up @@ -64,7 +66,7 @@ public class QueryRewriteContext {
protected Predicate<String> allowedFields;
private final ModelRegistry modelRegistry;
private final InferenceServiceRegistry inferenceServiceRegistry;
private final Map<String, Set<String>> modelsForFields;
private final Map<String, FieldInferenceMetadata> fieldInferenceMetadataForIndices;

public QueryRewriteContext(
final XContentParserConfiguration parserConfiguration,
Expand All @@ -83,8 +85,8 @@ public QueryRewriteContext(
final ScriptCompiler scriptService,
final ModelRegistry modelRegistry,
final InferenceServiceRegistry inferenceServiceRegistry,
final Map<String, Set<String>> modelsForFields
) {
@Nullable final Map<String, FieldInferenceMetadata> fieldInferenceMetadataForIndices
) {

this.parserConfiguration = parserConfiguration;
this.client = client;
Expand All @@ -103,9 +105,7 @@ public QueryRewriteContext(
this.scriptService = scriptService;
this.modelRegistry = modelRegistry;
this.inferenceServiceRegistry = inferenceServiceRegistry;
this.modelsForFields = modelsForFields != null ?
modelsForFields.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> Set.copyOf(e.getValue()))) :
Collections.emptyMap();
this.fieldInferenceMetadataForIndices = Objects.requireNonNullElse(fieldInferenceMetadataForIndices, Map.of());
}

public QueryRewriteContext(final XContentParserConfiguration parserConfiguration, final Client client, final LongSupplier nowInMillis) {
Expand Down Expand Up @@ -136,7 +136,7 @@ public QueryRewriteContext(
final LongSupplier nowInMillis,
final ModelRegistry modelRegistry,
final InferenceServiceRegistry inferenceServiceRegistry,
final Map<String, Set<String>> modelsForFields
final Map<String, FieldInferenceMetadata> fieldInferenceMetadataMap
) {
this(
parserConfiguration,
Expand All @@ -155,7 +155,7 @@ public QueryRewriteContext(
null,
modelRegistry,
inferenceServiceRegistry,
modelsForFields
fieldInferenceMetadataMap
);
}

Expand Down Expand Up @@ -400,8 +400,7 @@ public InferenceServiceRegistry getInferenceServiceRegistry() {
return inferenceServiceRegistry;
}

public Set<String> getModelsForField(String fieldName) {
Set<String> models = modelsForFields.get(fieldName);
return models != null ? models : Collections.emptySet();
public Set<String> getInferenceIdsForField(String fieldName) {
return fieldInferenceMetadataForIndices.values().stream().map(v -> v.getInferenceIdForField(fieldName)).collect(Collectors.toSet());
}
}
22 changes: 13 additions & 9 deletions server/src/main/java/org/elasticsearch/indices/IndicesService.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.DataStream;
import org.elasticsearch.cluster.metadata.FieldInferenceMetadata;
import org.elasticsearch.cluster.metadata.IndexAbstraction;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
Expand Down Expand Up @@ -1702,18 +1703,21 @@ public AliasFilter buildAliasFilter(ClusterState state, String index, Set<String
public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, IndicesRequest indicesRequest, ModelRegistry modelRegistry,
InferenceServiceRegistry inferenceServiceRegistry) {
Index[] indices = indexNameExpressionResolver.concreteIndices(clusterService.state(), indicesRequest);
Map<String, Set<String>> modelsForFields = new HashMap<>();
Map<String, FieldInferenceMetadata> fieldInferenceMetadataMap = Arrays.stream(indices)
.collect(Collectors.toMap(Index::getName, index -> indexServiceSafe(index).getMetadata().getFieldInferenceMetadata()));

Map<String, Set<String>> inferenceIdsForFields = new HashMap<>();
// Collect all index inference ids for each field
for (Index index : indices) {
Map<String, Set<String>> fieldsForModels = indexService(index).getMetadata().getFieldsForModels();
for (Map.Entry<String, Set<String>> entry : fieldsForModels.entrySet()) {
for (String fieldName : entry.getValue()) {
Set<String> models = modelsForFields.computeIfAbsent(fieldName, v -> new HashSet<>());
models.add(entry.getKey());
}
}
Map<String, String> inferenceIdForFieldsInIndex = indexService(index).getMetadata()
.getFieldInferenceMetadata()
.getInferenceIdForFields();
inferenceIdForFieldsInIndex.entrySet().forEach(entry -> {
inferenceIdsForFields.computeIfAbsent(entry.getKey(), k -> new HashSet<>()).add(entry.getValue());
});
}

return new QueryRewriteContext(parserConfig, client, nowInMillis, modelRegistry, inferenceServiceRegistry, modelsForFields);
return new QueryRewriteContext(parserConfig, client, nowInMillis, modelRegistry, inferenceServiceRegistry, fieldInferenceMetadataMap);
}

public DataRewriteContext getDataRewriteContext(LongSupplier nowInMillis) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ private static BulkShardRequest runBulkOperation(
) {
Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build();
IndexMetadata indexMetadata = IndexMetadata.builder(INDEX_NAME)
.fieldsForModels(fieldsForModels)
.fieldInferenceMetadata(fieldsForModels)
.settings(settings)
.numberOfShards(1)
.numberOfReplicas(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public void testIndexMetadataSerialization() throws IOException {
.stats(indexStats)
.indexWriteLoadForecast(indexWriteLoadForecast)
.shardSizeInBytesForecast(shardSizeInBytesForecast)
.fieldsForModels(fieldsForModels)
.fieldInferenceMetadata(fieldsForModels)
.build();
assertEquals(system, metadata.isSystem());

Expand Down Expand Up @@ -556,7 +556,7 @@ public void testFieldsForModels() {
assertThat(idxMeta1.getFieldsForModels(), equalTo(Map.of()));

Map<String, Set<String>> fieldsForModels = randomFieldsForModels(false);
IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).fieldsForModels(fieldsForModels).build();
IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).fieldInferenceMetadata(fieldsForModels).build();
assertThat(idxMeta2.getFieldsForModels(), equalTo(fieldsForModels));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public void testEmpty() {
assertNotNull(names);
assertThat(names, hasSize(0));

Map<String, Set<String>> fieldsForModels = lookup.getFieldsForModels();
Map<String, Set<String>> fieldsForModels = lookup.getInferenceIdsForFields();
assertNotNull(fieldsForModels);
assertTrue(fieldsForModels.isEmpty());
}
Expand All @@ -48,7 +48,7 @@ public void testAddNewField() {
assertNull(lookup.get("bar"));
assertEquals(f.fieldType(), lookup.get("foo"));

Map<String, Set<String>> fieldsForModels = lookup.getFieldsForModels();
Map<String, Set<String>> fieldsForModels = lookup.getInferenceIdsForFields();
assertNotNull(fieldsForModels);
assertTrue(fieldsForModels.isEmpty());
}
Expand Down Expand Up @@ -440,7 +440,7 @@ public void testInferenceModelFieldType() {
assertEquals(f2.fieldType(), lookup.get("foo2"));
assertEquals(f3.fieldType(), lookup.get("foo3"));

Map<String, Set<String>> fieldsForModels = lookup.getFieldsForModels();
Map<String, Set<String>> fieldsForModels = lookup.getInferenceIdsForFields();
assertNotNull(fieldsForModels);
assertEquals(2, fieldsForModels.size());
assertEquals(Set.of("foo1", "foo2"), fieldsForModels.get("bar1"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ public void testEmptyMappingLookup() {
assertEquals(0, mappingLookup.getMapping().getMetadataMappersMap().size());
assertFalse(mappingLookup.fieldMappers().iterator().hasNext());
assertEquals(0, mappingLookup.getMatchingFieldNames("*").size());
assertNotNull(mappingLookup.getFieldsForModels());
assertTrue(mappingLookup.getFieldsForModels().isEmpty());
assertNotNull(mappingLookup.getInferenceIdsForFields());
assertTrue(mappingLookup.getInferenceIdsForFields().isEmpty());
}

public void testValidateDoesNotShadow() {
Expand Down Expand Up @@ -201,7 +201,7 @@ public void testFieldsForModels() {
assertEquals(1, size(mappingLookup.fieldMappers()));
assertEquals(fieldType, mappingLookup.getFieldType("test_field_name"));

Map<String, Set<String>> fieldsForModels = mappingLookup.getFieldsForModels();
Map<String, Set<String>> fieldsForModels = mappingLookup.getInferenceIdsForFields();
assertNotNull(fieldsForModels);
assertEquals(1, fieldsForModels.size());
assertEquals(Collections.singleton("test_field_name"), fieldsForModels.get("test_model_id"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public ValueFetcher valueFetcher(SearchExecutionContext context, String format)
}

@Override
public String getInferenceModel() {
public String getInferenceId() {
return modelId;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.indices.recovery.RecoverySettings;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.ModelRegistry;
import org.elasticsearch.plugins.MockPluginsService;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.PluginsService;
Expand Down Expand Up @@ -101,7 +103,9 @@ SearchService newSearchService(
ResponseCollectorService responseCollectorService,
CircuitBreakerService circuitBreakerService,
ExecutorSelector executorSelector,
Tracer tracer
Tracer tracer,
ModelRegistry modelRegistry,
InferenceServiceRegistry inferenceServiceRegistry
) {
if (pluginsService.filterPlugins(MockSearchService.TestPlugin.class).findAny().isEmpty()) {
return super.newSearchService(
Expand All @@ -115,7 +119,9 @@ SearchService newSearchService(
responseCollectorService,
circuitBreakerService,
executorSelector,
tracer
tracer,
modelRegistry,
inferenceServiceRegistry
);
}
return new MockSearchService(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import org.elasticsearch.indices.ExecutorSelector;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.ModelRegistry;
import org.elasticsearch.node.MockNode;
import org.elasticsearch.node.ResponseCollectorService;
import org.elasticsearch.plugins.Plugin;
Expand Down Expand Up @@ -97,7 +99,9 @@ public MockSearchService(
responseCollectorService,
circuitBreakerService,
executorSelector,
tracer
tracer,
null,
null
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,7 @@ QueryRewriteContext createQueryRewriteContext() {
() -> true,
scriptService,
null,
null,
null
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
return this;
}

Set<String> modelsForField = queryRewriteContext.getModelsForField(fieldName);
Set<String> modelsForField = queryRewriteContext.getInferenceIdsForField(fieldName);
if (modelsForField.isEmpty()) {
throw new IllegalArgumentException("Field [" + fieldName + "] is not a semantic_text field type");
}
Expand Down

0 comments on commit 588e67f

Please sign in to comment.