
Optimize decimal state serializers for small value case #13573

Merged

Conversation


@lukasz-stec lukasz-stec commented Aug 9, 2022

Description

Given that many decimal aggregations (sum, avg) stay in the long range, the aggregation state serializer can be optimized for this case, significantly reducing the number of bytes per position (3-4x): the full state serializes four longs (32 bytes) per position, while the small-value case can be encoded as a single long (8 bytes).

tpch/tpcds benchmarks


Benchmarks_decimal_aggr_serde_simple_case.pdf

Is this change a fix, improvement, new feature, refactoring, or other?

improvement

Is this a change to the core query engine, a connector, client library, or the SPI interfaces? (be specific)

core query engine (sum, avg aggregation state serialization)

How would you describe this change to a non-technical end user or system administrator?

improve performance of queries with sum or avg aggregations

Related issues, pull requests, and links

Documentation

(X) No documentation is needed.
( ) Sufficient documentation is included in this PR.
( ) Documentation PR is available with #prnumber.
( ) Documentation issue #issuenumber is filed, and can be handled later.

Release notes

(X) No release notes entries required.
( ) Release notes entries required with the following suggested text:

# Section
* Fix some things. ({issue}`issuenumber`)

@cla-bot cla-bot bot added the cla-signed label Aug 9, 2022
@lukasz-stec lukasz-stec force-pushed the ls/034-paa-decimal-aggregation-serde branch 3 times, most recently from ea406fd to 1f42004 Compare August 10, 2022 13:02
@lukasz-stec lukasz-stec marked this pull request as ready for review August 10, 2022 19:44
@@ -42,7 +42,13 @@ public void serialize(LongDecimalWithOverflowAndLongState state, BlockBuilder ou
long overflow = state.getOverflow();
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();
VARBINARY.writeSlice(out, Slices.wrappedLongArray(count, overflow, decimal[offset], decimal[offset + 1]));
if (count == 1 && overflow == 0 && decimal[offset] == 0) {
Member

@sopel39 sopel39 Aug 10, 2022

The count == 1 && overflow == 0 truncation should be separate from the decimal[offset] == 0 truncation.

decimal[offset] == 0 will often be true when count == 1 && overflow == 0 isn't.

So you have cases:

  • 1 long -> count == 1 && overflow == 0 && decimal[offset] == 0
  • 2 longs -> count == 1 && overflow == 0
  • 3 longs -> decimal[offset] == 0
  • 4 longs -> full case

I think you can even make it branchless:

  • append decimal[offset + 1] unconditionally, len += 1
  • append decimal[offset] unconditionally, len = len + (decimal[offset] == 0 ? 0 : 1)
  • append overflow and count, len = len + ((count == 1 & overflow == 0) ? 0 : 2)
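
For illustration, here is a minimal sketch of the case-based layout described above (not the merged code; it reuses the local variables from the diff, assumes the low/high/overflow/count ordering discussed later in this thread, and uses the airlift Slices overload with an explicit length), so the serialized length alone tells the deserializer which fields are present:

long[] buffer = new long[4];
buffer[0] = decimal[offset + 1];        // low bytes, always written
int length = 1;
if (decimal[offset] != 0) {
    buffer[length++] = decimal[offset]; // high bytes only when non-zero
}
if (count != 1 || overflow != 0) {
    buffer[length++] = overflow;        // overflow and count only for the non-trivial state
    buffer[length++] = count;
}
// length is now 1, 2, 3, or 4 longs, matching the cases listed above
VARBINARY.writeSlice(out, Slices.wrappedLongArray(buffer, 0, length));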

Member Author

This may help in some cases, I will try that. It probably won't help in tpch q17/q18 as those queries already hit the condition.

Member Author

The implementation you suggested was not correct. I did something similar, but I couldn't figure out an efficient way to make the deserialization branchless. Let me know if you see one.

}
else {

@lukasz-stec lukasz-stec marked this pull request as draft August 11, 2022 14:36
@lukasz-stec
Member Author

Made it draft again as I'm gonna benchmark supporting more cases

@lukasz-stec lukasz-stec force-pushed the ls/034-paa-decimal-aggregation-serde branch from da7883e to f213351 Compare August 12, 2022 12:03
@lukasz-stec lukasz-stec changed the title Optimize decimal state serializers for low value case Optimize decimal state serializers for small value case Aug 12, 2022
@lukasz-stec lukasz-stec marked this pull request as ready for review August 12, 2022 12:06
@lukasz-stec
Member Author

Currently our benchmarking infra is failing, so the latest benchmark results will probably be available next week. cc @sopel39 @gaurav8297

@lukasz-stec lukasz-stec force-pushed the ls/034-paa-decimal-aggregation-serde branch from f213351 to db86662 Compare August 16, 2022 14:05
Member

@raunaqmorarka raunaqmorarka left a comment

Is BenchmarkBlockSerde#serializeLongDecimal or any other JMH benchmark relevant here?

Member Author

@lukasz-stec lukasz-stec left a comment

Comments addressed

@lukasz-stec lukasz-stec force-pushed the ls/034-paa-decimal-aggregation-serde branch from db86662 to 34ed7ef Compare August 17, 2022 09:09
@lukasz-stec
Member Author

Is BenchmarkBlockSerde#serializeLongDecimal or any other JMH benchmark relevant here?

@raunaqmorarka I think BenchmarkDecimalAggregation.benchmarkEvaluateIntermediate is the best match for these changes, although the first step, and 50% of the benchmark time, is addInput.

I don't see much change in this microbenchmark. The important thing to notice is that the number of groups, and thus the random memory load time, is the most important factor here.

Before
Benchmark                                                  (function)  (groupCount)  (type)  Mode  Cnt   Score   Error  Units
BenchmarkDecimalAggregation.benchmarkEvaluateIntermediate         avg          1000    LONG  avgt   20   7.190 ± 0.307  ns/op
BenchmarkDecimalAggregation.benchmarkEvaluateIntermediate         avg       1000000    LONG  avgt   20  99.130 ± 1.451  ns/op

After
Benchmark                                                  (function)  (groupCount)  (type)  Mode  Cnt   Score   Error  Units
BenchmarkDecimalAggregation.benchmarkEvaluateIntermediate         avg          1000    LONG  avgt   20   6.856 ± 0.037  ns/op
BenchmarkDecimalAggregation.benchmarkEvaluateIntermediate         avg       1000000    LONG  avgt   20  94.083 ± 1.077  ns/op

@lukasz-stec lukasz-stec force-pushed the ls/034-paa-decimal-aggregation-serde branch from 34ed7ef to 87aafb6 Compare August 17, 2022 10:35
buffer[1] = decimalHighBytes;
buffer[2] = overflow;
buffer[3] = count;
// if decimalHighBytes == 0 and count == 1 and overflow == 0 we only write decimalLowBytes (bufferLength = 1)
Member

You miss one important case: 3 longs -> decimal[offset] == 0
see #13573 (comment)

Member

@sopel39 sopel39 Aug 18, 2022

You can implement it by:

// append low
buffer[0] = decimalLowBytes;
offset = 1;
// append high
buffer[offset] = decimalHiBytes;
offset = offset + (decimalHiBytes == 0 ? 0 : 1);
// append overflow, count
buffer[offset] = overflow;
buffer[offset + 1] = count;
offset = offset + (overflow == 0 & count == 1) ? 0 : 2 // will this be branchless really?
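
For readers following along, a compilable form of that sketch could look like this (illustrative only, not necessarily what was merged; the final ternary needs its own parentheses, and the unconditional writes never run past the 4-long buffer):

long[] buffer = new long[4];
buffer[0] = decimalLowBytes;
int length = 1;
buffer[length] = decimalHighBytes;
length += decimalHighBytes == 0 ? 0 : 1;
buffer[length] = overflow;          // written unconditionally; only counted below
buffer[length + 1] = count;
length += (overflow == 0 & count == 1) ? 0 : 2;
VARBINARY.writeSlice(out, Slices.wrappedLongArray(buffer, 0, length));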

long high = 0;
long overflow = 0;
long count = 1;
if (slice.length() > Long.BYTES) {
Member

branchless version

int sliceLength = slices.length();
long low = slice.getLong(0);
long highOffset = sliceLength == 4 | sliceLength == 2 ? 1 : 0; 
long high = slice.getLong(highOffset) & (highOffset * -1L);
// similar for count & overflow

Would be great to learn if JIT actually generates branchless versions
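
For contrast, a length-dispatch (switch) version of the deserialization, which is the if/switch shape referred to later in the thread, could look like this (sketch only, assuming the low/high/overflow/count layout from the serializer side; Slice offsets are in bytes):

int sliceLength = slice.length();
long low = slice.getLong(0);
long high = 0;
long overflow = 0;
long count = 1;
switch (sliceLength / Long.BYTES) {
    case 4:
        high = slice.getLong(Long.BYTES);
        overflow = slice.getLong(2 * Long.BYTES);
        count = slice.getLong(3 * Long.BYTES);
        break;
    case 3:
        overflow = slice.getLong(Long.BYTES);
        count = slice.getLong(2 * Long.BYTES);
        break;
    case 2:
        high = slice.getLong(Long.BYTES);
        break;
    default:
        // a single long: high == 0, overflow == 0, count == 1
        break;
}

The branch predictor handles this well when a block contains mostly small or mostly large values, which matches the benchmark results reported later in the thread.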

@lukasz-stec
Member Author

tpch/tpcds benchmarks for the latest version.

Good improvement for tpch.

Benchmarks_decimal_aggr_serde_opt.pdf

@lukasz-stec lukasz-stec force-pushed the ls/034-paa-decimal-aggregation-serde branch 2 times, most recently from 29a2fae to 1870888 Compare August 25, 2022 07:47
Member Author

@lukasz-stec lukasz-stec left a comment

I implemented the additional case + UT.
For branchless deserialization, I tested multiple ways (including the one offered by @sopel39) and there is little difference in JMH for the mixed-decimals case (a block with both small and large decimals), but the if/switch approach is a little better for both the small and large cases, where branch prediction works efficiently.

The tpc benchmarks look very good with the current code (switch).

@lukasz-stec
Member Author

tpch/tpcds orc part 1k on the latest version (missing case implemented + using switch in deserialization).
Big gains for tpch, some for tpcds.


Benchmarks_decimal_aggr_serde_opt_orc_part_sf1k_250822.pdf

Member

@sopel39 sopel39 left a comment

lgtm [Optimize decimal state serializers for small value case](https://github.com/trinodb/trino/pull/13573/commits/18708885459f1921b720148552a14d1ca8ff7c93) % comments

@@ -53,12 +54,12 @@
@State(Scope.Thread)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Fork(3)
@Warmup(iterations = 10)
@Measurement(iterations = 10)
@Warmup(iterations = 20, time = 1)
Member

nit: is time needed?

Member Author

The default is 10s; this means 1s.
That said, looking through the generated assembly code I noticed that 1s may not be enough to trigger C2. Let me check.

Member Author

Yeah, 1s is not enough to trigger C2. A 5s warmup does the trick.
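
In other words, each warmup iteration has to run long enough for C2 to compile the hot path, e.g. something along these lines (values illustrative, not the final settings):

// roughly 5s of warmup per iteration so C2 kicks in before measurement starts
@Warmup(iterations = 10, time = 5, timeUnit = TimeUnit.SECONDS)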


@Benchmark
@OutputTimeUnit(TimeUnit.MICROSECONDS)
public Object benchmarkDeserialize(BenchmarkData data)
Member

Why is it called benchmarkDeserialize? I think processPage does more than just deserialization.

Member Author

There is no way to benchmark deserialize only, so I do processPage with unique group ids, which makes the combine behavior simple.

Member

Still, for a code reader this is confusing.

there is no way to benchmark deserialize only so

Sure there is: LongDecimalWithOverflowAndLongStateSerializer is a class and there is nothing stopping us from benchmarking it directly

Member Author

Benchmarking LongDecimalWithOverflowAndLongStateSerializer#deserialize in isolation has little value. It's always used in the loop body of addIntermediate together with combine, and it's supposed to be inlined there (that's the reason for the code generation).

@Param({"SHORT", "LONG", "MIXED"})
private String decimalSize = "SHORT";

@Param({"true", "false"})
Member

How does group ordering matter?

Member Author

It doesn't matter for serialization, but for evaluateIntermediate the main bottleneck is actually jumping to random memory locations when groups are located randomly.
For the partial aggregation adaptation case, the group ids are always consecutive.
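
A hypothetical illustration of what that parameter controls when building the benchmark input (the names positionCount, groupCount and groupsOrdered are made up for this sketch, not the benchmark's actual fields):

int[] groupIds = new int[positionCount];
for (int i = 0; i < positionCount; i++) {
    groupIds[i] = i % groupCount;                    // consecutive ids: sequential state access
}
if (!groupsOrdered) {
    java.util.Random random = new java.util.Random(42);
    for (int i = groupIds.length - 1; i > 0; i--) {  // Fisher-Yates shuffle
        int j = random.nextInt(i + 1);               // shuffled ids: each position becomes a
        int tmp = groupIds[i];                       // random load from the group state arrays
        groupIds[i] = groupIds[j];
        groupIds[j] = tmp;
    }
}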

Member

For the partial aggregation adaptation case, the group ids are always consecutive.

This has little to do with serialization.

It doesn't matter for serialization, but for evaluateIntermediate the main bottleneck is actually jumping to random memory locations in case groups are located randomly.

For skip-cases the groups are consecutive and the number of them is small, so it probably doesn't matter either.

Member Author

This has little to do with serialization.

That's what I said ^^

For skip-cases the groups are consecutive and the number of them is small, so it probably doesn't matter either.

For skip cases (I understand you mean PA disabled) yes, for other cases it does matter.

@@ -39,6 +39,7 @@
import org.testng.annotations.Test;
Member

This should be a separate PR.

Member Author

all changes to BenchmarkDecimalAggregation?

Member Author

@sopel39 ping

Member Author

extracted BenchmarkDecimalAggregation changes to #13939

@@ -108,36 +130,66 @@ public Block benchmarkEvaluateFinal(BenchmarkData data)
@Param({"10", "1000"})
private int groupCount = 10;

@Param({"SHORT", "LONG", "MIXED"})
private String decimalSize = "SHORT";
Member

decimalSize=LONG won't work with type=SHORT

Member Author

in this case, I use larger long values (close to Long.MAX_VALUE)

Member

in this case, I use larger long values (close to Long.MAX_VALUE)

It doesn't matter from a perf perspective.

Member Author

It can matter; it doesn't at the moment. I can also skip this case for type=SHORT if that bothers you.

values = createValues(functionResolution, type, (builder, value) -> {
boolean writeShort = "SHORT".equals(decimalSize) || ("MIXED".equals(decimalSize) && random.nextBoolean());
if (writeShort) {
builder.writeLong(value);
Member

Both cases really write a long (writeLong).

Member Author

yes, but if we use varint those are different cases.

Member

yes, but if we use varint those are different cases.

Here SHORT vs LONG is mixed with Short/Long decimal. This is confusing

Member Author

replaced SHORT, LONG with SMALL, BIG

@sopel39
Member

sopel39 commented Aug 29, 2022

Nice gains

@lukasz-stec lukasz-stec force-pushed the ls/034-paa-decimal-aggregation-serde branch from 1870888 to ab4315c Compare August 30, 2022 11:23
Member Author

@lukasz-stec lukasz-stec left a comment

comments addressed



Member

@sopel39 sopel39 left a comment

reviewed again

@lukasz-stec lukasz-stec force-pushed the ls/034-paa-decimal-aggregation-serde branch from ab4315c to b506c13 Compare August 30, 2022 20:54
Member Author

@lukasz-stec lukasz-stec left a comment

comments addressed



@lukasz-stec lukasz-stec force-pushed the ls/034-paa-decimal-aggregation-serde branch from b506c13 to ac1f861 Compare August 31, 2022 09:40
Given that many decimal aggregations (sum, avg) stay
in the long range, the aggregation state serializer can
be optimized for this case, limiting the number of
bytes per position significantly (3-4x) at the cost of
a small CPU overhead during serialization and deserialization.
@lukasz-stec lukasz-stec force-pushed the ls/034-paa-decimal-aggregation-serde branch from ac1f861 to 923fa7a Compare August 31, 2022 10:28
@sopel39 sopel39 merged commit 182b44e into trinodb:master Aug 31, 2022
@sopel39 sopel39 mentioned this pull request Aug 31, 2022
@github-actions github-actions bot added this to the 395 milestone Aug 31, 2022