diff --git a/qbit-core/src/commonMain/kotlin/qbit/Conn.kt b/qbit-core/src/commonMain/kotlin/qbit/Conn.kt index 55547c27..ac30c123 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/Conn.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/Conn.kt @@ -24,7 +24,7 @@ import qbit.index.Indexer import qbit.index.InternalDb import qbit.index.RawEntity import qbit.ns.Namespace -import qbit.resolving.lastWriterWinsResolve +import qbit.resolving.crdtResolve import qbit.resolving.logsDiff import qbit.serialization.* import qbit.spi.Storage @@ -122,14 +122,14 @@ class QConn( } } - override suspend fun update(trxLog: TrxLog, newLog: TrxLog, newDb: InternalDb) { + override suspend fun update(trxLog: TrxLog, baseDb: InternalDb, newLog: TrxLog, newDb: InternalDb) { val (log, db) = if (hasConcurrentTrx(trxLog)) { - mergeLogs(trxLog, this.trxLog, newLog, newDb) + mergeLogs(trxLog, this.trxLog, newLog, baseDb, newDb) } else { newLog to newDb } - storage.overwrite(Namespace("refs")["head"], newLog.hash.bytes) + storage.overwrite(Namespace("refs")["head"], log.hash.bytes) this.trxLog = log this.db = db } @@ -141,6 +141,7 @@ class QConn( baseLog: TrxLog, committedLog: TrxLog, committingLog: TrxLog, + baseDb: InternalDb, newDb: InternalDb ): Pair { val logsDifference = logsDiff(baseLog, committedLog, committingLog, resolveNode) @@ -149,7 +150,7 @@ class QConn( .logAEntities() .toEavsList() val reconciliationEavs = logsDifference - .reconciliationEntities(lastWriterWinsResolve { db.attr(it) }) + .reconciliationEntities(crdtResolve(baseDb::pullEntity, db::attr)) .toEavsList() val mergedDb = newDb diff --git a/qbit-core/src/commonMain/kotlin/qbit/api/model/DataTypes.kt b/qbit-core/src/commonMain/kotlin/qbit/api/model/DataTypes.kt index ce35fdea..36edae61 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/api/model/DataTypes.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/api/model/DataTypes.kt @@ -21,6 +21,10 @@ import kotlin.reflect.KClass * - List */ +val scalarRange = 0..31 +val listRange = 32..63 +val pnCounterRange = 64..95 + @Suppress("UNCHECKED_CAST") sealed class DataType { @@ -31,12 +35,12 @@ sealed class DataType { private val values: Array> get() = arrayOf(QBoolean, QByte, QInt, QLong, QString, QBytes, QGid, QRef) - fun ofCode(code: Byte): DataType<*>? = - if (code <= 19) { - values.firstOrNull { it.code == code } - } else { - values.map { it.list() }.firstOrNull { it.code == code } - } + fun ofCode(code: Byte): DataType<*>? = when(code) { + in scalarRange -> values.firstOrNull { it.code == code } + in listRange -> values.map { it.list() }.firstOrNull { it.code == code } + in pnCounterRange -> ofCode((code - 64).toByte())?.counter() + else -> null + } fun ofValue(value: T?): DataType? = when (value) { is Boolean -> QBoolean as DataType @@ -46,7 +50,7 @@ sealed class DataType { is String -> QString as DataType is ByteArray -> QBytes as DataType is Gid -> QGid as DataType - is List<*> -> value.firstOrNull()?.let { ofValue(it)?.list() } as DataType + is List<*> -> value.firstOrNull()?.let { ofValue(it)?.list() } as DataType? else -> QRef as DataType } } @@ -57,7 +61,14 @@ sealed class DataType { return QList(this) } - fun isList(): Boolean = (code.toInt().and(32)) > 0 + fun isList(): Boolean = code in listRange + + fun counter(): QCounter { + require(this is QByte || this is QInt || this is QLong) { "Only primitive number values are allowed in counters" } + return QCounter(this) + } + + fun isCounter(): Boolean = code in pnCounterRange fun ref(): Boolean = this == QRef || this is QList<*> && this.itemsType == QRef @@ -73,6 +84,7 @@ sealed class DataType { is QBytes -> ByteArray::class is QGid -> Gid::class is QList<*> -> this.itemsType.typeClass() + is QCounter<*> -> this.primitiveType.typeClass() QRef -> Any::class } } @@ -85,6 +97,12 @@ data class QList(val itemsType: DataType) : DataType>() } +data class QCounter(val primitiveType: DataType) : DataType() { + + override val code = (64 + primitiveType.code).toByte() + +} + object QBoolean : DataType() { override val code = 0.toByte() diff --git a/qbit-core/src/commonMain/kotlin/qbit/resolving/ConflictResolving.kt b/qbit-core/src/commonMain/kotlin/qbit/resolving/ConflictResolving.kt index 4d3bdbc8..32bcd849 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/resolving/ConflictResolving.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/resolving/ConflictResolving.kt @@ -2,10 +2,9 @@ package qbit.resolving import kotlinx.coroutines.flow.toList import qbit.api.Instances +import qbit.api.QBitException import qbit.api.gid.Gid -import qbit.api.model.Attr -import qbit.api.model.Eav -import qbit.api.model.Hash +import qbit.api.model.* import qbit.index.RawEntity import qbit.serialization.* import qbit.trx.TrxLog @@ -72,6 +71,51 @@ internal fun lastWriterWinsResolve(resolveAttrName: (String) -> Attr?): (Li } } +internal fun crdtResolve( + resolveEntity: (Gid) -> StoredEntity?, + resolveAttrName: (String) -> Attr? +): (List, List) -> List = { eavsFromA, eavsFromB -> + require(eavsFromA.isNotEmpty()) { "eavsFromA should be not empty" } + require(eavsFromB.isNotEmpty()) { "eavsFromB should be not empty" } + + val gid = eavsFromA[0].eav.gid + val attr = resolveAttrName(eavsFromA[0].eav.attr) + ?: throw IllegalArgumentException("Cannot resolve ${eavsFromA[0].eav.attr}") + + when { + // temporary dirty hack until crdt counter or custom resolution strategy support is implemented + attr == Instances.nextEid -> listOf((eavsFromA + eavsFromB).maxByOrNull { it.eav.value as Int }!!.eav) + attr.list -> (eavsFromA + eavsFromB).map { it.eav }.distinct() + DataType.ofCode(attr.type)!!.isCounter() -> { + val latestFromA = eavsFromA.maxByOrNull { it.timestamp }!!.eav.value + val latestFromB = eavsFromB.maxByOrNull { it.timestamp }!!.eav.value + val previous = resolveEntity(gid)?.tryGet(attr) + + listOf( + if (previous != null) + Eav( + eavsFromA[0].eav.gid, + eavsFromA[0].eav.attr, + if (previous is Byte && latestFromA is Byte && latestFromB is Byte) latestFromA + latestFromB - previous + else if (previous is Int && latestFromA is Int && latestFromB is Int) latestFromA + latestFromB - previous + else if (previous is Long && latestFromA is Long && latestFromB is Long) latestFromA + latestFromB - previous + else throw QBitException("Unexpected counter value type for eav with gid=$gid, attr=$attr") + ) + else + Eav( + eavsFromA[0].eav.gid, + eavsFromA[0].eav.attr, + if (latestFromA is Byte && latestFromB is Byte) latestFromA + latestFromB + else if (latestFromA is Int && latestFromB is Int) latestFromA + latestFromB + else if (latestFromA is Long && latestFromB is Long) latestFromA + latestFromB + else throw QBitException("Unexpected counter value type for eav with gid=$gid, attr=$attr") + ) + ) + } + else -> listOf((eavsFromA + eavsFromB).maxByOrNull { it.timestamp }!!.eav) + } +} + internal fun findBaseNode(node1: Node, node2: Node, nodesDepth: Map): Node { return when { node1 == node2 -> node1 diff --git a/qbit-core/src/commonMain/kotlin/qbit/schema/SchemaDsl.kt b/qbit-core/src/commonMain/kotlin/qbit/schema/SchemaDsl.kt index 384abb6e..930f696e 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/schema/SchemaDsl.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/schema/SchemaDsl.kt @@ -24,7 +24,7 @@ class SchemaBuilder(private val serialModule: SerializersModule) { ?: throw QBitException("Cannot find descriptor for $type") val eb = EntityBuilder(descr) eb.body() - attrs.addAll(schemaFor(descr, eb.uniqueProps)) + attrs.addAll(schemaFor(descr, eb.uniqueProps, eb.counters)) } } @@ -33,6 +33,8 @@ class EntityBuilder(private val descr: SerialDescriptor) { internal val uniqueProps = HashSet() + internal val counters = HashSet() + fun uniqueInt(prop: KProperty1) { uniqueAttr(prop) } @@ -42,21 +44,41 @@ class EntityBuilder(private val descr: SerialDescriptor) { } private fun uniqueAttr(prop: KProperty1) { + uniqueProps.add(getAttrName(prop)) + } + + fun byteCounter(prop: KProperty1) { + counter(prop) + } + + fun intCounter(prop: KProperty1) { + counter(prop) + } + + fun longCounter(prop: KProperty1) { + counter(prop) + } + + private fun counter(prop: KProperty1) { + counters.add(getAttrName(prop)) + } + + private fun getAttrName(prop: KProperty1): String { val (idx, _) = descr.elementNames .withIndex().firstOrNull { (_, name) -> name == prop.name } ?: throw QBitException("Cannot find attr for ${prop.name} in $descr") - uniqueProps.add(AttrName(descr, idx).asString()) + return AttrName(descr, idx).asString() } } -fun schemaFor(rootDesc: SerialDescriptor, unique: Set = emptySet()): List> { +fun schemaFor(rootDesc: SerialDescriptor, unique: Set = emptySet(), counters: Set = emptySet()): List> { return rootDesc.elementDescriptors .withIndex() .filter { rootDesc.getElementName(it.index) !in setOf("id", "gid") } .map { (idx, desc) -> - val dataType = DataType.of(desc) val attr = AttrName(rootDesc, idx).asString() + val dataType = if (attr in counters) DataType.of(desc).counter() else DataType.of(desc) Attr(null, attr, dataType.code, attr in unique, dataType.isList()) } } diff --git a/qbit-core/src/commonMain/kotlin/qbit/serialization/Simple.kt b/qbit-core/src/commonMain/kotlin/qbit/serialization/Simple.kt index f198ee6f..479daf90 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/serialization/Simple.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/serialization/Simple.kt @@ -176,7 +176,7 @@ internal fun deserialize(ins: Input): Any { private fun readMark(ins: Input, expectedMark: DataType): Any { return when (expectedMark) { QBoolean -> (ins.readByte() == 1.toByte()) as T - QByte, QInt, QLong -> readLong(ins) as T + QByte, QInt, QLong, is QCounter<*> -> readLong(ins) as T QBytes -> readLong(ins).let { count -> readBytes(ins, count.toInt()) as T diff --git a/qbit-core/src/commonMain/kotlin/qbit/trx/CommitHandler.kt b/qbit-core/src/commonMain/kotlin/qbit/trx/CommitHandler.kt index 830fc897..2f8264c2 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/trx/CommitHandler.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/trx/CommitHandler.kt @@ -5,6 +5,6 @@ import qbit.index.InternalDb internal interface CommitHandler { - suspend fun update(trxLog: TrxLog, newLog: TrxLog, newDb: InternalDb) + suspend fun update(trxLog: TrxLog, baseDb: InternalDb, newLog: TrxLog, newDb: InternalDb) } \ No newline at end of file diff --git a/qbit-core/src/commonMain/kotlin/qbit/trx/Trx.kt b/qbit-core/src/commonMain/kotlin/qbit/trx/Trx.kt index f598b549..0fdc88f3 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/trx/Trx.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/trx/Trx.kt @@ -64,8 +64,7 @@ internal class QTrx( val instance = factor(inst.copy(nextEid = gids.next().eid), curDb::attr, EmptyIterator) val newLog = trxLog.append(factsBuffer + instance) try { - base = curDb.with(instance) - commitHandler.update(trxLog, newLog, base) + commitHandler.update(trxLog, base, newLog, curDb.with(instance)) factsBuffer.clear() } catch (e: Throwable) { // todo clean up diff --git a/qbit-core/src/commonTest/kotlin/qbit/ConnTest.kt b/qbit-core/src/commonTest/kotlin/qbit/ConnTest.kt index 29fb3db2..d945bb23 100644 --- a/qbit-core/src/commonTest/kotlin/qbit/ConnTest.kt +++ b/qbit-core/src/commonTest/kotlin/qbit/ConnTest.kt @@ -63,7 +63,7 @@ class ConnTest { ) val newLog = FakeTrxLog(storedLeaf.hash) - conn.update(conn.trxLog, newLog, EmptyDb) + conn.update(conn.trxLog, EmptyDb, newLog, EmptyDb) assertArrayEquals(newLog.hash.bytes, storage.load(Namespace("refs")["head"])) } diff --git a/qbit-core/src/commonTest/kotlin/qbit/FakeConn.kt b/qbit-core/src/commonTest/kotlin/qbit/FakeConn.kt index 1636e341..9b9d5adc 100644 --- a/qbit-core/src/commonTest/kotlin/qbit/FakeConn.kt +++ b/qbit-core/src/commonTest/kotlin/qbit/FakeConn.kt @@ -40,7 +40,7 @@ internal class FakeConn : Conn(), CommitHandler { override val head: Hash get() = TODO("not implemented") - override suspend fun update(trxLog: TrxLog, newLog: TrxLog, newDb: InternalDb) { + override suspend fun update(trxLog: TrxLog, baseDb: InternalDb, newLog: TrxLog, newDb: InternalDb) { updatesCalls++ } diff --git a/qbit-core/src/commonTest/kotlin/qbit/FunTest.kt b/qbit-core/src/commonTest/kotlin/qbit/FunTest.kt index c53945e2..1959fa93 100644 --- a/qbit-core/src/commonTest/kotlin/qbit/FunTest.kt +++ b/qbit-core/src/commonTest/kotlin/qbit/FunTest.kt @@ -400,7 +400,7 @@ class FunTest { assertEquals(bomb.country, storedBomb.country) assertEquals(bomb.optCountry, storedBomb.optCountry) assertEquals( - listOf(Country(12884901889, "Country1", 0), Country(4294967383, "Country3", 2)), + listOf(Country(12884901889, "Country1", 0), Country(4294967384, "Country3", 2)), storedBomb.countiesList ) // todo: assertEquals(bomb.countriesListOpt, storedBomb.countriesListOpt) @@ -459,9 +459,9 @@ class FunTest { trx1.persist(eBrewer.copy(name = "Im different change")) val trx2 = conn.trx() trx2.persist(eCodd.copy(name = "Im change 2")) - delay(100) trx2.persist(pChen.copy(name = "Im different change")) trx1.commit() + delay(1) trx2.commit() conn.db { assertEquals("Im change 2", it.pull(eCodd.id!!)!!.name) @@ -540,6 +540,7 @@ class FunTest { ) ) trx1.commit() + delay(1) trx2.commit() conn.db { assertEquals("Im change 2", it.pull(eCodd.id!!)!!.name) @@ -574,4 +575,25 @@ class FunTest { assertEquals(Gid(nsk.id!!), trx2EntityAttrValues.first { it.attr.name == "City/region" }.value) } } + + @JsName("qbit_should_accumulate_concurrent_increments_of_counter") + @Test + fun `qbit should accumulate concurrent increments of counter`() { + runBlocking { + val conn = setupTestSchema() + val counter = IntCounterEntity(1, 10) + val trx = conn.trx() + trx.persist(counter) + trx.commit() + + val trx1 = conn.trx() + val trx2 = conn.trx() + trx1.persist(counter.copy(counter = 40)) + trx2.persist(counter.copy(counter = 70)) + trx1.commit() + trx2.commit() + + assertEquals(conn.db().pull(1)?.counter, 100) + } + } } \ No newline at end of file diff --git a/qbit-core/src/commonTest/kotlin/qbit/TestSchema.kt b/qbit-core/src/commonTest/kotlin/qbit/TestSchema.kt index fc7d77a9..f274bc02 100644 --- a/qbit-core/src/commonTest/kotlin/qbit/TestSchema.kt +++ b/qbit-core/src/commonTest/kotlin/qbit/TestSchema.kt @@ -36,6 +36,9 @@ val testSchema = schema(internalTestsSerialModule) { entity(NullableList::class) entity(NullableRef::class) entity(IntEntity::class) + entity(IntCounterEntity::class) { + intCounter(IntCounterEntity::counter) + } entity(ResearchGroup::class) entity(EntityWithByteArray::class) entity(EntityWithListOfBytes::class) diff --git a/qbit-core/src/commonTest/kotlin/qbit/TrxTest.kt b/qbit-core/src/commonTest/kotlin/qbit/TrxTest.kt index 840ccd9e..92d65d54 100644 --- a/qbit-core/src/commonTest/kotlin/qbit/TrxTest.kt +++ b/qbit-core/src/commonTest/kotlin/qbit/TrxTest.kt @@ -6,18 +6,16 @@ import qbit.api.Attrs import qbit.api.Instances import qbit.api.QBitException import qbit.api.db.Conn -import qbit.api.db.attrIs import qbit.api.db.pull -import qbit.api.db.query import qbit.api.gid.Gid import qbit.api.gid.nextGids -import qbit.api.model.Attr import qbit.api.system.Instance import qbit.ns.Key import qbit.ns.ns import qbit.platform.runBlocking import qbit.spi.Storage import qbit.storage.MemStorage +import qbit.test.model.IntCounterEntity import qbit.test.model.Region import qbit.test.model.Scientist import qbit.test.model.testsSerialModule @@ -176,6 +174,25 @@ class TrxTest { } } + @JsName("Counter_test") + @Test + fun `Counter test`() { // TODO: find an appropriate place for this test + runBlocking { + val conn = setupTestData() + val counterEntity = IntCounterEntity(1, 10) + + conn.trx { + persist(counterEntity) + } + assertEquals(conn.db().pull(1)?.counter, 10) + + conn.trx { + persist(counterEntity.copy(counter = 90)) + } + assertEquals(conn.db().pull(1)?.counter, 90) + } + } + private suspend fun openEmptyConn(): Pair { val storage = MemStorage() val conn = qbit(storage, testsSerialModule) diff --git a/qbit-test-fixtures/src/commonMain/kotlin/qbit/test/model/TestModels.kt b/qbit-test-fixtures/src/commonMain/kotlin/qbit/test/model/TestModels.kt index d88500fc..79760478 100644 --- a/qbit-test-fixtures/src/commonMain/kotlin/qbit/test/model/TestModels.kt +++ b/qbit-test-fixtures/src/commonMain/kotlin/qbit/test/model/TestModels.kt @@ -9,6 +9,9 @@ data class TheSimplestEntity(val id: Long?, val scalar: String) @Serializable data class IntEntity(val id: Long?, val int: Int) +@Serializable +data class IntCounterEntity(val id: Long?, val counter: Int) + @Serializable data class NullableIntEntity(val id: Long?, val int: Int?) @@ -307,6 +310,7 @@ val testsSerialModule = SerializersModule { contextual(ByteArrayEntity::class, ByteArrayEntity.serializer()) contextual(ListOfByteArraysEntity::class, ListOfByteArraysEntity.serializer()) contextual(IntEntity::class, IntEntity.serializer()) + contextual(IntCounterEntity::class, IntCounterEntity.serializer()) contextual(Region::class, Region.serializer()) contextual(ParentToChildrenTreeEntity::class, ParentToChildrenTreeEntity.serializer()) contextual(EntityWithRefsToSameType::class, EntityWithRefsToSameType.serializer())