Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix FilterOperator to cache next element and avoid repeated consumption on hasNext() calls #3123

Merged
merged 1 commit into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public class FilterOperator extends PhysicalPlan {
@Getter private final PhysicalPlan input;
@Getter private final Expression conditions;
@ToString.Exclude private ExprValue next = null;
@ToString.Exclude private boolean nextPrepared = false;

@Override
public <R, C> R accept(PhysicalPlanNodeVisitor<R, C> visitor, C context) {
Expand All @@ -41,19 +42,34 @@ public List<PhysicalPlan> getChild() {

@Override
public boolean hasNext() {
if (!nextPrepared) {
prepareNext();
}
return next != null;
}

@Override
public ExprValue next() {
if (!nextPrepared) {
prepareNext();
}
ExprValue result = next;
next = null;
nextPrepared = false;
return result;
}

private void prepareNext() {
while (input.hasNext()) {
ExprValue inputValue = input.next();
ExprValue exprValue = conditions.valueOf(inputValue.bindingTuples());
if (!(exprValue.isNull() || exprValue.isMissing()) && (exprValue.booleanValue())) {
if (!(exprValue.isNull() || exprValue.isMissing()) && exprValue.booleanValue()) {
next = inputValue;
return true;
nextPrepared = true;
return;
}
}
return false;
}

@Override
public ExprValue next() {
return next;
next = null;
nextPrepared = true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,24 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_FALSE;
import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_MISSING;
import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_NULL;
import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_TRUE;
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.filter;

import com.google.common.collect.ImmutableMap;
import java.util.LinkedHashMap;
import java.util.List;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.Test;
Expand All @@ -26,12 +36,22 @@
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;

@ExtendWith(MockitoExtension.class)
@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
class FilterOperatorTest extends PhysicalPlanTestBase {
@Mock private PhysicalPlan inputPlan;

@Mock private Expression condition;

private FilterOperator filterOperator;

@BeforeEach
public void setup() {
filterOperator = filter(inputPlan, condition);
}

@Test
public void filter_test() {
FilterOperator plan =
Expand Down Expand Up @@ -82,4 +102,68 @@ public void missing_value_should_been_ignored() {
List<ExprValue> result = execute(plan);
assertEquals(0, result.size());
}

@Test
public void testHasNextWhenInputHasNoElements() {
when(inputPlan.hasNext()).thenReturn(false);

assertFalse(
filterOperator.hasNext(), "hasNext() should return false when input has no elements");
}

@Test
public void testHasNextWithMatchingCondition() {
ExprValue inputValue = mock(ExprValue.class);
when(inputPlan.hasNext()).thenReturn(true).thenReturn(false);
when(inputPlan.next()).thenReturn(inputValue);
when(condition.valueOf(any())).thenReturn(LITERAL_TRUE);

assertTrue(filterOperator.hasNext(), "hasNext() should return true when condition matches");
assertEquals(
inputValue, filterOperator.next(), "next() should return the matching input value");
}

@Test
public void testHasNextWithNonMatchingCondition() {
ExprValue inputValue = mock(ExprValue.class);
when(inputPlan.hasNext()).thenReturn(true, false);
when(inputPlan.next()).thenReturn(inputValue);
when(condition.valueOf(any())).thenReturn(LITERAL_FALSE);

assertFalse(
filterOperator.hasNext(), "hasNext() should return false if no values match the condition");
}

@Test
public void testMultipleCallsToHasNextDoNotConsumeInput() {
ExprValue inputValue = mock(ExprValue.class);
when(inputPlan.hasNext()).thenReturn(true);
when(inputPlan.next()).thenReturn(inputValue);
when(condition.valueOf(any())).thenReturn(LITERAL_TRUE);

assertTrue(
filterOperator.hasNext(),
"First hasNext() call should return true if there is a matching value");
verify(inputPlan, times(1)).next();
assertTrue(
filterOperator.hasNext(),
"Subsequent hasNext() calls should still return true without advancing the input");
verify(inputPlan, times(1)).next();
assertEquals(
inputValue, filterOperator.next(), "next() should return the matching input value");
verify(inputPlan, times(1)).next();
}

@Test
public void testNextWithoutCallingHasNext() {
ExprValue inputValue = mock(ExprValue.class);
when(inputPlan.hasNext()).thenReturn(true, false);
when(inputPlan.next()).thenReturn(inputValue);
when(condition.valueOf(any())).thenReturn(LITERAL_TRUE);

assertEquals(
inputValue,
filterOperator.next(),
"next() should return the matching input value even if hasNext() was not called");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT Origin, Dest FROM (SELECT * FROM opensearch_dashboards_sample_data_flights WHERE AvgTicketPrice > 100 GROUP BY Origin, Dest, AvgTicketPrice) AS flights WHERE AvgTicketPrice < 1000 ORDER BY AvgTicketPrice LIMIT 30
Loading