Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize decimal state serializers for small value case #13573

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,13 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.type.Int128;
import io.trino.spi.type.Type;

import static io.trino.spi.type.VarbinaryType.VARBINARY;

public class LongDecimalWithOverflowAndLongStateSerializer
implements AccumulatorStateSerializer<LongDecimalWithOverflowAndLongState>
{
private static final int SERIALIZED_SIZE = (Long.BYTES * 2) + Int128.SIZE;

@Override
public Type getSerializedType()
{
Expand All @@ -42,7 +39,27 @@ public void serialize(LongDecimalWithOverflowAndLongState state, BlockBuilder ou
long overflow = state.getOverflow();
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();
sopel39 marked this conversation as resolved.
Show resolved Hide resolved
VARBINARY.writeSlice(out, Slices.wrappedLongArray(count, overflow, decimal[offset], decimal[offset + 1]));
long[] buffer = new long[4];
sopel39 marked this conversation as resolved.
Show resolved Hide resolved
long high = decimal[offset];
long low = decimal[offset + 1];

buffer[0] = low;
buffer[1] = high;
// if high = 0, the count will overwrite it
int countOffset = 1 + (high == 0 ? 0 : 1);
// append count, overflow
buffer[countOffset] = count;
buffer[countOffset + 1] = overflow;

// cases
// high == 0 (countOffset = 1)
// overflow == 0 & count == 1 -> bufferLength = 1
// overflow != 0 || count != 1 -> bufferLength = 3
// high != 0 (countOffset = 2)
// overflow == 0 & count == 1 -> bufferLength = 2
// overflow != 0 || count != 1 -> bufferLength = 4
int bufferLength = countOffset + ((overflow == 0 & count == 1) ? 0 : 2);
VARBINARY.writeSlice(out, Slices.wrappedLongArray(buffer, 0, bufferLength));
}
else {
out.appendNull();
Expand All @@ -54,20 +71,33 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowAndLongSt
{
if (!block.isNull(index)) {
Slice slice = VARBINARY.getSlice(block, index);
if (slice.length() != SERIALIZED_SIZE) {
throw new IllegalStateException("Unexpected serialized state size: " + slice.length());
}
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long count = slice.getLong(0);
long overflow = slice.getLong(Long.BYTES);
int sliceLength = slice.length();
long low = slice.getLong(0);
long high = 0;
long overflow = 0;
long count = 1;

state.setLong(count);
switch (sliceLength) {
case 4 * Long.BYTES:
overflow = slice.getLong(Long.BYTES * 3);
count = slice.getLong(Long.BYTES * 2);
// fall through
case 2 * Long.BYTES:
high = slice.getLong(Long.BYTES);
break;
case 3 * Long.BYTES:
overflow = slice.getLong(Long.BYTES * 2);
count = slice.getLong(Long.BYTES);
}

decimal[offset + 1] = low;
decimal[offset] = high;
state.setOverflow(overflow);
state.setLong(count);
state.setNotNull();
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();
decimal[offset] = slice.getLong(Long.BYTES * 2);
decimal[offset + 1] = slice.getLong(Long.BYTES * 3);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,13 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.type.Int128;
import io.trino.spi.type.Type;

import static io.trino.spi.type.VarbinaryType.VARBINARY;

public class LongDecimalWithOverflowStateSerializer
implements AccumulatorStateSerializer<LongDecimalWithOverflowState>
{
private static final int SERIALIZED_SIZE = Long.BYTES + Int128.SIZE;

@Override
public Type getSerializedType()
{
Expand All @@ -41,7 +38,18 @@ public void serialize(LongDecimalWithOverflowState state, BlockBuilder out)
long overflow = state.getOverflow();
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();
VARBINARY.writeSlice(out, Slices.wrappedLongArray(overflow, decimal[offset], decimal[offset + 1]));
long[] buffer = new long[3];
lukasz-stec marked this conversation as resolved.
Show resolved Hide resolved
long low = decimal[offset + 1];
long high = decimal[offset];
buffer[0] = low;
buffer[1] = high;
buffer[2] = overflow;
// if high == 0 and overflow == 0 we only write low (bufferLength = 1)
// if high != 0 and overflow == 0 we write both low and high (bufferLength = 2)
// if overflow != 0 we write all values (bufferLength = 3)
int decimalsCount = 1 + (high == 0 ? 0 : 1);
int bufferLength = overflow == 0 ? decimalsCount : 3;
sopel39 marked this conversation as resolved.
Show resolved Hide resolved
VARBINARY.writeSlice(out, Slices.wrappedLongArray(buffer, 0, bufferLength));
}
else {
out.appendNull();
Expand All @@ -53,18 +61,26 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowState sta
{
if (!block.isNull(index)) {
Slice slice = VARBINARY.getSlice(block, index);
if (slice.length() != SERIALIZED_SIZE) {
throw new IllegalStateException("Unexpected serialized state size: " + slice.length());
}
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long overflow = slice.getLong(0);
long low = slice.getLong(0);
int sliceLength = slice.length();
long high = 0;
long overflow = 0;

switch (sliceLength) {
case 3 * Long.BYTES:
overflow = slice.getLong(Long.BYTES * 2);
// fall through
case 2 * Long.BYTES:
high = slice.getLong(Long.BYTES);
}

decimal[offset + 1] = low;
decimal[offset] = high;
state.setOverflow(overflow);
state.setNotNull();
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();
decimal[offset] = slice.getLong(Long.BYTES);
decimal[offset + 1] = slice.getLong(Long.BYTES * 2);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.aggregation.state;

import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.VariableWidthBlockBuilder;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;

public class TestLongDecimalWithOverflowAndLongStateSerializer
{
private static final LongDecimalWithOverflowAndLongStateFactory STATE_FACTORY = new LongDecimalWithOverflowAndLongStateFactory();

@Test(dataProvider = "input")
public void testSerde(long low, long high, long overflow, long count, int expectedLength)
{
LongDecimalWithOverflowAndLongState state = STATE_FACTORY.createSingleState();
state.getDecimalArray()[0] = high;
state.getDecimalArray()[1] = low;
state.setOverflow(overflow);
state.setLong(count);
state.setNotNull();

LongDecimalWithOverflowAndLongState outState = roundTrip(state, expectedLength);

assertTrue(outState.isNotNull());
assertEquals(outState.getDecimalArray()[0], high);
assertEquals(outState.getDecimalArray()[1], low);
assertEquals(outState.getOverflow(), overflow);
assertEquals(outState.getLong(), count);
}

@Test
public void testNullSerde()
{
// state is created null
LongDecimalWithOverflowAndLongState state = STATE_FACTORY.createSingleState();

LongDecimalWithOverflowAndLongState outState = roundTrip(state, 0);

assertFalse(outState.isNotNull());
}

private LongDecimalWithOverflowAndLongState roundTrip(LongDecimalWithOverflowAndLongState state, int expectedLength)
{
LongDecimalWithOverflowAndLongStateSerializer serializer = new LongDecimalWithOverflowAndLongStateSerializer();
BlockBuilder out = new VariableWidthBlockBuilder(null, 1, 0);

serializer.serialize(state, out);

Block serialized = out.build();
assertEquals(serialized.getSliceLength(0), expectedLength * Long.BYTES);
LongDecimalWithOverflowAndLongState outState = STATE_FACTORY.createSingleState();
serializer.deserialize(serialized, 0, outState);
return outState;
}

@DataProvider
public Object[][] input()
sopel39 marked this conversation as resolved.
Show resolved Hide resolved
{
return new Object[][] {
{3, 0, 0, 1, 1},
{3, 5, 0, 1, 2},
{3, 5, 7, 1, 4},
{3, 0, 0, 2, 3},
{3, 5, 0, 2, 4},
{3, 5, 7, 2, 4},
{3, 0, 7, 1, 3},
{3, 0, 7, 2, 3},
{0, 0, 0, 1, 1},
{0, 5, 0, 1, 2},
{0, 5, 7, 1, 4},
{0, 0, 0, 2, 3},
{0, 5, 0, 2, 4},
{0, 5, 7, 2, 4},
{0, 0, 7, 1, 3},
{0, 0, 7, 2, 3}
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.aggregation.state;

import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.VariableWidthBlockBuilder;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;

public class TestLongDecimalWithOverflowStateSerializer
{
private static final LongDecimalWithOverflowStateFactory STATE_FACTORY = new LongDecimalWithOverflowStateFactory();

@Test(dataProvider = "input")
public void testSerde(long low, long high, long overflow, int expectedLength)
{
LongDecimalWithOverflowState state = STATE_FACTORY.createSingleState();
state.getDecimalArray()[0] = high;
state.getDecimalArray()[1] = low;
state.setOverflow(overflow);
state.setNotNull();

LongDecimalWithOverflowState outState = roundTrip(state, expectedLength);

assertTrue(outState.isNotNull());
assertEquals(outState.getDecimalArray()[0], high);
assertEquals(outState.getDecimalArray()[1], low);
assertEquals(outState.getOverflow(), overflow);
}

@Test
public void testNullSerde()
{
// state is created null
LongDecimalWithOverflowState state = STATE_FACTORY.createSingleState();

LongDecimalWithOverflowState outState = roundTrip(state, 0);

assertFalse(outState.isNotNull());
}

private LongDecimalWithOverflowState roundTrip(LongDecimalWithOverflowState state, int expectedLength)
{
LongDecimalWithOverflowStateSerializer serializer = new LongDecimalWithOverflowStateSerializer();
BlockBuilder out = new VariableWidthBlockBuilder(null, 1, 0);

serializer.serialize(state, out);

Block serialized = out.build();
assertEquals(serialized.getSliceLength(0), expectedLength * Long.BYTES);
LongDecimalWithOverflowState outState = STATE_FACTORY.createSingleState();
serializer.deserialize(serialized, 0, outState);
return outState;
}

@DataProvider
public Object[][] input()
{
return new Object[][] {
{3, 0, 0, 1},
{3, 5, 0, 2},
{3, 5, 7, 3},
{3, 0, 7, 3},
{0, 0, 0, 1},
{0, 5, 0, 2},
{0, 5, 7, 3},
{0, 0, 7, 3}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1400,4 +1400,16 @@ public void testApproxMostFrequentWithStringGroupBy()
assertEquals(actual1.getMaterializedRows().get(2).getFields().get(0), "c");
assertEquals(actual1.getMaterializedRows().get(2).getFields().get(1), ImmutableMap.of("C", 2L));
}

@Test
public void testLongDecimalAggregations()
{
assertQuery("""
SELECT avg(value_big), sum(value_big), avg(value_small), sum(value_small)
FROM (
SELECT orderkey as id, CAST(power(2, 65) as DECIMAL(38, 0)) as value_big, CAST(1 as DECIMAL(38, 0)) as value_small
FROM orders
LIMIT 10)
GROUP BY id""");
}
}