diff --git a/app/src/main/java/tech/relaycorp/gateway/data/disk/CargoStorage.kt b/app/src/main/java/tech/relaycorp/gateway/data/disk/CargoStorage.kt index fa7f397d..f04508af 100644 --- a/app/src/main/java/tech/relaycorp/gateway/data/disk/CargoStorage.kt +++ b/app/src/main/java/tech/relaycorp/gateway/data/disk/CargoStorage.kt @@ -27,7 +27,7 @@ class CargoStorage } try { - cargo.validate(RecipientAddressType.PRIVATE, setOf(localConfig.getCargoDeliveryAuth())) + cargo.validate(RecipientAddressType.PRIVATE, localConfig.getAllValidCargoDeliveryAuth()) } catch (exc: RelaynetException) { logger.warning("Invalid cargo received: ${exc.message}") throw Exception.InvalidCargo(null, exc) diff --git a/app/src/main/java/tech/relaycorp/gateway/domain/LocalConfig.kt b/app/src/main/java/tech/relaycorp/gateway/domain/LocalConfig.kt index 135d67d7..b3e355e4 100644 --- a/app/src/main/java/tech/relaycorp/gateway/domain/LocalConfig.kt +++ b/app/src/main/java/tech/relaycorp/gateway/domain/LocalConfig.kt @@ -1,8 +1,9 @@ package tech.relaycorp.gateway.domain +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import tech.relaycorp.gateway.common.nowInUtc import tech.relaycorp.gateway.common.toPublicKey -import tech.relaycorp.gateway.data.disk.FileStore import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences import tech.relaycorp.gateway.domain.courier.CalculateCRCMessageCreationDate import tech.relaycorp.relaynet.issueGatewayCertificate @@ -17,15 +18,17 @@ import java.security.PrivateKey import java.security.PublicKey import javax.inject.Inject import javax.inject.Provider +import kotlin.time.Duration.Companion.days import kotlin.time.toJavaDuration class LocalConfig @Inject constructor( - private val fileStore: FileStore, private val privateKeyStore: Provider, private val certificateStore: Provider, private val publicGatewayPreferences: PublicGatewayPreferences ) { + private val mutex = Mutex() + // Private Gateway Key Pair suspend fun getIdentityKey(): PrivateKey = @@ -70,26 +73,41 @@ class LocalConfig return certificate } - @Synchronized suspend fun bootstrap() { - try { - getIdentityKey() - } catch (_: RuntimeException) { - val keyPair = generateIdentityKeyPair() - generateIdentityCertificate(keyPair.private) - } - - try { - getCargoDeliveryAuth() - } catch (_: RuntimeException) { - generateCargoDeliveryAuth() + mutex.withLock { + try { + getIdentityKey() + } catch (_: RuntimeException) { + val keyPair = generateIdentityKeyPair() + generateIdentityCertificate(keyPair.private) + } + + getCargoDeliveryAuth() // Generates new CDA if non-existent } } suspend fun getCargoDeliveryAuth() = - fileStore.read(CDA_CERTIFICATE_FILE_NAME) - ?.let { Certificate.deserialize(it) } - ?: throw RuntimeException("No CDA issuer was found") + certificateStore.get() + .retrieveLatest( + getIdentityKey().privateAddress, + getIdentityCertificate().subjectPrivateAddress + ) + ?.leafCertificate + .let { storedCertificate -> + if (storedCertificate?.isExpiringSoon() == false) { + storedCertificate + } else { + generateCargoDeliveryAuth() + } + } + + suspend fun getAllValidCargoDeliveryAuth() = + certificateStore.get() + .retrieveAll( + getIdentityKey().privateAddress, + getIdentityCertificate().subjectPrivateAddress + ) + .map { it.leafCertificate } private fun selfIssueCargoDeliveryAuth( privateKey: PrivateKey, @@ -100,15 +118,16 @@ class LocalConfig issuerPrivateKey = privateKey, validityStartDate = nowInUtc() .minus(CalculateCRCMessageCreationDate.CLOCK_DRIFT_TOLERANCE.toJavaDuration()), - validityEndDate = nowInUtc().plusYears(1) + validityEndDate = nowInUtc().plusMonths(6) ) } - private suspend fun generateCargoDeliveryAuth() { + private suspend fun generateCargoDeliveryAuth(): Certificate { val key = getIdentityKey() val certificate = getIdentityCertificate() val cda = selfIssueCargoDeliveryAuth(key, certificate.subjectPublicKey) - fileStore.store(CDA_CERTIFICATE_FILE_NAME, cda.serialize()) + certificateStore.get().save(cda, emptyList(), certificate.subjectPrivateAddress) + return cda } suspend fun deleteExpiredCertificates() { @@ -118,9 +137,10 @@ class LocalConfig private suspend fun getPublicGatewayPrivateAddress() = publicGatewayPreferences.getPrivateAddress() - // Helpers + private fun Certificate.isExpiringSoon() = + expiryDate < (nowInUtc().plusNanos(CERTIFICATE_EXPIRING_THRESHOLD.inWholeNanoseconds)) companion object { - internal const val CDA_CERTIFICATE_FILE_NAME = "cda_local_gateway.certificate" + private val CERTIFICATE_EXPIRING_THRESHOLD = 90.days } } diff --git a/app/src/test/java/tech/relaycorp/gateway/data/disk/CargoStorageTest.kt b/app/src/test/java/tech/relaycorp/gateway/data/disk/CargoStorageTest.kt index aed0900c..9ac0751e 100644 --- a/app/src/test/java/tech/relaycorp/gateway/data/disk/CargoStorageTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/data/disk/CargoStorageTest.kt @@ -41,8 +41,8 @@ internal class CargoStorageTest { @Test fun `Valid cargo bound for a public gateway should be refused`() = runBlockingTest { - whenever(mockLocalConfig.getCargoDeliveryAuth()) - .thenReturn(CargoDeliveryCertPath.PRIVATE_GW) + whenever(mockLocalConfig.getAllValidCargoDeliveryAuth()) + .thenReturn(listOf(CargoDeliveryCertPath.PRIVATE_GW)) val cargo = Cargo( "https://foo.relaycorp.tech", @@ -62,8 +62,8 @@ internal class CargoStorageTest { @Test fun `Well-formed but unauthorized cargo should be refused`() = runBlockingTest { - whenever(mockLocalConfig.getCargoDeliveryAuth()) - .thenReturn(CargoDeliveryCertPath.PRIVATE_GW) + whenever(mockLocalConfig.getAllValidCargoDeliveryAuth()) + .thenReturn(listOf(CargoDeliveryCertPath.PRIVATE_GW)) val unauthorizedSenderKeyPair = generateRSAKeyPair() val unauthorizedSenderCert = issueGatewayCertificate( @@ -88,8 +88,8 @@ internal class CargoStorageTest { @Test fun `Authorized cargo should be accepted`() = runBlockingTest { - whenever(mockLocalConfig.getCargoDeliveryAuth()) - .thenReturn(CargoDeliveryCertPath.PRIVATE_GW) + whenever(mockLocalConfig.getAllValidCargoDeliveryAuth()) + .thenReturn(listOf(CargoDeliveryCertPath.PRIVATE_GW)) val cargoSerialized = CargoFactory.buildSerialized() cargoStorage.store(cargoSerialized.inputStream()) diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/LocalConfigTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/LocalConfigTest.kt index 5dcbfa30..6c2738a2 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/LocalConfigTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/LocalConfigTest.kt @@ -7,39 +7,28 @@ import com.nhaarman.mockitokotlin2.verify import com.nhaarman.mockitokotlin2.whenever import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runBlockingTest +import org.junit.jupiter.api.Assertions.assertArrayEquals import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows -import tech.relaycorp.gateway.data.disk.FileStore import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences import tech.relaycorp.gateway.test.BaseDataTestCase import tech.relaycorp.relaynet.testing.pki.PDACertPath import kotlin.test.assertEquals +import kotlin.test.assertNotNull class LocalConfigTest : BaseDataTestCase() { - private val fileStore = mock() private val publicGatewayPreferences = mock() private val localConfig = LocalConfig( - fileStore, privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences + privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences ) @BeforeEach fun setUp() { runBlocking { - val memoryStore = mutableMapOf() - whenever(fileStore.store(any(), any())).then { - val key = it.getArgument(0) - val value = it.getArgument(1) as ByteArray - memoryStore[key] = value - Unit - } - whenever(fileStore.read(any())).thenAnswer { - val key = it.getArgument(0) - memoryStore[key] - } whenever(publicGatewayPreferences.getPrivateAddress()) .thenReturn(PDACertPath.PUBLIC_GW.subjectPrivateAddress) } @@ -79,12 +68,11 @@ class LocalConfigTest : BaseDataTestCase() { } @Test - fun `Exception should be thrown if certificate does not exist yet`() = runBlockingTest { - val exception = assertThrows { - localConfig.getCargoDeliveryAuth() - } + fun `New certificate is generated if none exists`() = runBlockingTest { + localConfig.bootstrap() + certificateStore.clear() - assertEquals("No CDA issuer was found", exception.message) + assertNotNull(localConfig.getCargoDeliveryAuth()) } } @@ -139,7 +127,7 @@ class LocalConfigTest : BaseDataTestCase() { localConfig.bootstrap() val cdaIssuer = localConfig.getCargoDeliveryAuth() - assertEquals(originalCDAIssuer, cdaIssuer) + assertArrayEquals(originalCDAIssuer.serialize(), cdaIssuer.serialize()) } } diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/courier/GenerateCCATest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/courier/GenerateCCATest.kt index 979832f4..33dea112 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/courier/GenerateCCATest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/courier/GenerateCCATest.kt @@ -1,31 +1,33 @@ package tech.relaycorp.gateway.domain.courier +import com.nhaarman.mockitokotlin2.any import com.nhaarman.mockitokotlin2.eq import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.whenever import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runBlockingTest +import org.junit.jupiter.api.Assertions.assertArrayEquals import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import tech.relaycorp.gateway.common.nowInUtc -import tech.relaycorp.gateway.data.disk.FileStore import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences import tech.relaycorp.gateway.domain.LocalConfig import tech.relaycorp.gateway.test.BaseDataTestCase import tech.relaycorp.relaynet.issueGatewayCertificate +import tech.relaycorp.relaynet.keystores.CertificationPath import tech.relaycorp.relaynet.messages.CargoCollectionAuthorization import tech.relaycorp.relaynet.testing.pki.KeyPairSet import tech.relaycorp.relaynet.testing.pki.PDACertPath +import tech.relaycorp.relaynet.wrappers.privateAddress import java.time.Duration class GenerateCCATest : BaseDataTestCase() { private val publicGatewayPreferences = mock() - private val mockFileStore = mock() private val localConfig = LocalConfig( - mockFileStore, privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences + privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences ) private val calculateCreationDate = mock() @@ -46,14 +48,14 @@ class GenerateCCATest : BaseDataTestCase() { registerPrivateGatewayIdentity() val keyPair = KeyPairSet.PRIVATE_GW - val certificate = issueGatewayCertificate( + val cda = issueGatewayCertificate( subjectPublicKey = keyPair.public, issuerPrivateKey = keyPair.private, validityEndDate = nowInUtc().plusMinutes(1), validityStartDate = nowInUtc().minusDays(1) ) - whenever(mockFileStore.read(eq(LocalConfig.CDA_CERTIFICATE_FILE_NAME))) - .thenReturn(certificate.serialize()) + whenever(certificateStore.retrieveLatest(any(), eq(keyPair.public.privateAddress))) + .thenReturn(CertificationPath(cda, emptyList())) whenever(publicGatewayPreferences.getPrivateAddress()) .thenReturn(PDACertPath.PUBLIC_GW.subjectPrivateAddress) @@ -75,7 +77,7 @@ class GenerateCCATest : BaseDataTestCase() { cca.validate(null) assertEquals(ADDRESS, cca.recipientAddress) - assertEquals(PDACertPath.PRIVATE_GW, cca.senderCertificate) + assertArrayEquals(PDACertPath.PRIVATE_GW.serialize(), cca.senderCertificate.serialize()) assertTrue(Duration.between(creationDate, cca.creationDate).abs().seconds <= 1) // Check it was encrypted with the public gateway's session key diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/courier/GenerateCargoTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/courier/GenerateCargoTest.kt index ea1c6a5d..eed06a3f 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/courier/GenerateCargoTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/courier/GenerateCargoTest.kt @@ -13,7 +13,6 @@ import tech.relaycorp.gateway.common.nowInUtc import tech.relaycorp.gateway.data.database.ParcelCollectionDao import tech.relaycorp.gateway.data.database.StoredParcelDao import tech.relaycorp.gateway.data.disk.DiskMessageOperations -import tech.relaycorp.gateway.data.disk.FileStore import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences import tech.relaycorp.gateway.domain.LocalConfig import tech.relaycorp.gateway.test.BaseDataTestCase @@ -31,9 +30,8 @@ class GenerateCargoTest : BaseDataTestCase() { private val parcelCollectionDao = mock() private val diskMessageOperations = mock() private val publicGatewayPreferences = mock() - private val mockFileStore = mock() private val localConfig = LocalConfig( - mockFileStore, privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences + privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences ) private val calculateCRCMessageCreationDate = mock() private val generateCargo = GenerateCargo( diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/endpoint/EndpointRegistrationTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/endpoint/EndpointRegistrationTest.kt index ff43a107..764deedc 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/endpoint/EndpointRegistrationTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/endpoint/EndpointRegistrationTest.kt @@ -9,7 +9,6 @@ import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import tech.relaycorp.gateway.data.database.LocalEndpointDao -import tech.relaycorp.gateway.data.disk.FileStore import tech.relaycorp.gateway.data.model.LocalEndpoint import tech.relaycorp.gateway.data.model.PrivateMessageAddress import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences @@ -30,11 +29,9 @@ import kotlin.test.assertTrue class EndpointRegistrationTest : BaseDataTestCase() { private val mockLocalEndpointDao = mock() - private val mockFileStore = mock() private val mockPublicGatewayPreferences = mock() private val mockLocalConfig = LocalConfig( - mockFileStore, privateKeyStoreProvider, certificateStoreProvider, - mockPublicGatewayPreferences + privateKeyStoreProvider, certificateStoreProvider, mockPublicGatewayPreferences ) private val endpointRegistration = EndpointRegistration(mockLocalEndpointDao, mockLocalConfig) diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/CollectParcelsFromGatewayTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/CollectParcelsFromGatewayTest.kt index 902671f6..fb04b6f9 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/CollectParcelsFromGatewayTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/CollectParcelsFromGatewayTest.kt @@ -20,7 +20,6 @@ import kotlinx.coroutines.test.runBlockingTest import org.junit.Assert.assertEquals import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test -import tech.relaycorp.gateway.data.disk.FileStore import tech.relaycorp.gateway.data.doh.PublicAddressResolutionException import tech.relaycorp.gateway.data.model.MessageAddress import tech.relaycorp.gateway.data.model.RecipientLocation @@ -48,11 +47,9 @@ class CollectParcelsFromGatewayTest : BaseDataTestCase() { private val poWebClientBuilder = object : PoWebClientProvider { override suspend fun get() = poWebClient } - private val mockFileStore = mock() private val mockPublicGatewayPreferences = mock() private val mockLocalConfig = LocalConfig( - mockFileStore, privateKeyStoreProvider, certificateStoreProvider, - mockPublicGatewayPreferences + privateKeyStoreProvider, certificateStoreProvider, mockPublicGatewayPreferences ) private val notifyEndpoints = mock() private val subject = CollectParcelsFromGateway( diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/DeliverParcelsToGatewayTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/DeliverParcelsToGatewayTest.kt index 78ed9f7f..b6a38617 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/DeliverParcelsToGatewayTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/DeliverParcelsToGatewayTest.kt @@ -16,7 +16,6 @@ import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import tech.relaycorp.gateway.data.database.StoredParcelDao import tech.relaycorp.gateway.data.disk.DiskMessageOperations -import tech.relaycorp.gateway.data.disk.FileStore import tech.relaycorp.gateway.data.disk.MessageDataNotFoundException import tech.relaycorp.gateway.data.doh.PublicAddressResolutionException import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences @@ -39,11 +38,9 @@ class DeliverParcelsToGatewayTest : BaseDataTestCase() { private val poWebClientProvider = object : PoWebClientProvider { override suspend fun get() = poWebClient } - private val mockFileStore = mock() private val mockPublicGatewayPreferences = mock() private val localConfig = LocalConfig( - mockFileStore, privateKeyStoreProvider, certificateStoreProvider, - mockPublicGatewayPreferences + privateKeyStoreProvider, certificateStoreProvider, mockPublicGatewayPreferences ) private val deleteParcel = mock() private val subject = DeliverParcelsToGateway( diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/RegisterGatewayTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/RegisterGatewayTest.kt index 9a9443c2..20109fc3 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/RegisterGatewayTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/RegisterGatewayTest.kt @@ -10,7 +10,6 @@ import com.nhaarman.mockitokotlin2.whenever import kotlinx.coroutines.test.runBlockingTest import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test -import tech.relaycorp.gateway.data.disk.FileStore import tech.relaycorp.gateway.data.doh.PublicAddressResolutionException import tech.relaycorp.gateway.data.doh.ResolveServiceAddress import tech.relaycorp.gateway.data.model.RegistrationState @@ -35,9 +34,8 @@ import kotlin.test.assertEquals class RegisterGatewayTest : BaseDataTestCase() { private val pgwPreferences = mock() - private val mockFileStore = mock() private val localConfig = LocalConfig( - mockFileStore, privateKeyStoreProvider, certificateStoreProvider, pgwPreferences + privateKeyStoreProvider, certificateStoreProvider, pgwPreferences ) private val poWebClient = mock() private val poWebClientBuilder = object : PoWebClientBuilder {