Skip to content

Commit

Permalink
fix(push): Prevent Unintended OptOuts (#2587)
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerjroach authored Sep 15, 2023
1 parent fdba2d9 commit 10d8914
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package com.amplifyframework.analytics.pinpoint
import android.content.Context
import aws.sdk.kotlin.services.pinpoint.PinpointClient
import aws.smithy.kotlin.runtime.auth.awscredentials.CredentialsProvider
import com.amplifyframework.core.store.EncryptedKeyValueRepository
import com.amplifyframework.pinpoint.core.AnalyticsClient
import com.amplifyframework.pinpoint.core.TargetingClient
import com.amplifyframework.pinpoint.core.data.AndroidAppDetails
Expand Down Expand Up @@ -61,11 +62,17 @@ internal class PinpointManager constructor(
Context.MODE_PRIVATE
)

val encryptedStore = EncryptedKeyValueRepository(
context,
"${awsPinpointConfiguration.appId}$PINPOINT_SHARED_PREFS_SUFFIX"
)

val androidAppDetails = AndroidAppDetails(context, awsPinpointConfiguration.appId)
val androidDeviceDetails = AndroidDeviceDetails(context)
targetingClient = TargetingClient(
context,
pinpointClient,
encryptedStore,
sharedPrefs,
androidAppDetails,
androidDeviceDetails,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ class AWSPinpointAnalyticsPluginBehaviorTest {
sharedPrefs.getUniqueId(),
androidAppDetails,
androidDeviceDetails,
ApplicationProvider.getApplicationContext()
ApplicationProvider.getApplicationContext(),
mockk()
)
}
val actualEndpoint = slot<EndpointProfile>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import android.content.Context
import android.database.Cursor
import android.net.Uri
import aws.sdk.kotlin.services.pinpoint.PinpointClient
import aws.sdk.kotlin.services.pinpoint.model.ChannelType
import aws.sdk.kotlin.services.pinpoint.model.EndpointDemographic
import aws.sdk.kotlin.services.pinpoint.model.EndpointItemResponse
import aws.sdk.kotlin.services.pinpoint.model.EndpointLocation
Expand Down Expand Up @@ -286,12 +287,10 @@ class EventRecorder(
demographic = endpointDemographic
effectiveDate = endpointProfile.effectiveDate.millisToIsoDate()

if (endpointProfile.address != "" && endpointProfile.channelType != null) {
if (endpointProfile.address != "" && endpointProfile.channelType == ChannelType.Gcm) {
optOut = "NONE" // no opt out, send notifications
address = endpointProfile.address
channelType = endpointProfile.channelType
} else {
optOut = "ALL" // opt out from all notifications
}

attributes = endpointProfile.allAttributes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package com.amplifyframework.pinpoint.core
import android.content.Context
import android.content.SharedPreferences
import aws.sdk.kotlin.services.pinpoint.PinpointClient
import aws.sdk.kotlin.services.pinpoint.model.ChannelType
import aws.sdk.kotlin.services.pinpoint.model.EndpointDemographic
import aws.sdk.kotlin.services.pinpoint.model.EndpointLocation
import aws.sdk.kotlin.services.pinpoint.model.EndpointRequest
Expand All @@ -31,6 +32,7 @@ import com.amplifyframework.analytics.UserProfile
import com.amplifyframework.annotations.InternalAmplifyApi
import com.amplifyframework.core.Amplify
import com.amplifyframework.core.category.CategoryType
import com.amplifyframework.core.store.KeyValueRepository
import com.amplifyframework.pinpoint.core.data.AndroidAppDetails
import com.amplifyframework.pinpoint.core.data.AndroidDeviceDetails
import com.amplifyframework.pinpoint.core.endpointProfile.EndpointProfile
Expand All @@ -52,12 +54,13 @@ import org.json.JSONObject
class TargetingClient(
context: Context,
private val pinpointClient: PinpointClient,
store: KeyValueRepository,
private val prefs: SharedPreferences,
appDetails: AndroidAppDetails,
deviceDetails: AndroidDeviceDetails,
coroutineDispatcher: CoroutineDispatcher = Dispatchers.Default
coroutineDispatcher: CoroutineDispatcher = Dispatchers.Default,
) {
private val endpointProfile = EndpointProfile(prefs.getUniqueId(), appDetails, deviceDetails, context)
private val endpointProfile = EndpointProfile(prefs.getUniqueId(), appDetails, deviceDetails, context, store)
private val globalAttributes: MutableMap<String, List<String>>
private val globalMetrics: MutableMap<String, Double>
private val coroutineScope = CoroutineScope(coroutineDispatcher)
Expand Down Expand Up @@ -211,12 +214,10 @@ class TargetingClient(
this.location = location
this.demographic = demographic
effectiveDate = endpointProfile.effectiveDate.millisToIsoDate()
if (endpointProfile.address != "" && endpointProfile.channelType != null) {
if (endpointProfile.address != "" && endpointProfile.channelType == ChannelType.Gcm) {
optOut = "NONE" // no opt out, send notifications
address = endpointProfile.address
channelType = endpointProfile.channelType
} else {
optOut = "ALL" // opt out from all notifications
}

attributes = endpointProfile.allAttributes
Expand Down Expand Up @@ -371,6 +372,9 @@ class TargetingClient(
}

companion object {
@InternalAmplifyApi
const val AWS_PINPOINT_PUSHNOTIFICATIONS_DEVICE_TOKEN_KEY = "FCMDeviceToken"

private val LOG = Amplify.Logging.logger(CategoryType.ANALYTICS, "amplify:aws-analytics-pinpoint")
private const val CUSTOM_ATTRIBUTES_KEY = "ENDPOINT_PROFILE_CUSTOM_ATTRIBUTES"
private const val CUSTOM_METRICS_KEY = "ENDPOINT_PROFILE_CUSTOM_METRICS"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import aws.sdk.kotlin.services.pinpoint.model.ChannelType
import com.amplifyframework.annotations.InternalAmplifyApi
import com.amplifyframework.core.Amplify
import com.amplifyframework.core.category.CategoryType
import com.amplifyframework.core.store.KeyValueRepository
import com.amplifyframework.pinpoint.core.TargetingClient
import com.amplifyframework.pinpoint.core.data.AndroidAppDetails
import com.amplifyframework.pinpoint.core.data.AndroidDeviceDetails
import com.amplifyframework.pinpoint.core.util.millisToIsoDate
Expand All @@ -39,14 +41,17 @@ class EndpointProfile(
uniqueId: String,
appDetails: AndroidAppDetails,
deviceDetails: AndroidDeviceDetails,
applicationContext: Context
applicationContext: Context,
private val store: KeyValueRepository
) {
private val attributes: MutableMap<String, List<String>> = ConcurrentHashMap()
private val metrics: MutableMap<String, Double> = ConcurrentHashMap()
private val currentNumOfAttributesAndMetrics = AtomicInteger(0)

var channelType: ChannelType? = null
var address: String = ""
val address: String get() {
return store.get(TargetingClient.AWS_PINPOINT_PUSHNOTIFICATIONS_DEVICE_TOKEN_KEY) ?: ""
}

private val country: String = try {
applicationContext.resources.configuration.locales[0].isO3Country
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import android.content.Context
import android.content.SharedPreferences
import androidx.test.core.app.ApplicationProvider
import aws.sdk.kotlin.services.pinpoint.PinpointClient
import com.amplifyframework.core.store.KeyValueRepository
import com.amplifyframework.pinpoint.core.data.AndroidAppDetails
import com.amplifyframework.pinpoint.core.data.AndroidDeviceDetails
import com.amplifyframework.pinpoint.core.endpointProfile.EndpointProfile
Expand Down Expand Up @@ -48,13 +49,15 @@ internal val country = "en_US"
internal val effectiveDate = 0L

internal val preferences = mockk<SharedPreferences>()
internal val store = mockk<KeyValueRepository>()
internal val appDetails = AndroidAppDetails(appID, appTitle, packageName, versionCode, versionName)
internal val deviceDetails = AndroidDeviceDetails(carrier = carrier, locale = locale)
internal val applicationContext = mockk<Context>()

internal fun setup() {
mockkStatic("com.amplifyframework.pinpoint.core.util.SharedPreferencesUtilKt")
every { preferences.getUniqueId() }.returns(uniqueID)
every { store.get(TargetingClient.AWS_PINPOINT_PUSHNOTIFICATIONS_DEVICE_TOKEN_KEY) } returns ""
every { applicationContext.resources.configuration.locales[0].isO3Country }
.returns(country)
}
Expand All @@ -65,7 +68,8 @@ internal fun constructEndpointProfile(): EndpointProfile {
preferences.getUniqueId(),
appDetails,
deviceDetails,
applicationContext
applicationContext,
store
)
endpointProfile.effectiveDate = effectiveDate
return endpointProfile
Expand All @@ -83,8 +87,9 @@ internal fun constructTargetingClient(): TargetingClient {
return TargetingClient(
applicationContext,
pinpointClient,
store,
prefs,
appDetails,
deviceDetails
deviceDetails,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ package com.amplifyframework.pinpoint.core

import android.os.Build
import aws.sdk.kotlin.services.pinpoint.PinpointClient
import aws.sdk.kotlin.services.pinpoint.model.ChannelType
import aws.sdk.kotlin.services.pinpoint.model.EndpointRequest
import aws.sdk.kotlin.services.pinpoint.model.UpdateEndpointRequest
import aws.sdk.kotlin.services.pinpoint.model.UpdateEndpointResponse
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.runTest
import org.junit.Assert.assertEquals
Expand All @@ -47,34 +49,75 @@ class TargetingClientTest {
}

@Test
fun testCurrentEndpoint() {
targetingClient.addAttribute("attribute", listOf("a", "b", "c"))
targetingClient.addMetric("metric", 2.0)
val endpoint = targetingClient.currentEndpoint()
assertEquals(endpoint.getAttribute("attribute"), listOf("a", "b", "c"))
assertEquals(endpoint.getMetric("metric"), 2.0)
fun testUpdateEndpointProfile() = runTest {
setup()
targetingClient = constructTargetingClient()

val expectedToken = "token123"
every { store.get(TargetingClient.AWS_PINPOINT_PUSHNOTIFICATIONS_DEVICE_TOKEN_KEY) } returns expectedToken

val updateEndpointResponse = UpdateEndpointResponse.invoke {}
coEvery { pinpointClient.updateEndpoint(ofType(UpdateEndpointRequest::class)) }.returns(updateEndpointResponse)
targetingClient.updateEndpointProfile()

coVerify {
pinpointClient.updateEndpoint(
coWithArg {
assertNotNull(it.endpointRequest)
val request: EndpointRequest = it.endpointRequest!!
assertEquals("app id", it.applicationId)
assertEquals(expectedToken, request.address)
}
)
}
}

@Test
fun testUpdateEndpointProfile() = runTest {
fun testUpdateEndpointProfileOptsIn() = runTest {
setup()
targetingClient = constructTargetingClient()
targetingClient.currentEndpoint().channelType = ChannelType.Gcm

val expectedToken = "token123"
every { store.get(TargetingClient.AWS_PINPOINT_PUSHNOTIFICATIONS_DEVICE_TOKEN_KEY) } returns expectedToken

val updateEndpointResponse = UpdateEndpointResponse.invoke {}
coEvery { pinpointClient.updateEndpoint(ofType(UpdateEndpointRequest::class)) }.returns(updateEndpointResponse)
targetingClient.updateEndpointProfile()

coVerify {
pinpointClient.updateEndpoint(
coWithArg {
assertNotNull(it.endpointRequest)
val request: EndpointRequest = it.endpointRequest!!
assertEquals("app id", it.applicationId)
assertEquals(expectedToken, request.address)
assertEquals("NONE", request.optOut)
}
)
}
}

@Test
fun testUpdateEndpointProfileOptOutNotTouched() = runTest {
setup()
targetingClient = constructTargetingClient()
targetingClient.currentEndpoint().channelType = null

targetingClient.addAttribute("attribute", listOf("a1", "a2"))
targetingClient.addMetric("metric", 1.0)
val expectedToken = ""
every { store.get(TargetingClient.AWS_PINPOINT_PUSHNOTIFICATIONS_DEVICE_TOKEN_KEY) } returns expectedToken

val updateEndpointResponse = UpdateEndpointResponse.invoke {}
coEvery { pinpointClient.updateEndpoint(ofType(UpdateEndpointRequest::class)) }.returns(updateEndpointResponse)
targetingClient.updateEndpointProfile()

coVerify {
pinpointClient.updateEndpoint(
coWithArg<UpdateEndpointRequest> {
coWithArg {
assertNotNull(it.endpointRequest)
val request: EndpointRequest = it.endpointRequest!!
assertEquals("app id", it.applicationId)
assertEquals(listOf("a1", "a2"), request.attributes?.get("attribute") ?: listOf("wrong"))
assertEquals(1.0, request.metrics?.get("metric") ?: -1.0, 0.01)
assertEquals(expectedToken, request.address)
}
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class AWSPinpointPushNotificationsPlugin : PushNotificationsPlugin<PinpointClien
private const val DEFAULT_AUTO_FLUSH_INTERVAL = 30000L
private const val AWS_PINPOINT_PUSHNOTIFICATIONS_PREFERENCES_SUFFIX = "515d6767-01b7-49e5-8273-c8d11b0f331d"
private const val AWS_PINPOINT_PUSHNOTIFICATIONS_DEVICE_TOKEN_LEGACY_KEY = "AWSPINPOINT.GCMTOKEN"
private const val AWS_PINPOINT_PUSHNOTIFICATIONS_DEVICE_TOKEN_KEY = "FCMDeviceToken"
}

private lateinit var preferences: SharedPreferences
Expand Down Expand Up @@ -116,7 +115,7 @@ class AWSPinpointPushNotificationsPlugin : PushNotificationsPlugin<PinpointClien

val deviceToken = preferences.getString(AWS_PINPOINT_PUSHNOTIFICATIONS_DEVICE_TOKEN_LEGACY_KEY, null)
deviceToken?.let {
store.put(AWS_PINPOINT_PUSHNOTIFICATIONS_DEVICE_TOKEN_KEY, it)
store.put(TargetingClient.AWS_PINPOINT_PUSHNOTIFICATIONS_DEVICE_TOKEN_KEY, it)
preferences.edit { remove(AWS_PINPOINT_PUSHNOTIFICATIONS_DEVICE_TOKEN_LEGACY_KEY).apply() }
}
}
Expand All @@ -136,7 +135,14 @@ class AWSPinpointPushNotificationsPlugin : PushNotificationsPlugin<PinpointClien
androidAppDetails: AndroidAppDetails,
androidDeviceDetails: AndroidDeviceDetails
): TargetingClient {
return TargetingClient(context, pinpointClient, preferences, androidAppDetails, androidDeviceDetails)
return TargetingClient(
context,
pinpointClient,
store,
preferences,
androidAppDetails,
androidDeviceDetails
)
}

private fun createAnalyticsClient(
Expand Down Expand Up @@ -214,11 +220,10 @@ class AWSPinpointPushNotificationsPlugin : PushNotificationsPlugin<PinpointClien

override fun registerDevice(token: String, onSuccess: Action, onError: Consumer<PushNotificationsException>) {
try {
store.put(AWS_PINPOINT_PUSHNOTIFICATIONS_DEVICE_TOKEN_KEY, token)
store.put(TargetingClient.AWS_PINPOINT_PUSHNOTIFICATIONS_DEVICE_TOKEN_KEY, token)
// targetingClient needs to send the address, optOut etc. to Pinpoint so we can receive campaigns/journeys
val endpointProfile = targetingClient.currentEndpoint().apply {
channelType = ChannelType.Gcm
address = token
}

targetingClient.updateEndpointProfile(endpointProfile)
Expand Down

0 comments on commit 10d8914

Please sign in to comment.