Skip to content

Commit

Permalink
[ML][Data Frame] make response.count be total count of hits (#43241) (#…
Browse files Browse the repository at this point in the history
…43389)

* [ML][Data Frame] make response.count be total count of hits

* addressing line length check

* changing response count for filters

* adjusting serialization, variable name, and total count logic

* making count mandatory for creation
  • Loading branch information
benwtrent authored Jun 19, 2019
1 parent b333ced commit 77ce326
Show file tree
Hide file tree
Showing 14 changed files with 108 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1286,7 +1286,7 @@ public void testGetFilters() throws Exception {
GetFiltersResponse getFiltersResponse = execute(getFiltersRequest,
machineLearningClient::getFilter,
machineLearningClient::getFilterAsync);
assertThat(getFiltersResponse.count(), equalTo(2L));
assertThat(getFiltersResponse.count(), equalTo(3L));
assertThat(getFiltersResponse.filters().size(), equalTo(2));
assertThat(getFiltersResponse.filters().stream().map(MlFilter::getId).collect(Collectors.toList()),
containsInAnyOrder("get-filter-test-2", "get-filter-test-3"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ protected void searchResources(AbstractGetResourcesRequest request, ActionListen
indicesOptions.expandWildcardsOpen(),
indicesOptions.expandWildcardsClosed(),
indicesOptions))
.source(sourceBuilder);
.source(sourceBuilder.trackTotalHits(true));

executeAsyncWithOrigin(client.threadPool().getThreadContext(),
executionOrigin(),
Expand All @@ -98,6 +98,7 @@ protected void searchResources(AbstractGetResourcesRequest request, ActionListen
public void onResponse(SearchResponse response) {
List<Resource> docs = new ArrayList<>();
Set<String> foundResourceIds = new HashSet<>();
long totalHitCount = response.getHits().getTotalHits().value;
for (SearchHit hit : response.getHits().getHits()) {
BytesReference docSource = hit.getSourceRef();
try (InputStream stream = docSource.streamInput();
Expand All @@ -115,7 +116,7 @@ public void onResponse(SearchResponse response) {
if (requiredMatches.hasUnmatchedIds()) {
listener.onFailure(notFoundException(requiredMatches.unmatchedIdsString()));
} else {
listener.onResponse(new QueryPage<>(docs, docs.size(), getResultsField()));
listener.onResponse(new QueryPage<>(docs, totalHitCount, getResultsField()));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ public static class Response extends AbstractGetResourcesResponse<DataFrameTrans
public static final String INVALID_TRANSFORMS_DEPRECATION_WARNING = "Found [{}] invalid transforms";
private static final ParseField INVALID_TRANSFORMS = new ParseField("invalid_transforms");

public Response(List<DataFrameTransformConfig> transformConfigs) {
super(new QueryPage<>(transformConfigs, transformConfigs.size(), DataFrameField.TRANSFORMS));
public Response(List<DataFrameTransformConfig> transformConfigs, long count) {
super(new QueryPage<>(transformConfigs, count, DataFrameField.TRANSFORMS));
}

public Response() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package org.elasticsearch.xpack.core.dataframe.action;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.Action;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.TaskOperationFailure;
Expand All @@ -21,8 +22,10 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.dataframe.DataFrameField;
import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformStateAndStats;
import org.elasticsearch.xpack.core.dataframe.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -138,32 +141,52 @@ public boolean equals(Object obj) {
}

public static class Response extends BaseTasksResponse implements ToXContentObject {
private List<DataFrameTransformStateAndStats> transformsStateAndStats;
private final QueryPage<DataFrameTransformStateAndStats> transformsStateAndStats;

public Response(List<DataFrameTransformStateAndStats> transformsStateAndStats) {
super(Collections.emptyList(), Collections.emptyList());
this.transformsStateAndStats = transformsStateAndStats;
public Response(List<DataFrameTransformStateAndStats> transformStateAndStats, long count) {
this(new QueryPage<>(transformStateAndStats, count, DataFrameField.TRANSFORMS));
}

public Response(List<DataFrameTransformStateAndStats> transformsStateAndStats, List<TaskOperationFailure> taskFailures,
List<? extends ElasticsearchException> nodeFailures) {
public Response(List<DataFrameTransformStateAndStats> transformStateAndStats,
long count,
List<TaskOperationFailure> taskFailures,
List<? extends ElasticsearchException> nodeFailures) {
this(new QueryPage<>(transformStateAndStats, count, DataFrameField.TRANSFORMS), taskFailures, nodeFailures);
}

private Response(QueryPage<DataFrameTransformStateAndStats> transformsStateAndStats) {
this(transformsStateAndStats, Collections.emptyList(), Collections.emptyList());
}

private Response(QueryPage<DataFrameTransformStateAndStats> transformsStateAndStats,
List<TaskOperationFailure> taskFailures,
List<? extends ElasticsearchException> nodeFailures) {
super(taskFailures, nodeFailures);
this.transformsStateAndStats = transformsStateAndStats;
this.transformsStateAndStats = ExceptionsHelper.requireNonNull(transformsStateAndStats, "transformsStateAndStats");
}

public Response(StreamInput in) throws IOException {
super(in);
transformsStateAndStats = in.readList(DataFrameTransformStateAndStats::new);
if (in.getVersion().onOrAfter(Version.V_7_3_0)) {
transformsStateAndStats = new QueryPage<>(in, DataFrameTransformStateAndStats::new);
} else {
List<DataFrameTransformStateAndStats> stats = in.readList(DataFrameTransformStateAndStats::new);
transformsStateAndStats = new QueryPage<>(stats, stats.size(), DataFrameField.TRANSFORMS);
}
}

public List<DataFrameTransformStateAndStats> getTransformsStateAndStats() {
return transformsStateAndStats;
return transformsStateAndStats.results();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeList(transformsStateAndStats);
if (out.getVersion().onOrAfter(Version.V_7_3_0)) {
transformsStateAndStats.writeTo(out);
} else {
out.writeList(transformsStateAndStats.results());
}
}

@Override
Expand All @@ -175,8 +198,7 @@ public void readFrom(StreamInput in) {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
toXContentCommon(builder, params);
builder.field(DataFrameField.COUNT.getPreferredName(), transformsStateAndStats.size());
builder.field(DataFrameField.TRANSFORMS.getPreferredName(), transformsStateAndStats);
transformsStateAndStats.doXContentBody(builder, params);
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public void testInvalidTransforms() throws IOException {
transforms.add(DataFrameTransformConfigTests.randomDataFrameTransformConfig());
transforms.add(DataFrameTransformConfigTests.randomInvalidDataFrameTransformConfig());

Response r = new Response(transforms);
Response r = new Response(transforms, transforms.size());
XContentBuilder builder = XContentFactory.contentBuilder(randomFrom(XContentType.values()));
r.toXContent(builder, XContent.EMPTY_PARAMS);
Map<String, Object> responseAsMap = createParser(builder).map();
Expand All @@ -52,7 +52,7 @@ public void testNoHeaderInResponse() throws IOException {
transforms.add(DataFrameTransformConfigTests.randomDataFrameTransformConfig());
}

Response r = new Response(transforms);
Response r = new Response(transforms, transforms.size());
XContentBuilder builder = XContentFactory.contentBuilder(randomFrom(XContentType.values()));
r.toXContent(builder, XContent.EMPTY_PARAMS);
Map<String, Object> responseAsMap = createParser(builder).map();
Expand All @@ -76,7 +76,7 @@ protected Response createTestInstance() {
configs.add(DataFrameTransformConfigTests.randomDataFrameTransformConfig());
}

return new Response(configs);
return new Response(configs, randomNonNegativeLong());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ protected Response createTestInstance() {
taskFailures.add(new TaskOperationFailure("node1", randomLongBetween(1, 10), new Exception("error")));
nodeFailures.add(new FailedNodeException("node1", "message", new Exception("error")));
}
return new Response(stats, taskFailures, nodeFailures);
return new Response(stats, randomLongBetween(stats.size(), 10_000_000L), taskFailures, nodeFailures);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public TransportGetDataFrameTransformsAction(TransportService transportService,
@Override
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
searchResources(request, ActionListener.wrap(
r -> listener.onResponse(new Response(r.results())),
r -> listener.onResponse(new Response(r.results(), r.count())),
listener::onFailure
));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ protected Response newResponse(Request request, List<Response> tasks, List<TaskO
.collect(Collectors.toList());
List<ElasticsearchException> allFailedNodeExceptions = new ArrayList<>(failedNodeExceptions);
allFailedNodeExceptions.addAll(tasks.stream().flatMap(r -> r.getNodeFailures().stream()).collect(Collectors.toList()));
return new Response(responses, taskOperationFailures, allFailedNodeExceptions);
return new Response(responses, responses.size(), taskOperationFailures, allFailedNodeExceptions);
}

@Override
Expand All @@ -83,36 +83,47 @@ protected void taskOperation(Request request, DataFrameTransformTask task, Actio
String nodeId = state.nodes().getLocalNode().getId();
if (task.isCancelled() == false) {
transformsCheckpointService.getCheckpointStats(task.getTransformId(), task.getCheckpoint(), task.getInProgressCheckpoint(),
ActionListener.wrap(checkpointStats -> {
listener.onResponse(new Response(Collections.singletonList(
new DataFrameTransformStateAndStats(task.getTransformId(), task.getState(), task.getStats(), checkpointStats))));
}, e -> {
listener.onResponse(new Response(
Collections.singletonList(new DataFrameTransformStateAndStats(task.getTransformId(), task.getState(),
task.getStats(), DataFrameTransformCheckpointingInfo.EMPTY)),
ActionListener.wrap(checkpointStats -> listener.onResponse(new Response(
Collections.singletonList(new DataFrameTransformStateAndStats(task.getTransformId(),
task.getState(),
task.getStats(),
checkpointStats)),
1L)),
e -> listener.onResponse(new Response(
Collections.singletonList(new DataFrameTransformStateAndStats(task.getTransformId(),
task.getState(),
task.getStats(),
DataFrameTransformCheckpointingInfo.EMPTY)),
1L,
Collections.emptyList(),
Collections.singletonList(new FailedNodeException(nodeId, "Failed to retrieve checkpointing info", e))));
}));
Collections.singletonList(new FailedNodeException(nodeId, "Failed to retrieve checkpointing info", e))))
));
} else {
listener.onResponse(new Response(Collections.emptyList()));
listener.onResponse(new Response(Collections.emptyList(), 0L));
}
}

@Override
protected void doExecute(Task task, Request request, ActionListener<Response> finalListener) {
dataFrameTransformsConfigManager.expandTransformIds(request.getId(), request.getPageParams(), ActionListener.wrap(
ids -> {
request.setExpandedIds(ids);
request.setNodes(DataFrameNodes.dataFrameTaskNodes(ids, clusterService.state()));
hitsAndIds -> {
request.setExpandedIds(hitsAndIds.v2());
request.setNodes(DataFrameNodes.dataFrameTaskNodes(hitsAndIds.v2(), clusterService.state()));
super.doExecute(task, request, ActionListener.wrap(
response -> collectStatsForTransformsWithoutTasks(request, response, finalListener),
response -> collectStatsForTransformsWithoutTasks(request, response, ActionListener.wrap(
finalResponse -> finalListener.onResponse(new Response(finalResponse.getTransformsStateAndStats(),
hitsAndIds.v1(),
finalResponse.getTaskFailures(),
finalResponse.getNodeFailures())),
finalListener::onFailure
)),
finalListener::onFailure
));
},
e -> {
// If the index to search, or the individual config is not there, just return empty
if (e instanceof ResourceNotFoundException) {
finalListener.onResponse(new Response(Collections.emptyList()));
finalListener.onResponse(new Response(Collections.emptyList(), 0L));
} else {
finalListener.onFailure(e);
}
Expand Down Expand Up @@ -165,7 +176,10 @@ private void collectStatsForTransformsWithoutTasks(Request request,
// it can easily become arbitrarily ordered based on which transforms don't have a task or stats docs
allStateAndStats.sort(Comparator.comparing(DataFrameTransformStateAndStats::getId));

listener.onResponse(new Response(allStateAndStats, response.getTaskFailures(), response.getNodeFailures()));
listener.onResponse(new Response(allStateAndStats,
allStateAndStats.size(),
response.getTaskFailures(),
response.getNodeFailures()));
},
e -> {
if (e instanceof IndexNotFoundException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ protected void doExecute(Task task, StopDataFrameTransformAction.Request request
}

dataFrameTransformsConfigManager.expandTransformIds(request.getId(), new PageParams(0, 10_000), ActionListener.wrap(
expandedIds -> {
request.setExpandedIds(new HashSet<>(expandedIds));
request.setNodes(DataFrameNodes.dataFrameTaskNodes(expandedIds, clusterService.state()));
super.doExecute(task, request, finalListener);
},
listener::onFailure
hitsAndIds -> {
request.setExpandedIds(new HashSet<>(hitsAndIds.v2()));
request.setNodes(DataFrameNodes.dataFrameTaskNodes(hitsAndIds.v2(), clusterService.state()));
super.doExecute(task, request, finalListener);
},
listener::onFailure
));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.regex.Regex;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
Expand Down Expand Up @@ -196,13 +197,16 @@ public void getTransformConfiguration(String transformId, ActionListener<DataFra
* @param pageParams The paging params
* @param foundIdsListener The listener on signal on success or failure
*/
public void expandTransformIds(String transformIdsExpression, PageParams pageParams, ActionListener<List<String>> foundIdsListener) {
public void expandTransformIds(String transformIdsExpression,
PageParams pageParams,
ActionListener<Tuple<Long, List<String>>> foundIdsListener) {
String[] idTokens = ExpandedIdsMatcher.tokenizeExpression(transformIdsExpression);
QueryBuilder queryBuilder = buildQueryFromTokenizedIds(idTokens, DataFrameTransformConfig.NAME);

SearchRequest request = client.prepareSearch(DataFrameInternalIndex.INDEX_NAME)
.addSort(DataFrameField.ID.getPreferredName(), SortOrder.ASC)
.setFrom(pageParams.getFrom())
.setTrackTotalHits(true)
.setSize(pageParams.getSize())
.setQuery(queryBuilder)
// We only care about the `id` field, small optimization
Expand All @@ -214,6 +218,7 @@ public void expandTransformIds(String transformIdsExpression, PageParams pagePar
executeAsyncWithOrigin(client.threadPool().getThreadContext(), DATA_FRAME_ORIGIN, request,
ActionListener.<SearchResponse>wrap(
searchResponse -> {
long totalHits = searchResponse.getHits().getTotalHits().value;
List<String> ids = new ArrayList<>(searchResponse.getHits().getHits().length);
for (SearchHit hit : searchResponse.getHits().getHits()) {
BytesReference source = hit.getSourceRef();
Expand All @@ -235,7 +240,7 @@ public void expandTransformIds(String transformIdsExpression, PageParams pagePar
requiredMatches.unmatchedIdsString())));
return;
}
foundIdsListener.onResponse(ids);
foundIdsListener.onResponse(new Tuple<>(totalHits, ids));
},
foundIdsListener::onFailure
), client::search);
Expand Down
Loading

0 comments on commit 77ce326

Please sign in to comment.