Skip to content

Commit

Permalink
Extend async search keep alive (#67877)
Browse files Browse the repository at this point in the history
There can be a race between two GET async search requests in which the
request with the lower keep_alive parameter is applied last and shortens
the expiration time. This is undesirable, as we should retain the search
result long enough for all requests. This commit ensures that the
keep_alive can only be extended and never goes backward.
  • Loading branch information
dnhatn authored Jan 25, 2021
1 parent 3d2e82f commit 244fc95
Show file tree
Hide file tree
Showing 11 changed files with 229 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import org.apache.lucene.store.AlreadyClosedException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
Expand Down Expand Up @@ -345,7 +346,7 @@ public void testUpdateRunningKeepAlive() throws Exception {
assertThat(response.getSearchResponse().getSuccessfulShards(), equalTo(0));
assertThat(response.getSearchResponse().getFailedShards(), equalTo(0));

response = getAsyncSearch(response.getId(), TimeValue.timeValueDays(10));
response = getAsyncSearch(response.getId(), TimeValue.timeValueDays(6));
assertThat(response.getExpirationTime(), greaterThan(expirationTime));

assertTrue(response.isRunning());
Expand All @@ -364,8 +365,13 @@ public void testUpdateRunningKeepAlive() throws Exception {
assertEquals(0, statusResponse.getSkippedShards());
assertEquals(null, statusResponse.getCompletionStatus());

response = getAsyncSearch(response.getId(), TimeValue.timeValueMillis(1));
assertThat(response.getExpirationTime(), lessThan(expirationTime));
expirationTime = response.getExpirationTime();
response = getAsyncSearch(response.getId(), TimeValue.timeValueMinutes(between(1, 24 * 60)));
assertThat(response.getExpirationTime(), equalTo(expirationTime));
response = getAsyncSearch(response.getId(), TimeValue.timeValueDays(10));
assertThat(response.getExpirationTime(), greaterThan(expirationTime));

deleteAsyncSearch(response.getId());
ensureTaskNotRunning(response.getId());
ensureTaskRemoval(response.getId());
}
Expand All @@ -391,16 +397,21 @@ public void testUpdateStoreKeepAlive() throws Exception {
assertThat(response.getSearchResponse().getSuccessfulShards(), equalTo(numShards));
assertThat(response.getSearchResponse().getFailedShards(), equalTo(0));

response = getAsyncSearch(response.getId(), TimeValue.timeValueDays(10));
response = getAsyncSearch(response.getId(), TimeValue.timeValueDays(8));
assertThat(response.getExpirationTime(), greaterThan(expirationTime));
expirationTime = response.getExpirationTime();

assertFalse(response.isRunning());
assertThat(response.getSearchResponse().getTotalShards(), equalTo(numShards));
assertThat(response.getSearchResponse().getSuccessfulShards(), equalTo(numShards));
assertThat(response.getSearchResponse().getFailedShards(), equalTo(0));

response = getAsyncSearch(response.getId(), TimeValue.timeValueMillis(1));
assertThat(response.getExpirationTime(), lessThan(expirationTime));
assertThat(response.getExpirationTime(), equalTo(expirationTime));
response = getAsyncSearch(response.getId(), TimeValue.timeValueDays(10));
assertThat(response.getExpirationTime(), greaterThan(expirationTime));

deleteAsyncSearch(response.getId());
ensureTaskNotRunning(response.getId());
ensureTaskRemoval(response.getId());
}
Expand All @@ -427,22 +438,24 @@ public void testRemoveAsyncIndex() throws Exception {
ExceptionsHelper.unwrapCause(exc.getCause()) : ExceptionsHelper.unwrapCause(exc);
assertThat(ExceptionsHelper.status(cause).getStatus(), equalTo(404));

SubmitAsyncSearchRequest newReq = new SubmitAsyncSearchRequest(indexName);
SubmitAsyncSearchRequest newReq = new SubmitAsyncSearchRequest(indexName) {
@Override
public ActionRequestValidationException validate() {
return null; // to use a small keep_alive
}
};
newReq.getSearchRequest().source(
new SearchSourceBuilder().aggregation(new CancellingAggregationBuilder("test", randomLong()))
);
newReq.setWaitForCompletionTimeout(TimeValue.timeValueMillis(1));
newReq.setWaitForCompletionTimeout(TimeValue.timeValueMillis(1)).setKeepAlive(TimeValue.timeValueSeconds(5));
AsyncSearchResponse newResp = submitAsyncSearch(newReq);
assertNotNull(newResp.getSearchResponse());
assertTrue(newResp.isRunning());
assertThat(newResp.getSearchResponse().getTotalShards(), equalTo(numShards));
assertThat(newResp.getSearchResponse().getSuccessfulShards(), equalTo(0));
assertThat(newResp.getSearchResponse().getFailedShards(), equalTo(0));
long expirationTime = newResp.getExpirationTime();

// check garbage collection
newResp = getAsyncSearch(newResp.getId(), TimeValue.timeValueMillis(1));
assertThat(newResp.getExpirationTime(), lessThan(expirationTime));
ensureTaskNotRunning(newResp.getId());
ensureTaskRemoval(newResp.getId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.script.MockScriptPlugin;
import org.elasticsearch.search.aggregations.bucket.filter.InternalFilter;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
Expand Down Expand Up @@ -57,11 +58,15 @@
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;

import static org.elasticsearch.xpack.core.XPackPlugin.ASYNC_RESULTS_INDEX;
import static org.elasticsearch.xpack.core.async.AsyncTaskMaintenanceService.ASYNC_SEARCH_CLEANUP_INTERVAL_SETTING;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;

Expand Down Expand Up @@ -93,6 +98,44 @@ public List<AggregationSpec> getAggregations() {
}
}

/**
 * Test-only script plugin that registers, under the {@code "painless"} language name, a Java implementation of the
 * update script that extends {@code expiration_time} and never lowers it.
 */
public static class ExpirationTimeScriptPlugin extends MockScriptPlugin {

    /** Name of the document field and of the script parameter carrying the expiration time. */
    private static final String EXPIRATION_TIME = "expiration_time";

    /**
     * Source text used as the key of the registered script.
     * NOTE(review): presumably this has to match the script text sent by the code under test character for
     * character, whitespace included - confirm before reformatting.
     */
    private static final String EXTEND_SCRIPT =
        " if (ctx._source.expiration_time < params.expiration_time) { " +
        " ctx._source.expiration_time = params.expiration_time; " +
        " } else { " +
        " ctx.op = \"noop\"; " +
        " }";

    @Override
    public String pluginScriptLang() {
        return "painless";
    }

    /**
     * @return a single-entry map from the script source text to its Java implementation
     */
    @Override
    protected Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
        return Map.of(EXTEND_SCRIPT, ExpirationTimeScriptPlugin::extendIfLater);
    }

    /**
     * Java equivalent of the script: writes the requested expiration time into {@code ctx._source} when it is
     * strictly greater than the stored one, otherwise sets {@code ctx.op} to {@code "noop"}.
     *
     * @param vars script variables; the {@code params} and {@code ctx} entries are read
     * @return the {@code ctx} map
     */
    @SuppressWarnings("unchecked")
    private static Object extendIfLater(Map<String, Object> vars) {
        final Map<String, Object> scriptParams = (Map<String, Object>) vars.get("params");
        assertNotNull(scriptParams);
        // the expiration time must be the one and only parameter handed to the script
        assertThat(scriptParams.keySet(), contains(EXPIRATION_TIME));
        final long requested = (long) scriptParams.get(EXPIRATION_TIME);

        final Map<String, Object> context = (Map<String, Object>) vars.get("ctx");
        assertNotNull(context);
        final Map<String, Object> document = (Map<String, Object>) context.get("_source");
        final long stored = (long) document.get(EXPIRATION_TIME);

        if (stored >= requested) {
            // nothing to extend: skip the write
            context.put("op", "noop");
        } else {
            document.put(EXPIRATION_TIME, requested);
        }
        return context;
    }
}

@Before
public void startMaintenanceService() {
for (AsyncTaskMaintenanceService service : internalCluster().getDataNodeInstances(AsyncTaskMaintenanceService.class)) {
Expand Down Expand Up @@ -120,7 +163,7 @@ public void releaseQueryLatch() {
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Arrays.asList(LocalStateCompositeXPackPlugin.class, AsyncSearch.class, AsyncResultsIndexPlugin.class, IndexLifecycle.class,
SearchTestPlugin.class, ReindexPlugin.class);
SearchTestPlugin.class, ReindexPlugin.class, ExpirationTimeScriptPlugin.class);
}

@Override
Expand Down Expand Up @@ -189,7 +232,7 @@ protected void ensureTaskNotRunning(String id) throws Exception {
throw exc;
}
}
});
}, 30, TimeUnit.SECONDS);
}

