From 14f44d28c3a7dd7c29e78b057ceb19403ffa7906 Mon Sep 17 00:00:00 2001 From: Gus Narea Date: Tue, 26 Dec 2023 12:16:00 +0000 Subject: [PATCH] fix: Allow multiple first-party endpoints to communicate with the same third-party endpoint (#361) Due to bugs in the key stores (https://github.com/relaycorp/awala-jvm/pull/306, https://github.com/relaycorp/awala-jvm/pull/310), we were accidentally reusing the same session keys, which the third-party endpoint rightfully refused. This PR will integrate the two breaking changes above, plus one additional change needed to make this work. A consequence of this change is that **we'll now need to pass the linked first-party endpoint when importing and deleting a third-party endpoint**, so that we know which session keys to delete in case the same third-party endpoint is used by another first-party endpoint. # TODO Each of the following require updating the respective test suite. - [x] Pass first-party endpoint to every function call inside `ThirdPartyEndpoint.delete()`: https://github.com/relaycorp/awala-endpoint-android/blob/6ecb5345d67bf29033b0e27a02d91e7d1b35c02a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt#L31-L36 This should look like this, roughly (untested): ```kotlin public open suspend fun delete(linkedFirstPartyEndpoint: FirstPartyEndpoint) { val context = Awala.getContextOrThrow() context.privateKeyStore.deleteBoundSessionKeys(linkedFirstPartyEndpoint.nodeId, nodeId) context.sessionPublicKeyStore.delete(linkedFirstPartyEndpoint.nodeId, nodeId) context.channelManager.delete(linkedFirstPartyEndpoint, this) } ``` - [x] Replace `delete(ThirdPartyEndpoint)` in `ChannelManager` with `delete(FirstPartyEndpoint, ThirdPartyEndpoint)`, and only delete the items for the given first-/third-party endpoint pair. This change is needed by the previous task. - [x] Replace `import(ByteArray)` in `PrivateThirdPartyEndpoint` with `import(ByteArray, FirstPartyEndpoint)`, and pass the first-party endpoint to https://github.com/relaycorp/awala-endpoint-android/blob/6ecb5345d67bf29033b0e27a02d91e7d1b35c02a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt#L183 - [x] Replace `import(ByteArray)` in `PublicThirdPartyEndpoint` with `import(ByteArray, FirstPartyEndpoint)`, and pass the first-party endpoint to https://github.com/relaycorp/awala-endpoint-android/blob/6ecb5345d67bf29033b0e27a02d91e7d1b35c02a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt#L249-L252 --- lib/build.gradle | 6 +-- .../awaladroid/endpoint/ChannelManager.kt | 37 +++++++++++-------- .../awaladroid/endpoint/ThirdPartyEndpoint.kt | 25 ++++++++----- .../awaladroid/endpoint/ChannelManagerTest.kt | 8 ++-- .../endpoint/PrivateThirdPartyEndpointTest.kt | 28 ++++++++------ .../endpoint/PublicThirdPartyEndpointTest.kt | 24 +++++++++--- .../messaging/IncomingMessageTest.kt | 8 ++-- .../messaging/ReceiveMessagesTest.kt | 2 +- .../awaladroid/test/MockContextTestCase.kt | 1 + 9 files changed, 86 insertions(+), 53 deletions(-) diff --git a/lib/build.gradle b/lib/build.gradle index c0346ccd..54227631 100644 --- a/lib/build.gradle +++ b/lib/build.gradle @@ -66,10 +66,10 @@ dependencies { implementation "org.jetbrains.kotlinx:kotlinx-coroutines-core:$kotlinCoroutinesVersion" // Awala - implementation 'tech.relaycorp:awala:1.68.0' - implementation 'tech.relaycorp:awala-keystore-file:1.6.31' + implementation 'tech.relaycorp:awala:1.68.5' + implementation 'tech.relaycorp:awala-keystore-file:1.6.39' implementation 'tech.relaycorp:poweb:1.5.68' - testImplementation 'tech.relaycorp:awala-testing:1.5.24' + testImplementation 'tech.relaycorp:awala-testing:1.5.27' // Security implementation 'androidx.security:security-crypto:1.1.0-alpha06' diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ChannelManager.kt b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ChannelManager.kt index ba7db77f..6702c092 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ChannelManager.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ChannelManager.kt @@ -54,24 +54,29 @@ internal class ChannelManager( } } - suspend fun delete(thirdPartyEndpoint: ThirdPartyEndpoint) { + suspend fun delete( + linkedFirstPartyEndpoint: FirstPartyEndpoint, + thirdPartyEndpoint: ThirdPartyEndpoint, + ) { + val key = linkedFirstPartyEndpoint.nodeId withContext(coroutineContext) { - sharedPreferences.all.forEach { (key, value) -> - // Skip malformed values - if (value !is MutableSet<*>) { - return@forEach - } - val sanitizedValue: List = value.filterIsInstance() - if (value.size != sanitizedValue.size) { - return@forEach - } + val value = + try { + sharedPreferences.getStringSet(key, null) + } catch (exc: ClassCastException) { + // Skip malformed values + return@withContext + } ?: return@withContext + val sanitizedValue: List = value.filterIsInstance() + if (value.size != sanitizedValue.size) { + return@withContext + } - if ((value).contains(thirdPartyEndpoint.nodeId)) { - val newValue = sanitizedValue.filter { it != thirdPartyEndpoint.nodeId } - with(sharedPreferences.edit()) { - putStringSet(key, newValue.toMutableSet()) - commit() - } + if ((value).contains(thirdPartyEndpoint.nodeId)) { + val newValue = sanitizedValue.filter { it != thirdPartyEndpoint.nodeId } + with(sharedPreferences.edit()) { + putStringSet(key, newValue.toMutableSet()) + commit() } } } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt index 928ead87..ff9701dd 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt @@ -28,11 +28,11 @@ public sealed class ThirdPartyEndpoint( * Delete the endpoint. */ @Throws(PersistenceException::class) - public open suspend fun delete() { + public open suspend fun delete(linkedFirstPartyEndpoint: FirstPartyEndpoint) { val context = Awala.getContextOrThrow() - context.privateKeyStore.deleteSessionKeysForPeer(nodeId) - context.sessionPublicKeyStore.delete(nodeId) - context.channelManager.delete(this) + context.privateKeyStore.deleteBoundSessionKeys(linkedFirstPartyEndpoint.nodeId, nodeId) + context.sessionPublicKeyStore.delete(linkedFirstPartyEndpoint.nodeId, nodeId) + context.channelManager.delete(linkedFirstPartyEndpoint, this) } internal companion object { @@ -62,10 +62,10 @@ public class PrivateThirdPartyEndpoint internal constructor( private val storageKey = "${firstPartyEndpointAddress}_$nodeId" @Throws(PersistenceException::class, SetupPendingException::class) - override suspend fun delete() { + override suspend fun delete(linkedFirstPartyEndpoint: FirstPartyEndpoint) { val context = Awala.getContextOrThrow() context.storage.privateThirdParty.delete(storageKey) - super.delete() + super.delete(linkedFirstPartyEndpoint) } @Throws(InvalidAuthorizationException::class) @@ -135,6 +135,7 @@ public class PrivateThirdPartyEndpoint internal constructor( ) public suspend fun import( connectionParamsSerialized: ByteArray, + firstPartyEndpoint: FirstPartyEndpoint, ): PrivateThirdPartyEndpoint { val context = Awala.getContextOrThrow() @@ -180,7 +181,11 @@ public class PrivateThirdPartyEndpoint internal constructor( ) context.storage.privateThirdParty.set(endpoint.storageKey, data) - context.sessionPublicKeyStore.save(params.sessionKey, endpoint.nodeId) + context.sessionPublicKeyStore.save( + params.sessionKey, + firstPartyEndpoint.nodeId, + endpoint.nodeId, + ) return endpoint } @@ -197,10 +202,10 @@ public class PublicThirdPartyEndpoint internal constructor( identityKey: PublicKey, ) : ThirdPartyEndpoint(identityKey, internetAddress) { @Throws(PersistenceException::class, SetupPendingException::class) - override suspend fun delete() { + override suspend fun delete(linkedFirstPartyEndpoint: FirstPartyEndpoint) { val context = Awala.getContextOrThrow() context.storage.publicThirdParty.delete(nodeId) - super.delete() + super.delete(linkedFirstPartyEndpoint) } public companion object { @@ -227,6 +232,7 @@ public class PublicThirdPartyEndpoint internal constructor( ) public suspend fun import( connectionParamsSerialized: ByteArray, + firstPartyEndpoint: FirstPartyEndpoint, ): PublicThirdPartyEndpoint { val context = Awala.getContextOrThrow() val connectionParams = @@ -248,6 +254,7 @@ public class PublicThirdPartyEndpoint internal constructor( ) context.sessionPublicKeyStore.save( connectionParams.sessionKey, + firstPartyEndpoint.nodeId, peerNodeId, ) return PublicThirdPartyEndpoint( diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/ChannelManagerTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/ChannelManagerTest.kt index 2e58a4ef..c90c8b48 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/ChannelManagerTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/ChannelManagerTest.kt @@ -124,7 +124,7 @@ internal class ChannelManagerTest { apply() } - manager.delete(thirdPartyEndpoint) + manager.delete(firstPartyEndpoint, thirdPartyEndpoint) assertEquals( mutableSetOf(unrelatedThirdPartyEndpointAddress), @@ -145,7 +145,7 @@ internal class ChannelManagerTest { apply() } - manager.delete(thirdPartyEndpoint) + manager.delete(firstPartyEndpoint, thirdPartyEndpoint) assertEquals( setOf(unrelatedThirdPartyEndpointAddress), @@ -166,7 +166,7 @@ internal class ChannelManagerTest { apply() } - manager.delete(thirdPartyEndpoint) + manager.delete(firstPartyEndpoint, thirdPartyEndpoint) assertEquals( malformedValue, @@ -187,7 +187,7 @@ internal class ChannelManagerTest { apply() } - manager.delete(thirdPartyEndpoint) + manager.delete(firstPartyEndpoint, thirdPartyEndpoint) assertEquals( malformedValue, diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PrivateThirdPartyEndpointTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PrivateThirdPartyEndpointTest.kt index 9dc2a5e6..15eb3847 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PrivateThirdPartyEndpointTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PrivateThirdPartyEndpointTest.kt @@ -115,7 +115,7 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { listOf(thirdPartyEndpointCertificate), ) val paramsSerialized = serializeConnectionParams(delivAuth) - val endpoint = PrivateThirdPartyEndpoint.import(paramsSerialized) + val endpoint = PrivateThirdPartyEndpoint.import(paramsSerialized, firstPartyEndpoint) assertEquals( firstPartyEndpoint.nodeId, @@ -145,7 +145,10 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { }, ) - assertEquals(sessionKey, sessionPublicKeystore.retrieve(endpoint.nodeId)) + assertEquals( + sessionKey, + sessionPublicKeystore.retrieve(firstPartyEndpoint.nodeId, endpoint.nodeId), + ) } @Test @@ -154,8 +157,10 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { val firstPartyCert = PDACertPath.PRIVATE_ENDPOINT val pdaPath = CertificationPath(firstPartyCert, emptyList()) val paramsSerialized = serializeConnectionParams(pdaPath) + val firstPartyEndpoint = createFirstPartyEndpoint() + firstPartyEndpoint.delete() try { - PrivateThirdPartyEndpoint.import(paramsSerialized) + PrivateThirdPartyEndpoint.import(paramsSerialized, firstPartyEndpoint) } catch (exception: UnknownFirstPartyEndpointException) { assertEquals( "First-party endpoint ${firstPartyCert.subjectId} is not registered", @@ -195,7 +200,7 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { ) val paramsSerialized = serializeConnectionParams(pdaPath) try { - PrivateThirdPartyEndpoint.import(paramsSerialized) + PrivateThirdPartyEndpoint.import(paramsSerialized, firstPartyEndpoint) } catch (exception: InvalidAuthorizationException) { assertEquals("PDA path is invalid", exception.message) assertTrue(exception.cause is CertificationPathException) @@ -209,8 +214,9 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { @Test fun import_malformedParams() = runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() try { - PrivateThirdPartyEndpoint.import("malformed".toByteArray()) + PrivateThirdPartyEndpoint.import("malformed".toByteArray(), firstPartyEndpoint) } catch (exception: InvalidThirdPartyEndpoint) { assertEquals("Malformed connection params", exception.message) assertTrue(exception.cause is InvalidNodeConnectionParams) @@ -223,7 +229,7 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { @Test fun import_invalidPDAPath() = runTest { - createFirstPartyEndpoint() + val firstPartyEndpoint = createFirstPartyEndpoint() val pdaPath = CertificationPath( pda, @@ -232,7 +238,7 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { ) val paramsSerialized = serializeConnectionParams(pdaPath) try { - PrivateThirdPartyEndpoint.import(paramsSerialized) + PrivateThirdPartyEndpoint.import(paramsSerialized, firstPartyEndpoint) } catch (exception: InvalidAuthorizationException) { assertEquals("PDA path is invalid", exception.message) return@runTest @@ -259,7 +265,7 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { val pdaPath = CertificationPath(expiredPDA, listOf(thirdPartyEndpointCertificate)) val paramsSerialized = serializeConnectionParams(pdaPath) try { - PrivateThirdPartyEndpoint.import(paramsSerialized) + PrivateThirdPartyEndpoint.import(paramsSerialized, firstPartyEndpoint) } catch (exception: InvalidAuthorizationException) { assertEquals("PDA path is invalid", exception.message) assertTrue(exception.cause is CertificationPathException) @@ -401,13 +407,13 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { val endpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint val firstPartyEndpoint = channel.firstPartyEndpoint - endpoint.delete() + endpoint.delete(firstPartyEndpoint) verify(storage.privateThirdParty) .delete("${firstPartyEndpoint.nodeId}_${endpoint.nodeId}") assertEquals(0, privateKeyStore.sessionKeys[firstPartyEndpoint.nodeId]!!.size) - assertEquals(0, sessionPublicKeystore.keys.size) - verify(channelManager).delete(endpoint) + assertEquals(0, sessionPublicKeystore.keys[firstPartyEndpoint.nodeId]!!.size) + verify(channelManager).delete(firstPartyEndpoint, endpoint) } private fun serializeConnectionParams(deliveryAuth: CertificationPath) = diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PublicThirdPartyEndpointTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PublicThirdPartyEndpointTest.kt index 5d9e4839..e761c43f 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PublicThirdPartyEndpointTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PublicThirdPartyEndpointTest.kt @@ -80,7 +80,13 @@ internal class PublicThirdPartyEndpointTest : MockContextTestCase() { SessionKeyPair.generate().sessionKey, ) - val thirdPartyEndpoint = PublicThirdPartyEndpoint.import(connectionParams.serialize()) + val firstPartyEndpoint = createFirstPartyEndpoint() + + val thirdPartyEndpoint = + PublicThirdPartyEndpoint.import( + connectionParams.serialize(), + firstPartyEndpoint, + ) assertEquals(connectionParams.internetAddress, thirdPartyEndpoint.internetAddress) assertEquals(connectionParams.identityKey, thirdPartyEndpoint.identityKey) @@ -91,15 +97,17 @@ internal class PublicThirdPartyEndpointTest : MockContextTestCase() { connectionParams.identityKey, ), ) - sessionPublicKeystore.retrieve(thirdPartyEndpoint.nodeId) + sessionPublicKeystore.retrieve(firstPartyEndpoint.nodeId, thirdPartyEndpoint.nodeId) } @Test fun import_invalidConnectionParams() = runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() try { PublicThirdPartyEndpoint.import( "malformed".toByteArray(), + firstPartyEndpoint, ) } catch (exception: InvalidThirdPartyEndpoint) { assertEquals("Connection params serialization is malformed", exception.message) @@ -134,13 +142,17 @@ internal class PublicThirdPartyEndpointTest : MockContextTestCase() { thirdPartyEndpoint.nodeId, ) val peerSessionKey = SessionKeyPair.generate().sessionKey - sessionPublicKeystore.save(peerSessionKey, thirdPartyEndpoint.nodeId) + sessionPublicKeystore.save( + peerSessionKey, + firstPartyEndpoint.nodeId, + thirdPartyEndpoint.nodeId, + ) - thirdPartyEndpoint.delete() + thirdPartyEndpoint.delete(firstPartyEndpoint) verify(storage.publicThirdParty).delete(thirdPartyEndpoint.nodeId) assertEquals(0, privateKeyStore.sessionKeys[firstPartyEndpoint.nodeId]!!.size) - assertEquals(0, sessionPublicKeystore.keys.size) - verify(channelManager).delete(thirdPartyEndpoint) + assertEquals(0, sessionPublicKeystore.keys[firstPartyEndpoint.nodeId]!!.size) + verify(channelManager).delete(firstPartyEndpoint, thirdPartyEndpoint) } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/IncomingMessageTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/IncomingMessageTest.kt index f0873503..c156b291 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/IncomingMessageTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/IncomingMessageTest.kt @@ -65,8 +65,8 @@ internal class IncomingMessageTest : MockContextTestCase() { payload = thirdPartyEndpointManager.wrapMessagePayload( serviceMessage, - channel.firstPartyEndpoint.nodeId, channel.thirdPartyEndpoint.nodeId, + channel.firstPartyEndpoint.nodeId, ), senderCertificate = PDACertPath.PDA, ) @@ -279,6 +279,7 @@ internal class IncomingMessageTest : MockContextTestCase() { thirdPartySessionPublicKeyStore.save( channel.firstPartySessionKeyPair.sessionKey, channel.firstPartyEndpoint.nodeId, + channel.thirdPartyEndpoint.nodeId, ) return EndpointManager( thirdPartyPrivateKeyStore, @@ -299,8 +300,8 @@ internal class IncomingMessageTest : MockContextTestCase() { val pdaPathServiceMessage = makePDAPathMessage(plaintext) return thirdPartyEndpointManager.wrapMessagePayload( pdaPathServiceMessage, - channel.firstPartyEndpoint.nodeId, channel.thirdPartyEndpoint.nodeId, + channel.firstPartyEndpoint.nodeId, ) } @@ -310,7 +311,8 @@ internal class IncomingMessageTest : MockContextTestCase() { companion object { private val logCaptor = LogCaptor.forClass(IncomingMessage::class.java) + @JvmStatic @AfterClass - fun closeLogs() = logCaptor.close() + fun closeLogs(): Unit = logCaptor.close() } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/ReceiveMessagesTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/ReceiveMessagesTest.kt index 142fd2a7..a8c4cf49 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/ReceiveMessagesTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/ReceiveMessagesTest.kt @@ -218,7 +218,7 @@ internal class ReceiveMessagesTest : MockContextTestCase() { val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) pdcClient = MockPDCClient(collectParcelsCall) - channel.thirdPartyEndpoint.delete() + channel.thirdPartyEndpoint.delete(channel.firstPartyEndpoint) val messages = subject.receive().toCollection(mutableListOf()) diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/MockContextTestCase.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/MockContextTestCase.kt index 1406c1ac..795e5323 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/MockContextTestCase.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/MockContextTestCase.kt @@ -171,6 +171,7 @@ internal abstract class MockContextTestCase { sessionPublicKeystore.save( sessionKey, + firstPartyEndpoint.nodeId, thirdPartyEndpoint.nodeId, ) return thirdPartyEndpoint