Skip to content

Commit

Permalink
Support bitwise shift functions
Browse files Browse the repository at this point in the history
- bitwise_left_shift: Bitwise left shift function
- bitwise_right_shift: Bitwise logical right shift function
- bitwise_right_shift_arithmetic: Bitwise arithmetic right shift function
  • Loading branch information
Lewuathe authored and martint committed Aug 6, 2020
1 parent 79ebfe5 commit 277f154
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 2 deletions.
12 changes: 12 additions & 0 deletions presto-docs/src/main/sphinx/functions/bitwise.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,16 @@ Bitwise Functions

Returns the bitwise XOR of ``x`` and ``y`` in 2's complement representation.

.. function:: bitwise_left_shift(value, shift) -> [same as value]

Returns the left shifted value of ``value``.

.. function:: bitwise_right_shift(value, shift, digits) -> [same as value]

Returns the logical right shifted value of ``value``.

.. function:: bitwise_right_shift_arithmetic(value, shift) -> [same as value]

Returns the arithmetic right shifted value of ``value``.

See also :func:`bitwise_and_agg` and :func:`bitwise_or_agg`.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@

public final class BitwiseFunctions
{
private static final long TINYINT_MASK = 0b1111_1111L;
private static final long TINYINT_SIGNED_BIT = 0b1000_0000L;
private static final long SMALLINT_MASK = 0b1111_1111_1111_1111L;
private static final long SMALLINT_SIGNED_BIT = 0b1000_0000_0000_0000L;
private static final long INTEGER_MASK = 0x00_00_00_00_ff_ff_ff_ffL;
private static final long INTEGER_SIGNED_BIT = 0x00_00_00_00_00_80_00_00_00L;

private BitwiseFunctions() {}

@Description("Count number of set bits in 2's complement representation")
Expand Down Expand Up @@ -75,4 +82,178 @@ public static long bitwiseXor(@SqlType(StandardTypes.BIGINT) long left, @SqlType
{
return left ^ right;
}

@Description("bitwise left shift")
@ScalarFunction("bitwise_left_shift")
@SqlType(StandardTypes.TINYINT)
public static long bitwiseLeftShiftTinyint(@SqlType(StandardTypes.TINYINT) long value, @SqlType(StandardTypes.INTEGER) long shift)
{
if (shift >= 64) {
return 0L;
}
long shifted = (value << shift);
return preserveSign(shifted, TINYINT_MASK, TINYINT_SIGNED_BIT);
}

@Description("bitwise left shift")
@ScalarFunction("bitwise_left_shift")
@SqlType(StandardTypes.SMALLINT)
public static long bitwiseLeftShiftSmallint(@SqlType(StandardTypes.SMALLINT) long value, @SqlType(StandardTypes.INTEGER) long shift)
{
if (shift >= 64) {
return 0L;
}
long shifted = (value << shift);
return preserveSign(shifted, SMALLINT_MASK, SMALLINT_SIGNED_BIT);
}

@Description("bitwise left shift")
@ScalarFunction("bitwise_left_shift")
@SqlType(StandardTypes.INTEGER)
public static long bitwiseLeftShiftInteger(@SqlType(StandardTypes.INTEGER) long value, @SqlType(StandardTypes.INTEGER) long shift)
{
if (shift >= 64) {
return 0L;
}
long shifted = (value << shift);
return preserveSign(shifted, INTEGER_MASK, INTEGER_SIGNED_BIT);
}

@Description("bitwise left shift")
@ScalarFunction("bitwise_left_shift")
@SqlType(StandardTypes.BIGINT)
public static long bitwiseLeftShiftBigint(@SqlType(StandardTypes.BIGINT) long value, @SqlType(StandardTypes.INTEGER) long shift)
{
if (shift >= 64) {
return 0L;
}
return value << shift;
}

private static long preserveSign(long shiftedValue, long mask, long signedBit)
{
if ((shiftedValue & signedBit) != 0) {
// Preserve the sign in 2's complement format
return shiftedValue | ~mask;
}

return shiftedValue & mask;
}

@Description("bitwise logical right shift")
@ScalarFunction("bitwise_right_shift")
@SqlType(StandardTypes.TINYINT)
public static long bitwiseRightShiftTinyint(@SqlType(StandardTypes.TINYINT) long value, @SqlType(StandardTypes.INTEGER) long shift)
{
if (shift >= 64) {
return 0L;
}
if (shift == 0) {
return value;
}
return (value & TINYINT_MASK) >>> shift;
}

@Description("bitwise logical right shift")
@ScalarFunction("bitwise_right_shift")
@SqlType(StandardTypes.SMALLINT)
public static long bitwiseRightShiftSmallint(@SqlType(StandardTypes.SMALLINT) long value, @SqlType(StandardTypes.INTEGER) long shift)
{
if (shift >= 64) {
return 0L;
}
if (shift == 0) {
return value;
}
return (value & SMALLINT_MASK) >>> shift;
}

@Description("bitwise logical right shift")
@ScalarFunction("bitwise_right_shift")
@SqlType(StandardTypes.INTEGER)
public static long bitwiseRightShiftInteger(@SqlType(StandardTypes.INTEGER) long value, @SqlType(StandardTypes.INTEGER) long shift)
{
if (shift >= 64) {
return 0L;
}
if (shift == 0) {
return value;
}
return (value & INTEGER_MASK) >>> shift;
}

@Description("bitwise logical right shift")
@ScalarFunction("bitwise_right_shift")
@SqlType(StandardTypes.BIGINT)
public static long bitwiseRightShiftBigint(@SqlType(StandardTypes.BIGINT) long value, @SqlType(StandardTypes.INTEGER) long shift)
{
if (shift >= 64) {
return 0L;
}
return value >>> shift;
}

@Description("bitwise arithmetic right shift")
@ScalarFunction("bitwise_right_shift_arithmetic")
@SqlType(StandardTypes.TINYINT)
public static long bitwiseRightShiftArithmeticTinyint(@SqlType(StandardTypes.TINYINT) long value, @SqlType(StandardTypes.INTEGER) long shift)
{
if (shift >= 64) {
if (value >= 0) {
return 0L;
}
else {
return -1L;
}
}
return preserveSign(value, TINYINT_MASK, TINYINT_SIGNED_BIT) >> shift;
}

@Description("bitwise arithmetic right shift")
@ScalarFunction("bitwise_right_shift_arithmetic")
@SqlType(StandardTypes.SMALLINT)
public static long bitwiseRightShiftArithmeticSmallint(@SqlType(StandardTypes.SMALLINT) long value, @SqlType(StandardTypes.INTEGER) long shift)
{
if (shift >= 64) {
if (value >= 0) {
return 0L;
}
else {
return -1L;
}
}
return preserveSign(value, SMALLINT_MASK, SMALLINT_SIGNED_BIT) >> shift;
}

@Description("bitwise arithmetic right shift")
@ScalarFunction("bitwise_right_shift_arithmetic")
@SqlType(StandardTypes.INTEGER)
public static long bitwiseRightShiftArithmeticInteger(@SqlType(StandardTypes.INTEGER) long value, @SqlType(StandardTypes.INTEGER) long shift)
{
if (shift >= 64) {
if (value >= 0) {
return 0L;
}
else {
return -1L;
}
}
return preserveSign(value, INTEGER_MASK, INTEGER_SIGNED_BIT) >> shift;
}

@Description("bitwise arithmetic right shift")
@ScalarFunction("bitwise_right_shift_arithmetic")
@SqlType(StandardTypes.BIGINT)
public static long bitwiseRightShiftArithmeticBigint(@SqlType(StandardTypes.BIGINT) long value, @SqlType(StandardTypes.INTEGER) long shift)
{
if (shift >= 64) {
if (value >= 0) {
return 0L;
}
else {
return -1L;
}
}
return value >> shift;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import org.testng.annotations.Test;

import static io.prestosql.spi.type.BigintType.BIGINT;
import static io.prestosql.spi.type.IntegerType.INTEGER;
import static io.prestosql.spi.type.SmallintType.SMALLINT;
import static io.prestosql.spi.type.TinyintType.TINYINT;

public class TestBitwiseFunctions
extends AbstractTestFunctions
Expand Down Expand Up @@ -90,4 +93,77 @@ public void testBitwiseXor()
assertFunction("bitwise_xor(-4, 12)", BIGINT, -4L ^ 12L);
assertFunction("bitwise_xor(60, 21)", BIGINT, 60L ^ 21L);
}

@Test
public void testBitwiseLeftShift()
{
assertFunction("bitwise_left_shift(TINYINT'7', 2)", TINYINT, (byte) (7 << 2));
assertFunction("bitwise_left_shift(TINYINT '-7', 2)", TINYINT, (byte) (-7 << 2));
assertFunction("bitwise_left_shift(TINYINT '1', 7)", TINYINT, (byte) (1 << 7));
assertFunction("bitwise_left_shift(TINYINT '-128', 1)", TINYINT, (byte) 0);
assertFunction("bitwise_left_shift(TINYINT '-65', 1)", TINYINT, (byte) (-65 << 1));
assertFunction("bitwise_left_shift(TINYINT '-7', 64)", TINYINT, (byte) 0);
assertFunction("bitwise_left_shift(TINYINT '-128', 0)", TINYINT, (byte) -128);
assertFunction("bitwise_left_shift(SMALLINT '7', 2)", SMALLINT, (short) (7 << 2));
assertFunction("bitwise_left_shift(SMALLINT '-7', 2)", SMALLINT, (short) (-7 << 2));
assertFunction("bitwise_left_shift(SMALLINT '1', 7)", SMALLINT, (short) (1 << 7));
assertFunction("bitwise_left_shift(SMALLINT '-32768', 1)", SMALLINT, (short) 0);
assertFunction("bitwise_left_shift(SMALLINT '-65', 1)", SMALLINT, (short) (-65 << 1));
assertFunction("bitwise_left_shift(SMALLINT '-7', 64)", SMALLINT, (short) 0);
assertFunction("bitwise_left_shift(SMALLINT '-32768', 0)", SMALLINT, (short) -32768);
assertFunction("bitwise_left_shift(INTEGER '7', 2)", INTEGER, 7 << 2);
assertFunction("bitwise_left_shift(INTEGER '-7', 2)", INTEGER, -7 << 2);
assertFunction("bitwise_left_shift(INTEGER '1', 7)", INTEGER, 1 << 7);
assertFunction("bitwise_left_shift(INTEGER '-2147483648', 1)", INTEGER, 0);
assertFunction("bitwise_left_shift(INTEGER '-65', 1)", INTEGER, -65 << 1);
assertFunction("bitwise_left_shift(INTEGER '-7', 64)", INTEGER, 0);
assertFunction("bitwise_left_shift(INTEGER '-2147483648', 0)", INTEGER, -2147483648);
assertFunction("bitwise_left_shift(BIGINT '7', 2)", BIGINT, 7L << 2);
assertFunction("bitwise_left_shift(BIGINT '-7', 2)", BIGINT, -7L << 2);
assertFunction("bitwise_left_shift(BIGINT '-7', 64)", BIGINT, 0L);
}

@Test
public void testBitwiseRightShift()
{
assertFunction("bitwise_right_shift(TINYINT '7', 2)", TINYINT, (byte) (7 >>> 2));
assertFunction("bitwise_right_shift(TINYINT '-7', 2)", TINYINT, (byte) 62);
assertFunction("bitwise_right_shift(TINYINT '-7', 64)", TINYINT, (byte) 0);
assertFunction("bitwise_right_shift(TINYINT '-128', 0)", TINYINT, (byte) -128);
assertFunction("bitwise_right_shift(SMALLINT '7', 2)", SMALLINT, (short) (7 >>> 2));
assertFunction("bitwise_right_shift(SMALLINT '-7', 2)", SMALLINT, (short) 16382);
assertFunction("bitwise_right_shift(SMALLINT '-7', 64)", SMALLINT, (short) 0);
assertFunction("bitwise_right_shift(SMALLINT '-32768', 0)", SMALLINT, (short) -32768);
assertFunction("bitwise_right_shift(INTEGER '7', 2)", INTEGER, 7 >>> 2);
assertFunction("bitwise_right_shift(INTEGER '-7', 2)", INTEGER, 1073741822);
assertFunction("bitwise_right_shift(INTEGER '-7', 64)", INTEGER, 0);
assertFunction("bitwise_right_shift(INTEGER '-2147483648', 0)", INTEGER, -2147483648);
assertFunction("bitwise_right_shift(BIGINT '7', 2)", BIGINT, 7L >>> 2);
assertFunction("bitwise_right_shift(BIGINT '-7', 2)", BIGINT, -7L >>> 2);
assertFunction("bitwise_right_shift(BIGINT '-7', 64)", BIGINT, 0L);
}

@Test
public void testBitwiseRightShiftArithmetic()
{
assertFunction("bitwise_right_shift_arithmetic(TINYINT '7', 2)", TINYINT, (byte) (7 >> 2));
assertFunction("bitwise_right_shift_arithmetic(TINYINT '-7', 2)", TINYINT, (byte) (-7 >> 2));
assertFunction("bitwise_right_shift_arithmetic(TINYINT '7', 64)", TINYINT, (byte) 0);
assertFunction("bitwise_right_shift_arithmetic(TINYINT '-7', 64)", TINYINT, (byte) -1);
assertFunction("bitwise_right_shift_arithmetic(TINYINT '-128', 0)", TINYINT, (byte) -128);
assertFunction("bitwise_right_shift_arithmetic(SMALLINT '7', 2)", SMALLINT, (short) (7 >> 2));
assertFunction("bitwise_right_shift_arithmetic(SMALLINT '-7', 2)", SMALLINT, (short) (-7 >> 2));
assertFunction("bitwise_right_shift_arithmetic(SMALLINT '7', 64)", SMALLINT, (short) 0);
assertFunction("bitwise_right_shift_arithmetic(SMALLINT '-7', 64)", SMALLINT, (short) -1);
assertFunction("bitwise_right_shift_arithmetic(SMALLINT '-32768', 0)", SMALLINT, (short) -32768);
assertFunction("bitwise_right_shift_arithmetic(INTEGER '7', 2)", INTEGER, (7 >> 2));
assertFunction("bitwise_right_shift_arithmetic(INTEGER '-7', 2)", INTEGER, -7 >> 2);
assertFunction("bitwise_right_shift_arithmetic(INTEGER '7', 64)", INTEGER, 0);
assertFunction("bitwise_right_shift_arithmetic(INTEGER '-7', 64)", INTEGER, -1);
assertFunction("bitwise_right_shift_arithmetic(INTEGER '-2147483648', 0)", INTEGER, -2147483648);
assertFunction("bitwise_right_shift_arithmetic(BIGINT '7', 2)", BIGINT, 7L >> 2);
assertFunction("bitwise_right_shift_arithmetic(BIGINT '-7', 2)", BIGINT, -7L >> 2);
assertFunction("bitwise_right_shift_arithmetic(BIGINT '7', 64)", BIGINT, 0L);
assertFunction("bitwise_right_shift_arithmetic(BIGINT '-7', 64)", BIGINT, -1L);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public void testShowFunctionLike()
{
assertThat(assertions.query("SHOW FUNCTIONS LIKE 'split%'"))
.matches("VALUES " +
"(cast('split' AS VARCHAR(24)), cast('array(varchar(x))' AS VARCHAR(28)), cast('varchar(x), varchar(y)' AS VARCHAR(68)), cast('scalar' AS VARCHAR(9)), true, cast('' AS VARCHAR(131)))," +
"(cast('split' AS VARCHAR(30)), cast('array(varchar(x))' AS VARCHAR(28)), cast('varchar(x), varchar(y)' AS VARCHAR(68)), cast('scalar' AS VARCHAR(9)), true, cast('' AS VARCHAR(131)))," +
"('split', 'array(varchar(x))', 'varchar(x), varchar(y), bigint', 'scalar', true, '')," +
"('split_part', 'varchar(x)', 'varchar(x), varchar(y), bigint', 'scalar', true, 'Splits a string by a delimiter and returns the specified field (counting from one)')," +
"('split_to_map', 'map(varchar,varchar)', 'varchar, varchar, varchar', 'scalar', true, 'Creates a map using entryDelimiter and keyValueDelimiter')," +
Expand All @@ -98,7 +98,7 @@ public void testShowFunctionsLikeWithEscape()
{
assertThat(assertions.query("SHOW FUNCTIONS LIKE 'split$_to$_%' ESCAPE '$'"))
.matches("VALUES " +
"(cast('split_to_map' AS VARCHAR(24)), cast('map(varchar,varchar)' AS VARCHAR(28)), cast('varchar, varchar, varchar' AS VARCHAR(68)), cast('scalar' AS VARCHAR(9)), true, cast('Creates a map using entryDelimiter and keyValueDelimiter' AS VARCHAR(131)))," +
"(cast('split_to_map' AS VARCHAR(30)), cast('map(varchar,varchar)' AS VARCHAR(28)), cast('varchar, varchar, varchar' AS VARCHAR(68)), cast('scalar' AS VARCHAR(9)), true, cast('Creates a map using entryDelimiter and keyValueDelimiter' AS VARCHAR(131)))," +
"('split_to_multimap', 'map(varchar,array(varchar))', 'varchar, varchar, varchar', 'scalar', true, 'Creates a multimap by splitting a string into key/value pairs')");
}

Expand Down

0 comments on commit 277f154

Please sign in to comment.