diff --git a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/Helpers.kt b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/Helpers.kt index 07efb47c4..d696bf07b 100644 --- a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/Helpers.kt +++ b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/Helpers.kt @@ -12,10 +12,29 @@ import kotlinx.serialization.modules.* import kotlinx.serialization.protobuf.* internal typealias ProtoDesc = Long -internal const val VARINT = 0 -internal const val i64 = 1 -internal const val SIZE_DELIMITED = 2 -internal const val i32 = 5 + +internal enum class ProtoWireType(val typeId: Int) { + INVALID(-1), + VARINT(0), + i64(1), + SIZE_DELIMITED(2), + i32(5), + ; + + companion object { + fun from(typeId: Int): ProtoWireType { + return ProtoWireType.entries.find { it.typeId == typeId } ?: INVALID + } + } + + fun wireIntWithTag(tag: Int): Int { + return ((tag shl 3) or typeId) + } + + override fun toString(): String { + return "${this.name}($typeId)" + } +} internal const val ID_HOLDER_ONE_OF = -2 @@ -104,7 +123,7 @@ internal fun extractProtoId(descriptor: SerialDescriptor, index: Int, zeroBasedD return result } -internal class ProtobufDecodingException(message: String) : SerializationException(message) +internal class ProtobufDecodingException(message: String, e: Throwable? = null) : SerializationException(message, e) internal expect fun Int.reverseBytes(): Int internal expect fun Long.reverseBytes(): Long diff --git a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufDecoding.kt b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufDecoding.kt index 55549f4a1..7cae22a24 100644 --- a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufDecoding.kt +++ b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufDecoding.kt @@ -122,41 +122,53 @@ internal open class ProtobufDecoder( } override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder { - return when (descriptor.kind) { - StructureKind.LIST -> { - val tag = currentTagOrDefault - return if (this.descriptor.kind == StructureKind.LIST && tag != MISSING_TAG && this.descriptor != descriptor) { - val reader = makeDelimited(reader, tag) - // repeated decoder expects the first tag to be read already - reader.readTag() - // all elements always have id = 1 - RepeatedDecoder(proto, reader, ProtoDesc(1, ProtoIntegerType.DEFAULT), descriptor) - - } else if (reader.currentType == SIZE_DELIMITED && descriptor.getElementDescriptor(0).isPackable) { - val sliceReader = ProtobufReader(reader.objectInput()) - PackedArrayDecoder(proto, sliceReader, descriptor) - - } else { - RepeatedDecoder(proto, reader, tag, descriptor) + return try { + when (descriptor.kind) { + StructureKind.LIST -> { + val tag = currentTagOrDefault + return if (this.descriptor.kind == StructureKind.LIST && tag != MISSING_TAG && this.descriptor != descriptor) { + val reader = makeDelimited(reader, tag) + // repeated decoder expects the first tag to be read already + reader.readTag() + // all elements always have id = 1 + RepeatedDecoder(proto, reader, ProtoDesc(1, ProtoIntegerType.DEFAULT), descriptor) + + } else if (reader.currentType == ProtoWireType.SIZE_DELIMITED && descriptor.getElementDescriptor(0).isPackable) { + val sliceReader = ProtobufReader(reader.objectInput()) + PackedArrayDecoder(proto, sliceReader, descriptor) + + } else { + RepeatedDecoder(proto, reader, tag, descriptor) + } } - } - StructureKind.CLASS, StructureKind.OBJECT, is PolymorphicKind -> { - val tag = currentTagOrDefault - // Do not create redundant copy - if (tag == MISSING_TAG && this.descriptor == descriptor) return this - if (tag.isOneOf) { - // If a tag is annotated as oneof - // [tag.protoId] here is overwritten with index-based default id in - // [kotlinx.serialization.protobuf.internal.HelpersKt.extractParameters] - // and restored the real id from index2IdMap, set by [decodeElementIndex] - val rawIndex = tag.protoId - 1 - val restoredTag = index2IdMap?.get(rawIndex)?.let { tag.overrideId(it) } ?: tag - return OneOfPolymorphicReader(proto, reader, restoredTag, descriptor) + + StructureKind.CLASS, StructureKind.OBJECT, is PolymorphicKind -> { + val tag = currentTagOrDefault + // Do not create redundant copy + if (tag == MISSING_TAG && this.descriptor == descriptor) return this + if (tag.isOneOf) { + // If a tag is annotated as oneof + // [tag.protoId] here is overwritten with index-based default id in + // [kotlinx.serialization.protobuf.internal.HelpersKt.extractParameters] + // and restored the real id from index2IdMap, set by [decodeElementIndex] + val rawIndex = tag.protoId - 1 + val restoredTag = index2IdMap?.get(rawIndex)?.let { tag.overrideId(it) } ?: tag + return OneOfPolymorphicReader(proto, reader, restoredTag, descriptor) + } + return ProtobufDecoder(proto, makeDelimited(reader, tag), descriptor) } - return ProtobufDecoder(proto, makeDelimited(reader, tag), descriptor) + + StructureKind.MAP -> MapEntryReader( + proto, + makeDelimitedForced(reader, currentTagOrDefault), + currentTagOrDefault, + descriptor + ) + + else -> throw SerializationException("Primitives are not supported at top-level") } - StructureKind.MAP -> MapEntryReader(proto, makeDelimitedForced(reader, currentTagOrDefault), currentTagOrDefault, descriptor) - else -> throw SerializationException("Primitives are not supported at top-level") + } catch (e: ProtobufDecodingException) { + throw ProtobufDecodingException("Fail to begin structure for ${descriptor.serialName} in ${this.descriptor.serialName} at proto number ${currentTagOrDefault.protoId}", e) } } @@ -173,41 +185,51 @@ internal open class ProtobufDecoder( override fun decodeTaggedByte(tag: ProtoDesc): Byte = decodeTaggedInt(tag).toByte() override fun decodeTaggedShort(tag: ProtoDesc): Short = decodeTaggedInt(tag).toShort() override fun decodeTaggedInt(tag: ProtoDesc): Int { - return if (tag == MISSING_TAG) { - reader.readInt32NoTag() - } else { - reader.readInt(tag.integerType) + return decodeOrThrow(tag) { + if (tag == MISSING_TAG) { + reader.readInt32NoTag() + } else { + reader.readInt(tag.integerType) + } } } override fun decodeTaggedLong(tag: ProtoDesc): Long { - return if (tag == MISSING_TAG) { - reader.readLongNoTag() - } else { - reader.readLong(tag.integerType) + return decodeOrThrow(tag) { + if (tag == MISSING_TAG) { + reader.readLongNoTag() + } else { + reader.readLong(tag.integerType) + } } } override fun decodeTaggedFloat(tag: ProtoDesc): Float { - return if (tag == MISSING_TAG) { - reader.readFloatNoTag() - } else { - reader.readFloat() + return decodeOrThrow(tag) { + if (tag == MISSING_TAG) { + reader.readFloatNoTag() + } else { + reader.readFloat() + } } } override fun decodeTaggedDouble(tag: ProtoDesc): Double { - return if (tag == MISSING_TAG) { - reader.readDoubleNoTag() - } else { - reader.readDouble() + return decodeOrThrow(tag) { + if (tag == MISSING_TAG) { + reader.readDoubleNoTag() + } else { + reader.readDouble() + } } } override fun decodeTaggedChar(tag: ProtoDesc): Char = decodeTaggedInt(tag).toChar() override fun decodeTaggedString(tag: ProtoDesc): String { - return if (tag == MISSING_TAG) { - reader.readStringNoTag() - } else { - reader.readString() + return decodeOrThrow(tag) { + if (tag == MISSING_TAG) { + reader.readStringNoTag() + } else { + reader.readString() + } } } @@ -218,22 +240,49 @@ internal open class ProtobufDecoder( override fun decodeSerializableValue(deserializer: DeserializationStrategy): T = decodeSerializableValue(deserializer, null) @Suppress("UNCHECKED_CAST") - override fun decodeSerializableValue(deserializer: DeserializationStrategy, previousValue: T?): T = when { - deserializer is MapLikeSerializer<*, *, *, *> -> { - deserializeMap(deserializer as DeserializationStrategy, previousValue) + override fun decodeSerializableValue(deserializer: DeserializationStrategy, previousValue: T?): T = try { + when { + deserializer is MapLikeSerializer<*, *, *, *> -> { + deserializeMap(deserializer as DeserializationStrategy, previousValue) + } + + deserializer.descriptor == ByteArraySerializer().descriptor -> deserializeByteArray(previousValue as ByteArray?) as T + deserializer is AbstractCollectionSerializer<*, *, *> -> + (deserializer as AbstractCollectionSerializer<*, T, *>).merge(this, previousValue) + + else -> deserializer.deserialize(this) + } + } catch (e: ProtobufDecodingException) { + val currentTag = currentTagOrDefault + val msg = if (descriptor != deserializer.descriptor) { + // Decoding child element + if (descriptor.kind == StructureKind.LIST && deserializer.descriptor.kind != StructureKind.MAP) { + // Decoding repeated field + "Error while decoding index ${currentTag.protoId - 1} in repeated field of ${deserializer.descriptor.serialName}" + } else if (descriptor.kind == StructureKind.MAP) { + // Decoding map field + val index = (currentTag.protoId - 1) / 2 + val field = if ((currentTag.protoId - 1) % 2 == 0) { "key" } else "value" + "Error while decoding $field of index $index in map field of ${deserializer.descriptor.serialName}" + } else { + // Decoding common class + "Error while decoding ${deserializer.descriptor.serialName} at proto number ${currentTag.protoId} of ${descriptor.serialName}" + } + } else { + // Decoding self + "Error while decoding ${descriptor.serialName}" } - deserializer.descriptor == ByteArraySerializer().descriptor -> deserializeByteArray(previousValue as ByteArray?) as T - deserializer is AbstractCollectionSerializer<*, *, *> -> - (deserializer as AbstractCollectionSerializer<*, T, *>).merge(this, previousValue) - else -> deserializer.deserialize(this) + throw ProtobufDecodingException(msg, e) } private fun deserializeByteArray(previousValue: ByteArray?): ByteArray { val tag = currentTagOrDefault - val array = if (tag == MISSING_TAG) { - reader.readByteArrayNoTag() - } else { - reader.readByteArray() + val array = decodeOrThrow(tag) { + if (tag == MISSING_TAG) { + reader.readByteArrayNoTag() + } else { + reader.readByteArray() + } } return if (previousValue == null) array else previousValue + array } @@ -252,29 +301,33 @@ internal open class ProtobufDecoder( override fun SerialDescriptor.getTag(index: Int) = extractParameters(index) override fun decodeElementIndex(descriptor: SerialDescriptor): Int { - while (true) { - val protoId = reader.readTag() - if (protoId == -1) { // EOF - return elementMarker.nextUnmarkedIndex() - } - val index = getIndexByNum(protoId) - if (index == -1) { // not found - reader.skipElement() - } else { - if (descriptor.extractParameters(index).isOneOf) { - /** - * While decoding message with one-of field, - * the proto id read from wire data cannot be easily found - * in the properties of this type, - * So the index of this one-of property and the id read from the wire - * are saved in this map, then restored in [beginStructure] - * and passed to [OneOfPolymorphicReader] to get the actual deserializer. - */ - index2IdMap?.put(index, protoId) + try { + while (true) { + val protoId = reader.readTag() + if (protoId == -1) { // EOF + return elementMarker.nextUnmarkedIndex() + } + val index = getIndexByNum(protoId) + if (index == -1) { // not found + reader.skipElement() + } else { + if (descriptor.extractParameters(index).isOneOf) { + /** + * While decoding message with one-of field, + * the proto id read from wire data cannot be easily found + * in the properties of this type, + * So the index of this one-of property and the id read from the wire + * are saved in this map, then restored in [beginStructure] + * and passed to [OneOfPolymorphicReader] to get the actual deserializer. + */ + index2IdMap?.put(index, protoId) + } + elementMarker.mark(index) + return index } - elementMarker.mark(index) - return index } + } catch (e: ProtobufDecodingException) { + throw ProtobufDecodingException("Fail to get element index for ${descriptor.serialName} in ${this.descriptor.serialName}", e) } } @@ -296,6 +349,19 @@ internal open class ProtobufDecoder( } return false } + + private inline fun decodeOrThrow(tag: ProtoDesc, action: (tag: ProtoDesc) -> T): T { + try { + return action(tag) + } catch (e: ProtobufDecodingException) { + rethrowException(tag, e) + } + } + + @Suppress("NOTHING_TO_INLINE") + private inline fun rethrowException(tag: ProtoDesc, e: ProtobufDecodingException): Nothing { + throw ProtobufDecodingException("Error while decoding proto number ${tag.protoId} of ${descriptor.serialName}", e) + } } private class RepeatedDecoder( diff --git a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufReader.kt b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufReader.kt index c7d4ea087..3fab92001 100644 --- a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufReader.kt +++ b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufReader.kt @@ -13,7 +13,7 @@ internal class ProtobufReader(private val input: ByteArrayInput) { @JvmField public var currentId = -1 @JvmField - public var currentType = -1 + public var currentType = ProtoWireType.INVALID private var pushBack = false private var pushBackHeader = 0 @@ -23,13 +23,13 @@ internal class ProtobufReader(private val input: ByteArrayInput) { public fun readTag(): Int { if (pushBack) { pushBack = false - val previousHeader = (currentId shl 3) or currentType + val previousHeader = (currentId shl 3) or currentType.typeId return updateIdAndType(pushBackHeader).also { pushBackHeader = previousHeader } } // Header to use when pushed back is the old id/type - pushBackHeader = (currentId shl 3) or currentType + pushBackHeader = (currentId shl 3) or currentType.typeId val header = input.readVarint64(true).toInt() return updateIdAndType(header) @@ -38,11 +38,11 @@ internal class ProtobufReader(private val input: ByteArrayInput) { private fun updateIdAndType(header: Int): Int { return if (header == -1) { currentId = -1 - currentType = -1 + currentType = ProtoWireType.INVALID -1 } else { currentId = header ushr 3 - currentType = header and 0b111 + currentType = ProtoWireType.from(header and 0b111) currentId } } @@ -50,28 +50,28 @@ internal class ProtobufReader(private val input: ByteArrayInput) { public fun pushBackTag() { pushBack = true - val nextHeader = (currentId shl 3) or currentType + val nextHeader = (currentId shl 3) or currentType.typeId updateIdAndType(pushBackHeader) pushBackHeader = nextHeader } fun skipElement() { when (currentType) { - VARINT -> readInt(ProtoIntegerType.DEFAULT) - i64 -> readLong(ProtoIntegerType.FIXED) - SIZE_DELIMITED -> readByteArray() - i32 -> readInt(ProtoIntegerType.FIXED) + ProtoWireType.VARINT -> readInt(ProtoIntegerType.DEFAULT) + ProtoWireType.i64 -> readLong(ProtoIntegerType.FIXED) + ProtoWireType.SIZE_DELIMITED -> readByteArray() + ProtoWireType.i32 -> readInt(ProtoIntegerType.FIXED) else -> throw ProtobufDecodingException("Unsupported start group or end group wire type: $currentType") } } @Suppress("NOTHING_TO_INLINE") - private inline fun assertWireType(expected: Int) { + private inline fun assertWireType(expected: ProtoWireType) { if (currentType != expected) throw ProtobufDecodingException("Expected wire type $expected, but found $currentType") } fun readByteArray(): ByteArray { - assertWireType(SIZE_DELIMITED) + assertWireType(ProtoWireType.SIZE_DELIMITED) return readByteArrayNoTag() } @@ -82,7 +82,7 @@ internal class ProtobufReader(private val input: ByteArrayInput) { } fun objectInput(): ByteArrayInput { - assertWireType(SIZE_DELIMITED) + assertWireType(ProtoWireType.SIZE_DELIMITED) return objectTaglessInput() } @@ -93,7 +93,7 @@ internal class ProtobufReader(private val input: ByteArrayInput) { } fun readInt(format: ProtoIntegerType): Int { - val wireType = if (format == ProtoIntegerType.FIXED) i32 else VARINT + val wireType = if (format == ProtoIntegerType.FIXED) ProtoWireType.i32 else ProtoWireType.VARINT assertWireType(wireType) return decode32(format) } @@ -101,7 +101,7 @@ internal class ProtobufReader(private val input: ByteArrayInput) { fun readInt32NoTag(): Int = decode32() fun readLong(format: ProtoIntegerType): Long { - val wireType = if (format == ProtoIntegerType.FIXED) i64 else VARINT + val wireType = if (format == ProtoIntegerType.FIXED) ProtoWireType.i64 else ProtoWireType.VARINT assertWireType(wireType) return decode64(format) } @@ -109,7 +109,7 @@ internal class ProtobufReader(private val input: ByteArrayInput) { fun readLongNoTag(): Long = decode64(ProtoIntegerType.DEFAULT) fun readFloat(): Float { - assertWireType(i32) + assertWireType(ProtoWireType.i32) return Float.fromBits(readIntLittleEndian()) } @@ -136,7 +136,7 @@ internal class ProtobufReader(private val input: ByteArrayInput) { } fun readDouble(): Double { - assertWireType(i64) + assertWireType(ProtoWireType.i64) return Double.fromBits(readLongLittleEndian()) } @@ -145,7 +145,7 @@ internal class ProtobufReader(private val input: ByteArrayInput) { } fun readString(): String { - assertWireType(SIZE_DELIMITED) + assertWireType(ProtoWireType.SIZE_DELIMITED) val length = decode32() checkLength(length) return input.readString(length) diff --git a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufWriter.kt b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufWriter.kt index ba1642729..f43211228 100644 --- a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufWriter.kt +++ b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufWriter.kt @@ -10,7 +10,7 @@ import kotlinx.serialization.protobuf.* internal class ProtobufWriter(private val out: ByteArrayOutput) { fun writeBytes(bytes: ByteArray, tag: Int) { - out.encode32((tag shl 3) or SIZE_DELIMITED) + out.encode32(ProtoWireType.SIZE_DELIMITED.wireIntWithTag(tag)) writeBytes(bytes) } @@ -20,7 +20,7 @@ internal class ProtobufWriter(private val out: ByteArrayOutput) { } fun writeOutput(output: ByteArrayOutput, tag: Int) { - out.encode32((tag shl 3) or SIZE_DELIMITED) + out.encode32(ProtoWireType.SIZE_DELIMITED.wireIntWithTag(tag)) writeOutput(output) } @@ -30,8 +30,8 @@ internal class ProtobufWriter(private val out: ByteArrayOutput) { } fun writeInt(value: Int, tag: Int, format: ProtoIntegerType) { - val wireType = if (format == ProtoIntegerType.FIXED) i32 else VARINT - out.encode32((tag shl 3) or wireType) + val wireType = if (format == ProtoIntegerType.FIXED) ProtoWireType.i32 else ProtoWireType.VARINT + out.encode32(wireType.wireIntWithTag(tag)) out.encode32(value, format) } @@ -40,8 +40,8 @@ internal class ProtobufWriter(private val out: ByteArrayOutput) { } fun writeLong(value: Long, tag: Int, format: ProtoIntegerType) { - val wireType = if (format == ProtoIntegerType.FIXED) i64 else VARINT - out.encode32((tag shl 3) or wireType) + val wireType = if (format == ProtoIntegerType.FIXED) ProtoWireType.i64 else ProtoWireType.VARINT + out.encode32(wireType.wireIntWithTag(tag)) out.encode64(value, format) } @@ -60,7 +60,7 @@ internal class ProtobufWriter(private val out: ByteArrayOutput) { } fun writeDouble(value: Double, tag: Int) { - out.encode32((tag shl 3) or i64) + out.encode32(ProtoWireType.i64.wireIntWithTag(tag)) out.writeLong(value.reverseBytes()) } @@ -69,7 +69,7 @@ internal class ProtobufWriter(private val out: ByteArrayOutput) { } fun writeFloat(value: Float, tag: Int) { - out.encode32((tag shl 3) or i32) + out.encode32(ProtoWireType.i32.wireIntWithTag(tag)) out.writeInt(value.reverseBytes()) } diff --git a/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/ProtoTagExceptionTest.kt b/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/ProtoTagExceptionTest.kt new file mode 100644 index 000000000..a119ea5e8 --- /dev/null +++ b/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/ProtoTagExceptionTest.kt @@ -0,0 +1,200 @@ +/* + * Copyright 2017-2024 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.serialization.protobuf + +import kotlinx.serialization.Serializable +import kotlinx.serialization.decodeFromHexString +import kotlinx.serialization.encodeToHexString +import kotlinx.serialization.protobuf.internal.ProtobufDecodingException +import kotlin.test.Test +import kotlin.test.assertEquals + +class ProtoTagExceptionTest { + + @Serializable + data class TestDataToBuildWrongWireType( + @ProtoNumber(1) val a: Int, + @ProtoNumber(2) val b: Int, + ) + + @Serializable + data class TestData( + @ProtoNumber(1) val a: Int, + @ProtoNumber(2) val b: String, + ) + + @Test + fun testWrongTypeMessage() { + val build = ProtoBuf.encodeToHexString(TestDataToBuildWrongWireType(42, 42)) + + assertFailsWith( + assertion = { + assertFailsWith( + "Error while decoding kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestData", + "Error while decoding proto number 2 of kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestData", + "Expected wire type SIZE_DELIMITED(2), but found VARINT(0)", + ) + } + ) { + ProtoBuf.decodeFromHexString(build) + } + } + + @Serializable + data class TestNestedDataToBuild( + @ProtoNumber(1) val nested: TestDataToBuildWrongWireType, + @ProtoNumber(2) val a: String, + ) + + @Serializable + data class TestNestedData( + @ProtoNumber(1) val nested: TestData, + @ProtoNumber(2) val a: String, + ) + + @Test + fun testWrongIntFieldInNestedMessage() { + val build = ProtoBuf.encodeToHexString(TestNestedDataToBuild(TestDataToBuildWrongWireType(42, 42), "foo")) + + assertFailsWith( + assertion = { + assertFailsWith( + "Error while decoding kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestNestedData", + "Error while decoding kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestData at proto number 1 of kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestNestedData", + "Error while decoding proto number 2 of kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestData", + "Expected wire type SIZE_DELIMITED(2), but found VARINT(0)", + ) + } + ) { + ProtoBuf.decodeFromHexString(build) + } + assertFailsWith( + assertion = { + assertFailsWith( + "Error while decoding kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestData", + "Error while decoding proto number 1 of kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestData", + "Expected wire type VARINT(0), but found SIZE_DELIMITED(2)", + ) + } + ) { + ProtoBuf.decodeFromHexString(build) + } + } + + @Test + fun testWrongStringFieldInNestedMessage() { + val build = ProtoBuf.encodeToHexString(TestNestedDataToBuild(TestDataToBuildWrongWireType(42, 42), "foo")) + assertFailsWith( + assertion = { + assertFailsWith( + "Error while decoding kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestData", + "Error while decoding proto number 1 of kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestData", + "Expected wire type VARINT(0), but found SIZE_DELIMITED(2)", + ) + } + ) { + ProtoBuf.decodeFromHexString(build) + } + } + + @Serializable + data class TestDataWithMessageList(@ProtoNumber(1) @ProtoPacked val list: List) + + @Serializable + data class TestDataWithWrongList(@ProtoNumber(1) @ProtoPacked val list: List) + + @Test + fun testWrongIntFieldInNestedMessageInList() { + val build = ProtoBuf.encodeToHexString(TestDataWithWrongList(listOf(TestDataToBuildWrongWireType(42, 42)))) + assertFailsWith( + assertion = { + assertFailsWith("Error while decoding kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestDataWithMessageList") + assertCausedBy { + assertFailsWith("Error while decoding kotlin.collections.ArrayList at proto number 1 of kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestDataWithMessageList") + assertCausedBy { + assertFailsWith( + "Error while decoding index 0 in repeated field of kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestData", + "Error while decoding proto number 2 of kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestData", + "Expected wire type SIZE_DELIMITED(2), but found VARINT(0)", + ) + } + } + } + ) { + val result = ProtoBuf.decodeFromHexString(build) + } + } + + @Serializable + data class TestDataWithMessageMapValue(@ProtoNumber(1) val map: Map) + + @Serializable + data class TestDataWithWrongMapValue(@ProtoNumber(1) val map: Map) + + @Test + fun testWrongIntFieldInNestedMapValue() { + val build = ProtoBuf.encodeToHexString(TestDataWithWrongMapValue(map = mapOf("1" to TestDataToBuildWrongWireType(42, 42)))) + assertFailsWith( + assertion = { + assertFailsWith("Error while decoding kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestDataWithMessageMapValue") + assertCausedBy { + assertFailsWith("Error while decoding kotlin.collections.LinkedHashMap at proto number 1 of kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestDataWithMessageMapValue") + assertCausedBy { + assertFailsWith( + "Error while decoding kotlin.collections.Map.Entry at proto number 1 of kotlin.collections.LinkedHashSet", + "Error while decoding value of index 0 in map field of kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestData", + "Error while decoding proto number 2 of kotlinx.serialization.protobuf.ProtoTagExceptionTest.TestData", + "Expected wire type SIZE_DELIMITED(2), but found VARINT(0)", + ) + } + } + } + ) { + ProtoBuf.decodeFromHexString(build) + } + } + + + @Serializable + data class DuplicatingIdData( + @ProtoOneOf val bad: IDuplicatingIdType, + @ProtoNumber(3) val d: Int, + ) + + @Serializable + sealed interface IDuplicatingIdType + + @Serializable + data class DuplicatingIdStringType(@ProtoNumber(3) val s: String) : IDuplicatingIdType + + @Test + fun testDuplicatedIdClass() { + val duplicated = DuplicatingIdData(DuplicatingIdStringType("foo"), 42) + // Fine to encode duplicated proto number properties in wire data + ProtoBuf.encodeToHexString(duplicated).also { + /** + * 3:LEN {"foo"} + * 3:VARINT 42 + */ + assertEquals("1a03666f6f182a", it) + } + + assertFailsWith( + assertion = { + assertFailsWith( + "Error while decoding kotlinx.serialization.protobuf.ProtoTagExceptionTest.DuplicatingIdData", + "Error while decoding proto number 3 of kotlinx.serialization.protobuf.ProtoTagExceptionTest.DuplicatingIdData", + "Expected wire type VARINT(0), but found SIZE_DELIMITED(2)", + ) + } + ) { + /** + * 3:LEN {"foo"} + * 3:VARINT 42 + */ + ProtoBuf.decodeFromHexString("1a03666f6f182a") + } + } +} \ No newline at end of file diff --git a/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/ProtobufOneOfTest.kt b/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/ProtobufOneOfTest.kt index 63d7dd504..e272738fc 100644 --- a/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/ProtobufOneOfTest.kt +++ b/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/ProtobufOneOfTest.kt @@ -361,7 +361,9 @@ class ProtobufOneOfTest { } assertFailsWithMessage( - message = "Serializer for subclass 'OtherIntType' is not found in the polymorphic scope of 'OtherType'." + message = "Serializer for subclass 'OtherIntType' is not found in the polymorphic scope of 'OtherType'.\n" + + "Check if class with serial name 'OtherIntType' exists and serializer is registered in a corresponding SerializersModule.\n" + + "To be registered automatically, class 'OtherIntType' has to be '@Serializable', and the base class 'OtherType' has to be sealed and '@Serializable'." ) { buf.encodeToHexString( DoubleOneOfElement.serializer(), DoubleOneOfElement( @@ -531,44 +533,6 @@ class ProtobufOneOfTest { assertEquals(data, buf.decodeFromHexString("082a")) } - @Serializable - data class DuplicatingIdData( - @ProtoOneOf val bad: IDuplicatingIdType, - @ProtoNumber(3) val d: Int, - ) - - @Serializable - sealed interface IDuplicatingIdType - - @Serializable - data class DuplicatingIdStringType(@ProtoNumber(3) val s: String) : IDuplicatingIdType - - @Test - fun testDuplicatedIdClass() { - val duplicated = DuplicatingIdData(DuplicatingIdStringType("foo"), 42) - // Fine to encode duplicated proto number properties in wire data - ProtoBuf.encodeToHexString(duplicated).also { - /** - * 3:LEN {"foo"} - * 3:VARINT 42 - */ - assertEquals("1a03666f6f182a", it) - } - - // Without checking duplication of proto numbers, - // ProtoBuf just throw exception about wrong wire type - assertFailsWithMessage( -// "Duplicated proto number 3 in kotlinx.serialization.protobuf.ProtobufOneOfTest.DuplicatingIdData for elements: d, bad." - "Expected wire type 0, but found 2" - ) { - /** - * 3:LEN {"foo"} - * 3:VARINT 42 - */ - ProtoBuf.decodeFromHexString("1a03666f6f182a") - } - } - @Serializable data class TypedIntOuter( @ProtoOneOf val i: ITypedInt, @@ -739,7 +703,7 @@ class ProtobufOneOfTest { fun testNonePolymorphicClass() { val data = Outer(Inner(42)) assertFailsWithMessage( - "The serializer of one of type kotlinx.serialization.protobuf.ProtobufOneOfTest.Inner should be using generic polymorphic serializer, but got CLASS" + "The serializer of one of type kotlinx.serialization.protobuf.ProtobufOneOfTest.Inner should be using generic polymorphic serializer, but got CLASS." ) { // Fails in [kotlinx.serialization.protobuf.internal.OneOfPolymorphicEncoder.init] ProtoBuf.encodeToHexString(data) diff --git a/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/TestFunctionTest.kt b/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/TestFunctionTest.kt new file mode 100644 index 000000000..823422c12 --- /dev/null +++ b/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/TestFunctionTest.kt @@ -0,0 +1,42 @@ +/* + * Copyright 2017-2024 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.serialization.protobuf + +import kotlinx.serialization.protobuf.internal.ProtobufDecodingException +import kotlin.test.Ignore +import kotlin.test.Test + +/** + * Tests for [assertFailsWith] to see if output in IDEA can be checked with button. + * Expected to fail so ignore in CI. + */ +@Ignore +class TestFunctionTest { + @Test + fun testAssertionMessage() { + assertFailsWith(assertion = { + assertFailsWith("expected message") + }) { + throw IllegalArgumentException("actual message") + } + } + @Test + fun testAssertionType() { + assertFailsWith(assertion = { + assertFailsWith("") + assertCausedBy { + assertFailsWith("expected message") + } + }) { + throw IllegalArgumentException("", IllegalArgumentException()) + } + } + @Test + fun testAssertionFailWith() { + assertFailsWith(assertion = {}) { + throw ProtobufDecodingException("expected message") + } + } +} \ No newline at end of file diff --git a/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/TestFunctions.kt b/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/TestFunctions.kt index c321c4786..d886e0ff4 100644 --- a/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/TestFunctions.kt +++ b/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/TestFunctions.kt @@ -5,6 +5,7 @@ package kotlinx.serialization.protobuf import kotlinx.serialization.* +import kotlin.reflect.KClass import kotlin.test.* fun testConversion(data: T, serializer: KSerializer, expectedHexString: String) { @@ -24,9 +25,87 @@ inline fun assertFailsWithMessage( assertionMessage: String? = null, block: () -> Unit ) { - val exception = assertFailsWith(T::class, assertionMessage, block) - assertTrue( - exception.message!!.contains(message), - "expected:<$message> but was:<${exception.message}>" + assertFailsWith( + assertionMessage, + { + assertFailsWith(message) + }, + block, ) -} \ No newline at end of file +} + +@DslMarker +annotation class ExceptionCheckDsl + +@ExceptionCheckDsl +interface ExceptionCheckScope { + fun assertFailsWith(vararg message: String) + fun assertCausedBy(byType: KClass, assertion: ExceptionCheckScope.() -> Unit) +} + +@ExceptionCheckDsl +inline fun ExceptionCheckScope<*>.assertCausedBy(noinline assertion: ExceptionCheckScope.() -> Unit) { + assertCausedBy(R::class, assertion) +} + +inline fun assertFailsWith( + assertionMessage: String? = null, + assertion: ExceptionCheckScope.() -> Unit = {}, + block: () -> Unit +) { + val exception = assertFailsWith(T::class, assertionMessage, block = block) + val scope = buildExceptionCheckScope(exception) + scope.assertion() +} + +fun buildExceptionCheckScope(exception: T, depth: Int = 0): ExceptionCheckScope = + object : ExceptionCheckScope { + override fun assertFailsWith(vararg message: String) { + val exceptionStackSize = exception.exceptionStackSize + assertTrue( + message.size <= exceptionStackSize, + "Expected exception to be assembled by at least ${message.size} throwable(s), but it has $exceptionStackSize, actual exception is $exception." + ) + var index = 0 + var currentException: Throwable? = exception + while (index < message.size) { + val currentMessage = message[index] + assertNotNull( + currentException, + "Expected exception to have a cause with message $currentMessage, but it was null." + ) + assertEquals( + currentMessage, + currentException.message, + "Exception messages are different at cause stack ${index + depth}." + ) + val nextException = currentException.cause + currentException = nextException + index++ + } + } + + @Suppress("UNCHECKED_CAST") + override fun assertCausedBy(byType: KClass, assertion: ExceptionCheckScope.() -> Unit) { + val cause = exception.cause + assertNotNull(cause, "Expected exception to have a cause of type $byType, but it was null.") + assertEquals( + byType, + cause::class, + "Current exception is caused by unexpected exception at cause stack $depth." + ) + buildExceptionCheckScope(cause as R, depth + 1).assertion() + } + + } + +private val Throwable.exceptionStackSize: Int + get() { + var count = 1 + var current = this + while (current.cause != null) { + count++ + current = current.cause!! + } + return count + } \ No newline at end of file