From 4d6a51fb1ef770dbdf5e396dcc45e8c9cf4ee922 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9rgio=20Santos?= Date: Mon, 30 Oct 2023 12:15:07 +0000 Subject: [PATCH] Register gateway on sync --- .../main/java/tech/relaycorp/gateway/App.kt | 9 +-- .../gateway/domain/publicsync/PublicSync.kt | 14 ++-- .../domain/publicsync/RegisterGateway.kt | 65 ++++++++++++------- .../domain/publicsync/PublicSyncTest.kt | 39 +++++------ .../domain/publicsync/RegisterGatewayTest.kt | 20 ++---- .../tech/relaycorp/gateway/test/PNRFactory.kt | 14 ++++ 6 files changed, 91 insertions(+), 70 deletions(-) create mode 100644 app/src/test/java/tech/relaycorp/gateway/test/PNRFactory.kt diff --git a/app/src/main/java/tech/relaycorp/gateway/App.kt b/app/src/main/java/tech/relaycorp/gateway/App.kt index 64db00ad..dd5ba17c 100644 --- a/app/src/main/java/tech/relaycorp/gateway/App.kt +++ b/app/src/main/java/tech/relaycorp/gateway/App.kt @@ -23,7 +23,6 @@ import tech.relaycorp.gateway.common.di.AppComponent import tech.relaycorp.gateway.common.di.DaggerAppComponent import tech.relaycorp.gateway.domain.LocalConfig import tech.relaycorp.gateway.domain.publicsync.PublicSync -import tech.relaycorp.gateway.domain.publicsync.RegisterGateway import java.security.Security import java.time.Duration import java.util.logging.Level @@ -56,9 +55,6 @@ open class App : Application() { @Inject lateinit var localConfig: LocalConfig - @Inject - lateinit var registerGateway: RegisterGateway - @Inject lateinit var publicSync: PublicSync @@ -77,8 +73,8 @@ open class App : Application() { backgroundScope.launch { bootstrapGateway() - startPublicSyncWhenPossible() - deleteExpiredCertificates() + launch { startPublicSyncWhenPossible() } + launch { deleteExpiredCertificates() } } registerActivityLifecycleCallbacks(foregroundAppMonitor) @@ -132,7 +128,6 @@ open class App : Application() { private suspend fun bootstrapGateway() { if (mode != Mode.Test) { localConfig.bootstrap() - registerGateway.registerIfNeeded() } } diff --git a/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/PublicSync.kt b/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/PublicSync.kt index 803a9dbb..f97a7624 100644 --- a/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/PublicSync.kt +++ b/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/PublicSync.kt @@ -13,12 +13,14 @@ import tech.relaycorp.gateway.background.ConnectionState import tech.relaycorp.gateway.background.ConnectionStateObserver import tech.relaycorp.gateway.background.ForegroundAppMonitor import tech.relaycorp.gateway.common.Logging.logger +import tech.relaycorp.gateway.common.interval import tech.relaycorp.gateway.data.model.RegistrationState import tech.relaycorp.gateway.data.preference.InternetGatewayPreferences import tech.relaycorp.gateway.pdc.local.PDCServer import tech.relaycorp.gateway.pdc.local.PDCServerStateManager import javax.inject.Inject import javax.inject.Singleton +import kotlin.time.Duration.Companion.minutes import kotlin.time.Duration.Companion.seconds @Singleton @@ -28,6 +30,7 @@ class PublicSync private val pdcServerStateManager: PDCServerStateManager, private val internetGatewayPreferences: InternetGatewayPreferences, private val connectionStateObserver: ConnectionStateObserver, + private val registerGateway: RegisterGateway, private val deliverParcelsToGateway: DeliverParcelsToGateway, private val collectParcelsFromGateway: CollectParcelsFromGateway ) { @@ -42,11 +45,11 @@ class PublicSync combine( foregroundAppMonitor.observe(), pdcServerStateManager.observe(), - internetGatewayPreferences.observeRegistrationState(), - connectionStateObserver.observe() - ) { foregroundState, pdcState, registrationState, connectionState -> + connectionStateObserver.observe(), + // Retry registration and sync every minute in case there's a failure + interval(1.minutes) + ) { foregroundState, pdcState, connectionState, _ -> if ( - registrationState == RegistrationState.Done && connectionState is ConnectionState.InternetWithGateway && ( foregroundState == ForegroundAppMonitor.State.Foreground || pdcState == PDCServer.State.Started @@ -71,8 +74,9 @@ class PublicSync collectParcelsFromGateway.collect(false) } - private fun startSync() { + private suspend fun startSync() { if (isSyncing) return + if (!registerGateway.registerIfNeeded().isSuccessful) return logger.info("Starting public sync") val syncJob = Job() diff --git a/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/RegisterGateway.kt b/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/RegisterGateway.kt index 3c4591e7..13797716 100644 --- a/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/RegisterGateway.kt +++ b/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/RegisterGateway.kt @@ -1,5 +1,7 @@ package tech.relaycorp.gateway.domain.publicsync +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import tech.relaycorp.gateway.common.Logging.logger import tech.relaycorp.gateway.data.doh.InternetAddressResolutionException import tech.relaycorp.gateway.data.doh.ResolveServiceAddress @@ -16,7 +18,9 @@ import java.time.Duration import java.time.ZonedDateTime import java.util.logging.Level import javax.inject.Inject +import javax.inject.Singleton +@Singleton @JvmSuppressWildcards class RegisterGateway @Inject constructor( @@ -28,36 +32,42 @@ class RegisterGateway private val publicKeyStore: SessionPublicKeyStore ) { + private val mutex = Mutex() + suspend fun registerIfNeeded(): Result { - val isFirstRegistration = - internetGatewayPreferences.getRegistrationState() == RegistrationState.ToDo - - if ( - !isFirstRegistration && - !currentCertificateIsAboutToExpire() - ) { - return Result.AlreadyRegisteredAndNotExpiring - } + mutex.withLock { + val isFirstRegistration = + internetGatewayPreferences.getRegistrationState() == RegistrationState.ToDo + + if ( + !isFirstRegistration && + !currentCertificateIsAboutToExpire() + ) { + return Result.AlreadyRegisteredAndNotExpiring + } - val address = internetGatewayPreferences.getAddress() - val result = register(address) - if (result is Result.Registered) { - saveSuccessfulResult(address, result.pnr) + val address = internetGatewayPreferences.getAddress() + val result = register(address) + if (result is Result.Registered) { + saveSuccessfulResult(address, result.pnr) - if (!isFirstRegistration) { - gatewayCertificateChangeNotifier.notifyAll() + if (!isFirstRegistration) { + gatewayCertificateChangeNotifier.notifyAll() + } } - } - return result + return result + } } suspend fun registerNewAddress(newAddress: String): Result { - val result = register(newAddress) - if (result is Result.Registered) { - saveSuccessfulResult(newAddress, result.pnr) + mutex.withLock { + val result = register(newAddress) + if (result is Result.Registered) { + saveSuccessfulResult(newAddress, result.pnr) + } + return result } - return result } private suspend fun currentCertificateIsAboutToExpire() = @@ -65,6 +75,8 @@ class RegisterGateway private suspend fun register(address: String): Result { return try { + logger.info("Registering with $address") + val poWebAddress = resolveServiceAddress.resolvePoWeb(address) val poWeb = poWebClientBuilder.build(poWebAddress) val privateKey = localConfig.getIdentityKey() @@ -114,13 +126,16 @@ class RegisterGateway } sealed class Result { - object FailedToResolve : Result() - object FailedToRegister : Result() + data object FailedToResolve : Result() + data object FailedToRegister : Result() data class Registered(val pnr: PrivateNodeRegistration) : Result() - object AlreadyRegisteredAndNotExpiring : Result() + data object AlreadyRegisteredAndNotExpiring : Result() + + val isSuccessful + get() = this is Registered || this is AlreadyRegisteredAndNotExpiring } companion object { - private val ABOUT_TO_EXPIRE = Duration.ofDays(90) + private val ABOUT_TO_EXPIRE = Duration.ofDays(25) } } diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/PublicSyncTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/PublicSyncTest.kt index 43a5a195..2f259df9 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/PublicSyncTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/PublicSyncTest.kt @@ -15,7 +15,6 @@ import org.junit.jupiter.api.Test import tech.relaycorp.gateway.background.ConnectionState import tech.relaycorp.gateway.background.ConnectionStateObserver import tech.relaycorp.gateway.background.ForegroundAppMonitor -import tech.relaycorp.gateway.data.model.RegistrationState import tech.relaycorp.gateway.data.preference.InternetGatewayPreferences import tech.relaycorp.gateway.pdc.local.PDCServer import tech.relaycorp.gateway.pdc.local.PDCServerStateManager @@ -27,6 +26,7 @@ class PublicSyncTest { private val pdcServerStateManager = mock() private val internetGatewayPreferences = mock() private val connectionStateObserver = mock() + private val registerGateway = mock() private val deliverParcelsToGateway = mock() private val collectParcelsFromGateway = mock() private val publicSync = PublicSync( @@ -34,17 +34,18 @@ class PublicSyncTest { pdcServerStateManager, internetGatewayPreferences, connectionStateObserver, + registerGateway, deliverParcelsToGateway, collectParcelsFromGateway ) @Test - internal fun `does not sync if gateway is not registered`() = testSuspend { + internal fun `does not sync if gateway is failed to register`() = testSuspend { setState( ForegroundAppMonitor.State.Foreground, PDCServer.State.Started, - RegistrationState.ToDo, - ConnectionState.InternetWithGateway + ConnectionState.InternetWithGateway, + RegisterGateway.Result.FailedToRegister ) sync() @@ -57,8 +58,8 @@ class PublicSyncTest { setState( ForegroundAppMonitor.State.Foreground, PDCServer.State.Started, - RegistrationState.ToDo, - ConnectionState.Disconnected + ConnectionState.Disconnected, + RegisterGateway.Result.AlreadyRegisteredAndNotExpiring ) sync() @@ -71,8 +72,8 @@ class PublicSyncTest { setState( ForegroundAppMonitor.State.Background, PDCServer.State.Stopped, - RegistrationState.Done, - ConnectionState.InternetWithGateway + ConnectionState.InternetWithGateway, + RegisterGateway.Result.AlreadyRegisteredAndNotExpiring ) sync() @@ -85,8 +86,8 @@ class PublicSyncTest { setState( ForegroundAppMonitor.State.Foreground, PDCServer.State.Stopped, - RegistrationState.Done, - ConnectionState.InternetWithGateway + ConnectionState.InternetWithGateway, + RegisterGateway.Result.AlreadyRegisteredAndNotExpiring ) sync() @@ -99,8 +100,8 @@ class PublicSyncTest { setState( ForegroundAppMonitor.State.Background, PDCServer.State.Started, - RegistrationState.Done, - ConnectionState.InternetWithGateway + ConnectionState.InternetWithGateway, + RegisterGateway.Result.AlreadyRegisteredAndNotExpiring ) sync() @@ -113,8 +114,8 @@ class PublicSyncTest { setState( ForegroundAppMonitor.State.Background, PDCServer.State.Stopped, - RegistrationState.Done, - ConnectionState.InternetWithGateway + ConnectionState.InternetWithGateway, + RegisterGateway.Result.AlreadyRegisteredAndNotExpiring ) val appStateFlow = MutableStateFlow(ForegroundAppMonitor.State.Background) whenever(foregroundAppMonitor.observe()).thenReturn(appStateFlow.asSharedFlow()) @@ -128,20 +129,20 @@ class PublicSyncTest { waitForAssertEquals(false) { publicSync.isSyncing } } - private fun setState( + private suspend fun setState( appState: ForegroundAppMonitor.State, pdcState: PDCServer.State, - registrationState: RegistrationState, - connectionState: ConnectionState + connectionState: ConnectionState, + registrationResult: RegisterGateway.Result ) { whenever(foregroundAppMonitor.observe()) .thenReturn(flowOf(appState)) whenever(pdcServerStateManager.observe()) .thenReturn(flowOf(pdcState)) - whenever(internetGatewayPreferences.observeRegistrationState()) - .thenReturn(flowOf(registrationState)) whenever(connectionStateObserver.observe()) .thenReturn(flowOf(connectionState)) + whenever(registerGateway.registerIfNeeded()) + .thenReturn(registrationResult) } private fun sync() { diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/RegisterGatewayTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/RegisterGatewayTest.kt index 0798e365..b0306258 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/RegisterGatewayTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/RegisterGatewayTest.kt @@ -19,11 +19,10 @@ import tech.relaycorp.gateway.domain.LocalConfig import tech.relaycorp.gateway.domain.endpoint.GatewayCertificateChangeNotifier import tech.relaycorp.gateway.pdc.PoWebClientBuilder import tech.relaycorp.gateway.test.BaseDataTestCase +import tech.relaycorp.gateway.test.PNRFactory import tech.relaycorp.poweb.PoWebClient -import tech.relaycorp.relaynet.SessionKey import tech.relaycorp.relaynet.bindings.pdc.ClientBindingException import tech.relaycorp.relaynet.issueGatewayCertificate -import tech.relaycorp.relaynet.messages.control.PrivateNodeRegistration import tech.relaycorp.relaynet.messages.control.PrivateNodeRegistrationAuthorization import tech.relaycorp.relaynet.messages.control.PrivateNodeRegistrationRequest import tech.relaycorp.relaynet.testing.pki.KeyPairSet @@ -113,7 +112,7 @@ class RegisterGatewayTest : BaseDataTestCase() { whenever(poWebClient.preRegisterNode(any())) .thenReturn(buildPNRR()) whenever(poWebClient.registerNode(any())) - .thenReturn(buildPNR(internetGatewaySessionKeyPair.sessionKey)) + .thenReturn(PNRFactory.build(internetGatewaySessionKeyPair.sessionKey)) registerGateway.registerIfNeeded() @@ -126,7 +125,7 @@ class RegisterGatewayTest : BaseDataTestCase() { whenever(pgwPreferences.getRegistrationState()).thenReturn(RegistrationState.ToDo) val pnrr = buildPNRR() whenever(poWebClient.preRegisterNode(any())).thenReturn(pnrr) - val pnr = buildPNR(internetGatewaySessionKeyPair.sessionKey) + val pnr = PNRFactory.build(internetGatewaySessionKeyPair.sessionKey) whenever(poWebClient.registerNode(any())).thenReturn(pnr) registerGateway.registerIfNeeded() @@ -155,7 +154,7 @@ class RegisterGatewayTest : BaseDataTestCase() { whenever(pgwPreferences.getRegistrationState()).thenReturn(RegistrationState.ToDo) val pnrr = buildPNRR() whenever(poWebClient.preRegisterNode(any())).thenReturn(pnrr) - val pnr = buildPNR(null) + val pnr = PNRFactory.build(null) whenever(poWebClient.registerNode(any())).thenReturn(pnr) assertEquals(RegisterGateway.Result.FailedToRegister, registerGateway.registerIfNeeded()) @@ -172,7 +171,7 @@ class RegisterGatewayTest : BaseDataTestCase() { whenever(pgwPreferences.getRegistrationState()).thenReturn(RegistrationState.Done) val pnrr = buildPNRR() whenever(poWebClient.preRegisterNode(any())).thenReturn(pnrr) - val pnr = buildPNR(internetGatewaySessionKeyPair.sessionKey) + val pnr = PNRFactory.build(internetGatewaySessionKeyPair.sessionKey) whenever(poWebClient.registerNode(any())).thenReturn(pnr) registerGateway.registerIfNeeded() @@ -185,7 +184,7 @@ class RegisterGatewayTest : BaseDataTestCase() { whenever(pgwPreferences.getRegistrationState()).thenReturn(RegistrationState.ToDo) val pnrr = buildPNRR() whenever(poWebClient.preRegisterNode(any())).thenReturn(pnrr) - val pnr = buildPNR(internetGatewaySessionKeyPair.sessionKey) + val pnr = PNRFactory.build(internetGatewaySessionKeyPair.sessionKey) whenever(poWebClient.registerNode(any())).thenReturn(pnr) registerGateway.registerIfNeeded() @@ -203,11 +202,4 @@ class RegisterGatewayTest : BaseDataTestCase() { authorization.serialize(KeyPairSet.PRIVATE_GW.private) ) } - - private fun buildPNR(internetGatewaySessionKey: SessionKey?) = PrivateNodeRegistration( - PDACertPath.PRIVATE_GW, - PDACertPath.INTERNET_GW, - "example.org", - internetGatewaySessionKey - ) } diff --git a/app/src/test/java/tech/relaycorp/gateway/test/PNRFactory.kt b/app/src/test/java/tech/relaycorp/gateway/test/PNRFactory.kt new file mode 100644 index 00000000..ec333603 --- /dev/null +++ b/app/src/test/java/tech/relaycorp/gateway/test/PNRFactory.kt @@ -0,0 +1,14 @@ +package tech.relaycorp.gateway.test + +import tech.relaycorp.relaynet.SessionKey +import tech.relaycorp.relaynet.messages.control.PrivateNodeRegistration +import tech.relaycorp.relaynet.testing.pki.PDACertPath + +object PNRFactory { + fun build(internetGatewaySessionKey: SessionKey? = null) = PrivateNodeRegistration( + PDACertPath.PRIVATE_GW, + PDACertPath.INTERNET_GW, + "example.org", + internetGatewaySessionKey + ) +}