Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added delimiter option to S3InputFormat and S3GeoTiffRDD. #2062

Merged
merged 1 commit into from
Mar 14, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class S3GeoTiffRDDSpec
with RasterMatchers
with TestEnvironment {

describe("S3GeoTiffRDD Spatial") {
describe("S3GeoTiffRDD") {
implicit val mockClient = new MockS3Client()
val bucket = this.getClass.getSimpleName

Expand All @@ -62,7 +62,7 @@ class S3GeoTiffRDDSpec

assertEqual(stitched1, stitched2)
}

it("should read the same rasters when reading small windows or with no windows, Spatial, MultibandGeoTiff") {
val key = "geoTiff/multi.tif"
val testGeoTiffPath = "raster-test/data/geotiff-test-files/3bands/byte/3bands-striped-band.tif"
Expand All @@ -82,7 +82,7 @@ class S3GeoTiffRDDSpec

assertEqual(stitched1, stitched2)
}

it("should read the same rasters when reading small windows or with no windows, TemporalSpatial, SinglebandGeoTiff") {
val key = "geoTiff/time.tif"
val testGeoTiffPath = "raster-test/data/one-month-tiles/test-200506000000_0_0.tif"
Expand All @@ -106,7 +106,7 @@ class S3GeoTiffRDDSpec
val dateTime = wholeInfo.time

val collection = source2.collect

cfor(0)(_ < source2.count, _ + 1){ i =>
val (info, _) = collection(i)

Expand Down Expand Up @@ -142,12 +142,32 @@ class S3GeoTiffRDDSpec
val dateTime = wholeInfo.time

val collection = source2.collect

cfor(0)(_ < source2.count, _ + 1){ i =>
val (info, _) = collection(i)

info.time should be (dateTime)
}
}

it("should apply the delimiter option") {
MockS3Client.reset()

val key = "geoTiff/multi-time.tif"

val source1 =
S3GeoTiffRDD.temporalMultiband(
bucket,
key,
S3GeoTiffRDD.Options(
timeTag = "ISO_TIME",
timeFormat = "yyyy-MM-dd'T'HH:mm:ss",
getS3Client = () => new MockS3Client,
delimiter = Some("/")
)
).count

MockS3Client.lastListObjectsRequest.get.getDelimiter should be ("/")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class MockS3Client() extends S3Client with LazyLogging {
import MockS3Client._

def listObjects(r: ListObjectsRequest): ObjectListing = this.synchronized {
setLastListObjectsRequest(r)
if (null == r.getMaxKeys)
r.setMaxKeys(64)

Expand Down Expand Up @@ -256,8 +257,10 @@ object MockS3Client{
}
}

def reset(): Unit =
def reset(): Unit = {
buckets.clear()
_lastListObjectsRequest = None
}

val buckets = new ConcurrentHashMap[String, Bucket]()

Expand All @@ -270,4 +273,13 @@ object MockS3Client{
bucket
}
}

// Allow tests to inspect the last ListObjectRequest

var _lastListObjectsRequest: Option[ListObjectsRequest] = None
def lastListObjectsRequest = _lastListObjectsRequest
def setLastListObjectsRequest(r: ListObjectsRequest) =
_lastListObjectsRequest.synchronized {
_lastListObjectsRequest = Some(r)
}
}
3 changes: 3 additions & 0 deletions s3/src/main/scala/geotrellis/spark/io/s3/S3GeoTiffRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ object S3GeoTiffRDD {
* @param numPartitions How many partitions Spark should create when it repartitions the data.
* @param partitionBytes Desired partition size in bytes, at least one item per partition will be assigned
* @param chunkSize How many bytes should be read in at a time.
* @param delimiter Delimiter to use for S3 objet listings. See
* @param getS3Client A function to instantiate an S3Client. Must be serializable.
*/
case class Options(
Expand All @@ -64,6 +65,7 @@ object S3GeoTiffRDD {
numPartitions: Option[Int] = None,
partitionBytes: Option[Long] = None,
chunkSize: Option[Int] = None,
delimiter: Option[String] = None,
getS3Client: () => S3Client = () => S3Client.DEFAULT
) extends RasterReader.Options

Expand All @@ -86,6 +88,7 @@ object S3GeoTiffRDD {
S3InputFormat.setCreateS3Client(conf, options.getS3Client)
options.numPartitions.foreach{ n => S3InputFormat.setPartitionCount(conf, n) }
options.partitionBytes.foreach{ n => S3InputFormat.setPartitionBytes(conf, n) }
options.delimiter.foreach { n => S3InputFormat.setDelimiter(conf, n) }
conf
}

Expand Down
28 changes: 28 additions & 0 deletions s3/src/main/scala/geotrellis/spark/io/s3/S3InputFormat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ abstract class S3InputFormat[K, V] extends InputFormat[K,V] with LazyLogging {
chunkSizeConf
}

val delimiter = S3InputFormat.getDelimiter(conf)

val partitionCountConf = conf.get(PARTITION_COUNT)
val partitionSizeConf = conf.get(PARTITION_BYTES)
require(null == partitionCountConf || null == partitionSizeConf,
Expand All @@ -88,6 +90,11 @@ abstract class S3InputFormat[K, V] extends InputFormat[K,V] with LazyLogging {
.withBucketName(bucket)
.withPrefix(prefix)

delimiter match {
case Some(d) => request.setDelimiter(d)
case None => // pass
}

def makeNewSplit = {
val split = new S3InputSplit
split.bucket = bucket
Expand Down Expand Up @@ -185,6 +192,7 @@ object S3InputFormat {
final val CHUNK_SIZE = "s3.chunkSize"
final val CRS_VALUE = "s3.crs"
final val CREATE_S3CLIENT = "s3.client"
final val DELIMITER = "s3.delimiter"

private val idRx = "[A-Z0-9]{20}"
private val keyRx = "[a-zA-Z0-9+/]+={0,2}"
Expand Down Expand Up @@ -269,4 +277,24 @@ object S3InputFormat {
/** Set valid key extensions filter */
def setExtensions(conf: Configuration, extensions: Seq[String]): Unit =
conf.set(EXTENSIONS, extensions.mkString(","))

/** Set delimiter for S3 object listing requests */
def setDelimiter(job: Job, delimiter: String): Unit =
setDelimiter(job.getConfiguration, delimiter)

/** Set delimiter for S3 object listing requests */
def setDelimiter(conf: Configuration, delimiter: String): Unit =
conf.set(DELIMITER, delimiter)

def getDelimiter(job: JobContext): Option[String] =
getDelimiter(job.getConfiguration)

def getDelimiter(conf: Configuration): Option[String] = {
val d = conf.get(DELIMITER)
if(d != null)
Some(d)
else
None
}

}