Skip to content

Commit

Permalink
[SPARK-19405][STREAMING] Support for cross-account Kinesis reads via STS
Browse files Browse the repository at this point in the history
- Add dependency on aws-java-sdk-sts
- Replace SerializableAWSCredentials with new SerializableCredentialsProvider interface
- Make KinesisReceiver take SerializableCredentialsProvider as argument and
  pass credential provider to KCL
- Add new implementations of KinesisUtils.createStream() that take STS
  arguments
- Make JavaKinesisStreamSuite test the entire KinesisUtils Java API
- Update KCL/AWS SDK dependencies to 1.7.x/1.11.x

## What changes were proposed in this pull request?

[JIRA link with detailed description.](https://issues.apache.org/jira/browse/SPARK-19405)

* Replace SerializableAWSCredentials with new SerializableKCLAuthProvider class that takes 5 optional config params for configuring AWS auth and returns the appropriate credential provider object
* Add new public createStream() APIs for specifying these parameters in KinesisUtils

## How was this patch tested?

* Manually tested using explicit keypair and instance profile to read data from Kinesis stream in separate account (difficult to write a test orchestrating creation and assumption of IAM roles across separate accounts)
* Expanded JavaKinesisStreamSuite to test the entire Java API in KinesisUtils

## License acknowledgement
This contribution is my original work and that I license the work to the project under the project’s open source license.

Author: Budde <[email protected]>

Closes apache#16744 from budde/master.
  • Loading branch information
Adam Budde authored and Yun Ni committed Feb 27, 2017
1 parent 022d919 commit 93c6477
Show file tree
Hide file tree
Showing 17 changed files with 407 additions and 83 deletions.
5 changes: 5 additions & 0 deletions external/kinesis-asl/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@
<artifactId>amazon-kinesis-client</artifactId>
<version>${aws.kinesis.client.version}</version>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-sts</artifactId>
<version>${aws.java.sdk.version}</version>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>amazon-kinesis-producer</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public static void main(String[] args) throws Exception {

// Get the region name from the endpoint URL to save Kinesis Client Library metadata in
// DynamoDB of the same region as the Kinesis stream
String regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName();
String regionName = KinesisExampleUtils.getRegionNameByEndpoint(endpointUrl);

// Setup the Spark config and StreamingContext
SparkConf sparkConfig = new SparkConf().setAppName("JavaKinesisWordCountASL");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.streaming

import scala.collection.JavaConverters._

import com.amazonaws.regions.RegionUtils
import com.amazonaws.services.kinesis.AmazonKinesis

private[streaming] object KinesisExampleUtils {
def getRegionNameByEndpoint(endpoint: String): String = {
val uri = new java.net.URI(endpoint)
RegionUtils.getRegionsForService(AmazonKinesis.ENDPOINT_PREFIX)
.asScala
.find(_.getAvailableEndpoints.asScala.toSeq.contains(uri.getHost))
.map(_.getName)
.getOrElse(
throw new IllegalArgumentException(s"Could not resolve region for endpoint: $endpoint"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ object KinesisWordCountASL extends Logging {

// Get the region name from the endpoint URL to save Kinesis Client Library metadata in
// DynamoDB of the same region as the Kinesis stream
val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName()
val regionName = KinesisExampleUtils.getRegionNameByEndpoint(endpointUrl)

// Setup the SparkConfig and StreamingContext
val sparkConfig = new SparkConf().setAppName("KinesisWordCountASL")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class KinesisBackedBlockRDD[T: ClassTag](
@transient private val isBlockIdValid: Array[Boolean] = Array.empty,
val retryTimeoutMs: Int = 10000,
val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _,
val awsCredentialsOption: Option[SerializableAWSCredentials] = None
val kinesisCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider
) extends BlockRDD[T](sc, _blockIds) {

require(_blockIds.length == arrayOfseqNumberRanges.length,
Expand All @@ -105,9 +105,7 @@ class KinesisBackedBlockRDD[T: ClassTag](
}

def getBlockFromKinesis(): Iterator[T] = {
val credentials = awsCredentialsOption.getOrElse {
new DefaultAWSCredentialsProviderChain().getCredentials()
}
val credentials = kinesisCredsProvider.provider.getCredentials
partition.seqNumberRanges.ranges.iterator.flatMap { range =>
new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName,
range, retryTimeoutMs).map(messageHandler)
Expand Down Expand Up @@ -143,7 +141,7 @@ class KinesisSequenceRangeIterator(
private var lastSeqNumber: String = null
private var internalIterator: Iterator[Record] = null

client.setEndpoint(endpointUrl, "kinesis", regionId)
client.setEndpoint(endpointUrl)

override protected def getNext(): Record = {
var nextRecord: Record = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.util.concurrent._
import scala.util.control.NonFatal

import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason

import org.apache.spark.internal.Logging
import org.apache.spark.streaming.Duration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag](
checkpointInterval: Duration,
storageLevel: StorageLevel,
messageHandler: Record => T,
awsCredentialsOption: Option[SerializableAWSCredentials]
kinesisCredsProvider: SerializableCredentialsProvider
) extends ReceiverInputDStream[T](_ssc) {

private[streaming]
Expand All @@ -61,7 +61,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag](
isBlockIdValid = isBlockIdValid,
retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt,
messageHandler = messageHandler,
awsCredentialsOption = awsCredentialsOption)
kinesisCredsProvider = kinesisCredsProvider)
} else {
logWarning("Kinesis sequence number information was not present with some block metadata," +
" it may not be possible to recover from failures")
Expand All @@ -71,6 +71,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag](

override def getReceiver(): Receiver[T] = {
new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream,
checkpointAppName, checkpointInterval, storageLevel, messageHandler, awsCredentialsOption)
checkpointAppName, checkpointInterval, storageLevel, messageHandler,
kinesisCredsProvider)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.control.NonFatal

import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain}
import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory}
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker}
import com.amazonaws.services.kinesis.model.Record
Expand All @@ -34,13 +33,6 @@ import org.apache.spark.streaming.Duration
import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver}
import org.apache.spark.util.Utils

private[kinesis]
case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
extends AWSCredentials {
override def getAWSAccessKeyId: String = accessKeyId
override def getAWSSecretKey: String = secretKey
}

/**
* Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver.
* This implementation relies on the Kinesis Client Library (KCL) Worker as described here:
Expand Down Expand Up @@ -78,8 +70,9 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
* See the Kinesis Spark Streaming documentation for more
* details on the different types of checkpoints.
* @param storageLevel Storage level to use for storing the received objects
* @param awsCredentialsOption Optional AWS credentials, used when user directly specifies
* the credentials
* @param kinesisCredsProvider SerializableCredentialsProvider instance that will be used to
* generate the AWSCredentialsProvider instance used for KCL
* authorization.
*/
private[kinesis] class KinesisReceiver[T](
val streamName: String,
Expand All @@ -90,7 +83,7 @@ private[kinesis] class KinesisReceiver[T](
checkpointInterval: Duration,
storageLevel: StorageLevel,
messageHandler: Record => T,
awsCredentialsOption: Option[SerializableAWSCredentials])
kinesisCredsProvider: SerializableCredentialsProvider)
extends Receiver[T](storageLevel) with Logging { receiver =>

/*
Expand Down Expand Up @@ -147,14 +140,15 @@ private[kinesis] class KinesisReceiver[T](
workerId = Utils.localHostName() + ":" + UUID.randomUUID()

kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId)
// KCL config instance
val awsCredProvider = resolveAWSCredentialsProvider()
val kinesisClientLibConfiguration =
new KinesisClientLibConfiguration(checkpointAppName, streamName, awsCredProvider, workerId)
.withKinesisEndpoint(endpointUrl)
.withInitialPositionInStream(initialPositionInStream)
.withTaskBackoffTimeMillis(500)
.withRegionName(regionName)
val kinesisClientLibConfiguration = new KinesisClientLibConfiguration(
checkpointAppName,
streamName,
kinesisCredsProvider.provider,
workerId)
.withKinesisEndpoint(endpointUrl)
.withInitialPositionInStream(initialPositionInStream)
.withTaskBackoffTimeMillis(500)
.withRegionName(regionName)

/*
* RecordProcessorFactory creates impls of IRecordProcessor.
Expand Down Expand Up @@ -305,25 +299,6 @@ private[kinesis] class KinesisReceiver[T](
}
}

/**
* If AWS credential is provided, return a AWSCredentialProvider returning that credential.
* Otherwise, return the DefaultAWSCredentialsProviderChain.
*/
private def resolveAWSCredentialsProvider(): AWSCredentialsProvider = {
awsCredentialsOption match {
case Some(awsCredentials) =>
logInfo("Using provided AWS credentials")
new AWSCredentialsProvider {
override def getCredentials: AWSCredentials = awsCredentials
override def refresh(): Unit = { }
}
case None =>
logInfo("Using DefaultAWSCredentialsProviderChain")
new DefaultAWSCredentialsProviderChain()
}
}


/**
* Class to handle blocks generated by this receiver's block generator. Specifically, in
* the context of the Kinesis Receiver, this handler does the following.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.util.control.NonFatal

import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException}
import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer}
import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason
import com.amazonaws.services.kinesis.model.Record

import org.apache.spark.internal.Logging
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain}
import com.amazonaws.regions.RegionUtils
import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient
import com.amazonaws.services.dynamodbv2.document.DynamoDB
import com.amazonaws.services.kinesis.AmazonKinesisClient
import com.amazonaws.services.kinesis.{AmazonKinesis, AmazonKinesisClient}
import com.amazonaws.services.kinesis.model._

import org.apache.spark.internal.Logging
Expand All @@ -43,7 +43,7 @@ import org.apache.spark.internal.Logging
private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Logging {

val endpointUrl = KinesisTestUtils.endpointUrl
val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName()
val regionName = KinesisTestUtils.getRegionNameByEndpoint(endpointUrl)

private val createStreamTimeoutSeconds = 300
private val describeStreamPollTimeSeconds = 1
Expand Down Expand Up @@ -205,6 +205,16 @@ private[kinesis] object KinesisTestUtils {
val endVarNameForEndpoint = "KINESIS_TEST_ENDPOINT_URL"
val defaultEndpointUrl = "https://kinesis.us-west-2.amazonaws.com"

def getRegionNameByEndpoint(endpoint: String): String = {
val uri = new java.net.URI(endpoint)
RegionUtils.getRegionsForService(AmazonKinesis.ENDPOINT_PREFIX)
.asScala
.find(_.getAvailableEndpoints.asScala.toSeq.contains(uri.getHost))
.map(_.getName)
.getOrElse(
throw new IllegalArgumentException(s"Could not resolve region for endpoint: $endpoint"))
}

lazy val shouldRunTests = {
val isEnvSet = sys.env.get(envVarNameForEnablingTests) == Some("1")
if (isEnvSet) {
Expand Down
Loading

0 comments on commit 93c6477

Please sign in to comment.