From 0a2f7757e38b07c2e40c5c885a7100e8c16eb40f Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Thu, 14 Dec 2023 18:54:05 -0800 Subject: [PATCH] Branch-free varint size calculation. On android art, compiles to: ``` int com.google.android.test.Outer.computeUInt32SizeNoTag(int) [24 bytes] 0x000024d0 mov w0, #0x160 0x000024d4 clz w1, w1 0x000024d8 add w1, w1, w1, lsl #3 0x000024dc sub w0, w0, w1 0x000024e0 lsr w0, w0, #6 0x000024e4 ret ``` versus existing: ``` int com.google.android.test.Outer.computeUInt32SizeNoTag(int) [72 bytes] 0x000022a0 and w0, w1, #0xffffff80 0x000022a4 cbnz w0, #+0xc (addr 0x22b0) 0x000022a8 mov w0, #0x1 0x000022ac b #+0x38 (addr 0x22e4) 0x000022b0 and w0, w1, #0xffffc000 0x000022b4 cbnz w0, #+0xc (addr 0x22c0) 0x000022b8 mov w0, #0x2 0x000022bc b #+0x28 (addr 0x22e4) 0x000022c0 and w0, w1, #0xffe00000 0x000022c4 cbnz w0, #+0xc (addr 0x22d0) 0x000022c8 mov w0, #0x3 0x000022cc b #+0x18 (addr 0x22e4) 0x000022d0 mov w2, #0x5 0x000022d4 mov w0, #0x4 0x000022d8 and w1, w1, #0xf0000000 0x000022dc cmp w1, #0x0 (0) 0x000022e0 csel w0, w2, w0, ne 0x000022e4 ret ``` PiperOrigin-RevId: 591113652 --- .../google/protobuf/CodedOutputStream.java | 94 ++++++++++--------- .../protobuf/CodedOutputStreamTest.java | 75 +++++++++++++++ 2 files changed, 125 insertions(+), 44 deletions(-) diff --git a/java/core/src/main/java/com/google/protobuf/CodedOutputStream.java b/java/core/src/main/java/com/google/protobuf/CodedOutputStream.java index 37bb44a05780..cd8127bd63f4 100644 --- a/java/core/src/main/java/com/google/protobuf/CodedOutputStream.java +++ b/java/core/src/main/java/com/google/protobuf/CodedOutputStream.java @@ -40,7 +40,9 @@ public abstract class CodedOutputStream extends ByteOutput { /** Used to adapt to the experimental {@link Writer} interface. */ CodedOutputStreamWriter wrapper; - /** @deprecated Use {@link #computeFixed32SizeNoTag(int)} instead. */ + /** + * @deprecated Use {@link #computeFixed32SizeNoTag(int)} instead. 
+ */ @Deprecated public static final int LITTLE_ENDIAN_32_SIZE = FIXED32_SIZE; /** The buffer size used in {@link #newInstance(OutputStream)}. */ @@ -669,9 +671,8 @@ public static int computeRawMessageSetExtensionSize( } /** - * Compute the number of bytes that would be needed to encode a lazily parsed MessageSet - * extension field to the stream. For historical reasons, the wire format differs from normal - * fields. + * Compute the number of bytes that would be needed to encode a lazily parsed MessageSet extension + * field to the stream. For historical reasons, the wire format differs from normal fields. */ public static int computeLazyFieldMessageSetExtensionSize( final int fieldNumber, final LazyFieldLite value) { @@ -692,29 +693,52 @@ public static int computeTagSize(final int fieldNumber) { * tag. */ public static int computeInt32SizeNoTag(final int value) { - if (value >= 0) { - return computeUInt32SizeNoTag(value); - } else { - // Must sign-extend. - return MAX_VARINT_SIZE; - } + return computeUInt64SizeNoTag((long) value); } /** Compute the number of bytes that would be needed to encode a {@code uint32} field. */ public static int computeUInt32SizeNoTag(final int value) { - if ((value & (~0 << 7)) == 0) { - return 1; - } - if ((value & (~0 << 14)) == 0) { - return 2; - } - if ((value & (~0 << 21)) == 0) { - return 3; - } - if ((value & (~0 << 28)) == 0) { - return 4; - } - return 5; + /* + This code is ported from the C++ varint implementation. + Implementation notes: + + To calculate varint size, we want to count the number of 7 bit chunks required. Rather than using + division by 7 to accomplish this, we use multiplication by 9/64. This has a number of important + properties: + * It's roughly 1/7.111111. This makes the 0 bits set case have the same value as the 7 bits set + case, so offsetting by 1 gives us the correct value we want for integers up to 448 bits. + * Multiplying by 9 is special. 
(x * 9) = (x << 3) + x, and so this multiplication can be done by a + single shifted add on arm (add w0, w0, w0, lsl #3), or a single lea instruction + (leal (%rax,%rax,8), %eax) on x86. + * Dividing by 64 is a 6 bit right shift. + + An explicit non-sign-extended right shift is used instead of the more obvious '/ 64' because + that actually produces worse code on android arm64 at time of authoring because of sign + extension. Rather than + lsr w0, w0, #6 + It would emit: + add w16, w0, #0x3f (63) + cmp w0, #0x0 (0) + csel w0, w16, w0, lt + asr w0, w0, #6 + + Summarized: + floor((Integer.SIZE - clz) / 7.1111) + 1 + ((Integer.SIZE - clz) * 9) / 64 + 1 + (((Integer.SIZE - clz) * 9) >>> 6) + 1 + ((Integer.SIZE - clz) * 9 + (1 << 6)) >>> 6 + (Integer.SIZE * 9 + (1 << 6) - clz * 9) >>> 6 + (352 - clz * 9) >>> 6 + on arm: + (352 - clz - (clz << 3)) >>> 6 + on x86: + (352 - lea(clz, clz, 8)) >>> 6 + + If you make changes here, please validate their compiled output on different architectures and + runtimes. + */ + int clz = Integer.numberOfLeadingZeros(value); + return ((Integer.SIZE * 9 + (1 << 6)) - (clz * 9)) >>> 6; } /** Compute the number of bytes that would be needed to encode an {@code sint32} field. */ @@ -745,27 +769,9 @@ public static int computeInt64SizeNoTag(final long value) { * tag. */ public static int computeUInt64SizeNoTag(long value) { - // handle two popular special cases up front ... - if ((value & (~0L << 7)) == 0L) { - return 1; - } - if (value < 0L) { - return 10; - } - // ... 
leaving us with 8 remaining, which we can divide and conquer - int n = 2; - if ((value & (~0L << 35)) != 0L) { - n += 4; - value >>>= 28; - } - if ((value & (~0L << 21)) != 0L) { - n += 2; - value >>>= 14; - } - if ((value & (~0L << 14)) != 0L) { - n += 1; - } - return n; + int clz = Long.numberOfLeadingZeros(value); + // See computeUInt32SizeNoTag for explanation + return ((Long.SIZE * 9 + (1 << 6)) - (clz * 9)) >>> 6; } /** Compute the number of bytes that would be needed to encode an {@code sint64} field. */ diff --git a/java/core/src/test/java/com/google/protobuf/CodedOutputStreamTest.java b/java/core/src/test/java/com/google/protobuf/CodedOutputStreamTest.java index 3007c83a1f99..51e66b9759ff 100644 --- a/java/core/src/test/java/com/google/protobuf/CodedOutputStreamTest.java +++ b/java/core/src/test/java/com/google/protobuf/CodedOutputStreamTest.java @@ -327,6 +327,81 @@ public void testEncodeZigZag() throws Exception { .isEqualTo(-75123905439571256L); } + @Test + public void computeIntSize() { + assertThat(CodedOutputStream.computeUInt32SizeNoTag(0)).isEqualTo(1); + assertThat(CodedOutputStream.computeUInt64SizeNoTag(0)).isEqualTo(1); + int i; + for (i = 0; i < 7; i++) { + assertThat(CodedOutputStream.computeInt32SizeNoTag(1 << i)).isEqualTo(1); + assertThat(CodedOutputStream.computeUInt32SizeNoTag(1 << i)).isEqualTo(1); + assertThat(CodedOutputStream.computeUInt64SizeNoTag(1L << i)).isEqualTo(1); + } + for (; i < 14; i++) { + assertThat(CodedOutputStream.computeInt32SizeNoTag(1 << i)).isEqualTo(2); + assertThat(CodedOutputStream.computeUInt32SizeNoTag(1 << i)).isEqualTo(2); + assertThat(CodedOutputStream.computeUInt64SizeNoTag(1L << i)).isEqualTo(2); + } + for (; i < 21; i++) { + assertThat(CodedOutputStream.computeInt32SizeNoTag(1 << i)).isEqualTo(3); + assertThat(CodedOutputStream.computeUInt32SizeNoTag(1 << i)).isEqualTo(3); + assertThat(CodedOutputStream.computeUInt64SizeNoTag(1L << i)).isEqualTo(3); + } + for (; i < 28; i++) { + 
assertThat(CodedOutputStream.computeInt32SizeNoTag(1 << i)).isEqualTo(4); + assertThat(CodedOutputStream.computeUInt32SizeNoTag(1 << i)).isEqualTo(4); + assertThat(CodedOutputStream.computeUInt64SizeNoTag(1L << i)).isEqualTo(4); + } + for (; i < 31; i++) { + assertThat(CodedOutputStream.computeInt32SizeNoTag(1 << i)).isEqualTo(5); + assertThat(CodedOutputStream.computeUInt32SizeNoTag(1 << i)).isEqualTo(5); + assertThat(CodedOutputStream.computeUInt64SizeNoTag(1L << i)).isEqualTo(5); + } + for (; i < 32; i++) { + assertThat(CodedOutputStream.computeInt32SizeNoTag(1 << i)).isEqualTo(10); + assertThat(CodedOutputStream.computeUInt32SizeNoTag(1 << i)).isEqualTo(5); + assertThat(CodedOutputStream.computeUInt64SizeNoTag(1L << i)).isEqualTo(5); + } + for (; i < 35; i++) { + assertThat(CodedOutputStream.computeUInt64SizeNoTag(1L << i)).isEqualTo(5); + } + for (; i < 42; i++) { + assertThat(CodedOutputStream.computeUInt64SizeNoTag(1L << i)).isEqualTo(6); + } + for (; i < 49; i++) { + assertThat(CodedOutputStream.computeUInt64SizeNoTag(1L << i)).isEqualTo(7); + } + for (; i < 56; i++) { + assertThat(CodedOutputStream.computeUInt64SizeNoTag(1L << i)).isEqualTo(8); + } + for (; i < 63; i++) { + assertThat(CodedOutputStream.computeUInt64SizeNoTag(1L << i)).isEqualTo(9); + } + } + + @Test + public void computeTagSize() { + assertThat(CodedOutputStream.computeTagSize(0)).isEqualTo(1); + int i; + for (i = 0; i < 4; i++) { + assertThat(CodedOutputStream.computeTagSize(1 << i)).isEqualTo(1); + } + for (; i < 11; i++) { + assertThat(CodedOutputStream.computeTagSize(1 << i)).isEqualTo(2); + } + for (; i < 18; i++) { + assertThat(CodedOutputStream.computeTagSize(1 << i)).isEqualTo(3); + } + for (; i < 25; i++) { + assertThat(CodedOutputStream.computeTagSize(1 << i)).isEqualTo(4); + } + for (; i < 29; i++) { + assertThat(CodedOutputStream.computeTagSize(1 << i)).isEqualTo(5); + } + // Invalid tags + assertThat(CodedOutputStream.computeTagSize((1 << 30) + 1)).isEqualTo(1); + } + /** 
Tests writing a whole message with every field type. */ @Test public void testWriteWholeMessage() throws Exception {