Skip to content

Commit

Permalink
Store TotalHits and use it to report total in response.
Browse files Browse the repository at this point in the history
Signed-off-by: Yury-Fridlyand <[email protected]>
  • Loading branch information
Yury-Fridlyand committed Mar 11, 2023
1 parent 1b5ab7e commit f4ea4ad
Show file tree
Hide file tree
Showing 20 changed files with 158 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void execute(PhysicalPlan plan, ExecutionContext context,
class QueryResponse {
private final Schema schema;
private final List<ExprValue> results;

private final long total;
private final Cursor cursor;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ public ExecutionEngine.Schema schema() {
+ "ProjectOperator, instead of %s", this.getClass().getSimpleName()));
}

public long getTotalHits() {
return getChild().stream().mapToLong(PhysicalPlan::getTotalHits).max().orElse(0);
}

public String toCursor() {
throw new IllegalStateException(String.format("%s is not compatible with cursor feature",
this.getClass().getSimpleName()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ Helper executeSuccess(Split split) {
invocation -> {
ResponseListener<ExecutionEngine.QueryResponse> listener = invocation.getArgument(2);
listener.onResponse(
new ExecutionEngine.QueryResponse(schema, Collections.emptyList(),
new ExecutionEngine.QueryResponse(schema, Collections.emptyList(), 0,
Cursor.None));
return null;
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ Helper executeSuccess(Long... offsets) {
ResponseListener<ExecutionEngine.QueryResponse> listener =
invocation.getArgument(2);
listener.onResponse(
new ExecutionEngine.QueryResponse(null, Collections.emptyList(),
new ExecutionEngine.QueryResponse(null, Collections.emptyList(), 0,
Cursor.None));

PlanContext planContext = invocation.getArgument(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,16 @@

package org.opensearch.sql.planner.physical;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.CALLS_REAL_METHODS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.util.List;

import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
Expand All @@ -16,6 +23,7 @@
import org.opensearch.sql.storage.split.Split;

@ExtendWith(MockitoExtension.class)
@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
class PhysicalPlanTest {
@Mock
Split split;
Expand Down Expand Up @@ -46,8 +54,25 @@ public List<PhysicalPlan> getChild() {
};

@Test
void addSplitToChildByDefault() {
void add_split_to_child_by_default() {
testPlan.add(split);
verify(child).add(split);
}

@Test
void get_total_hits_from_child() {
var plan = mock(PhysicalPlan.class);
when(child.getTotalHits()).thenReturn(42L);
when(plan.getChild()).thenReturn(List.of(child));
when(plan.getTotalHits()).then(CALLS_REAL_METHODS);
assertEquals(42, plan.getTotalHits());
verify(child).getTotalHits();
}

@Test
void get_total_hits_uses_default_value() {
var plan = mock(PhysicalPlan.class);
when(plan.getTotalHits()).then(CALLS_REAL_METHODS);
assertEquals(0, plan.getTotalHits());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public void execute(
result.add(plan.next());
}
QueryResponse response = new QueryResponse(new Schema(new ArrayList<>()), new ArrayList<>(),
Cursor.None);
0, Cursor.None);
listener.onResponse(response);
} catch (Exception e) {
listener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ private ResponseListener<QueryResponse> createQueryResponseListener(
@Override
public void onResponse(QueryResponse response) {
sendResponse(channel, OK,
formatter.format(new QueryResult(response.getSchema(), response.getResults(), response.getCursor())));
formatter.format(new QueryResult(response.getSchema(), response.getResults(),
response.getCursor(), response.getTotal())));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public void execute(PhysicalPlan physicalPlan, ExecutionContext context,

Cursor qc = paginatedPlanCache.convertToCursor(plan);

QueryResponse response = new QueryResponse(physicalPlan.schema(), result, qc);
QueryResponse response = new QueryResponse(physicalPlan.schema(), result, plan.getTotalHits(), qc);
listener.onResponse(response);
} catch (Exception e) {
listener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ public ExprValue next() {
return delegate.next();
}

@Override
public long getTotalHits() {
return delegate.getTotalHits();
}

@Override
public String toCursor() {
return delegate.toCursor();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ public OpenSearchResponse(SearchHits hits, OpenSearchExprValueFactory exprValueF
*/
public boolean isEmpty() {
return (hits.getHits() == null) || (hits.getHits().length == 0) && aggregations == null;
// TODO TBD ^ ^
}

public long getTotalHits() {
return hits.getTotalHits().value;
}

public boolean isAggregationResponse() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ public ExprValue next() {
return iterator.next();
}

@Override
public long getTotalHits() {
// TODO maybe store totalHits from `response`
return queryCount;
}

private void fetchNextBatch() {
OpenSearchResponse response = client.search(request);
if (!response.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class OpenSearchPagedIndexScan extends TableScanOperator {
private OpenSearchRequest request;
private Iterator<ExprValue> iterator;
private boolean needClean = false;
private long totalHits = 0;

public OpenSearchPagedIndexScan(OpenSearchClient client,
PagedRequestBuilder requestBuilder) {
Expand Down Expand Up @@ -56,6 +57,7 @@ public void open() {
OpenSearchResponse response = client.search(request);
if (!response.isEmpty()) {
iterator = response.iterator();
totalHits = response.getTotalHits();
} else {
needClean = true;
iterator = Collections.emptyIterator();
Expand All @@ -71,6 +73,11 @@ public void close() {
}
}

@Override
public long getTotalHits() {
return totalHits;
}

@Override
public String toCursor() {
// TODO this assumes exactly one index is scanned.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,10 @@ void acceptSuccess() {
monitorPlan.accept(visitor, context);
verify(plan, times(1)).accept(visitor, context);
}

@Test
void getTotalHitsSuccess() {
monitorPlan.getTotalHits();
verify(plan, times(1)).getTotalHits();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,29 @@ void isEmpty() {
new TotalHits(2L, TotalHits.Relation.EQUAL_TO),
1.0F));

assertFalse(new OpenSearchResponse(searchResponse, factory).isEmpty());
var response = new OpenSearchResponse(searchResponse, factory);
assertFalse(response.isEmpty());
assertEquals(2L, response.getTotalHits());

when(searchResponse.getHits()).thenReturn(SearchHits.empty());
when(searchResponse.getAggregations()).thenReturn(null);
assertTrue(new OpenSearchResponse(searchResponse, factory).isEmpty());

response = new OpenSearchResponse(searchResponse, factory);
assertTrue(response.isEmpty());
assertEquals(0L, response.getTotalHits());

when(searchResponse.getHits())
.thenReturn(new SearchHits(null, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0));
OpenSearchResponse response3 = new OpenSearchResponse(searchResponse, factory);
assertTrue(response3.isEmpty());
response = new OpenSearchResponse(searchResponse, factory);
assertTrue(response.isEmpty());
assertEquals(0L, response.getTotalHits());

when(searchResponse.getHits()).thenReturn(SearchHits.empty());
when(searchResponse.getAggregations()).thenReturn(new Aggregations(emptyList()));
assertFalse(new OpenSearchResponse(searchResponse, factory).isEmpty());

response = new OpenSearchResponse(searchResponse, factory);
assertFalse(response.isEmpty());
assertEquals(0L, response.getTotalHits());
}

@Test
Expand All @@ -104,7 +113,8 @@ void iterator() {
when(factory.construct(any())).thenReturn(exprTupleValue1).thenReturn(exprTupleValue2);

int i = 0;
for (ExprValue hit : new OpenSearchResponse(searchResponse, factory)) {
var response = new OpenSearchResponse(searchResponse, factory);
for (ExprValue hit : response) {
if (i == 0) {
assertEquals(exprTupleValue1, hit);
} else if (i == 1) {
Expand All @@ -114,6 +124,7 @@ void iterator() {
}
i++;
}
assertEquals(2L, response.getTotalHits());
}

@Test
Expand Down
Loading

0 comments on commit f4ea4ad

Please sign in to comment.