Skip to content

Commit

Permalink
Add inference calculation for semantic_text
Browse files — browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Apr 30, 2024
1 parent b382b3b commit ba06e01
Show file tree
Hide file tree
Showing 15 changed files with 1,608 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,10 @@ private void executeBulkRequestsByShard(
bulkRequest.getRefreshPolicy(),
requests.toArray(new BulkItemRequest[0])
);
var indexMetadata = clusterState.getMetadata().index(shardId.getIndexName());
if (indexMetadata != null && indexMetadata.getInferenceFields().isEmpty() == false) {
bulkShardRequest.setInferenceFieldMap(indexMetadata.getInferenceFields());
}
bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards());
bulkShardRequest.timeout(bulkRequest.timeout());
bulkShardRequest.routedBasedOnClusterVersion(clusterState.version());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
import org.elasticsearch.action.support.replication.ReplicatedWriteRequest;
import org.elasticsearch.action.support.replication.ReplicationRequest;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.transport.RawIndexingDataTransportRequest;

import java.io.IOException;
import java.util.Map;
import java.util.Set;

public final class BulkShardRequest extends ReplicatedWriteRequest<BulkShardRequest>
Expand All @@ -33,6 +35,8 @@ public final class BulkShardRequest extends ReplicatedWriteRequest<BulkShardRequ

private final BulkItemRequest[] items;

private transient Map<String, InferenceFieldMetadata> inferenceFieldMap = null;

public BulkShardRequest(StreamInput in) throws IOException {
super(in);
items = in.readArray(i -> i.readOptionalWriteable(inpt -> new BulkItemRequest(shardId, inpt)), BulkItemRequest[]::new);
Expand All @@ -44,6 +48,30 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe
setRefreshPolicy(refreshPolicy);
}

/**
 * Sets the transient inference field metadata for this request, marking that inference
 * must run on the bulk items before the request proceeds. The map is transient: it is
 * not serialized over the wire (see the {@code transient} field declaration).
 * <p>
 * Public for test.
 *
 * @param fieldInferenceMap inference field metadata keyed by field name; consumed once
 *                          via {@code consumeInferenceFieldMap()}
 */
public void setInferenceFieldMap(Map<String, InferenceFieldMetadata> fieldInferenceMap) {
this.inferenceFieldMap = fieldInferenceMap;
}

/**
 * Returns the inference field metadata attached to this request and clears it, so that
 * inference is executed on the bulk items at most once.
 *
 * @return the previously set inference field map, or {@code null} if none was set or it
 *         has already been consumed
 */
public Map<String, InferenceFieldMetadata> consumeInferenceFieldMap() {
    var consumed = inferenceFieldMap;
    inferenceFieldMap = null;
    return consumed;
}

/**
 * Returns the transient inference field metadata for this request without consuming it.
 * <p>
 * Public for test.
 *
 * @return the inference field map, or {@code null} if not set (or already consumed)
 */
public Map<String, InferenceFieldMetadata> getInferenceFieldMap() {
return inferenceFieldMap;
}

public long totalSizeInBytes() {
long totalSizeInBytes = 0;
for (int i = 0; i < items.length; i++) {
Expand Down
2 changes: 1 addition & 1 deletion x-pack/plugin/inference/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ apply plugin: 'elasticsearch.internal-yaml-rest-test'

restResources {
restApi {
include '_common', 'indices', 'inference', 'index'
include '_common', 'bulk', 'indices', 'inference', 'index', 'get', 'update', 'reindex', 'search'
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,6 @@ public TestServiceModel(
super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings));
}

@Override
public TestDenseInferenceServiceExtension.TestServiceSettings getServiceSettings() {
return (TestDenseInferenceServiceExtension.TestServiceSettings) super.getServiceSettings();
}

@Override
public TestTaskSettings getTaskSettings() {
return (TestTaskSettings) super.getTaskSettings();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ public static TestServiceSettings fromMap(Map<String, Object> map) {
SimilarityMeasure similarity = null;
String similarityStr = (String) map.remove("similarity");
if (similarityStr != null) {
similarity = SimilarityMeasure.valueOf(similarityStr);
similarity = SimilarityMeasure.fromString(similarityStr);
}

return new TestServiceSettings(model, dimensions, similarity);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,23 +121,25 @@ private SparseEmbeddingResults makeResults(List<String> input) {
for (int i = 0; i < input.size(); i++) {
var tokens = new ArrayList<SparseEmbeddingResults.WeightedToken>();
for (int j = 0; j < 5; j++) {
tokens.add(new SparseEmbeddingResults.WeightedToken(Integer.toString(j), (float) j));
tokens.add(new SparseEmbeddingResults.WeightedToken("feature_" + j, j + 1.0F));
}
embeddings.add(new SparseEmbeddingResults.Embedding(tokens, false));
}
return new SparseEmbeddingResults(embeddings);
}

private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> input) {
var chunks = new ArrayList<ChunkedTextExpansionResults.ChunkedResult>();
List<ChunkedInferenceServiceResults> results = new ArrayList<>();
for (int i = 0; i < input.size(); i++) {
var tokens = new ArrayList<TextExpansionResults.WeightedToken>();
for (int j = 0; j < 5; j++) {
tokens.add(new TextExpansionResults.WeightedToken(Integer.toString(j), (float) j));
tokens.add(new TextExpansionResults.WeightedToken("feature_" + j, j + 1.0F));
}
chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens));
results.add(
new ChunkedSparseEmbeddingResults(List.of(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)))
);
}
return List.of(new ChunkedSparseEmbeddingResults(chunks));
return results;
}

protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.support.ActionFilter;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
Expand Down Expand Up @@ -45,6 +46,7 @@
import org.elasticsearch.xpack.inference.action.TransportInferenceAction;
import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction;
import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction;
import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.HttpSettings;
Expand Down Expand Up @@ -76,6 +78,8 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.Collections.singletonList;

public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin, MapperPlugin {

/**
Expand All @@ -101,6 +105,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
private final SetOnce<ServiceComponents> serviceComponents = new SetOnce<>();

private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
private List<InferenceServiceExtension> inferenceServiceExtensions;

public InferencePlugin(Settings settings) {
Expand Down Expand Up @@ -166,6 +171,9 @@ public Collection<?> createComponents(PluginServices services) {
registry.init(services.client());
inferenceServiceRegistry.set(registry);

var actionFilter = new ShardBulkInferenceActionFilter(registry, modelRegistry);
shardBulkInferenceActionFilter.set(actionFilter);

return List.of(modelRegistry, registry);
}

Expand Down Expand Up @@ -272,4 +280,12 @@ public Map<String, Mapper.TypeParser> getMappers() {
}
return Map.of();
}

/**
 * Registers the shard-level bulk inference filter when the semantic_text feature is
 * enabled; otherwise no action filters are installed.
 *
 * NOTE: singletonList (not List.of) is kept deliberately — List.of is null-hostile,
 * and the SetOnce holder is only populated in createComponents.
 */
@Override
public Collection<ActionFilter> getActionFilters() {
    return SemanticTextFeature.isEnabled() ? singletonList(shardBulkInferenceActionFilter.get()) : List.of();
}
}
Loading

0 comments on commit ba06e01

Please sign in to comment.