/**
Expand All @@ -207,7 +250,7 @@ protected void ensureTaskCompletion(String id) throws Exception {
throw exc;
}
}
});
}, 30, TimeUnit.SECONDS);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Supplier;
Expand All @@ -62,7 +63,7 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask {
private final List<Runnable> initListeners = new ArrayList<>();
private final Map<Long, Consumer<AsyncSearchResponse>> completionListeners = new HashMap<>();

private volatile long expirationTimeMillis;
private final AtomicLong expirationTimeMillis;
private final AtomicBoolean isCancelling = new AtomicBoolean(false);

private final AtomicReference<MutableSearchResponse> searchResponse = new AtomicReference<>();
Expand Down Expand Up @@ -93,7 +94,7 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask {
ThreadPool threadPool,
Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier) {
super(id, type, action, () -> "async_search{" + descriptionSupplier.get() + "}", parentTaskId, taskHeaders);
this.expirationTimeMillis = getStartTime() + keepAlive.getMillis();
this.expirationTimeMillis = new AtomicLong(getStartTime() + keepAlive.getMillis());
this.originHeaders = originHeaders;
this.searchId = searchId;
this.client = client;
Expand Down Expand Up @@ -127,8 +128,8 @@ Listener getSearchProgressActionListener() {
* Update the expiration time of the (partial) response.
*/
@Override
public void setExpirationTime(long expirationTimeMillis) {
this.expirationTimeMillis = expirationTimeMillis;
public void extendExpirationTime(long newExpirationTimeMillis) {
this.expirationTimeMillis.updateAndGet(curr -> Math.max(curr, newExpirationTimeMillis));
}

@Override
Expand Down Expand Up @@ -330,19 +331,19 @@ private AsyncSearchResponse getResponse(boolean restoreResponseHeaders) {
checkCancellation();
AsyncSearchResponse asyncSearchResponse;
try {
asyncSearchResponse = mutableSearchResponse.toAsyncSearchResponse(this, expirationTimeMillis, restoreResponseHeaders);
asyncSearchResponse = mutableSearchResponse.toAsyncSearchResponse(this, expirationTimeMillis.get(), restoreResponseHeaders);
} catch(Exception e) {
ElasticsearchException exception = new ElasticsearchStatusException("Async search: error while reducing partial results",
ExceptionsHelper.status(e), e);
asyncSearchResponse = mutableSearchResponse.toAsyncSearchResponse(this, expirationTimeMillis, exception);
asyncSearchResponse = mutableSearchResponse.toAsyncSearchResponse(this, expirationTimeMillis.get(), exception);
}
return asyncSearchResponse;
}

// checks if the search task should be cancelled
private synchronized void checkCancellation() {
long now = System.currentTimeMillis();
if (hasCompleted == false && expirationTimeMillis < now) {
if (hasCompleted == false && expirationTimeMillis.get() < now) {
// we cancel expired search task even if they are still running
cancelTask(() -> {}, "async search has expired");
}
Expand All @@ -354,7 +355,7 @@ private synchronized void checkCancellation() {
public AsyncStatusResponse getStatusResponse() {
MutableSearchResponse mutableSearchResponse = searchResponse.get();
assert mutableSearchResponse != null;
return mutableSearchResponse.toStatusResponse(searchId.getEncoded(), getStartTime(), expirationTimeMillis);
return mutableSearchResponse.toStatusResponse(searchId.getEncoded(), getStartTime(), expirationTimeMillis.get());
}

class Listener extends SearchProgressActionListener {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public void retrieveResult(GetAsyncResultRequest request, ActionListener<Respons
// EQL doesn't store initial or intermediate results, so the expiration time only needs to be updated in the store for
// async search
if (updateInitialResultsInStore & expirationTime > 0) {
store.updateExpirationTime(searchId.getDocId(), expirationTime,
store.extendExpirationTime(searchId.getDocId(), expirationTime,
ActionListener.wrap(
p -> getSearchResponseFromTask(searchId, request, nowInMillis, expirationTime, listener),
exc -> {
Expand Down Expand Up @@ -123,7 +123,7 @@ private void getSearchResponseFromTask(AsyncExecutionId searchId,
}

if (expirationTimeMillis != -1) {
task.setExpirationTime(expirationTimeMillis);
task.extendExpirationTime(expirationTimeMillis);
}
addCompletionListener.apply(task, new ActionListener<>() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ public interface AsyncTask {
boolean isCancelled();

/**
* Update the expiration time of the (partial) response.
* Extends the expiration time of the (partial) response if needed
*/
void setExpirationTime(long expirationTimeMillis);
void extendExpirationTime(long newExpirationTimeMillis);

/**
* Performs necessary checks, cancels the task and calls the runnable upon completion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.indices.SystemIndexDescriptor;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.xpack.core.XPackPlugin;
Expand All @@ -45,7 +47,6 @@
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -64,6 +65,13 @@ public final class AsyncTaskIndexService<R extends AsyncResponse<R>> {
public static final String HEADERS_FIELD = "headers";
public static final String RESPONSE_HEADERS_FIELD = "response_headers";
public static final String EXPIRATION_TIME_FIELD = "expiration_time";
public static final String EXPIRATION_TIME_SCRIPT =
" if (ctx._source.expiration_time < params.expiration_time) { " +
" ctx._source.expiration_time = params.expiration_time; " +
" } else { " +
" ctx.op = \"noop\"; " +
" }";

public static final String RESULT_FIELD = "result";

// Usually the settings, mappings and system index descriptor below
Expand Down Expand Up @@ -196,16 +204,15 @@ public void updateResponse(String docId,
}

/**
* Updates the expiration time of the provided <code>docId</code> if the place-holder
* document is still present (update).
* Extends the expiration time of the provided <code>docId</code> if the place-holder document is still present (update).
*/
public void updateExpirationTime(String docId,
long expirationTimeMillis,
ActionListener<UpdateResponse> listener) {
Map<String, Object> source = Collections.singletonMap(EXPIRATION_TIME_FIELD, expirationTimeMillis);
UpdateRequest request = new UpdateRequest().index(index)
public void extendExpirationTime(String docId, long expirationTimeMillis, ActionListener<UpdateResponse> listener) {
Script script = new Script(ScriptType.INLINE, "painless", EXPIRATION_TIME_SCRIPT,
Map.of(EXPIRATION_TIME_FIELD, expirationTimeMillis));
UpdateRequest request = new UpdateRequest()
.index(index)
.id(docId)
.doc(source, XContentType.JSON)
.script(script)
.retryOnConflict(5);
client.update(request, listener);
}
Expand Down
Loading

0 comments on commit 244fc95

Please sign in to comment.