Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(datastore): OIDC Rework #966

Merged
merged 15 commits into from
Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ This code is the iOS part of the Amplify Flutter Pinpoint Analytics Plugin. The
s.source = { :path => '.' }
s.source_files = 'Classes/**/*'
s.dependency 'Flutter'
s.dependency 'Amplify', '~> 1.15.0'
s.dependency 'AmplifyPlugins/AWSPinpointAnalyticsPlugin', '~> 1.15.0'
s.dependency 'Amplify', '~> 1.15.2'
s.dependency 'AmplifyPlugins/AWSPinpointAnalyticsPlugin', '~> 1.15.2'
s.dependency 'amplify_core'
s.platform = :ios, '11.0'

Expand Down
1 change: 1 addition & 0 deletions packages/amplify_api/android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ dependencies {
testImplementation 'junit:junit:4.13.2'
testImplementation 'org.mockito:mockito-core:3.10.0'
testImplementation 'org.mockito:mockito-inline:3.10.0'
testImplementation "org.mockito.kotlin:mockito-kotlin:3.2.0"
testImplementation 'androidx.test:core:1.4.0'
testImplementation 'org.robolectric:robolectric:4.3.1'
testImplementation 'org.jetbrains.kotlinx:kotlinx-coroutines-test:1.3.9'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,10 @@ import androidx.annotation.NonNull
import androidx.annotation.VisibleForTesting
import com.amazonaws.amplify.amplify_api.auth.FlutterAuthProviders
import com.amazonaws.amplify.amplify_api.rest_api.FlutterRestApi
import com.amazonaws.amplify.amplify_core.cast
import com.amazonaws.amplify.amplify_core.exception.ExceptionUtil.Companion.createSerializedUnrecognizedError
import com.amazonaws.amplify.amplify_core.exception.ExceptionUtil.Companion.handleAddPluginException
import com.amazonaws.amplify.amplify_core.exception.ExceptionUtil.Companion.postExceptionToFlutterChannel
import com.amplifyframework.api.ApiException
import com.amplifyframework.api.aws.AWSApiPlugin
import com.amplifyframework.api.aws.AuthorizationType
import com.amplifyframework.core.Amplify
import io.flutter.embedding.engine.plugins.FlutterPlugin
import io.flutter.plugin.common.EventChannel
Expand All @@ -39,15 +36,6 @@ import io.flutter.plugin.common.MethodChannel.Result

/** AmplifyApiPlugin */
class AmplifyApiPlugin : FlutterPlugin, MethodCallHandler {
companion object {
/**
* Thrown when [tokenType] is used but is not a valid [AuthorizationType].
*/
private fun invalidTokenType(tokenType: String? = null) = ApiException.ApiAuthException(
"Invalid arguments",
"Invalid token type: $tokenType"
)
}

private lateinit var channel: MethodChannel
private lateinit var eventchannel: EventChannel
Expand Down Expand Up @@ -89,7 +77,7 @@ class AmplifyApiPlugin : FlutterPlugin, MethodCallHandler {
Amplify.addPlugin(
AWSApiPlugin
.builder()
.apiAuthProviders(FlutterAuthProviders.factory)
.apiAuthProviders(FlutterAuthProviders(channel).factory)
.build()
)
logger.info("Added API plugin")
Expand All @@ -103,12 +91,6 @@ class AmplifyApiPlugin : FlutterPlugin, MethodCallHandler {
try {
val arguments: Map<String, Any> = call.arguments as Map<String, Any>

// Update tokens if included with request
val tokens = arguments["tokens"] as? List<*>
if (tokens != null && tokens.isNotEmpty()) {
updateTokens(tokens)
}

when (call.method) {
"get" -> FlutterRestApi.get(result, arguments)
"post" -> FlutterRestApi.post(result, arguments)
Expand All @@ -123,17 +105,6 @@ class AmplifyApiPlugin : FlutterPlugin, MethodCallHandler {
arguments,
graphqlSubscriptionStreamHandler
)
"updateTokens" -> {
if (tokens == null || tokens.isEmpty()) {
throw ApiException(
"Invalid token map provided",
"Provide tokens in the \"tokens\" field"
)
}

// Tokens already updated
result.success(null)
}
else -> result.notImplemented()
}
} catch (e: Exception) {
Expand Down Expand Up @@ -162,19 +133,6 @@ class AmplifyApiPlugin : FlutterPlugin, MethodCallHandler {
}
}

private fun updateTokens(tokens: List<*>) {
for (authToken in tokens.cast<Map<String, Any?>>()) {
val token = authToken["token"] as? String?
val tokenType = authToken["type"] as? String ?: throw invalidTokenType()
val authType: AuthorizationType = try {
AuthorizationType.from(tokenType)
} catch (e: Exception) {
throw invalidTokenType(tokenType)
}
FlutterAuthProviders.setToken(authType, token)
}
}

override fun onDetachedFromEngine(@NonNull binding: FlutterPlugin.FlutterPluginBinding) {
channel.setMethodCallHandler(null)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,60 +14,137 @@
*/
package com.amazonaws.amplify.amplify_api.auth

import android.os.Looper
import com.amplifyframework.api.ApiException
import com.amplifyframework.api.aws.ApiAuthProviders
import com.amplifyframework.api.aws.AuthorizationType
import com.amplifyframework.api.aws.sigv4.FunctionAuthProvider
import com.amplifyframework.api.aws.sigv4.OidcAuthProvider
import io.flutter.Log
import io.flutter.plugin.common.MethodChannel
import kotlinx.coroutines.*

/**
* Manages the shared state of all [FlutterAuthProvider] instances.
*/
object FlutterAuthProviders {
class FlutterAuthProviders(private val methodChannel: MethodChannel) {

private companion object {
/**
* Timeout on a single [getToken] call.
*/
const val getTokenTimeoutMillis = 2000L

/**
* Logger tag.
*/
const val tag = "FlutterAuthProviders"

/**
* Name for suspending block in [getToken]. Used for debugging
*/
val coroutineName = CoroutineName(tag)
}

/**
* A factory of [FlutterAuthProvider] instances.
*/
val factory: ApiAuthProviders by lazy {
ApiAuthProviders
.Builder()
.functionAuthProvider(FlutterAuthProvider(AuthorizationType.AWS_LAMBDA))
.oidcAuthProvider(FlutterAuthProvider(AuthorizationType.OPENID_CONNECT))
.functionAuthProvider(FlutterAuthProvider(this, AuthorizationType.AWS_LAMBDA))
.oidcAuthProvider(FlutterAuthProvider(this, AuthorizationType.OPENID_CONNECT))
.build()
}

/**
* Token cache for all [FlutterAuthProvider] instances.
*/
private var tokens: MutableMap<AuthorizationType, String?> = mutableMapOf()

/**
* Retrieves the token for [authType] or `null`, if unavailable.
*
* This function is typically called from within the Amplify library and from a thread besides
* the main one, where it is safe to block. In API REST, the calling thread is main and we must
* return `null`.
*
* Not blocking the main thread is required for making platform channel calls without deadlock.
*/
fun getToken(authType: AuthorizationType): String? = tokens[authType]
fun getToken(authType: AuthorizationType): String? {
if (Thread.currentThread() == Looper.getMainLooper().thread) {
// API REST will call this function from the main thread on configuration. This is bad.
// Since we have to block the calling thread to retrieve the token, just return null.
Log.e(tag, "REST OIDC/Lambda is not supported yet.")
return null
}
try {
return runBlocking(coroutineName) {
val completer = Job()

/**
* Sets the token for [authType] to [value].
*/
fun setToken(authType: AuthorizationType, value: String?) {
tokens[authType] = value
val result = object : MethodChannel.Result {
var token: String? = null

override fun success(result: Any?) {
token = result as? String
launch(Dispatchers.Main) {
completer.complete()
}
}

override fun error(
errorCode: String?,
errorMessage: String?,
errorDetails: Any?
) {
launch(Dispatchers.Main) {
completer.complete()
}
}

override fun notImplemented() {
launch(Dispatchers.Main) {
completer.complete()
}
}
}
launch(Dispatchers.Main) {
methodChannel.invokeMethod(
"getLatestAuthToken",
authType.name,
result
)
}

withTimeout(getTokenTimeoutMillis) {
completer.join()
}

return@runBlocking result.token
}
} catch (e: Exception) {
Log.e(tag, "Exception in getToken", e)
return null
}
}
}

/**
* A provider which manages token retrieval for its [AuthorizationType].
*/
class FlutterAuthProvider(private val type: AuthorizationType) : FunctionAuthProvider,
class FlutterAuthProvider(
private val provider: FlutterAuthProviders,
private val type: AuthorizationType
) : FunctionAuthProvider,
OidcAuthProvider {
private companion object {
/**
* Thrown when there is no token available for [type].
*/
fun noTokenAvailable(type: AuthorizationType) = ApiException.ApiAuthException(
"No $type token available",
"Ensure that `getLatestAuthToken` returns a value"
"Unable to retrieve token for $type",
"""
Make sure you register your auth providers in the addPlugin call and
that getLatestAuthToken returns a value.
""".trimIndent()
)
}

override fun getLatestAuthToken(): String =
FlutterAuthProviders.getToken(type) ?: throw noTokenAvailable(type)
provider.getToken(type) ?: throw noTokenAvailable(type)
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ import org.junit.After
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mockito.*
import org.mockito.kotlin.clearInvocations
import org.mockito.kotlin.mock
import org.mockito.kotlin.times
import org.mockito.kotlin.verify

/**
* Mock model object for building dummy GraphQL requests.
Expand All @@ -49,6 +52,7 @@ class MockModel : Model {
@ExperimentalCoroutinesApi
@RunWith(AndroidJUnit4::class)
class AuthProviderTests {

/**
* Using categories allows us to create a new one per-test and configure as appropriate.
*/
Expand All @@ -67,13 +71,12 @@ class AuthProviderTests {
/**
* Mock OIDC provider.
*/
private val mockOidcAuthProvider: OidcAuthProvider = mock(OidcAuthProvider::class.java)
private val mockOidcAuthProvider: OidcAuthProvider = mock()

/**
* Mock Lambda provider.
*/
private val mockFunctionAuthProvider: FunctionAuthProvider =
mock(FunctionAuthProvider::class.java)
private val mockFunctionAuthProvider: FunctionAuthProvider = mock()

/**
* Multi-auth configuration JSON for API category.
Expand Down
13 changes: 13 additions & 0 deletions packages/amplify_api/build.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
targets:
$default:
sources:
- $package$
- lib/$lib$
- lib/**.dart
- test/**.dart
- integration_test/**.dart
builders:
mockito|mockBuilder:
generate_for:
- test/**.dart
- integration_test/**.dart
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
/* Begin PBXBuildFile section */
1498D2341E8E89220040F4C2 /* GeneratedPluginRegistrant.m in Sources */ = {isa = PBXBuildFile; fileRef = 1498D2331E8E89220040F4C2 /* GeneratedPluginRegistrant.m */; };
3B3967161E833CAA004F5970 /* AppFrameworkInfo.plist in Resources */ = {isa = PBXBuildFile; fileRef = 3B3967151E833CAA004F5970 /* AppFrameworkInfo.plist */; };
4B6CB97626BAEC7B004E4AA2 /* AuthProviderTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4B6CB97526BAEC7B004E4AA2 /* AuthProviderTests.swift */; };
4BF054AD269DE2FB00D1F2BF /* FlutterURLSessionTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4BF054AC269DE2FB00D1F2BF /* FlutterURLSessionTests.swift */; };
74858FAF1ED2DC5600515810 /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 74858FAE1ED2DC5600515810 /* AppDelegate.swift */; };
7E94B1FC95440E5C0AD8AB8D /* Pods_Runner.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 30DF0857352AA115AF86C33E /* Pods_Runner.framework */; };
Expand Down Expand Up @@ -40,7 +39,6 @@
1498D2331E8E89220040F4C2 /* GeneratedPluginRegistrant.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = GeneratedPluginRegistrant.m; sourceTree = "<group>"; };
30DF0857352AA115AF86C33E /* Pods_Runner.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_Runner.framework; sourceTree = BUILT_PRODUCTS_DIR; };
3B3967151E833CAA004F5970 /* AppFrameworkInfo.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; name = AppFrameworkInfo.plist; path = Flutter/AppFrameworkInfo.plist; sourceTree = "<group>"; };
4B6CB97526BAEC7B004E4AA2 /* AuthProviderTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AuthProviderTests.swift; sourceTree = "<group>"; };
4BF054AC269DE2FB00D1F2BF /* FlutterURLSessionTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FlutterURLSessionTests.swift; sourceTree = "<group>"; };
7457497C86253194DEE2467D /* Pods-unit_tests.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-unit_tests.release.xcconfig"; path = "Target Support Files/Pods-unit_tests/Pods-unit_tests.release.xcconfig"; sourceTree = "<group>"; };
74858FAD1ED2DC5600515810 /* Runner-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "Runner-Bridging-Header.h"; sourceTree = "<group>"; };
Expand Down Expand Up @@ -160,7 +158,6 @@
FB7C4075259A81E00021F98A /* Info.plist */,
840F5D16259C147800C968A8 /* RestApiUnitTests.swift */,
4BF054AC269DE2FB00D1F2BF /* FlutterURLSessionTests.swift */,
4B6CB97526BAEC7B004E4AA2 /* AuthProviderTests.swift */,
);
path = unit_tests;
sourceTree = "<group>";
Expand Down Expand Up @@ -393,7 +390,6 @@
files = (
FB7C4074259A81E00021F98A /* GraphQLApiUnitTests.swift in Sources */,
840F5D17259C147800C968A8 /* RestApiUnitTests.swift in Sources */,
4B6CB97626BAEC7B004E4AA2 /* AuthProviderTests.swift in Sources */,
4BF054AD269DE2FB00D1F2BF /* FlutterURLSessionTests.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
Expand Down
Loading