Skip to content

Commit

Permalink
Use distinct for array_intersect, array_except, array_union
Browse files Browse the repository at this point in the history
Use distinct semantics rather than equal to semantics for
array_intersect, array_except, and array_union. This fixes failures
where Arrays of rows that had null fields would fail with errors that
comparison of rows with null elements were not allowed
  • Loading branch information
rschlussel committed Jun 25, 2024
1 parent 1722ea7 commit 0c8fbb7
Show file tree
Hide file tree
Showing 10 changed files with 160 additions and 27 deletions.
16 changes: 16 additions & 0 deletions presto-docs/src/main/sphinx/functions/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ Array Functions
.. function:: array_except(x, y) -> array

Returns an array of elements in ``x`` but not in ``y``, without duplicates.
This function uses ``IS NOT DISTINCT FROM`` to determine which elements are the same. ::

SELECT array_except(ARRAY[1, 3, 3, 2, null], ARRAY[1,2, 2, 4]) -- ARRAY[3, null]

.. function:: array_frequency(array(E)) -> map(E, int)

Expand All @@ -71,14 +74,24 @@ Array Functions
.. function:: array_has_duplicates(array(T)) -> boolean

Returns a boolean: whether ``array`` has any elements that occur more than once.
Throws an exception if any of the elements are rows or arrays that contain nulls. ::

SELECT array_has_duplicates(ARRAY[1, 2, null, 1, null, 3]) -- true
SELECT array_has_duplicates(ARRAY[ROW(1, null), ROW(1, null)]) -- "map key cannot be null or contain nulls"

.. function:: array_intersect(x, y) -> array

Returns an array of the elements in the intersection of ``x`` and ``y``, without duplicates.
This function uses ``IS NOT DISTINCT FROM`` to determine which elements are the same. ::

SELECT array_intersect(ARRAY[1, 2, 3, 2, null], ARRAY[1,2, 2, 4, null]) -- ARRAY[1, 2, null]

.. function:: array_intersect(array(array(E))) -> array(E)

Returns an array of the elements in the intersection of all arrays in the given array, without duplicates.
This function uses ``IS NOT DISTINCT FROM`` to determine which elements are the same. ::

SELECT array_intersect(ARRAY[ARRAY[1, 2, 3, 2, null], ARRAY[1,2,2, 4, null], ARRAY [1, 2, 3, 4 null]]) -- ARRAY[1, 2, null]

.. function:: array_join(x, delimiter, null_replacement) -> varchar

Expand Down Expand Up @@ -209,6 +222,9 @@ Array Functions
.. function:: array_union(x, y) -> array

Returns an array of the elements in the union of ``x`` and ``y``, without duplicates.
This function uses ``IS NOT DISTINCT FROM`` to determine which elements are the same. ::

SELECT array_union(ARRAY[1, 2, 3, 2, null], ARRAY[1,2, 2, 4, null]) -- ARRAY[1, 2, 3, 4 null]

.. function:: cardinality(x) -> bigint

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@
import com.facebook.presto.operator.project.SelectedPositions;
import org.openjdk.jol.info.ClassLayout;

import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

import static com.facebook.presto.common.array.Arrays.ensureCapacity;
import static com.facebook.presto.common.type.TypeUtils.readNativeValue;
import static com.facebook.presto.operator.project.SelectedPositions.positionsList;
import static com.facebook.presto.type.TypeUtils.hashPosition;
import static com.facebook.presto.type.TypeUtils.positionEqualsPosition;
import static com.facebook.presto.util.Failures.internalError;
import static com.google.common.base.Defaults.defaultValue;
import static com.google.common.base.Preconditions.checkArgument;
import static io.airlift.slice.SizeOf.sizeOf;
import static it.unimi.dsi.fastutil.HashCommon.arraySize;
Expand All @@ -47,6 +52,7 @@ public class OptimizedTypedSet
private static final SelectedPositions EMPTY_SELECTED_POSITIONS = positionsList(new int[0], 0, 0);

private final Type elementType;
private final Optional<MethodHandle> elementIsDistinctFrom;
private final int hashCapacity;
private final int hashMask;

Expand All @@ -56,17 +62,18 @@ public class OptimizedTypedSet
private long[] blockPositionByHash; // Each 64-bit long is 32-bit index for blocks + 32-bit position within block
private int currentBlockIndex = -1; // The index into the blocks array and positionsForBlocks list

