Skip to content

Commit

Permalink
POC - Hide inference results for semantic_text
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed May 22, 2024
1 parent d6f838c commit 89c1ca0
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ static <Request extends AbstractBulkByScrollRequest<Request>> SearchRequest prep
}
sourceBuilder.version(needsSourceDocumentVersions);
sourceBuilder.seqNoAndPrimaryTerm(needsSourceDocumentSeqNoAndPrimaryTerm);
sourceBuilder.hideSourceFields(false);

/*
* Do not open scroll if max docs <= scroll size and not resuming on version conflicts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_REMOVE_ES_SOURCE_OPTIONS = def(8_661_00_0);
public static final TransportVersion NODE_STATS_INGEST_BYTES = def(8_662_00_0);
public static final TransportVersion SEMANTIC_QUERY = def(8_663_00_0);
public static final TransportVersion HIDE_SOURCE_FIELDS = def(8_664_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.elasticsearch.indices.ExecutorSelector;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.node.NodeClosedException;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;

Expand Down Expand Up @@ -154,7 +155,7 @@ protected GetResponse shardOperation(GetRequest request, ShardId shardId) throws
request.realtime(),
request.version(),
request.versionType(),
request.fetchSourceContext(),
FetchSourceContext.of(request.fetchSourceContext(), indexShard),
request.isForceSyntheticSource()
);
return new GetResponse(result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ public AbstractBulkByScrollRequest(SearchRequest searchRequest, boolean setDefau
searchRequest.source(new SearchSourceBuilder());
searchRequest.source().size(DEFAULT_SCROLL_SIZE);
}
this.searchRequest.source().hideSourceFields(false);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
import org.elasticsearch.search.fetch.ShardFetchRequest;
import org.elasticsearch.search.fetch.subphase.FetchDocValuesContext;
import org.elasticsearch.search.fetch.subphase.FetchFieldsContext;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.search.fetch.subphase.ScriptFieldsContext.ScriptField;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.internal.AliasFilter;
Expand Down Expand Up @@ -1345,6 +1346,10 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
if (source.fetchSource() != null) {
context.fetchSourceContext(source.fetchSource());
}
if (source.hideSourceFields()) {
context.fetchSourceContext(FetchSourceContext.of(source.fetchSource(), context.indexShard()));
}

if (source.docValueFields() != null) {
FetchDocValuesContext docValuesContext = new FetchDocValuesContext(
context.getSearchExecutionContext(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ public static HighlightBuilder highlight() {

private Map<String, Object> runtimeMappings = emptyMap();

private boolean hideSourceFields = true;

/**
* Constructs a new search source builder.
*/
Expand Down Expand Up @@ -279,6 +281,9 @@ public SearchSourceBuilder(StreamInput in) throws IOException {
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) {
rankBuilder = in.readOptionalNamedWriteable(RankBuilder.class);
}
if (in.getTransportVersion().onOrAfter(TransportVersions.HIDE_SOURCE_FIELDS)) {
hideSourceFields = in.readBoolean();
}
}

@Override
Expand Down Expand Up @@ -365,6 +370,9 @@ public void writeTo(StreamOutput out) throws IOException {
} else if (rankBuilder != null) {
throw new IllegalArgumentException("cannot serialize [rank] to version [" + out.getTransportVersion().toReleaseVersion() + "]");
}
if (out.getTransportVersion().onOrAfter(TransportVersions.HIDE_SOURCE_FIELDS)) {
out.writeBoolean(hideSourceFields);
}
}

/**
Expand Down Expand Up @@ -1113,6 +1121,15 @@ public SearchSourceBuilder runtimeMappings(Map<String, Object> runtimeMappings)
return this;
}

public boolean hideSourceFields() {
return hideSourceFields;
}

public SearchSourceBuilder hideSourceFields(boolean hideSourceFields) {
this.hideSourceFields = hideSourceFields;
return this;
}

/**
* Rewrites this search source builder into its primitive form. e.g. by
* rewriting the QueryBuilder. If the builder did not change the identity
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.ArrayUtils;
import org.elasticsearch.core.Booleans;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.search.lookup.SourceFilter;
import org.elasticsearch.xcontent.ParseField;
Expand Down Expand Up @@ -52,6 +54,31 @@ public static FetchSourceContext of(boolean fetchSource, @Nullable String[] incl
return new FetchSourceContext(fetchSource, includes, excludes);
}

public static FetchSourceContext of(FetchSourceContext original, IndexShard indexShard) {
// TODO: Delegate name retrieval to InferencerFieldMapper
String[] inferenceFields = indexShard.mapperService()
.mappingLookup()
.inferenceFields()
.keySet()
.stream()
.map(s -> s + ".inference")
.toArray(String[]::new);
if (inferenceFields.length == 0) {
return original;
}
if (original == null) {
return FetchSourceContext.of(true, null, inferenceFields);
}
if (original.includes() == null || original.includes().length == 0) {
return FetchSourceContext.of(
original.fetchSource(),
original.includes(),
ArrayUtils.concat(original.excludes(), inferenceFields)
);
}
return original;
}

public static FetchSourceContext readFrom(StreamInput in) throws IOException {
final boolean fetchSource = in.readBoolean();
final String[] includes = in.readStringArray();
Expand Down

0 comments on commit 89c1ca0

Please sign in to comment.