Skip to content

Commit

Permalink
fix: Allow multiple first-party endpoints to communicate with the sam…
Browse files Browse the repository at this point in the history
…e third-party endpoint (#361)

Due to bugs in the key stores (relaycorp/awala-jvm#306, relaycorp/awala-jvm#310), we were accidentally reusing the same session keys, which the third-party endpoint rightfully refused.

This PR will integrate the two breaking changes above, plus one additional change needed to make this work.

A consequence of this change is that **we'll now need to pass the linked first-party endpoint when importing and deleting a third-party endpoint**, so that we know which session keys to delete in case the same third-party endpoint is used by another first-party endpoint.

# TODO

Each of the following require updating the respective test suite.

- [x] Pass first-party endpoint to every function call inside `ThirdPartyEndpoint.delete()`:

  https://github.com/relaycorp/awala-endpoint-android/blob/6ecb5345d67bf29033b0e27a02d91e7d1b35c02a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt#L31-L36

  This should look like this, roughly (untested):

  ```kotlin
  public open suspend fun delete(linkedFirstPartyEndpoint: FirstPartyEndpoint) {
    val context = Awala.getContextOrThrow()
    context.privateKeyStore.deleteBoundSessionKeys(linkedFirstPartyEndpoint.nodeId, nodeId) 
    context.sessionPublicKeyStore.delete(linkedFirstPartyEndpoint.nodeId, nodeId) 
    context.channelManager.delete(linkedFirstPartyEndpoint, this)
  }
  ```
- [x] Replace `delete(ThirdPartyEndpoint)` in `ChannelManager` with `delete(FirstPartyEndpoint, ThirdPartyEndpoint)`, and only delete the items for the given first-/third-party endpoint pair. This change is needed by the previous task.
- [x] Replace `import(ByteArray)` in `PrivateThirdPartyEndpoint` with `import(ByteArray, FirstPartyEndpoint)`, and pass the first-party endpoint to

  https://github.com/relaycorp/awala-endpoint-android/blob/6ecb5345d67bf29033b0e27a02d91e7d1b35c02a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt#L183
- [x] Replace `import(ByteArray)` in `PublicThirdPartyEndpoint` with `import(ByteArray, FirstPartyEndpoint)`, and pass the first-party endpoint to

  https://github.com/relaycorp/awala-endpoint-android/blob/6ecb5345d67bf29033b0e27a02d91e7d1b35c02a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt#L249-L252
  • Loading branch information
gnarea authored Dec 26, 2023
1 parent 1346945 commit 14f44d2
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 53 deletions.
6 changes: 3 additions & 3 deletions lib/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ dependencies {
implementation "org.jetbrains.kotlinx:kotlinx-coroutines-core:$kotlinCoroutinesVersion"

// Awala
implementation 'tech.relaycorp:awala:1.68.0'
implementation 'tech.relaycorp:awala-keystore-file:1.6.31'
implementation 'tech.relaycorp:awala:1.68.5'
implementation 'tech.relaycorp:awala-keystore-file:1.6.39'
implementation 'tech.relaycorp:poweb:1.5.68'
testImplementation 'tech.relaycorp:awala-testing:1.5.24'
testImplementation 'tech.relaycorp:awala-testing:1.5.27'

// Security
implementation 'androidx.security:security-crypto:1.1.0-alpha06'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,24 +54,29 @@ internal class ChannelManager(
}
}

suspend fun delete(thirdPartyEndpoint: ThirdPartyEndpoint) {
suspend fun delete(
linkedFirstPartyEndpoint: FirstPartyEndpoint,
thirdPartyEndpoint: ThirdPartyEndpoint,
) {
val key = linkedFirstPartyEndpoint.nodeId
withContext(coroutineContext) {
sharedPreferences.all.forEach { (key, value) ->
// Skip malformed values
if (value !is MutableSet<*>) {
return@forEach
}
val sanitizedValue: List<String> = value.filterIsInstance<String>()
if (value.size != sanitizedValue.size) {
return@forEach
}
val value =
try {
sharedPreferences.getStringSet(key, null)
} catch (exc: ClassCastException) {
// Skip malformed values
return@withContext
} ?: return@withContext
val sanitizedValue: List<String> = value.filterIsInstance<String>()
if (value.size != sanitizedValue.size) {
return@withContext
}

if ((value).contains(thirdPartyEndpoint.nodeId)) {
val newValue = sanitizedValue.filter { it != thirdPartyEndpoint.nodeId }
with(sharedPreferences.edit()) {
putStringSet(key, newValue.toMutableSet())
commit()
}
if ((value).contains(thirdPartyEndpoint.nodeId)) {
val newValue = sanitizedValue.filter { it != thirdPartyEndpoint.nodeId }
with(sharedPreferences.edit()) {
putStringSet(key, newValue.toMutableSet())
commit()
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ public sealed class ThirdPartyEndpoint(
* Delete the endpoint.
*/
@Throws(PersistenceException::class)
public open suspend fun delete() {
public open suspend fun delete(linkedFirstPartyEndpoint: FirstPartyEndpoint) {
val context = Awala.getContextOrThrow()
context.privateKeyStore.deleteSessionKeysForPeer(nodeId)
context.sessionPublicKeyStore.delete(nodeId)
context.channelManager.delete(this)
context.privateKeyStore.deleteBoundSessionKeys(linkedFirstPartyEndpoint.nodeId, nodeId)
context.sessionPublicKeyStore.delete(linkedFirstPartyEndpoint.nodeId, nodeId)
context.channelManager.delete(linkedFirstPartyEndpoint, this)
}

internal companion object {
Expand Down Expand Up @@ -62,10 +62,10 @@ public class PrivateThirdPartyEndpoint internal constructor(
private val storageKey = "${firstPartyEndpointAddress}_$nodeId"

@Throws(PersistenceException::class, SetupPendingException::class)
override suspend fun delete() {
override suspend fun delete(linkedFirstPartyEndpoint: FirstPartyEndpoint) {
val context = Awala.getContextOrThrow()
context.storage.privateThirdParty.delete(storageKey)
super.delete()
super.delete(linkedFirstPartyEndpoint)
}

@Throws(InvalidAuthorizationException::class)
Expand Down Expand Up @@ -135,6 +135,7 @@ public class PrivateThirdPartyEndpoint internal constructor(
)
public suspend fun import(
connectionParamsSerialized: ByteArray,
firstPartyEndpoint: FirstPartyEndpoint,
): PrivateThirdPartyEndpoint {
val context = Awala.getContextOrThrow()

Expand Down Expand Up @@ -180,7 +181,11 @@ public class PrivateThirdPartyEndpoint internal constructor(
)
context.storage.privateThirdParty.set(endpoint.storageKey, data)

context.sessionPublicKeyStore.save(params.sessionKey, endpoint.nodeId)
context.sessionPublicKeyStore.save(
params.sessionKey,
firstPartyEndpoint.nodeId,
endpoint.nodeId,
)

return endpoint
}
Expand All @@ -197,10 +202,10 @@ public class PublicThirdPartyEndpoint internal constructor(
identityKey: PublicKey,
) : ThirdPartyEndpoint(identityKey, internetAddress) {
@Throws(PersistenceException::class, SetupPendingException::class)
override suspend fun delete() {
override suspend fun delete(linkedFirstPartyEndpoint: FirstPartyEndpoint) {
val context = Awala.getContextOrThrow()
context.storage.publicThirdParty.delete(nodeId)
super.delete()
super.delete(linkedFirstPartyEndpoint)
}

public companion object {
Expand All @@ -227,6 +232,7 @@ public class PublicThirdPartyEndpoint internal constructor(
)
public suspend fun import(
connectionParamsSerialized: ByteArray,
firstPartyEndpoint: FirstPartyEndpoint,
): PublicThirdPartyEndpoint {
val context = Awala.getContextOrThrow()
val connectionParams =
Expand All @@ -248,6 +254,7 @@ public class PublicThirdPartyEndpoint internal constructor(
)
context.sessionPublicKeyStore.save(
connectionParams.sessionKey,
firstPartyEndpoint.nodeId,
peerNodeId,
)
return PublicThirdPartyEndpoint(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ internal class ChannelManagerTest {
apply()
}

manager.delete(thirdPartyEndpoint)
manager.delete(firstPartyEndpoint, thirdPartyEndpoint)

assertEquals(
mutableSetOf(unrelatedThirdPartyEndpointAddress),
Expand All @@ -145,7 +145,7 @@ internal class ChannelManagerTest {
apply()
}

manager.delete(thirdPartyEndpoint)
manager.delete(firstPartyEndpoint, thirdPartyEndpoint)

assertEquals(
setOf(unrelatedThirdPartyEndpointAddress),
Expand All @@ -166,7 +166,7 @@ internal class ChannelManagerTest {
apply()
}

manager.delete(thirdPartyEndpoint)
manager.delete(firstPartyEndpoint, thirdPartyEndpoint)

assertEquals(
malformedValue,
Expand All @@ -187,7 +187,7 @@ internal class ChannelManagerTest {
apply()
}

manager.delete(thirdPartyEndpoint)
manager.delete(firstPartyEndpoint, thirdPartyEndpoint)

assertEquals(
malformedValue,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() {
listOf(thirdPartyEndpointCertificate),
)
val paramsSerialized = serializeConnectionParams(delivAuth)
val endpoint = PrivateThirdPartyEndpoint.import(paramsSerialized)
val endpoint = PrivateThirdPartyEndpoint.import(paramsSerialized, firstPartyEndpoint)

assertEquals(
firstPartyEndpoint.nodeId,
Expand Down Expand Up @@ -145,7 +145,10 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() {
},
)

assertEquals(sessionKey, sessionPublicKeystore.retrieve(endpoint.nodeId))
assertEquals(
sessionKey,
sessionPublicKeystore.retrieve(firstPartyEndpoint.nodeId, endpoint.nodeId),
)
}

@Test
Expand All @@ -154,8 +157,10 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() {
val firstPartyCert = PDACertPath.PRIVATE_ENDPOINT
val pdaPath = CertificationPath(firstPartyCert, emptyList())
val paramsSerialized = serializeConnectionParams(pdaPath)
val firstPartyEndpoint = createFirstPartyEndpoint()
firstPartyEndpoint.delete()
try {
PrivateThirdPartyEndpoint.import(paramsSerialized)
PrivateThirdPartyEndpoint.import(paramsSerialized, firstPartyEndpoint)
} catch (exception: UnknownFirstPartyEndpointException) {
assertEquals(
"First-party endpoint ${firstPartyCert.subjectId} is not registered",
Expand Down Expand Up @@ -195,7 +200,7 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() {
)
val paramsSerialized = serializeConnectionParams(pdaPath)
try {
PrivateThirdPartyEndpoint.import(paramsSerialized)
PrivateThirdPartyEndpoint.import(paramsSerialized, firstPartyEndpoint)
} catch (exception: InvalidAuthorizationException) {
assertEquals("PDA path is invalid", exception.message)
assertTrue(exception.cause is CertificationPathException)
Expand All @@ -209,8 +214,9 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() {
@Test
fun import_malformedParams() =
runTest {
val firstPartyEndpoint = createFirstPartyEndpoint()
try {
PrivateThirdPartyEndpoint.import("malformed".toByteArray())
PrivateThirdPartyEndpoint.import("malformed".toByteArray(), firstPartyEndpoint)
} catch (exception: InvalidThirdPartyEndpoint) {
assertEquals("Malformed connection params", exception.message)
assertTrue(exception.cause is InvalidNodeConnectionParams)
Expand All @@ -223,7 +229,7 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() {
@Test
fun import_invalidPDAPath() =
runTest {
createFirstPartyEndpoint()
val firstPartyEndpoint = createFirstPartyEndpoint()
val pdaPath =
CertificationPath(
pda,
Expand All @@ -232,7 +238,7 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() {
)
val paramsSerialized = serializeConnectionParams(pdaPath)
try {
PrivateThirdPartyEndpoint.import(paramsSerialized)
PrivateThirdPartyEndpoint.import(paramsSerialized, firstPartyEndpoint)
} catch (exception: InvalidAuthorizationException) {
assertEquals("PDA path is invalid", exception.message)
return@runTest
Expand All @@ -259,7 +265,7 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() {
val pdaPath = CertificationPath(expiredPDA, listOf(thirdPartyEndpointCertificate))
val paramsSerialized = serializeConnectionParams(pdaPath)
try {
PrivateThirdPartyEndpoint.import(paramsSerialized)
PrivateThirdPartyEndpoint.import(paramsSerialized, firstPartyEndpoint)
} catch (exception: InvalidAuthorizationException) {
assertEquals("PDA path is invalid", exception.message)
assertTrue(exception.cause is CertificationPathException)
Expand Down Expand Up @@ -401,13 +407,13 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() {
val endpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint
val firstPartyEndpoint = channel.firstPartyEndpoint

endpoint.delete()
endpoint.delete(firstPartyEndpoint)

verify(storage.privateThirdParty)
.delete("${firstPartyEndpoint.nodeId}_${endpoint.nodeId}")
assertEquals(0, privateKeyStore.sessionKeys[firstPartyEndpoint.nodeId]!!.size)
assertEquals(0, sessionPublicKeystore.keys.size)
verify(channelManager).delete(endpoint)
assertEquals(0, sessionPublicKeystore.keys[firstPartyEndpoint.nodeId]!!.size)
verify(channelManager).delete(firstPartyEndpoint, endpoint)
}

private fun serializeConnectionParams(deliveryAuth: CertificationPath) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,13 @@ internal class PublicThirdPartyEndpointTest : MockContextTestCase() {
SessionKeyPair.generate().sessionKey,
)

val thirdPartyEndpoint = PublicThirdPartyEndpoint.import(connectionParams.serialize())
val firstPartyEndpoint = createFirstPartyEndpoint()

val thirdPartyEndpoint =
PublicThirdPartyEndpoint.import(
connectionParams.serialize(),
firstPartyEndpoint,
)

assertEquals(connectionParams.internetAddress, thirdPartyEndpoint.internetAddress)
assertEquals(connectionParams.identityKey, thirdPartyEndpoint.identityKey)
Expand All @@ -91,15 +97,17 @@ internal class PublicThirdPartyEndpointTest : MockContextTestCase() {
connectionParams.identityKey,
),
)
sessionPublicKeystore.retrieve(thirdPartyEndpoint.nodeId)
sessionPublicKeystore.retrieve(firstPartyEndpoint.nodeId, thirdPartyEndpoint.nodeId)
}

@Test
fun import_invalidConnectionParams() =
runTest {
val firstPartyEndpoint = createFirstPartyEndpoint()
try {
PublicThirdPartyEndpoint.import(
"malformed".toByteArray(),
firstPartyEndpoint,
)
} catch (exception: InvalidThirdPartyEndpoint) {
assertEquals("Connection params serialization is malformed", exception.message)
Expand Down Expand Up @@ -134,13 +142,17 @@ internal class PublicThirdPartyEndpointTest : MockContextTestCase() {
thirdPartyEndpoint.nodeId,
)
val peerSessionKey = SessionKeyPair.generate().sessionKey
sessionPublicKeystore.save(peerSessionKey, thirdPartyEndpoint.nodeId)
sessionPublicKeystore.save(
peerSessionKey,
firstPartyEndpoint.nodeId,
thirdPartyEndpoint.nodeId,
)

thirdPartyEndpoint.delete()
thirdPartyEndpoint.delete(firstPartyEndpoint)

verify(storage.publicThirdParty).delete(thirdPartyEndpoint.nodeId)
assertEquals(0, privateKeyStore.sessionKeys[firstPartyEndpoint.nodeId]!!.size)
assertEquals(0, sessionPublicKeystore.keys.size)
verify(channelManager).delete(thirdPartyEndpoint)
assertEquals(0, sessionPublicKeystore.keys[firstPartyEndpoint.nodeId]!!.size)
verify(channelManager).delete(firstPartyEndpoint, thirdPartyEndpoint)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ internal class IncomingMessageTest : MockContextTestCase() {
payload =
thirdPartyEndpointManager.wrapMessagePayload(
serviceMessage,
channel.firstPartyEndpoint.nodeId,
channel.thirdPartyEndpoint.nodeId,
channel.firstPartyEndpoint.nodeId,
),
senderCertificate = PDACertPath.PDA,
)
Expand Down Expand Up @@ -279,6 +279,7 @@ internal class IncomingMessageTest : MockContextTestCase() {
thirdPartySessionPublicKeyStore.save(
channel.firstPartySessionKeyPair.sessionKey,
channel.firstPartyEndpoint.nodeId,
channel.thirdPartyEndpoint.nodeId,
)
return EndpointManager(
thirdPartyPrivateKeyStore,
Expand All @@ -299,8 +300,8 @@ internal class IncomingMessageTest : MockContextTestCase() {
val pdaPathServiceMessage = makePDAPathMessage(plaintext)
return thirdPartyEndpointManager.wrapMessagePayload(
pdaPathServiceMessage,
channel.firstPartyEndpoint.nodeId,
channel.thirdPartyEndpoint.nodeId,
channel.firstPartyEndpoint.nodeId,
)
}

Expand All @@ -310,7 +311,8 @@ internal class IncomingMessageTest : MockContextTestCase() {
companion object {
private val logCaptor = LogCaptor.forClass(IncomingMessage::class.java)

@JvmStatic
@AfterClass
fun closeLogs() = logCaptor.close()
fun closeLogs(): Unit = logCaptor.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ internal class ReceiveMessagesTest : MockContextTestCase() {
val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection)))
pdcClient = MockPDCClient(collectParcelsCall)

channel.thirdPartyEndpoint.delete()
channel.thirdPartyEndpoint.delete(channel.firstPartyEndpoint)

val messages = subject.receive().toCollection(mutableListOf())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ internal abstract class MockContextTestCase {

sessionPublicKeystore.save(
sessionKey,
firstPartyEndpoint.nodeId,
thirdPartyEndpoint.nodeId,
)
return thirdPartyEndpoint
Expand Down

0 comments on commit 14f44d2

Please sign in to comment.