public OptimizedTypedSet(Type elementType, int maxPositionCount)
public OptimizedTypedSet(Type elementType, MethodHandle elementIsDistinctFrom, int maxPositionCount)
{
this(elementType, INITIAL_BLOCK_COUNT, maxPositionCount);
this(elementType, Optional.of(elementIsDistinctFrom), INITIAL_BLOCK_COUNT, maxPositionCount);
}

public OptimizedTypedSet(Type elementType, int expectedBlockCount, int maxPositionCount)
public OptimizedTypedSet(Type elementType, Optional<MethodHandle> elementIsDistinctFrom, int expectedBlockCount, int maxPositionCount)
{
checkArgument(expectedBlockCount >= 0, "expectedBlockCount must not be negative");
checkArgument(maxPositionCount >= 0, "maxPositionCount must not be negative");

this.elementType = requireNonNull(elementType, "elementType must not be null");
this.elementIsDistinctFrom = requireNonNull(elementIsDistinctFrom, "elementIsDistinctFrom is null");
this.hashCapacity = arraySize(maxPositionCount, FILL_RATIO);
this.hashMask = hashCapacity - 1;

Expand Down Expand Up @@ -293,14 +300,31 @@ private int getInsertPosition(long[] hashtable, int hashPosition, Block block, i
// Already has this element
int blockIndex = (int) ((blockPosition & 0xffff_ffff_0000_0000L) >> 32);
int positionWithinBlock = (int) (blockPosition & 0xffff_ffff);
if (positionEqualsPosition(elementType, blocks[blockIndex], positionWithinBlock, block, position)) {
if (isContainedAt(blocks[blockIndex], positionWithinBlock, block, position)) {
return INVALID_POSITION;
}

hashPosition = getMaskedHash(hashPosition + 1);
}
}

private boolean isContainedAt(Block firstBlock, int positionWithinFirstBlock, Block secondBlock, int positionWithinSecondBlock)
{
if (elementIsDistinctFrom.isPresent()) {
boolean firstValueNull = firstBlock.isNull(positionWithinFirstBlock);
Object firstValue = firstValueNull ? defaultValue(elementType.getJavaType()) : readNativeValue(elementType, firstBlock, positionWithinFirstBlock);
boolean secondValueNull = secondBlock.isNull(positionWithinSecondBlock);
Object secondValue = secondValueNull ? defaultValue(elementType.getJavaType()) : readNativeValue(elementType, secondBlock, positionWithinSecondBlock);
try {
return !(boolean) elementIsDistinctFrom.get().invoke(firstValue, firstValueNull, secondValue, secondValueNull);
}
catch (Throwable t) {
throw internalError(t);
}
}
return positionEqualsPosition(elementType, firstBlock, positionWithinFirstBlock, secondBlock, positionWithinSecondBlock);
}

/**
* Add an element to the hash table if it's not already existed.
*
Expand All @@ -322,7 +346,7 @@ private boolean addElement(long[] hashtable, int hashPosition, Block block, int
// Already has this element
int blockIndex = (int) ((blockPosition & 0xffff_ffff_0000_0000L) >> 32);
int positionWithinBlock = (int) (blockPosition & 0xffff_ffff);
if (positionEqualsPosition(elementType, blocks[blockIndex], positionWithinBlock, block, position)) {
if (isContainedAt(blocks[blockIndex], positionWithinBlock, block, position)) {
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
import com.facebook.presto.common.type.Type;
import com.facebook.presto.operator.aggregation.OptimizedTypedSet;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.OperatorDependency;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;

import java.lang.invoke.MethodHandle;

import static com.facebook.presto.common.function.OperatorType.IS_DISTINCT_FROM;
import static java.lang.Math.max;

@ScalarFunction("array_except")
Expand All @@ -33,6 +37,7 @@ private ArrayExceptFunction() {}
@SqlType("array(E)")
public static Block except(
@TypeParameter("E") Type type,
@OperatorDependency(operator = IS_DISTINCT_FROM, argumentTypes = {"E", "E"}) MethodHandle elementIsDistinctFrom,
@SqlType("array(E)") Block leftArray,
@SqlType("array(E)") Block rightArray)
{
Expand All @@ -43,7 +48,7 @@ public static Block except(
return leftArray;
}

OptimizedTypedSet typedSet = new OptimizedTypedSet(type, max(leftPositionCount, rightPositionCount));
OptimizedTypedSet typedSet = new OptimizedTypedSet(type, elementIsDistinctFrom, max(leftPositionCount, rightPositionCount));
typedSet.union(rightArray);
typedSet.except(leftArray);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@
import com.facebook.presto.common.type.Type;
import com.facebook.presto.operator.aggregation.OptimizedTypedSet;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.OperatorDependency;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.SqlInvokedScalarFunction;
import com.facebook.presto.spi.function.SqlParameter;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;

import java.lang.invoke.MethodHandle;

import static com.facebook.presto.common.function.OperatorType.IS_DISTINCT_FROM;

public final class ArrayIntersectFunction
{
private ArrayIntersectFunction() {}
Expand All @@ -33,6 +38,7 @@ private ArrayIntersectFunction() {}
@SqlType("array(E)")
public static Block intersect(
@TypeParameter("E") Type type,
@OperatorDependency(operator = IS_DISTINCT_FROM, argumentTypes = {"E", "E"}) MethodHandle elementIsDistinctFrom,
@SqlType("array(E)") Block leftArray,
@SqlType("array(E)") Block rightArray)
{
Expand All @@ -48,7 +54,7 @@ public static Block intersect(
return rightArray;
}

OptimizedTypedSet typedSet = new OptimizedTypedSet(type, rightPositionCount);
OptimizedTypedSet typedSet = new OptimizedTypedSet(type, elementIsDistinctFrom, rightPositionCount);
typedSet.union(rightArray);
typedSet.intersect(leftArray);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@
import com.facebook.presto.common.type.Type;
import com.facebook.presto.operator.aggregation.OptimizedTypedSet;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.OperatorDependency;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;

import java.lang.invoke.MethodHandle;

import static com.facebook.presto.common.function.OperatorType.IS_DISTINCT_FROM;

@ScalarFunction("array_union")
@Description("Union elements of the two given arrays")
public final class ArrayUnionFunction
Expand All @@ -31,12 +36,13 @@ private ArrayUnionFunction() {}
@SqlType("array(E)")
public static Block union(
@TypeParameter("E") Type type,
@OperatorDependency(operator = IS_DISTINCT_FROM, argumentTypes = {"E", "E"}) MethodHandle elementIsDistinctFrom,
@SqlType("array(E)") Block leftArray,
@SqlType("array(E)") Block rightArray)
{
int leftArrayCount = leftArray.getPositionCount();
int rightArrayCount = rightArray.getPositionCount();
OptimizedTypedSet typedSet = new OptimizedTypedSet(type, leftArrayCount + rightArrayCount);
OptimizedTypedSet typedSet = new OptimizedTypedSet(type, elementIsDistinctFrom, leftArrayCount + rightArrayCount);

typedSet.union(leftArray);
typedSet.union(rightArray);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ public static Block mapConcat(MapType mapType, Block[] maps)
Type valueType = mapType.getValueType();

// We need to divide the entries by 2 because the maps array is SingleMapBlocks and it had the positionCount twice as large as a normal Block
OptimizedTypedSet typedSet = new OptimizedTypedSet(keyType, maps.length, entries / 2);
OptimizedTypedSet typedSet = new OptimizedTypedSet(keyType, Optional.empty(), maps.length, entries / 2);
for (int i = lastMapIndex; i >= firstMapIndex; i--) {
SingleMapBlock singleMapBlock = (SingleMapBlock) maps[i];
Block keyBlock = singleMapBlock.getKeyBlock();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,32 @@
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.type.BigintOperators;
import org.testng.annotations.Test;

import java.lang.invoke.MethodHandle;
import java.util.Optional;

import static com.facebook.presto.block.BlockAssertions.assertBlockEquals;
import static com.facebook.presto.block.BlockAssertions.createEmptyBlock;
import static com.facebook.presto.block.BlockAssertions.createLongRepeatBlock;
import static com.facebook.presto.block.BlockAssertions.createLongSequenceBlock;
import static com.facebook.presto.common.block.MethodHandleUtil.methodHandle;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static org.testng.Assert.fail;

public class TestOptimizedTypedSet
{
private static final String FUNCTION_NAME = "optimized_typed_set_test";
private static final int POSITIONS_PER_PAGE = 100;
private static final MethodHandle BIGINT_DISTINCT_METHOD_HANDLE = methodHandle(BigintOperators.BigintDistinctFromOperator.class, "isDistinctFrom", long.class, boolean.class, long.class, boolean.class);

@Test
public void testConstructor()
{
for (int i = -2; i <= -1; i++) {
try {
//noinspection ResultOfObjectAllocationIgnored
new OptimizedTypedSet(BIGINT, 2, i);
new OptimizedTypedSet(BIGINT, Optional.of(BIGINT_DISTINCT_METHOD_HANDLE), 2, i);
fail("Should throw exception if expectedSize < 0");
}
catch (IllegalArgumentException e) {
Expand All @@ -44,7 +49,7 @@ public void testConstructor()

try {
//noinspection ResultOfObjectAllocationIgnored
new OptimizedTypedSet(null, -1, 1);
new OptimizedTypedSet(null, Optional.of(BIGINT_DISTINCT_METHOD_HANDLE), -1, 1);
fail("Should throw exception if expectedBlockCount is negative");
}
catch (NullPointerException | IllegalArgumentException e) {
Expand All @@ -53,7 +58,7 @@ public void testConstructor()

try {
//noinspection ResultOfObjectAllocationIgnored
new OptimizedTypedSet(null, 2, 1);
new OptimizedTypedSet(null, Optional.of(BIGINT_DISTINCT_METHOD_HANDLE), 2, 1);
fail("Should throw exception if type is null");
}
catch (NullPointerException | IllegalArgumentException e) {
Expand All @@ -64,7 +69,7 @@ public void testConstructor()
@Test
public void testUnionWithDistinctValues()
{
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE + 1);
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE + 1);

Block block = createLongSequenceBlock(0, POSITIONS_PER_PAGE / 2);
testUnion(typedSet, block, block);
Expand All @@ -80,7 +85,7 @@ public void testUnionWithDistinctValues()
@Test
public void testUnionWithRepeatingValues()
{
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE);
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE);

Block block = createLongRepeatBlock(0, POSITIONS_PER_PAGE);
Block expectedBlock = createLongRepeatBlock(0, 1);
Expand All @@ -95,14 +100,14 @@ public void testUnionWithRepeatingValues()
@Test
public void testIntersectWithEmptySet()
{
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE);
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE);
testIntersect(typedSet, createLongSequenceBlock(0, POSITIONS_PER_PAGE - 1).appendNull(), createEmptyBlock(BIGINT));
}

@Test
public void testIntersectWithDistinctValues()
{
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE);
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE);

Block block = createLongSequenceBlock(0, POSITIONS_PER_PAGE - 1).appendNull();
typedSet.union(block);
Expand All @@ -119,7 +124,7 @@ public void testIntersectWithDistinctValues()
@Test
public void testIntersectWithNonDistinctValues()
{
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE);
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE);

Block block = createLongSequenceBlock(0, POSITIONS_PER_PAGE - 1).appendNull();
typedSet.union(block);
Expand All @@ -137,7 +142,7 @@ public void testIntersectWithNonDistinctValues()
@Test
public void testExceptWithDistinctValues()
{
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE);
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE);

Block block = createLongSequenceBlock(0, POSITIONS_PER_PAGE - 1).appendNull();
typedSet.union(block);
Expand All @@ -149,7 +154,7 @@ public void testExceptWithDistinctValues()
@Test
public void testExceptWithRepeatingValues()
{
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE);
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE);

Block block = createLongRepeatBlock(0, POSITIONS_PER_PAGE - 1).appendNull();
testExcept(typedSet, block, createLongSequenceBlock(0, 1).appendNull());
Expand All @@ -158,7 +163,7 @@ public void testExceptWithRepeatingValues()
@Test
public void testMultipleOperations()
{
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE + 1);
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE + 1);

Block block = createLongSequenceBlock(0, POSITIONS_PER_PAGE / 2).appendNull();

Expand All @@ -176,7 +181,7 @@ public void testMultipleOperations()
@Test
public void testNulls()
{
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE + 1);
OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE + 1);

// Empty block
Block emptyBlock = createLongSequenceBlock(0, 0);
Expand Down
Loading

0 comments on commit 0c8fbb7

Please sign in to comment.