Skip to content

Commit

Permalink
Merge pull request #3877 from element-hq/feature/bma/fixUnifiedPushUn…
Browse files Browse the repository at this point in the history
…register

Fix unified push unregister
  • Loading branch information
bmarty authored Nov 15, 2024
2 parents d62313d + f0aca00 commit bb69e1e
Show file tree
Hide file tree
Showing 22 changed files with 243 additions and 108 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,12 @@ class LoggedInPresenter @Inject constructor(
.also { Timber.tag(pusherTag.value).w("No distributors available") }
.also {
// In this case, consider the push provider is chosen.
pushService.selectPushProvider(matrixClient, pushProvider)
pushService.selectPushProvider(matrixClient.sessionId, pushProvider)
}
.also { pusherRegistrationState.value = AsyncData.Failure(PusherRegistrationFailure.NoDistributorsAvailable()) }
pushService.registerWith(matrixClient, pushProvider, distributor)
} else {
val currentPushDistributor = currentPushProvider.getCurrentDistributor(matrixClient)
val currentPushDistributor = currentPushProvider.getCurrentDistributor(matrixClient.sessionId)
if (currentPushDistributor == null) {
Timber.tag(pusherTag.value).d("Register with the first available distributor")
val distributor = currentPushProvider.getDistributors().firstOrNull()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ class LoggedInPresenterTest {
val lambda = lambdaRecorder<MatrixClient, PushProvider, Distributor, Result<Unit>> { _, _, _ ->
Result.success(Unit)
}
val selectPushProviderLambda = lambdaRecorder<MatrixClient, PushProvider, Unit> { _, _ -> }
val selectPushProviderLambda = lambdaRecorder<SessionId, PushProvider, Unit> { _, _ -> }
val sessionVerificationService = FakeSessionVerificationService(
initialSessionVerifiedStatus = SessionVerifiedStatus.Verified
)
Expand Down Expand Up @@ -408,8 +408,8 @@ class LoggedInPresenterTest {
selectPushProviderLambda.assertions()
.isCalledOnce()
.with(
// MatrixClient
any(),
// SessionId
value(A_SESSION_ID),
// PushProvider
value(pushProvider),
)
Expand Down Expand Up @@ -481,7 +481,7 @@ class LoggedInPresenterTest {
registerWithLambda: (MatrixClient, PushProvider, Distributor) -> Result<Unit> = { _, _, _ ->
Result.success(Unit)
},
selectPushProviderLambda: (MatrixClient, PushProvider) -> Unit = { _, _ -> lambdaError() },
selectPushProviderLambda: (SessionId, PushProvider) -> Unit = { _, _ -> lambdaError() },
currentPushProvider: () -> PushProvider? = { null },
setIgnoreRegistrationErrorLambda: (SessionId, Boolean) -> Unit = { _, _ -> lambdaError() },
): PushService {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class NotificationSettingsPresenter @Inject constructor(

LaunchedEffect(refreshPushProvider) {
val p = pushService.getCurrentPushProvider()
val name = p?.getCurrentDistributor(matrixClient)?.name
val name = p?.getCurrentDistributor(matrixClient.sessionId)?.name
currentDistributorName = if (name != null) {
AsyncData.Success(name)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ interface PushService {
* To be used when there is no distributor available.
*/
suspend fun selectPushProvider(
matrixClient: MatrixClient,
sessionId: SessionId,
pushProvider: PushProvider,
)

Expand Down
2 changes: 2 additions & 0 deletions libraries/push/impl/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ dependencies {
implementation(projects.libraries.matrix.api)
implementation(projects.libraries.matrixui)
implementation(projects.libraries.preferences.api)
implementation(projects.libraries.sessionStorage.api)
implementation(projects.libraries.uiStrings)
implementation(projects.libraries.troubleshoot.api)
implementation(projects.features.call.api)
Expand All @@ -64,6 +65,7 @@ dependencies {
testImplementation(libs.coroutines.test)
testImplementation(projects.libraries.matrix.test)
testImplementation(projects.libraries.preferences.test)
testImplementation(projects.libraries.sessionStorage.test)
testImplementation(projects.libraries.push.test)
testImplementation(projects.libraries.pushproviders.test)
testImplementation(projects.libraries.pushstore.test)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package io.element.android.libraries.push.impl

import com.squareup.anvil.annotations.ContributesBinding
import io.element.android.libraries.di.AppScope
import io.element.android.libraries.di.SingleIn
import io.element.android.libraries.matrix.api.MatrixClient
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.push.api.GetCurrentPushProvider
Expand All @@ -17,17 +18,27 @@ import io.element.android.libraries.push.impl.test.TestPush
import io.element.android.libraries.pushproviders.api.Distributor
import io.element.android.libraries.pushproviders.api.PushProvider
import io.element.android.libraries.pushstore.api.UserPushStoreFactory
import io.element.android.libraries.pushstore.api.clientsecret.PushClientSecretStore
import io.element.android.libraries.sessionstorage.api.observer.SessionListener
import io.element.android.libraries.sessionstorage.api.observer.SessionObserver
import kotlinx.coroutines.flow.Flow
import timber.log.Timber
import javax.inject.Inject

@ContributesBinding(AppScope::class)
@ContributesBinding(AppScope::class, boundType = PushService::class)
@SingleIn(AppScope::class)
class DefaultPushService @Inject constructor(
private val testPush: TestPush,
private val userPushStoreFactory: UserPushStoreFactory,
private val pushProviders: Set<@JvmSuppressWildcards PushProvider>,
private val getCurrentPushProvider: GetCurrentPushProvider,
) : PushService {
private val sessionObserver: SessionObserver,
private val pushClientSecretStore: PushClientSecretStore,
) : PushService, SessionListener {
init {
observeSessions()
}

override suspend fun getCurrentPushProvider(): PushProvider? {
val currentPushProvider = getCurrentPushProvider.getCurrentPushProvider()
return pushProviders.find { it.name == currentPushProvider }
Expand All @@ -47,7 +58,7 @@ class DefaultPushService @Inject constructor(
val userPushStore = userPushStoreFactory.getOrCreate(matrixClient.sessionId)
val currentPushProviderName = userPushStore.getPushProviderName()
val currentPushProvider = pushProviders.find { it.name == currentPushProviderName }
val currentDistributorValue = currentPushProvider?.getCurrentDistributor(matrixClient)?.value
val currentDistributorValue = currentPushProvider?.getCurrentDistributor(matrixClient.sessionId)?.value
if (currentPushProviderName != pushProvider.name || currentDistributorValue != distributor.value) {
// Unregister previous one if any
currentPushProvider
Expand All @@ -65,11 +76,11 @@ class DefaultPushService @Inject constructor(
}

override suspend fun selectPushProvider(
matrixClient: MatrixClient,
sessionId: SessionId,
pushProvider: PushProvider,
) {
Timber.d("Select ${pushProvider.name}")
val userPushStore = userPushStoreFactory.getOrCreate(matrixClient.sessionId)
val userPushStore = userPushStoreFactory.getOrCreate(sessionId)
userPushStore.setPushProviderName(pushProvider.name)
}

Expand All @@ -87,4 +98,31 @@ class DefaultPushService @Inject constructor(
testPush.execute(config)
return true
}

private fun observeSessions() {
sessionObserver.addListener(this)
}

override suspend fun onSessionCreated(userId: String) {
// Nothing to do
}

/**
* The session has been deleted.
* In this case, this is not necessary to unregister the pusher from the homeserver,
* but we need to do some cleanup locally.
* The current push provider may want to take action, and we need to
* cleanup the stores.
*/
override suspend fun onSessionDeleted(userId: String) {
val sessionId = SessionId(userId)
val userPushStore = userPushStoreFactory.getOrCreate(sessionId)
val currentPushProviderName = userPushStore.getPushProviderName()
val currentPushProvider = pushProviders.find { it.name == currentPushProviderName }
// Cleanup the current push provider. They may need the client secret, so delete the secret after.
currentPushProvider?.onSessionDeleted(sessionId)
// Now we can safely reset the stores.
pushClientSecretStore.resetSecret(sessionId)
userPushStore.reset()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package io.element.android.libraries.push.impl

import com.google.common.truth.Truth.assertThat
import io.element.android.libraries.matrix.api.MatrixClient
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.matrix.test.AN_EXCEPTION
import io.element.android.libraries.matrix.test.A_SESSION_ID
import io.element.android.libraries.matrix.test.FakeMatrixClient
Expand All @@ -22,8 +23,12 @@ import io.element.android.libraries.pushproviders.api.PushProvider
import io.element.android.libraries.pushproviders.test.FakePushProvider
import io.element.android.libraries.pushproviders.test.aCurrentUserPushConfig
import io.element.android.libraries.pushstore.api.UserPushStoreFactory
import io.element.android.libraries.pushstore.api.clientsecret.PushClientSecretStore
import io.element.android.libraries.pushstore.test.userpushstore.FakeUserPushStore
import io.element.android.libraries.pushstore.test.userpushstore.FakeUserPushStoreFactory
import io.element.android.libraries.pushstore.test.userpushstore.clientsecret.InMemoryPushClientSecretStore
import io.element.android.libraries.sessionstorage.api.observer.SessionObserver
import io.element.android.libraries.sessionstorage.test.observer.NoOpSessionObserver
import io.element.android.tests.testutils.lambda.lambdaRecorder
import io.element.android.tests.testutils.lambda.value
import kotlinx.coroutines.flow.first
Expand Down Expand Up @@ -210,17 +215,87 @@ class DefaultPushServiceTest {
assertThat(defaultPushService.ignoreRegistrationError(A_SESSION_ID).first()).isTrue()
}

@Test
fun `onSessionCreated is noop`() = runTest {
val defaultPushService = createDefaultPushService()
defaultPushService.onSessionCreated(A_SESSION_ID.value)
}

@Test
fun `onSessionDeleted should transmit the info to the current push provider and cleanup the stores`() = runTest {
val onSessionDeletedLambda = lambdaRecorder<SessionId, Unit> { }
val aCurrentPushProvider = FakePushProvider(
name = "aCurrentPushProvider",
onSessionDeletedLambda = onSessionDeletedLambda,
)
val userPushStore = FakeUserPushStore(
pushProviderName = aCurrentPushProvider.name,
)
val pushClientSecretStore = InMemoryPushClientSecretStore()
val defaultPushService = createDefaultPushService(
pushProviders = setOf(aCurrentPushProvider),
getCurrentPushProvider = FakeGetCurrentPushProvider(currentPushProvider = aCurrentPushProvider.name),
userPushStoreFactory = FakeUserPushStoreFactory(
userPushStore = { userPushStore },
),
pushClientSecretStore = pushClientSecretStore,
)
defaultPushService.onSessionDeleted(A_SESSION_ID.value)
assertThat(userPushStore.getPushProviderName()).isNull()
assertThat(pushClientSecretStore.getSecret(A_SESSION_ID)).isNull()
onSessionDeletedLambda.assertions().isCalledOnce().with(value(A_SESSION_ID))
}

@Test
fun `onSessionDeleted when there is no push provider should just cleanup the stores`() = runTest {
val userPushStore = FakeUserPushStore(
pushProviderName = null,
)
val pushClientSecretStore = InMemoryPushClientSecretStore()
val defaultPushService = createDefaultPushService(
pushProviders = emptySet(),
getCurrentPushProvider = FakeGetCurrentPushProvider(currentPushProvider = null),
userPushStoreFactory = FakeUserPushStoreFactory(
userPushStore = { userPushStore },
),
pushClientSecretStore = pushClientSecretStore,
)
defaultPushService.onSessionDeleted(A_SESSION_ID.value)
assertThat(userPushStore.getPushProviderName()).isNull()
assertThat(pushClientSecretStore.getSecret(A_SESSION_ID)).isNull()
}

@Test
fun `selectPushProvider should store the data in the store`() = runTest {
val userPushStore = FakeUserPushStore()
val defaultPushService = createDefaultPushService(
userPushStoreFactory = FakeUserPushStoreFactory(
userPushStore = { userPushStore },
),
)
val aPushProvider = FakePushProvider(
name = "aCurrentPushProvider",
)
assertThat(userPushStore.getPushProviderName()).isNull()
defaultPushService.selectPushProvider(A_SESSION_ID, aPushProvider)
assertThat(userPushStore.getPushProviderName()).isEqualTo(aPushProvider.name)
}

private fun createDefaultPushService(
testPush: TestPush = FakeTestPush(),
userPushStoreFactory: UserPushStoreFactory = FakeUserPushStoreFactory(),
pushProviders: Set<@JvmSuppressWildcards PushProvider> = emptySet(),
getCurrentPushProvider: GetCurrentPushProvider = FakeGetCurrentPushProvider(currentPushProvider = null),
sessionObserver: SessionObserver = NoOpSessionObserver(),
pushClientSecretStore: PushClientSecretStore = InMemoryPushClientSecretStore(),
): DefaultPushService {
return DefaultPushService(
testPush = testPush,
userPushStoreFactory = userPushStoreFactory,
pushProviders = pushProviders,
getCurrentPushProvider = getCurrentPushProvider,
sessionObserver = sessionObserver,
pushClientSecretStore = pushClientSecretStore,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class FakePushService(
Result.success(Unit)
},
private val currentPushProvider: () -> PushProvider? = { availablePushProviders.firstOrNull() },
private val selectPushProviderLambda: suspend (MatrixClient, PushProvider) -> Unit = { _, _ -> lambdaError() },
private val selectPushProviderLambda: suspend (SessionId, PushProvider) -> Unit = { _, _ -> lambdaError() },
private val setIgnoreRegistrationErrorLambda: (SessionId, Boolean) -> Unit = { _, _ -> lambdaError() },
) : PushService {
override suspend fun getCurrentPushProvider(): PushProvider? {
Expand All @@ -50,8 +50,8 @@ class FakePushService(
}
}

override suspend fun selectPushProvider(matrixClient: MatrixClient, pushProvider: PushProvider) {
selectPushProviderLambda(matrixClient, pushProvider)
override suspend fun selectPushProvider(sessionId: SessionId, pushProvider: PushProvider) {
selectPushProviderLambda(sessionId, pushProvider)
}

private val ignoreRegistrationError = MutableStateFlow(false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package io.element.android.libraries.pushproviders.api

import io.element.android.libraries.matrix.api.MatrixClient
import io.element.android.libraries.matrix.api.core.SessionId

/**
* This is the main API for this module.
Expand Down Expand Up @@ -36,13 +37,18 @@ interface PushProvider {
/**
* Return the current distributor, or null if none.
*/
suspend fun getCurrentDistributor(matrixClient: MatrixClient): Distributor?
suspend fun getCurrentDistributor(sessionId: SessionId): Distributor?

/**
* Unregister the pusher.
*/
suspend fun unregister(matrixClient: MatrixClient): Result<Unit>

/**
* To invoke when the session is deleted.
*/
suspend fun onSessionDeleted(sessionId: SessionId)

suspend fun getCurrentUserPushConfig(): CurrentUserPushConfig?

fun canRotateToken(): Boolean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import com.squareup.anvil.annotations.ContributesMultibinding
import io.element.android.libraries.core.log.logger.LoggerTag
import io.element.android.libraries.di.AppScope
import io.element.android.libraries.matrix.api.MatrixClient
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.pushproviders.api.CurrentUserPushConfig
import io.element.android.libraries.pushproviders.api.Distributor
import io.element.android.libraries.pushproviders.api.PushProvider
Expand Down Expand Up @@ -51,7 +52,7 @@ class FirebasePushProvider @Inject constructor(
)
}

override suspend fun getCurrentDistributor(matrixClient: MatrixClient) = firebaseDistributor
override suspend fun getCurrentDistributor(sessionId: SessionId) = firebaseDistributor

override suspend fun unregister(matrixClient: MatrixClient): Result<Unit> {
val pushKey = firebaseStore.getFcmToken()
Expand All @@ -63,6 +64,11 @@ class FirebasePushProvider @Inject constructor(
}
}

/**
* Nothing to clean up here.
*/
override suspend fun onSessionDeleted(sessionId: SessionId) = Unit

override suspend fun getCurrentUserPushConfig(): CurrentUserPushConfig? {
return firebaseStore.getFcmToken()?.let { fcmToken ->
CurrentUserPushConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package io.element.android.libraries.pushproviders.firebase
import com.google.common.truth.Truth.assertThat
import io.element.android.libraries.matrix.api.MatrixClient
import io.element.android.libraries.matrix.test.AN_EXCEPTION
import io.element.android.libraries.matrix.test.A_SESSION_ID
import io.element.android.libraries.matrix.test.FakeMatrixClient
import io.element.android.libraries.push.test.FakePusherSubscriber
import io.element.android.libraries.pushproviders.api.CurrentUserPushConfig
Expand Down Expand Up @@ -47,9 +48,9 @@ class FirebasePushProviderTest {
}

@Test
fun `getCurrentDistributor always return the unique distributor`() = runTest {
fun `getCurrentDistributor always returns the unique distributor`() = runTest {
val firebasePushProvider = createFirebasePushProvider()
val result = firebasePushProvider.getCurrentDistributor(FakeMatrixClient())
val result = firebasePushProvider.getCurrentDistributor(A_SESSION_ID)
assertThat(result).isEqualTo(Distributor("Firebase", "Firebase"))
}

Expand Down Expand Up @@ -176,6 +177,18 @@ class FirebasePushProviderTest {
lambda.assertions().isCalledOnce()
}

@Test
fun `canRotateToken should return true`() = runTest {
val firebasePushProvider = createFirebasePushProvider()
assertThat(firebasePushProvider.canRotateToken()).isTrue()
}

@Test
fun `onSessionDeleted should be noop`() = runTest {
val firebasePushProvider = createFirebasePushProvider()
firebasePushProvider.onSessionDeleted(A_SESSION_ID)
}

private fun createFirebasePushProvider(
firebaseStore: FirebaseStore = InMemoryFirebaseStore(),
pusherSubscriber: PusherSubscriber = FakePusherSubscriber(),
Expand Down
Loading

0 comments on commit bb69e1e

Please sign in to comment.