diff --git a/ipv8/src/main/java/nl/tudelft/ipv8/messaging/Serialization.kt b/ipv8/src/main/java/nl/tudelft/ipv8/messaging/Serialization.kt index de3edf4b..de4567fb 100644 --- a/ipv8/src/main/java/nl/tudelft/ipv8/messaging/Serialization.kt +++ b/ipv8/src/main/java/nl/tudelft/ipv8/messaging/Serialization.kt @@ -2,11 +2,15 @@ package nl.tudelft.ipv8.messaging import java.nio.Buffer import java.nio.ByteBuffer +import kotlin.reflect.full.isSubclassOf +import kotlin.reflect.full.memberProperties +import kotlin.reflect.full.primaryConstructor const val SERIALIZED_USHORT_SIZE = 2 const val SERIALIZED_UINT_SIZE = 4 +const val SERIALIZED_INT_SIZE = 4 const val SERIALIZED_ULONG_SIZE = 8 -const val SERIALIZED_LONG_SIZE = 4 +const val SERIALIZED_LONG_SIZE = 8 const val SERIALIZED_UBYTE_SIZE = 1 const val SERIALIZED_PUBLIC_KEY_SIZE = 74 @@ -20,6 +24,81 @@ interface Serializable { interface Deserializable { fun deserialize(buffer: ByteArray, offset: Int = 0): Pair + +} + +/** + * Serializes the object and returns the buffer. + * Alternative to manually defining the serialize function. + */ +interface AutoSerializable : Serializable { + override fun serialize(): ByteArray { + return this::class.primaryConstructor!!.parameters.map { param -> + val value = + this.javaClass.kotlin.memberProperties.find { it.name == param.name }!!.get(this) + simpleSerialize(value) + }.reduce { acc, bytes -> acc + bytes } + } +} + +///** +// * Deserializes the object from the buffer and returns the object and the new offset. +// * Alternative to manually defining the deserialize function. +// */ +//inline fun Deserializable.autoDeserialize(buffer: ByteArray, offset: Int = 0): Pair { +// TODO() +//} + +fun simpleSerialize(data: U): ByteArray { + return when (data) { + is Int -> serializeInt(data) + is Long -> serializeLong(data) + is UByte -> serializeUChar(data) + is UInt -> serializeUInt(data) + is UShort -> serializeUShort(data.toInt()) + is ULong -> serializeULong(data) + is String -> serializeVarLen(data.toByteArray()) + is ByteArray -> serializeVarLen(data) + is Boolean -> serializeBool(data) + is Enum<*> -> serializeUInt(data.ordinal.toUInt()) + is Serializable -> data.serialize() + else -> throw IllegalArgumentException("Unsupported serialization type") + } +} + +inline fun simpleDeserialize(buffer: ByteArray, offset: Int = 0): Pair { + val (value, off) = when (U::class) { + Int::class -> Pair(deserializeInt(buffer, offset) as U, SERIALIZED_INT_SIZE) + Long::class -> Pair(deserializeLong(buffer, offset) as U, SERIALIZED_LONG_SIZE) + UByte::class -> Pair(deserializeUChar(buffer, offset) as U, SERIALIZED_UBYTE_SIZE) + UShort::class -> Pair( + deserializeUShort(buffer, offset).toUShort() as U, + SERIALIZED_USHORT_SIZE + ) + + UInt::class -> Pair(deserializeUInt(buffer, offset) as U, SERIALIZED_UINT_SIZE) + ULong::class -> Pair(deserializeULong(buffer, offset) as U, SERIALIZED_ULONG_SIZE) + String::class -> { + val (data, len) = deserializeVarLen(buffer, offset) + Pair(data.decodeToString() as U, len) + } + + ByteArray::class -> { + val (data, len) = deserializeVarLen(buffer, offset) + Pair(data as U, len) + } + + Boolean::class -> Pair(deserializeBool(buffer, offset) as U, 1) + else -> { + println("Enum: ${U::class.qualifiedName}") + if (U::class.isSubclassOf(Enum::class)) { + val ordinal = deserializeUInt(buffer, offset).toInt() + val values = U::class.java.enumConstants + Pair(values[ordinal] as U, SERIALIZED_UINT_SIZE) + } else throw IllegalArgumentException("Unsupported deserialization type") + } + } + return (value to (offset + off)) } fun serializeBool(data: Boolean): ByteArray { @@ -40,10 +119,21 @@ fun serializeUShort(value: Int): ByteArray { return bytes } +fun serializeUShort(value: UShort): ByteArray { + val bytes = ByteBuffer.allocate(SERIALIZED_USHORT_SIZE) + bytes.putShort(value.toShort()) + return bytes.array() +} + fun deserializeUShort(buffer: ByteArray, offset: Int = 0): Int { return (((buffer[offset].toInt() and 0xFF) shl 8) or (buffer[offset + 1].toInt() and 0xFF)) } +fun deserializeRealUShort(buffer: ByteArray, offset: Int = 0): UShort { + val buf = ByteBuffer.wrap(buffer, offset, SERIALIZED_USHORT_SIZE) + return buf.short.toUShort() +} + fun serializeUInt(value: UInt): ByteArray { val bytes = UByteArray(SERIALIZED_UINT_SIZE) for (i in 0 until SERIALIZED_UINT_SIZE) { @@ -80,7 +170,7 @@ fun deserializeULong(buffer: ByteArray, offset: Int = 0): ULong { fun serializeLong(value: Long): ByteArray { val buffer = ByteBuffer.allocate(SERIALIZED_LONG_SIZE) - buffer.putInt(value.toInt()) + buffer.putLong(value) return buffer.array() } @@ -89,7 +179,20 @@ fun deserializeLong(bytes: ByteArray, offset: Int = 0): Long { buffer.put(bytes.copyOfRange(offset, offset + SERIALIZED_LONG_SIZE)) // In JDK 8 this returns a Buffer. (buffer as Buffer).flip() - return buffer.int.toLong() + return buffer.long +} + +fun serializeInt(value: Int): ByteArray { + val buffer = ByteBuffer.allocate(SERIALIZED_INT_SIZE) + buffer.putInt(value) + return buffer.array() +} + +fun deserializeInt(bytes: ByteArray, offset: Int = 0): Int { + val buffer = ByteBuffer.allocate(SERIALIZED_INT_SIZE) + buffer.put(bytes.copyOfRange(offset, offset + SERIALIZED_INT_SIZE)) + buffer.flip() + return buffer.int } fun serializeUChar(char: UByte): ByteArray { @@ -107,8 +210,10 @@ fun serializeVarLen(bytes: ByteArray): ByteArray { fun deserializeVarLen(buffer: ByteArray, offset: Int = 0): Pair { val len = deserializeUInt(buffer, offset).toInt() - val payload = buffer.copyOfRange(offset + SERIALIZED_UINT_SIZE, - offset + SERIALIZED_UINT_SIZE + len) + val payload = buffer.copyOfRange( + offset + SERIALIZED_UINT_SIZE, + offset + SERIALIZED_UINT_SIZE + len + ) return Pair(payload, SERIALIZED_UINT_SIZE + len) } @@ -117,19 +222,31 @@ fun deserializeRecursively(buffer: ByteArray, offset: Int = 0): Array return arrayOf() } val len = deserializeUInt(buffer, offset).toInt() - val payload = buffer.copyOfRange(offset + SERIALIZED_UINT_SIZE, - offset + SERIALIZED_UINT_SIZE + len) - return arrayOf(payload) + deserializeRecursively(buffer.copyOfRange(offset + SERIALIZED_UINT_SIZE + len, - buffer.size), offset) + val payload = buffer.copyOfRange( + offset + SERIALIZED_UINT_SIZE, + offset + SERIALIZED_UINT_SIZE + len + ) + return arrayOf(payload) + deserializeRecursively( + buffer.copyOfRange( + offset + SERIALIZED_UINT_SIZE + len, + buffer.size + ), offset + ) } -fun deserializeAmount(buffer: ByteArray, amount: Int, offset: Int = 0): Pair, ByteArray> { +fun deserializeAmount( + buffer: ByteArray, + amount: Int, + offset: Int = 0 +): Pair, ByteArray> { val returnValues = arrayListOf() var localOffset = offset for (i in 0 until amount) { val len = deserializeUInt(buffer, localOffset).toInt() - val payload = buffer.copyOfRange(localOffset + SERIALIZED_UINT_SIZE, - localOffset + SERIALIZED_UINT_SIZE + len) + val payload = buffer.copyOfRange( + localOffset + SERIALIZED_UINT_SIZE, + localOffset + SERIALIZED_UINT_SIZE + len + ) localOffset += SERIALIZED_UINT_SIZE + len returnValues.add(payload) } diff --git a/ipv8/src/test/java/nl/tudelft/ipv8/messaging/SerializationTest.kt b/ipv8/src/test/java/nl/tudelft/ipv8/messaging/SerializationTest.kt index 66b1e678..fb9323a6 100644 --- a/ipv8/src/test/java/nl/tudelft/ipv8/messaging/SerializationTest.kt +++ b/ipv8/src/test/java/nl/tudelft/ipv8/messaging/SerializationTest.kt @@ -6,6 +6,170 @@ import org.junit.Test import org.junit.Assert.* class SerializationTest { + + private enum class TestEnum { + A, B, C + } + @Test + fun simpleSerializeInt() { + val value = 248375682 + val simple = simpleSerialize(value) + val explicit = serializeInt(value) + assertEquals(simple.toHex(), explicit.toHex()) + } + + @Test + fun simpleSerializeLong() { + val value = -3483756823489756836 + val simple = simpleSerialize(value) + val explicit = serializeLong(value) + assertEquals(simple.toHex(), explicit.toHex()) + } + + @Test + fun simpleSerializeUInt() { + val value = 248375682u + val simple = simpleSerialize(value) + val explicit = serializeUInt(value) + assertEquals(simple.toHex(), explicit.toHex()) + } + + @Test + fun simpleSerializeULong() { + val value = 9483756823489756836u + val simple = simpleSerialize(value) + val explicit = serializeULong(value) + assertEquals(simple.toHex(), explicit.toHex()) + } + + @Test + fun simpleSerializeUShort() { + val value = 1025.toUShort() + val simple = simpleSerialize(value) + val explicit = serializeUShort(value.toInt()) + assertEquals(simple.toHex(), explicit.toHex()) + } + + @Test + fun simpleSerializeUByte() { + val value = 248u.toUByte() + val simple = simpleSerialize(value) + val explicit = serializeUChar(value) + assertEquals(simple.toHex(), explicit.toHex()) + } + + @Test + fun simpleSerializeByteArray() { + val value = byteArrayOf(0x01, 0x02, 0x03, 0x04) + val simple = simpleSerialize(value) + val explicit = serializeVarLen(value) + assertEquals(simple.toHex(), explicit.toHex()) + } + + @Test + fun simpleSerializeBoolean() { + val value = true + val simple = simpleSerialize(value) + val explicit = serializeBool(value) + assertEquals(simple.toHex(), explicit.toHex()) + } + + @Test + fun simpleSerializeString() { + val value = "Hello, World!" + val simple = simpleSerialize(value) + val explicit = serializeVarLen(value.toByteArray()) + assertEquals(simple.toHex(), explicit.toHex()) + } + + @Test + fun simpleSerializeEnum() { + val value = TestEnum.B + val simple = simpleSerialize(value) + val explicit = serializeUInt(value.ordinal.toUInt()) + assertEquals(simple.toHex(), explicit.toHex()) + } + + @Test + fun simpleDeserializeString() { + val value = "Hello, World!" + val serialized = serializeVarLen(value.toByteArray()) + val (deserialized, _) = simpleDeserialize(serialized) + assertEquals(value, deserialized) + } + + @Test + fun simpleDeserializeEnum() { + val value = TestEnum.B + val serialized = simpleSerialize(value) + val deserialized = simpleDeserialize(serialized) + assertEquals(value, TestEnum.entries[deserialized.first.toInt()]) + } + + @Test + fun simpleDeserializeBoolean() { + val value = true + val serialized = simpleSerialize(value) + val deserialized = simpleDeserialize(serialized) + assertEquals(value, deserialized.first) + } + + @Test + fun simpleDeserializeByteArray() { + val value = byteArrayOf(0x01, 0x02, 0x03, 0x04) + val serialized = serializeVarLen(value) + val (deserialized, _) = simpleDeserialize(serialized) + assertArrayEquals(value, deserialized) + } + + @Test + fun simpleDeserializeUByte() { + val value = 248u.toUByte() + val serialized = serializeUChar(value) + val (deserialized, _) = simpleDeserialize(serialized) + assertEquals(value, deserialized) + } + + @Test + fun simpleDeserializeUShort() { + val value = 1025.toUShort() + val serialized = serializeUShort(value.toInt()) + val (deserialized, _) = simpleDeserialize(serialized) + assertEquals(value, deserialized) + } + + @Test + fun simpleDeserializeUInt() { + val value = 248375682u + val serialized = serializeUInt(value) + val (deserialized, _) = simpleDeserialize(serialized) + assertEquals(value, deserialized) + } + + @Test + fun simpleDeserializeULong() { + val value = 9483756823489756836u + val serialized = serializeULong(value) + val (deserialized, _) = simpleDeserialize(serialized) + assertEquals(value, deserialized) + } + + @Test + fun simpleDeserializeLong() { + val value: Long = -3483756823489756836L + val serialized = serializeLong(value) + val (deserialized, _) = simpleDeserialize(serialized) + assertEquals(value, deserialized) + } + + @Test + fun simpleDeserializeInt() { + val value = 248375682 + val serialized = serializeInt(value) + val (deserialized, _) = simpleDeserialize(serialized) + assertEquals(value, deserialized) + } + @Test fun serializeBool_true() { val serialized = serializeBool(true) @@ -36,6 +200,20 @@ class SerializationTest { assertEquals("0401", serialized.toHex()) } + @Test + fun serializeUShort_max() { + val uShort = UShort.MAX_VALUE + val serialized = serializeUShort(uShort) + assertEquals("ffff", serialized.toHex()) + } + + @Test + fun deserializeRealUShort() { + val uShort = UShort.MAX_VALUE + val serialized = serializeUShort(uShort) + assertEquals(uShort, deserializeRealUShort(serialized)) + } + @Test fun deserializeUShort_simple() { val value = 1025 @@ -62,6 +240,45 @@ class SerializationTest { assertEquals("ffffffffffffffff", serialized.toHex()) } + @Test + fun serializeUInt() { + val serialized = serializeUInt(UInt.MAX_VALUE) + assertEquals("ffffffff", serialized.toHex()) + } + @Test + fun deserializeUInt_simple() { + val value = 248375682u + val serialized = serializeUInt(value) + assertEquals(value, deserializeUInt(serialized)) + } + + @Test + fun deserializeUInt_max() { + val value = UInt.MAX_VALUE + val serialized = serializeUInt(value) + assertEquals(value, deserializeUInt(serialized)) + } + + @Test + fun serializeInt() { + val serialized = serializeInt(Int.MAX_VALUE) + assertEquals("7fffffff", serialized.toHex()) + } + + @Test + fun deserializeInt_simple() { + val value = 248375682 + val serialized = serializeInt(value) + assertEquals(value, deserializeInt(serialized)) + } + + @Test + fun deserializeInt_max() { + val value = Int.MAX_VALUE + val serialized = serializeInt(value) + assertEquals(value, deserializeInt(serialized)) + } + @Test fun deserializeULong_test() { val value = 18446744073709551615uL