Skip to content

Commit

Permalink
Register gateway on sync
Browse files Browse the repository at this point in the history
  • Loading branch information
sdsantos committed Oct 30, 2023
1 parent e73ea1f commit 4d6a51f
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 70 deletions.
9 changes: 2 additions & 7 deletions app/src/main/java/tech/relaycorp/gateway/App.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import tech.relaycorp.gateway.common.di.AppComponent
import tech.relaycorp.gateway.common.di.DaggerAppComponent
import tech.relaycorp.gateway.domain.LocalConfig
import tech.relaycorp.gateway.domain.publicsync.PublicSync
import tech.relaycorp.gateway.domain.publicsync.RegisterGateway
import java.security.Security
import java.time.Duration
import java.util.logging.Level
Expand Down Expand Up @@ -56,9 +55,6 @@ open class App : Application() {
@Inject
lateinit var localConfig: LocalConfig

@Inject
lateinit var registerGateway: RegisterGateway

@Inject
lateinit var publicSync: PublicSync

Expand All @@ -77,8 +73,8 @@ open class App : Application() {

backgroundScope.launch {
bootstrapGateway()
startPublicSyncWhenPossible()
deleteExpiredCertificates()
launch { startPublicSyncWhenPossible() }
launch { deleteExpiredCertificates() }
}

registerActivityLifecycleCallbacks(foregroundAppMonitor)
Expand Down Expand Up @@ -132,7 +128,6 @@ open class App : Application() {
private suspend fun bootstrapGateway() {
if (mode != Mode.Test) {
localConfig.bootstrap()
registerGateway.registerIfNeeded()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ import tech.relaycorp.gateway.background.ConnectionState
import tech.relaycorp.gateway.background.ConnectionStateObserver
import tech.relaycorp.gateway.background.ForegroundAppMonitor
import tech.relaycorp.gateway.common.Logging.logger
import tech.relaycorp.gateway.common.interval
import tech.relaycorp.gateway.data.model.RegistrationState
import tech.relaycorp.gateway.data.preference.InternetGatewayPreferences
import tech.relaycorp.gateway.pdc.local.PDCServer
import tech.relaycorp.gateway.pdc.local.PDCServerStateManager
import javax.inject.Inject
import javax.inject.Singleton
import kotlin.time.Duration.Companion.minutes
import kotlin.time.Duration.Companion.seconds

@Singleton
Expand All @@ -28,6 +30,7 @@ class PublicSync
private val pdcServerStateManager: PDCServerStateManager,
private val internetGatewayPreferences: InternetGatewayPreferences,
private val connectionStateObserver: ConnectionStateObserver,
private val registerGateway: RegisterGateway,
private val deliverParcelsToGateway: DeliverParcelsToGateway,
private val collectParcelsFromGateway: CollectParcelsFromGateway
) {
Expand All @@ -42,11 +45,11 @@ class PublicSync
combine(
foregroundAppMonitor.observe(),
pdcServerStateManager.observe(),
internetGatewayPreferences.observeRegistrationState(),
connectionStateObserver.observe()
) { foregroundState, pdcState, registrationState, connectionState ->
connectionStateObserver.observe(),
// Retry registration and sync every minute in case there's a failure
interval(1.minutes)
) { foregroundState, pdcState, connectionState, _ ->
if (
registrationState == RegistrationState.Done &&
connectionState is ConnectionState.InternetWithGateway && (
foregroundState == ForegroundAppMonitor.State.Foreground ||
pdcState == PDCServer.State.Started
Expand All @@ -71,8 +74,9 @@ class PublicSync
collectParcelsFromGateway.collect(false)
}

private fun startSync() {
private suspend fun startSync() {
if (isSyncing) return
if (!registerGateway.registerIfNeeded().isSuccessful) return

logger.info("Starting public sync")
val syncJob = Job()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package tech.relaycorp.gateway.domain.publicsync

import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import tech.relaycorp.gateway.common.Logging.logger
import tech.relaycorp.gateway.data.doh.InternetAddressResolutionException
import tech.relaycorp.gateway.data.doh.ResolveServiceAddress
Expand All @@ -16,7 +18,9 @@ import java.time.Duration
import java.time.ZonedDateTime
import java.util.logging.Level
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
@JvmSuppressWildcards
class RegisterGateway
@Inject constructor(
Expand All @@ -28,43 +32,51 @@ class RegisterGateway
private val publicKeyStore: SessionPublicKeyStore
) {

private val mutex = Mutex()

suspend fun registerIfNeeded(): Result {
val isFirstRegistration =
internetGatewayPreferences.getRegistrationState() == RegistrationState.ToDo

if (
!isFirstRegistration &&
!currentCertificateIsAboutToExpire()
) {
return Result.AlreadyRegisteredAndNotExpiring
}
mutex.withLock {
val isFirstRegistration =
internetGatewayPreferences.getRegistrationState() == RegistrationState.ToDo

if (
!isFirstRegistration &&
!currentCertificateIsAboutToExpire()
) {
return Result.AlreadyRegisteredAndNotExpiring
}

val address = internetGatewayPreferences.getAddress()
val result = register(address)
if (result is Result.Registered) {
saveSuccessfulResult(address, result.pnr)
val address = internetGatewayPreferences.getAddress()
val result = register(address)
if (result is Result.Registered) {
saveSuccessfulResult(address, result.pnr)

if (!isFirstRegistration) {
gatewayCertificateChangeNotifier.notifyAll()
if (!isFirstRegistration) {
gatewayCertificateChangeNotifier.notifyAll()
}
}
}

return result
return result
}
}

suspend fun registerNewAddress(newAddress: String): Result {
val result = register(newAddress)
if (result is Result.Registered) {
saveSuccessfulResult(newAddress, result.pnr)
mutex.withLock {
val result = register(newAddress)
if (result is Result.Registered) {
saveSuccessfulResult(newAddress, result.pnr)
}
return result
}
return result
}

private suspend fun currentCertificateIsAboutToExpire() =
localConfig.getIdentityCertificate().expiryDate < ZonedDateTime.now().plus(ABOUT_TO_EXPIRE)

private suspend fun register(address: String): Result {
return try {
logger.info("Registering with $address")

val poWebAddress = resolveServiceAddress.resolvePoWeb(address)
val poWeb = poWebClientBuilder.build(poWebAddress)
val privateKey = localConfig.getIdentityKey()
Expand Down Expand Up @@ -114,13 +126,16 @@ class RegisterGateway
}

sealed class Result {
object FailedToResolve : Result()
object FailedToRegister : Result()
data object FailedToResolve : Result()
data object FailedToRegister : Result()
data class Registered(val pnr: PrivateNodeRegistration) : Result()
object AlreadyRegisteredAndNotExpiring : Result()
data object AlreadyRegisteredAndNotExpiring : Result()

val isSuccessful
get() = this is Registered || this is AlreadyRegisteredAndNotExpiring
}

companion object {
private val ABOUT_TO_EXPIRE = Duration.ofDays(90)
private val ABOUT_TO_EXPIRE = Duration.ofDays(25)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import org.junit.jupiter.api.Test
import tech.relaycorp.gateway.background.ConnectionState
import tech.relaycorp.gateway.background.ConnectionStateObserver
import tech.relaycorp.gateway.background.ForegroundAppMonitor
import tech.relaycorp.gateway.data.model.RegistrationState
import tech.relaycorp.gateway.data.preference.InternetGatewayPreferences
import tech.relaycorp.gateway.pdc.local.PDCServer
import tech.relaycorp.gateway.pdc.local.PDCServerStateManager
Expand All @@ -27,24 +26,26 @@ class PublicSyncTest {
private val pdcServerStateManager = mock<PDCServerStateManager>()
private val internetGatewayPreferences = mock<InternetGatewayPreferences>()
private val connectionStateObserver = mock<ConnectionStateObserver>()
private val registerGateway = mock<RegisterGateway>()
private val deliverParcelsToGateway = mock<DeliverParcelsToGateway>()
private val collectParcelsFromGateway = mock<CollectParcelsFromGateway>()
private val publicSync = PublicSync(
foregroundAppMonitor,
pdcServerStateManager,
internetGatewayPreferences,
connectionStateObserver,
registerGateway,
deliverParcelsToGateway,
collectParcelsFromGateway
)

@Test
internal fun `does not sync if gateway is not registered`() = testSuspend {
internal fun `does not sync if gateway is failed to register`() = testSuspend {
setState(
ForegroundAppMonitor.State.Foreground,
PDCServer.State.Started,
RegistrationState.ToDo,
ConnectionState.InternetWithGateway
ConnectionState.InternetWithGateway,
RegisterGateway.Result.FailedToRegister
)

sync()
Expand All @@ -57,8 +58,8 @@ class PublicSyncTest {
setState(
ForegroundAppMonitor.State.Foreground,
PDCServer.State.Started,
RegistrationState.ToDo,
ConnectionState.Disconnected
ConnectionState.Disconnected,
RegisterGateway.Result.AlreadyRegisteredAndNotExpiring
)

sync()
Expand All @@ -71,8 +72,8 @@ class PublicSyncTest {
setState(
ForegroundAppMonitor.State.Background,
PDCServer.State.Stopped,
RegistrationState.Done,
ConnectionState.InternetWithGateway
ConnectionState.InternetWithGateway,
RegisterGateway.Result.AlreadyRegisteredAndNotExpiring
)

sync()
Expand All @@ -85,8 +86,8 @@ class PublicSyncTest {
setState(
ForegroundAppMonitor.State.Foreground,
PDCServer.State.Stopped,
RegistrationState.Done,
ConnectionState.InternetWithGateway
ConnectionState.InternetWithGateway,
RegisterGateway.Result.AlreadyRegisteredAndNotExpiring
)

sync()
Expand All @@ -99,8 +100,8 @@ class PublicSyncTest {
setState(
ForegroundAppMonitor.State.Background,
PDCServer.State.Started,
RegistrationState.Done,
ConnectionState.InternetWithGateway
ConnectionState.InternetWithGateway,
RegisterGateway.Result.AlreadyRegisteredAndNotExpiring
)

sync()
Expand All @@ -113,8 +114,8 @@ class PublicSyncTest {
setState(
ForegroundAppMonitor.State.Background,
PDCServer.State.Stopped,
RegistrationState.Done,
ConnectionState.InternetWithGateway
ConnectionState.InternetWithGateway,
RegisterGateway.Result.AlreadyRegisteredAndNotExpiring
)
val appStateFlow = MutableStateFlow(ForegroundAppMonitor.State.Background)
whenever(foregroundAppMonitor.observe()).thenReturn(appStateFlow.asSharedFlow())
Expand All @@ -128,20 +129,20 @@ class PublicSyncTest {
waitForAssertEquals(false) { publicSync.isSyncing }
}

private fun setState(
private suspend fun setState(
appState: ForegroundAppMonitor.State,
pdcState: PDCServer.State,
registrationState: RegistrationState,
connectionState: ConnectionState
connectionState: ConnectionState,
registrationResult: RegisterGateway.Result
) {
whenever(foregroundAppMonitor.observe())
.thenReturn(flowOf(appState))
whenever(pdcServerStateManager.observe())
.thenReturn(flowOf(pdcState))
whenever(internetGatewayPreferences.observeRegistrationState())
.thenReturn(flowOf(registrationState))
whenever(connectionStateObserver.observe())
.thenReturn(flowOf(connectionState))
whenever(registerGateway.registerIfNeeded())
.thenReturn(registrationResult)
}

private fun sync() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ import tech.relaycorp.gateway.domain.LocalConfig
import tech.relaycorp.gateway.domain.endpoint.GatewayCertificateChangeNotifier
import tech.relaycorp.gateway.pdc.PoWebClientBuilder
import tech.relaycorp.gateway.test.BaseDataTestCase
import tech.relaycorp.gateway.test.PNRFactory
import tech.relaycorp.poweb.PoWebClient
import tech.relaycorp.relaynet.SessionKey
import tech.relaycorp.relaynet.bindings.pdc.ClientBindingException
import tech.relaycorp.relaynet.issueGatewayCertificate
import tech.relaycorp.relaynet.messages.control.PrivateNodeRegistration
import tech.relaycorp.relaynet.messages.control.PrivateNodeRegistrationAuthorization
import tech.relaycorp.relaynet.messages.control.PrivateNodeRegistrationRequest
import tech.relaycorp.relaynet.testing.pki.KeyPairSet
Expand Down Expand Up @@ -113,7 +112,7 @@ class RegisterGatewayTest : BaseDataTestCase() {
whenever(poWebClient.preRegisterNode(any()))
.thenReturn(buildPNRR())
whenever(poWebClient.registerNode(any()))
.thenReturn(buildPNR(internetGatewaySessionKeyPair.sessionKey))
.thenReturn(PNRFactory.build(internetGatewaySessionKeyPair.sessionKey))

registerGateway.registerIfNeeded()

Expand All @@ -126,7 +125,7 @@ class RegisterGatewayTest : BaseDataTestCase() {
whenever(pgwPreferences.getRegistrationState()).thenReturn(RegistrationState.ToDo)
val pnrr = buildPNRR()
whenever(poWebClient.preRegisterNode(any())).thenReturn(pnrr)
val pnr = buildPNR(internetGatewaySessionKeyPair.sessionKey)
val pnr = PNRFactory.build(internetGatewaySessionKeyPair.sessionKey)
whenever(poWebClient.registerNode(any())).thenReturn(pnr)

registerGateway.registerIfNeeded()
Expand Down Expand Up @@ -155,7 +154,7 @@ class RegisterGatewayTest : BaseDataTestCase() {
whenever(pgwPreferences.getRegistrationState()).thenReturn(RegistrationState.ToDo)
val pnrr = buildPNRR()
whenever(poWebClient.preRegisterNode(any())).thenReturn(pnrr)
val pnr = buildPNR(null)
val pnr = PNRFactory.build(null)
whenever(poWebClient.registerNode(any())).thenReturn(pnr)

assertEquals(RegisterGateway.Result.FailedToRegister, registerGateway.registerIfNeeded())
Expand All @@ -172,7 +171,7 @@ class RegisterGatewayTest : BaseDataTestCase() {
whenever(pgwPreferences.getRegistrationState()).thenReturn(RegistrationState.Done)
val pnrr = buildPNRR()
whenever(poWebClient.preRegisterNode(any())).thenReturn(pnrr)
val pnr = buildPNR(internetGatewaySessionKeyPair.sessionKey)
val pnr = PNRFactory.build(internetGatewaySessionKeyPair.sessionKey)
whenever(poWebClient.registerNode(any())).thenReturn(pnr)

registerGateway.registerIfNeeded()
Expand All @@ -185,7 +184,7 @@ class RegisterGatewayTest : BaseDataTestCase() {
whenever(pgwPreferences.getRegistrationState()).thenReturn(RegistrationState.ToDo)
val pnrr = buildPNRR()
whenever(poWebClient.preRegisterNode(any())).thenReturn(pnrr)
val pnr = buildPNR(internetGatewaySessionKeyPair.sessionKey)
val pnr = PNRFactory.build(internetGatewaySessionKeyPair.sessionKey)
whenever(poWebClient.registerNode(any())).thenReturn(pnr)

registerGateway.registerIfNeeded()
Expand All @@ -203,11 +202,4 @@ class RegisterGatewayTest : BaseDataTestCase() {
authorization.serialize(KeyPairSet.PRIVATE_GW.private)
)
}

private fun buildPNR(internetGatewaySessionKey: SessionKey?) = PrivateNodeRegistration(
PDACertPath.PRIVATE_GW,
PDACertPath.INTERNET_GW,
"example.org",
internetGatewaySessionKey
)
}
Loading

0 comments on commit 4d6a51f

Please sign in to comment.