diff --git a/aws-datastore/src/main/java/com/amplifyframework/datastore/syncengine/ReachabilityMonitor.kt b/aws-datastore/src/main/java/com/amplifyframework/datastore/syncengine/ReachabilityMonitor.kt index 009bd67e09..6ebc0b0176 100644 --- a/aws-datastore/src/main/java/com/amplifyframework/datastore/syncengine/ReachabilityMonitor.kt +++ b/aws-datastore/src/main/java/com/amplifyframework/datastore/syncengine/ReachabilityMonitor.kt @@ -22,8 +22,7 @@ import android.net.Network import androidx.annotation.VisibleForTesting import com.amplifyframework.datastore.DataStoreException import io.reactivex.rxjava3.core.Observable -import io.reactivex.rxjava3.core.ObservableEmitter -import io.reactivex.rxjava3.core.ObservableOnSubscribe +import io.reactivex.rxjava3.subjects.BehaviorSubject import java.util.concurrent.TimeUnit /** @@ -54,25 +53,33 @@ public interface ReachabilityMonitor { } private class ReachabilityMonitorImpl constructor(val schedulerProvider: SchedulerProvider) : ReachabilityMonitor { - private var emitter: ObservableOnSubscribe? = null + private val subject = BehaviorSubject.create() + private var connectivityProvider: ConnectivityProvider? = null override fun configure(context: Context) { return configure(context, DefaultConnectivityProvider()) } override fun configure(context: Context, connectivityProvider: ConnectivityProvider) { - emitter = ObservableOnSubscribe { emitter -> - val callback = getCallback(emitter) - connectivityProvider.registerDefaultNetworkCallback(context, callback) - // Provide the current network status upon subscription. - emitter.onNext(connectivityProvider.hasActiveNetwork) - } + this.connectivityProvider = connectivityProvider + connectivityProvider.registerDefaultNetworkCallback( + context, + object : NetworkCallback() { + override fun onAvailable(network: Network) { + subject.onNext(true) + } + + override fun onLost(network: Network) { + subject.onNext(false) + } + } + ) } override fun getObservable(): Observable { - emitter?.let { emitter -> - return Observable.create(emitter) - .subscribeOn(schedulerProvider.io()) + connectivityProvider?.let { connectivityProvider -> + return subject.subscribeOn(schedulerProvider.io()) + .doOnSubscribe { subject.onNext(connectivityProvider.hasActiveNetwork) } .debounce(250, TimeUnit.MILLISECONDS, schedulerProvider.computation()) } ?: run { throw DataStoreException( @@ -81,17 +88,6 @@ private class ReachabilityMonitorImpl constructor(val schedulerProvider: Schedul ) } } - - private fun getCallback(emitter: ObservableEmitter): NetworkCallback { - return object : NetworkCallback() { - override fun onAvailable(network: Network) { - emitter.onNext(true) - } - override fun onLost(network: Network) { - emitter.onNext(false) - } - } - } } /** diff --git a/aws-datastore/src/test/java/com/amplifyframework/datastore/syncengine/ReachabilityMonitorTest.kt b/aws-datastore/src/test/java/com/amplifyframework/datastore/syncengine/ReachabilityMonitorTest.kt index 6bccd55700..945ef9320d 100644 --- a/aws-datastore/src/test/java/com/amplifyframework/datastore/syncengine/ReachabilityMonitorTest.kt +++ b/aws-datastore/src/test/java/com/amplifyframework/datastore/syncengine/ReachabilityMonitorTest.kt @@ -8,6 +8,7 @@ import io.reactivex.rxjava3.core.BackpressureStrategy import io.reactivex.rxjava3.schedulers.TestScheduler import io.reactivex.rxjava3.subscribers.TestSubscriber import java.util.concurrent.TimeUnit +import org.junit.Assert.assertEquals import org.junit.Test import org.mockito.Mockito.mock @@ -69,4 +70,47 @@ class ReachabilityMonitorTest { testSubscriber.assertValues(true, false, true, true) } + + /** + * Test that calling getObservable() multiple times only results in the network + * callback being registered once. + */ + @Test + fun testNetworkCallbackRegisteredOnce() { + var networkCallback: ConnectivityManager.NetworkCallback? = null + var numCallbacksRegistered = 0 + + val connectivityProvider = object : ConnectivityProvider { + override val hasActiveNetwork: Boolean + get() = run { + return true + } + override fun registerDefaultNetworkCallback( + context: Context, + callback: ConnectivityManager.NetworkCallback + ) { + networkCallback = callback + numCallbacksRegistered += 1 + } + } + + // TestScheduler allows the virtual time to be advanced by exact amounts, to allow for repeatable tests + val testScheduler = TestScheduler() + val reachabilityMonitor = ReachabilityMonitor.createForTesting(TestSchedulerProvider(testScheduler)) + val mockContext = mock(Context::class.java) + reachabilityMonitor.configure(mockContext, connectivityProvider) + + reachabilityMonitor.getObservable().subscribe() + val network = mock(Network::class.java) + // Should provide initial network state (true) upon subscription (after debounce) + testScheduler.advanceTimeBy(251, TimeUnit.MILLISECONDS) + networkCallback!!.onAvailable(network) + + reachabilityMonitor.getObservable().subscribe() + testScheduler.advanceTimeBy(251, TimeUnit.MILLISECONDS) + networkCallback!!.onAvailable(network) + + // Only 1 network callback should be registered + assertEquals(1, numCallbacksRegistered) + } } diff --git a/aws-push-notifications-pinpoint-common/build.gradle.kts b/aws-push-notifications-pinpoint-common/build.gradle.kts index 372e0c6fb9..e20f95c399 100644 --- a/aws-push-notifications-pinpoint-common/build.gradle.kts +++ b/aws-push-notifications-pinpoint-common/build.gradle.kts @@ -43,7 +43,7 @@ dependencies { //noinspection GradleDependency implementation("androidx.lifecycle:lifecycle-livedata-ktx:$lifecycleVersion") //noinspection GradleDependency - implementation("com.google.android.material:material:1.8.0") + implementation(dependency.google.material) testImplementation(testDependency.junit) testImplementation(testDependency.mockk) diff --git a/settings.gradle.kts b/settings.gradle.kts index ea86e3a94d..b0f0ef2a33 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -130,7 +130,7 @@ dependencyResolutionManagement { library("rxjava", "io.reactivex.rxjava3:rxjava:3.0.6") // Google - library("google-material", "com.google.android.material:material:1.4.0") + library("google-material", "com.google.android.material:material:1.8.0") library("firebasemessaging", "com.google.firebase:firebase-messaging-ktx:23.1.0") // Misc