Skip to content

Commit

Permalink
Add support of special WrappingSearchAsyncActionPhase so the onPhaseS…
Browse files Browse the repository at this point in the history
…tart() will always be followed by onPhaseEnd() within AbstractSearchAsyncAction

Signed-off-by: Andriy Redko <[email protected]>
  • Loading branch information
reta committed Feb 13, 2024
1 parent 76ae14a commit 9f38f6f
Show file tree
Hide file tree
Showing 8 changed files with 345 additions and 75 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

### Fixed
- Fix for deserilization bug in weighted round-robin metadata ([#11679](https://github.com/opensearch-project/OpenSearch/pull/11679))
- [Revert] [Bug] Check phase name before SearchRequestOperationsListener onPhaseStart ([#12035](https://github.com/opensearch-project/OpenSearch/pull/12035))
- Add support of special WrappingSearchAsyncActionPhase so the onPhaseStart() will always be followed by onPhaseEnd() within AbstractSearchAsyncAction ([#12293](https://github.com/opensearch-project/OpenSearch/pull/12293))

### Security

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,18 +432,16 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha
}

private void onPhaseEnd(SearchRequestContext searchRequestContext) {
if (getCurrentPhase() != null && SearchPhaseName.isValidName(getName())) {
if (getCurrentPhase() != null) {
long tookInNanos = System.nanoTime() - getCurrentPhase().getStartTimeInNanos();
searchRequestContext.updatePhaseTookMap(getCurrentPhase().getName(), TimeUnit.NANOSECONDS.toMillis(tookInNanos));
this.searchRequestContext.getSearchRequestOperationsListener().onPhaseEnd(this, searchRequestContext);
}
this.searchRequestContext.getSearchRequestOperationsListener().onPhaseEnd(this, searchRequestContext);
}

void onPhaseStart(SearchPhase phase) {
private void onPhaseStart(SearchPhase phase) {
setCurrentPhase(phase);
if (SearchPhaseName.isValidName(phase.getName())) {
this.searchRequestContext.getSearchRequestOperationsListener().onPhaseStart(this);
}
this.searchRequestContext.getSearchRequestOperationsListener().onPhaseStart(this);
}

private void onRequestEnd(SearchRequestContext searchRequestContext) {
Expand All @@ -454,10 +452,19 @@ private void executePhase(SearchPhase phase) {
try {
onPhaseStart(phase);
phase.recordAndRun();
// The WrappingSearchAsyncActionPhase (see please CanMatchPreFilterSearchPhase as one example) is a special case
// of search phase that wraps SearchAsyncActionPhase as SearchPhase. The AbstractSearchAsyncAction manages own
// onPhaseStart / onPhaseFailure / OnPhaseDone callbacks and the wrapping SearchPhase is being abandoned
// (fe, has no onPhaseEnd callbacks called ever). To fix that, the explicit onPhaseEnd is being called
// since SearchPhase::recordAndRun would delegate to AbstractSearchAsyncAction::start internally.
if (phase instanceof WrappingSearchAsyncActionPhase) {
onPhaseEnd(searchRequestContext);
}
} catch (Exception e) {
if (logger.isDebugEnabled()) {
logger.debug(new ParameterizedMessage("Failed to execute [{}] while moving to [{}] phase", request, phase.getName()), e);
}

onPhaseFailure(phase, "", e);
}
}
Expand Down Expand Up @@ -716,9 +723,7 @@ public void sendSearchResponse(InternalSearchResponse internalSearchResponse, At

@Override
public final void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) {
if (SearchPhaseName.isValidName(phase.getName())) {
this.searchRequestContext.getSearchRequestOperationsListener().onPhaseFailure(this);
}
this.searchRequestContext.getSearchRequestOperationsListener().onPhaseFailure(this);
raisePhaseFailure(new SearchPhaseExecutionException(phase.getName(), msg, cause, buildShardFailures()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@

import org.opensearch.common.annotation.PublicApi;

import java.util.HashSet;
import java.util.Set;

/**
* Enum for different Search Phases in OpenSearch
*
Expand All @@ -28,12 +25,6 @@ public enum SearchPhaseName {
CAN_MATCH("can_match");

private final String name;
private static final Set<String> PHASE_NAMES = new HashSet<>();
static {
for (SearchPhaseName phaseName : SearchPhaseName.values()) {
PHASE_NAMES.add(phaseName.name);
}
}

SearchPhaseName(final String name) {
this.name = name;
Expand All @@ -42,8 +33,4 @@ public enum SearchPhaseName {
public String getName() {
return name;
}

public static boolean isValidName(String phaseName) {
return PHASE_NAMES.contains(phaseName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1220,8 +1220,8 @@ private AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction
timeProvider,
clusterState,
task,
(iter) -> {
AbstractSearchAsyncAction<? extends SearchPhaseResult> action = searchAsyncAction(
(iter) -> new WrappingSearchAsyncActionPhase(
searchAsyncAction(
task,
searchRequest,
executor,
Expand All @@ -1237,14 +1237,8 @@ private AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction
threadPool,
clusters,
searchRequestContext
);
return new SearchPhase("none") {
@Override
public void run() {
action.start();
}
};
},
)
),
clusters,
searchRequestContext
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.action.search;

import org.opensearch.search.SearchPhaseResult;

/**
* The WrappingSearchAsyncActionPhase (see please {@link CanMatchPreFilterSearchPhase} as one example) is a special case
* of search phase that wraps SearchAsyncActionPhase as {@link SearchPhase}. The {@link AbstractSearchAsyncAction} manages own
* onPhaseStart / onPhaseFailure / OnPhaseDone callbacks and but just wrapping it with the SearchPhase causes
* only some callbacks being called. The {@link AbstractSearchAsyncAction} has special treatment of {@link WrappingSearchAsyncActionPhase}.
*/
class WrappingSearchAsyncActionPhase extends SearchPhase {
private final AbstractSearchAsyncAction<? extends SearchPhaseResult> action;

protected WrappingSearchAsyncActionPhase(AbstractSearchAsyncAction<? extends SearchPhaseResult> action) {
super(action.getName());
this.action = action;
}

@Override
public void run() {
action.start();
}

SearchPhase getSearchPhase() {
return action;

Check warning on line 33 in server/src/main/java/org/opensearch/action/search/WrappingSearchAsyncActionPhase.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/action/search/WrappingSearchAsyncActionPhase.java#L33

Added line #L33 was not covered by tests
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@
import java.util.function.BiFunction;
import java.util.stream.IntStream;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.instanceOf;
Expand All @@ -95,6 +97,7 @@ public class AbstractSearchAsyncActionTests extends OpenSearchTestCase {
private final List<Tuple<String, String>> resolvedNodes = new ArrayList<>();
private final Set<ShardSearchContextId> releasedContexts = new CopyOnWriteArraySet<>();
private ExecutorService executor;
private SearchRequestOperationsListener assertingListener;
ThreadPool threadPool;

@Before
Expand All @@ -103,6 +106,27 @@ public void setUp() throws Exception {
super.setUp();
executor = Executors.newFixedThreadPool(1);
threadPool = new TestThreadPool(getClass().getName());
assertingListener = new SearchRequestOperationsListener() {
private volatile SearchPhase phase;

@Override
protected void onPhaseStart(SearchPhaseContext context) {
assertThat(phase, is(nullValue()));
phase = context.getCurrentPhase();
}

@Override
protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {
assertThat(phase, is(context.getCurrentPhase()));
phase = null;
}

@Override
protected void onPhaseFailure(SearchPhaseContext context) {
assertThat(phase, is(context.getCurrentPhase()));
phase = null;
}
};
}

@After
Expand Down Expand Up @@ -178,7 +202,10 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
results,
request.getMaxConcurrentShardRequests(),
SearchResponse.Clusters.EMPTY,
new SearchRequestContext(new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), request)
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
request
)
) {
@Override
protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> results, SearchPhaseContext context) {
Expand Down Expand Up @@ -334,18 +361,11 @@ public void testOnPhaseFailureAndVerifyListeners() {
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
SearchRequestStats testListener = new SearchRequestStats(clusterSettings);

final List<SearchRequestOperationsListener> requestOperationListeners = new ArrayList<>(List.of(testListener));
final List<SearchRequestOperationsListener> requestOperationListeners = List.of(testListener);
SearchQueryThenFetchAsyncAction action = createSearchQueryThenFetchAsyncAction(requestOperationListeners);
action.start();
assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName()));
action.onPhaseFailure(new SearchPhase("none") {
@Override
public void run() {

}
}, "message", null);
assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName()));
action.onPhaseFailure(new SearchPhase(action.getName()) {
action.onPhaseFailure(new SearchPhase("test") {
@Override
public void run() {

Expand All @@ -359,14 +379,14 @@ public void run() {
);
searchDfsQueryThenFetchAsyncAction.start();
assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName()));
searchDfsQueryThenFetchAsyncAction.onPhaseFailure(new SearchPhase(searchDfsQueryThenFetchAsyncAction.getName()) {
searchDfsQueryThenFetchAsyncAction.onPhaseFailure(new SearchPhase("test") {
@Override
public void run() {

}
}, "message", null);
assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName()));
assertEquals(0, testListener.getPhaseTotal(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName()));
assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName()));
assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseName()));

FetchSearchPhase fetchPhase = createFetchSearchPhase();
ShardId shardId = new ShardId(randomAlphaOfLengthBetween(5, 10), randomAlphaOfLength(10), randomInt());
Expand All @@ -375,7 +395,7 @@ public void run() {
action.skipShard(searchShardIterator);
action.executeNextPhase(action, fetchPhase);
assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName()));
action.onPhaseFailure(new SearchPhase(fetchPhase.getName()) {
action.onPhaseFailure(new SearchPhase("test") {
@Override
public void run() {

Expand Down Expand Up @@ -410,30 +430,6 @@ public void run() {
assertEquals(requestIds, releasedContexts);
}

public void testOnPhaseStart() {
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
SearchRequestStats testListener = new SearchRequestStats(clusterSettings);

final List<SearchRequestOperationsListener> requestOperationListeners = new ArrayList<>(List.of(testListener));
SearchQueryThenFetchAsyncAction action = createSearchQueryThenFetchAsyncAction(requestOperationListeners);

action.onPhaseStart(new SearchPhase("test") {
@Override
public void run() {}
});
action.onPhaseStart(new SearchPhase("none") {
@Override
public void run() {}
});
assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName()));

action.onPhaseStart(new SearchPhase(action.getName()) {
@Override
public void run() {}
});
assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName()));
}

public void testShardNotAvailableWithDisallowPartialFailures() {
SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false);
AtomicReference<Exception> exception = new AtomicReference<>();
Expand Down
Loading

0 comments on commit 9f38f6f

Please sign in to comment.