Skip to content

Commit

Permalink
fix(PublicSessionKeyStore): Include nodeId in scope
Browse files Browse the repository at this point in the history
  • Loading branch information
gnarea committed Dec 17, 2023
1 parent b7dc045 commit af58de9
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ abstract class SessionPublicKeyStore {
@Throws(KeyStoreBackendException::class)
suspend fun save(
key: SessionKey,
nodeId: String,
peerId: String,
creationTime: ZonedDateTime = ZonedDateTime.now()
) {
val creationTimestamp = creationTime.toEpochSecond()

val existingKeyData = retrieveKeyData(peerId)
val existingKeyData = retrieveKeyData(nodeId, peerId)
if (existingKeyData != null && creationTimestamp < existingKeyData.creationTimestamp) {
return
}
Expand All @@ -23,31 +24,32 @@ abstract class SessionPublicKeyStore {
key.publicKey.encoded,
creationTimestamp
)
saveKeyData(keyData, peerId)
saveKeyData(keyData, nodeId, peerId)
}

@Throws(KeyStoreBackendException::class)
suspend fun retrieve(peerId: String): SessionKey {
val keyData = retrieveKeyData(peerId)
?: throw MissingKeyException("There is no session key for $peerId")
suspend fun retrieve(nodeId: String, peerId: String): SessionKey {
val keyData = retrieveKeyData(nodeId, peerId)
?: throw MissingKeyException("Node $nodeId has no session key for $peerId")

val sessionPublicKey = keyData.keyDer.deserializeECPublicKey()
return SessionKey(keyData.keyId, sessionPublicKey)
}

/**
* Delete the session key for [peerId], if it exists.
* Delete the session key for [peerId], if it exists under [nodeId].
*/
@Throws(KeyStoreBackendException::class)
abstract suspend fun delete(peerId: String)
abstract suspend fun delete(nodeId: String, peerId: String)

@Throws(KeyStoreBackendException::class)
protected abstract suspend fun saveKeyData(
keyData: SessionPublicKeyData,
nodeId: String,
peerId: String
)

@Throws(KeyStoreBackendException::class)
protected abstract suspend fun retrieveKeyData(peerId: String):
protected abstract suspend fun retrieveKeyData(nodeId: String, peerId: String):
SessionPublicKeyData?
}
3 changes: 2 additions & 1 deletion src/main/kotlin/tech/relaycorp/relaynet/nodes/NodeManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ abstract class NodeManager<P : Payload>(
peerId: String,
nodeId: String,
): ByteArray {
val recipientSessionKey = sessionPublicKeyStore.retrieve(peerId)
val recipientSessionKey = sessionPublicKeyStore.retrieve(nodeId, peerId)
val senderSessionKeyPair = generateSessionKeyPair(nodeId, peerId)
return payload.encrypt(
recipientSessionKey,
Expand Down Expand Up @@ -74,6 +74,7 @@ abstract class NodeManager<P : Payload>(
val unwrapping = message.unwrapPayload(privateKeyStore)
sessionPublicKeyStore.save(
unwrapping.peerSessionKey,
message.recipient.id,
message.senderCertificate.subjectId,
message.creationDate
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@ import java.time.ZoneOffset.UTC
import java.time.ZonedDateTime
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Nested
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import tech.relaycorp.relaynet.SessionKeyPair
import tech.relaycorp.relaynet.utils.MockSessionPublicKeyStore

@OptIn(ExperimentalCoroutinesApi::class)
class SessionPublicKeyStoreTest {
private val nodeId = "0deadc0de"
private val peerId = "0deadbeef"
private val storeKey = "$nodeId,$peerId"
private val creationTime: ZonedDateTime = ZonedDateTime.now(UTC).withNano(0)

private val sessionKeyGeneration = SessionKeyPair.generate()
Expand All @@ -27,10 +27,10 @@ class SessionPublicKeyStoreTest {
fun `Key data should be saved if there is no prior key for recipient`() = runTest {
val store = MockSessionPublicKeyStore()

store.save(sessionKey, peerId)
store.save(sessionKey, nodeId, peerId)

assertTrue(store.keys.containsKey(peerId))
val keyData = store.keys[peerId]!!
assertTrue(store.keys.containsKey(storeKey))
val keyData = store.keys[storeKey]!!
assertEquals(sessionKey.keyId.asList(), keyData.keyId.asList())
assertEquals(sessionKey.publicKey.encoded.asList(), keyData.keyDer.asList())
}
Expand All @@ -39,11 +39,11 @@ class SessionPublicKeyStoreTest {
fun `Key data should be saved if prior key is older`() = runTest {
val store = MockSessionPublicKeyStore()
val (oldSessionKey) = SessionKeyPair.generate()
store.save(oldSessionKey, peerId, creationTime.minusSeconds(1))
store.save(oldSessionKey, nodeId, peerId, creationTime.minusSeconds(1))

store.save(sessionKey, peerId, creationTime)
store.save(sessionKey, nodeId, peerId, creationTime)

val keyData = store.keys[peerId]!!
val keyData = store.keys[storeKey]!!
assertEquals(sessionKey.keyId.asList(), keyData.keyId.asList())
assertEquals(sessionKey.publicKey.encoded.asList(), keyData.keyDer.asList())
assertEquals(creationTime.toEpochSecond(), keyData.creationTimestamp)
Expand All @@ -52,12 +52,12 @@ class SessionPublicKeyStoreTest {
@Test
fun `Key data should not be saved if prior key is newer`() = runTest {
val store = MockSessionPublicKeyStore()
store.save(sessionKey, peerId, creationTime)
store.save(sessionKey, nodeId, peerId, creationTime)

val (oldSessionKey) = SessionKeyPair.generate()
store.save(oldSessionKey, peerId, creationTime.minusSeconds(1))
store.save(oldSessionKey, nodeId, peerId, creationTime.minusSeconds(1))

val keyData = store.keys[peerId]!!
val keyData = store.keys[storeKey]!!
assertEquals(sessionKey.keyId.asList(), keyData.keyId.asList())
assertEquals(sessionKey.publicKey.encoded.asList(), keyData.keyDer.asList())
assertEquals(creationTime.toEpochSecond(), keyData.creationTimestamp)
Expand All @@ -70,9 +70,9 @@ class SessionPublicKeyStoreTest {
val now = ZonedDateTime.now(UTC)
val store = MockSessionPublicKeyStore()

store.save(sessionKey, peerId)
store.save(sessionKey, nodeId, peerId)

val keyData = store.keys[peerId]!!
val keyData = store.keys[storeKey]!!
val creationTimestamp = keyData.creationTimestamp
assertTrue(now.toEpochSecond() <= creationTimestamp)
assertTrue(creationTimestamp <= ZonedDateTime.now(UTC).toEpochSecond())
Expand All @@ -82,9 +82,9 @@ class SessionPublicKeyStoreTest {
fun `Any explicit time should be honored`() = runTest {
val store = MockSessionPublicKeyStore()

store.save(sessionKey, peerId, creationTime)
store.save(sessionKey, nodeId, peerId, creationTime)

val keyData = store.keys[peerId]!!
val keyData = store.keys[storeKey]!!
assertEquals(creationTime.toEpochSecond(), keyData.creationTimestamp)
}

Expand All @@ -93,9 +93,9 @@ class SessionPublicKeyStoreTest {
val creationTime = ZonedDateTime.now(ZoneId.of("America/Caracas")).minusDays(1)
val store = MockSessionPublicKeyStore()

store.save(sessionKey, peerId, creationTime)
store.save(sessionKey, nodeId, peerId, creationTime)

val keyData = store.keys[peerId]!!
val keyData = store.keys[storeKey]!!
assertEquals(
creationTime.withZoneSameInstant(UTC).toEpochSecond(),
keyData.creationTimestamp
Expand All @@ -109,9 +109,9 @@ class SessionPublicKeyStoreTest {
@Test
fun `Key data should be returned if key for recipient exists`() = runTest {
val store = MockSessionPublicKeyStore()
store.save(sessionKey, peerId, creationTime)
store.save(sessionKey, nodeId, peerId, creationTime)

val fetchedSessionKey = store.retrieve(peerId)
val fetchedSessionKey = store.retrieve(nodeId, peerId)

assertEquals(sessionKey, fetchedSessionKey)
}
Expand All @@ -121,9 +121,9 @@ class SessionPublicKeyStoreTest {
val store = MockSessionPublicKeyStore()

val exception =
assertThrows<MissingKeyException> { (store.retrieve(peerId)) }
assertThrows<MissingKeyException> { (store.retrieve(nodeId, peerId)) }

assertEquals("There is no session key for $peerId", exception.message)
assertEquals("Node $nodeId has no session key for $peerId", exception.message)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class NodeManagerTest {
inner class WrapMessagePayload {
@BeforeEach
fun registerPeerSessionKey() = runTest {
publicKeyStore.save(peerSessionKey, peerId)
publicKeyStore.save(peerSessionKey, nodeId, peerId)
}

@Test
Expand All @@ -144,7 +144,7 @@ class NodeManagerTest {
manager.wrapMessagePayload(payload, peerId, nodeId)
}

assertEquals("There is no session key for $peerId", exception.message)
assertEquals("Node $nodeId has no session key for $peerId", exception.message)
}

@Test
Expand Down Expand Up @@ -328,8 +328,8 @@ class NodeManagerTest {

manager.unwrapMessagePayload(message)

assertEquals(peerSessionKey, publicKeyStore.retrieve(peerId))
val storedKey = publicKeyStore.keys[peerId]!!
assertEquals(peerSessionKey, publicKeyStore.retrieve(nodeId, peerId))
val storedKey = publicKeyStore.keys["$nodeId,$peerId"]!!
assertEquals(message.creationDate.toEpochSecond(), storedKey.creationTimestamp)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,26 @@ class MockSessionPublicKeyStore(
keys.clear()
}

override suspend fun saveKeyData(keyData: SessionPublicKeyData, peerId: String) {
override suspend fun saveKeyData(
keyData: SessionPublicKeyData,
nodeId: String,
peerId: String,
) {
if (savingException != null) {
throw savingException
}
this.keys[peerId] = keyData
this.keys["$nodeId,$peerId"] = keyData
}

override suspend fun retrieveKeyData(peerId: String): SessionPublicKeyData? {
override suspend fun retrieveKeyData(nodeId: String, peerId: String): SessionPublicKeyData? {
if (retrievalException != null) {
throw retrievalException
}

return keys[peerId]
return keys["$nodeId,$peerId"]
}

override suspend fun delete(peerId: String) {
keys.remove(peerId)
override suspend fun delete(nodeId: String, peerId: String) {
keys.remove("$nodeId,$peerId")
}
}

0 comments on commit af58de9

Please sign in to comment.