Skip to content

Commit

Permalink
Fix array_min/max function for nans
Browse files Browse the repository at this point in the history
also fixes when array has nans and nulls
  • Loading branch information
rschlussel committed Jun 6, 2024
1 parent 984749a commit b17a916
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
import java.lang.invoke.MethodHandle;

import static com.facebook.presto.util.Failures.internalError;
import static java.lang.Double.NaN;
import static java.lang.Double.isNaN;

public final class ArrayMinMaxUtils
{
Expand Down Expand Up @@ -98,9 +96,6 @@ public static Double doubleArrayMinMax(MethodHandle compareMethodHandle, Type el
if ((boolean) compareMethodHandle.invokeExact(value, selectedValue)) {
selectedValue = value;
}
else if (isNaN(value)) {
return NaN;
}
}

return containNull ? null : selectedValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -675,11 +675,11 @@ public void testArrayMin()
assertFunction("ARRAY_MIN(ARRAY [NULL, NULL, NULL])", UNKNOWN, null);
assertFunction("ARRAY_MIN(ARRAY [NaN(), NaN(), NaN()])", DOUBLE, NaN);
assertFunction("ARRAY_MIN(ARRAY [NULL, 2, 3])", INTEGER, null);
assertFunction("ARRAY_MIN(ARRAY [NaN(), 2, 3])", DOUBLE, NaN);
assertFunction("ARRAY_MIN(ARRAY [NULL, NaN(), 1])", DOUBLE, NaN);
assertFunction("ARRAY_MIN(ARRAY [NaN(), NULL, 3.0])", DOUBLE, NaN);
assertFunction("ARRAY_MIN(ARRAY [NaN(), 2, 3])", DOUBLE, 2.0);
assertFunction("ARRAY_MIN(ARRAY [NULL, NaN(), 1])", DOUBLE, null);
assertFunction("ARRAY_MIN(ARRAY [NaN(), NULL, 3.0])", DOUBLE, null);
assertFunction("ARRAY_MIN(ARRAY [1.0E0, NULL, 3])", DOUBLE, null);
assertFunction("ARRAY_MIN(ARRAY [1.0, NaN(), 3])", DOUBLE, NaN);
assertFunction("ARRAY_MIN(ARRAY [1.0, NaN(), 3])", DOUBLE, 1.0);
assertFunction("ARRAY_MIN(ARRAY ['1', '2', NULL])", createVarcharType(1), null);
assertFunction("ARRAY_MIN(ARRAY [3, 2, 1])", INTEGER, 1);
assertFunction("ARRAY_MIN(ARRAY [1, 2, 3])", INTEGER, 1);
Expand All @@ -706,8 +706,8 @@ public void testArrayMax()
assertFunction("ARRAY_MAX(ARRAY [NaN(), NaN(), NaN()])", DOUBLE, NaN);
assertFunction("ARRAY_MAX(ARRAY [NULL, 2, 3])", INTEGER, null);
assertFunction("ARRAY_MAX(ARRAY [NaN(), 2, 3])", DOUBLE, NaN);
assertFunction("ARRAY_MAX(ARRAY [NULL, NaN(), 1])", DOUBLE, NaN);
assertFunction("ARRAY_MAX(ARRAY [NaN(), NULL, 3.0])", DOUBLE, NaN);
assertFunction("ARRAY_MAX(ARRAY [NULL, NaN(), 1])", DOUBLE, null);
assertFunction("ARRAY_MAX(ARRAY [NaN(), NULL, 3.0])", DOUBLE, null);
assertFunction("ARRAY_MAX(ARRAY [1.0E0, NULL, 3])", DOUBLE, null);
assertFunction("ARRAY_MAX(ARRAY [1.0, NaN(), 3])", DOUBLE, NaN);
assertFunction("ARRAY_MAX(ARRAY ['1', '2', NULL])", createVarcharType(1), null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,30 @@ public void testMin()
assertQueryWithSameQueryRunner(format("SELECT min(%s), min(%s), min(%s) FROM %s", REAL_NAN_FIRST_COLUMN, REAL_NAN_MIDDLE_COLUMN, REAL_NAN_LAST_COLUMN, REAL_NANS_TABLE_NAME), "SELECT REAL '-4.0', REAL '0.0', REAL '0.0'");
}

@Test
public void testDoubleArrayMinAgg()
{
assertQueryWithSameQueryRunner(format("SELECT min(%s) FROM %s WHERE none_match(%s, x -> x IS NULL)", SIMPLE_DOUBLE_ARRAY_COLUMN, ARRAY_TABLE_NAME, SIMPLE_DOUBLE_ARRAY_COLUMN), "SELECT ARRAY[0, 1, -1, nan()]");
}

@Test
public void testRealArrayMinAgg()
{
assertQueryWithSameQueryRunner(format("SELECT min(%s) FROM %s WHERE none_match(%s, x -> x IS NULL)", SIMPLE_REAL_ARRAY_COLUMN, ARRAY_TABLE_NAME, SIMPLE_REAL_ARRAY_COLUMN), "SELECT ARRAY[REAL'0', REAL '1', REAL '-1', CAST(nan() AS REAL)]");
}

@Test
public void testDoubleArrayMaxAgg()
{
assertQueryWithSameQueryRunner(format("SELECT max(%s) FROM %s WHERE none_match(%s, x -> x IS NULL)", SIMPLE_DOUBLE_ARRAY_COLUMN, ARRAY_TABLE_NAME, SIMPLE_DOUBLE_ARRAY_COLUMN), "SELECT ARRAY[nan(), nan()]");
}

@Test
public void testRealArrayMaxAgg()
{
assertQueryWithSameQueryRunner(format("SELECT max(%s) FROM %s WHERE none_match(%s, x -> x IS NULL)", SIMPLE_REAL_ARRAY_COLUMN, ARRAY_TABLE_NAME, SIMPLE_REAL_ARRAY_COLUMN), "SELECT ARRAY[CAST(nan() AS REAL), CAST(nan() AS REAL)]");
}

@Test
public void testMax()
{
Expand Down

0 comments on commit b17a916

Please sign in to comment.