Commit

switched to output stream
johnny-schmidt committed Oct 18, 2024
1 parent 6546e39 commit 834513e
Showing 8 changed files with 106 additions and 54 deletions.
@@ -0,0 +1,11 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.load.util

import java.io.OutputStream

fun OutputStream.write(string: String) {
write(string.toByteArray(Charsets.UTF_8))
}
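
For context, a minimal standalone sketch of how this new extension reads at a call site (not part of this commit; the ByteArrayOutputStream target is illustrative):

import io.airbyte.cdk.load.util.write
import java.io.ByteArrayOutputStream

fun main() {
    val out = ByteArrayOutputStream()
    out.write("""{"id":1}""") // resolves to the String extension above; encodes as UTF-8
    out.write("\n")
    println(out.toString(Charsets.UTF_8.name()))
}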
@@ -29,7 +29,14 @@ class CsvRowToAirbyteValue {

private fun convertInner(value: String, field: AirbyteType): AirbyteValue {
return when (field) {
is ArrayType -> ArrayValue(value.split(",").map { convertInner(it, field.items.type) })
is ArrayType ->
value
.deserializeToNode()
.elements()
.asSequence()
.map { it.toAirbyteValue(field.items.type) }
.toList()
.let(::ArrayValue)
is BooleanType -> BooleanValue(value.toBoolean())
is IntegerType -> IntegerValue(value.toLong())
is NumberType -> NumberValue(value.toBigDecimal())
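
For illustration, the behavioral difference for array cells, sketched with Jackson directly (the CDK's deserializeToNode()/toAirbyteValue() helpers are assumed to wrap equivalent JSON parsing; this snippet is not part of the commit):

import com.fasterxml.jackson.databind.ObjectMapper

fun main() {
    val cell = """["a", "b,c"]"""              // an array cell serialized as JSON
    val naiveSplit = cell.split(",")           // old approach: 3 fragments, quoting broken
    val parsed = ObjectMapper().readTree(cell) // new approach: a 2-element JSON array
    println(naiveSplit.size) // 3
    println(parsed.size())   // 2
}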
@@ -8,7 +8,6 @@ import io.airbyte.cdk.load.file.NoopProcessor
import io.airbyte.cdk.load.file.StreamProcessor
import java.io.InputStream
import java.io.OutputStream
import java.io.Writer
import kotlinx.coroutines.flow.Flow

interface ObjectStorageClient<T : RemoteObject<*>> {
@@ -17,11 +16,11 @@ interface ObjectStorageClient<T : RemoteObject<*>> {
suspend fun <U> get(key: String, block: (InputStream) -> U): U
suspend fun put(key: String, bytes: ByteArray): T
suspend fun delete(remoteObject: T)
suspend fun streamingUpload(key: String, block: suspend (Writer) -> Unit): T =
suspend fun streamingUpload(key: String, block: suspend (OutputStream) -> Unit): T =
streamingUpload(key, NoopProcessor, block)
suspend fun <V : OutputStream> streamingUpload(
key: String,
streamProcessor: StreamProcessor<V>,
block: suspend (Writer) -> Unit
block: suspend (OutputStream) -> Unit
): T
}
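
As a usage sketch of the revised interface, a hypothetical caller now works against a raw OutputStream and picks its own encoding (helper name is illustrative; CDK imports for ObjectStorageClient/RemoteObject omitted; not part of this commit):

import io.airbyte.cdk.load.util.write

suspend fun <T : RemoteObject<*>> uploadLines(
    client: ObjectStorageClient<T>,
    key: String,
    lines: List<String>
): T =
    client.streamingUpload(key) { outputStream ->
        lines.forEach {
            outputStream.write(it)   // String extension added in the first file of this commit
            outputStream.write("\n")
        }
    }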
@@ -32,7 +32,6 @@ import jakarta.inject.Singleton
import java.io.ByteArrayOutputStream
import java.io.InputStream
import java.io.OutputStream
import java.io.Writer
import kotlinx.coroutines.flow.flow

data class S3Object(override val key: String, override val storageConfig: S3BucketConfiguration) :
@@ -113,14 +112,17 @@ class S3Client(
client.deleteObject(request)
}

override suspend fun streamingUpload(key: String, block: suspend (Writer) -> Unit): S3Object {
override suspend fun streamingUpload(
key: String,
block: suspend (OutputStream) -> Unit
): S3Object {
return streamingUpload(key, compressionConfig?.compressor ?: NoopProcessor, block)
}

override suspend fun <U : OutputStream> streamingUpload(
key: String,
streamProcessor: StreamProcessor<U>,
block: suspend (Writer) -> Unit
block: suspend (OutputStream) -> Unit
): S3Object {
val request = CreateMultipartUploadRequest {
this.bucket = bucketConfig.s3BucketName
@@ -135,13 +137,7 @@ class S3Client(
streamProcessor,
uploadConfig
)
log.info {
"Starting multipart upload to ${response.bucket}/${response.key} (${response.uploadId}"
}
val uploadJob = upload.start()
block(upload.UploadWriter())
upload.complete()
uploadJob.join()
upload.runUsing(block)
return S3Object(key, bucketConfig)
}
}
@@ -12,17 +12,27 @@ import aws.sdk.kotlin.services.s3.model.UploadPartRequest
import aws.smithy.kotlin.runtime.content.ByteStream
import io.airbyte.cdk.load.command.object_storage.ObjectStorageUploadConfiguration
import io.airbyte.cdk.load.file.StreamProcessor
import io.airbyte.cdk.load.util.setOnce
import io.github.oshai.kotlinlogging.KotlinLogging
import java.io.ByteArrayOutputStream
import java.io.OutputStream
import java.io.Writer
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import java.util.concurrent.atomic.AtomicBoolean
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking

/**
* An S3MultipartUpload that provides an [OutputStream] abstraction for writing data. This should
* never be created directly, but used indirectly through [S3Client.streamingUpload].
*
* NOTE: The OutputStream interface does not support suspending functions, but the kotlin s3 SDK
* does. To stitch them together, we could use `runBlocking`, but that would risk blocking the
* thread (and defeating the purpose of using the kotlin sdk). In order to avoid this, we use a
* [Channel] to queue up work and process it in a coroutine, launched asynchronously in the same
* context. The work will be coherent as long as the calls to the interface are made synchronously
* (which would be the case without coroutines).
*/
class S3MultipartUpload<T : OutputStream>(
private val client: aws.sdk.kotlin.services.s3.S3Client,
private val response: CreateMultipartUploadResponse,
@@ -37,35 +47,65 @@ class S3MultipartUpload<T : OutputStream>(
uploadConfig?.streamingUploadPartSize
?: throw IllegalStateException("Streaming upload part size is not configured")
private val wrappingBuffer = streamProcessor.wrapper(underlyingBuffer)
private val workQueue = Channel<suspend () -> Unit>(Channel.UNLIMITED)
private val closeOnce = AtomicBoolean(false)

private val work = Channel<suspend () -> Unit>(Channel.UNLIMITED)

suspend fun start(): Job =
CoroutineScope(Dispatchers.IO).launch {
for (unit in work) {
uploadPart()
/**
* Run the upload using the provided block. This should only be used by the
* [S3Client.streamingUpload] method. Work items are processed asynchronously in the [launch]
* block. The for loop will suspend until [workQueue] is closed, after which the call to
* [complete] will finish the upload.
*
* Moreover, [runUsing] will not return until the launch block exits. This ensures
* - work items are processed in order
* - minimal work is done in [runBlocking] (just enough to enqueue the work items)
* - the upload will not complete until [OutputStream.close] is called (either by the user
* in [block] or when the [use] block terminates).
* - the upload will not complete until all the work is done
*/
suspend fun runUsing(block: suspend (OutputStream) -> Unit) = coroutineScope {
log.info {
"Starting multipart upload to ${response.bucket}/${response.key} (${response.uploadId}"
}
launch {
for (item in workQueue) {
item()
}
completeInner()
complete()
}

inner class UploadWriter : Writer() {
override fun close() {
log.warn { "Close called on UploadWriter, ignoring." }
UploadStream().use { block(it) }
log.info {
"Completed multipart upload to ${response.bucket}/${response.key} (${response.uploadId}"
}
}

override fun flush() {
throw NotImplementedError("flush() is not supported on S3MultipartUpload.UploadWriter")
inner class UploadStream : OutputStream() {
override fun close() = runBlocking {
workQueue.send {
if (closeOnce.setOnce()) {
workQueue.close()
}
}
}

override fun write(str: String) {
wrappingBuffer.write(str.toByteArray(Charsets.UTF_8))
if (underlyingBuffer.size() >= partSize) {
runBlocking { work.send { uploadPart() } }
override fun flush() = runBlocking { workQueue.send { wrappingBuffer.flush() } }

override fun write(b: Int) = runBlocking {
workQueue.send {
wrappingBuffer.write(b)
if (underlyingBuffer.size() >= partSize) {
uploadPart()
}
}
}

override fun write(cbuf: CharArray, off: Int, len: Int) {
write(String(cbuf, off, len))
override fun write(b: ByteArray) = runBlocking {
workQueue.send {
wrappingBuffer.write(b)
if (underlyingBuffer.size() >= partSize) {
uploadPart()
}
}
}
}

@@ -89,17 +129,10 @@ class S3MultipartUpload<T : OutputStream>(
underlyingBuffer.reset()
}

suspend fun complete() {
work.close()
}

private suspend fun completeInner() {
private suspend fun complete() {
if (underlyingBuffer.size() > 0) {
uploadPart()
}
log.info {
"Completing multipart upload to ${response.bucket}/${response.key} (${response.uploadId}"
}
val request = CompleteMultipartUploadRequest {
uploadId = response.uploadId
bucket = response.bucket
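
To make the Channel-based bridging described in the KDoc above concrete, here is a standalone, simplified sketch under the same idea: blocking OutputStream calls only enqueue suspend work onto a Channel, which a coroutine launched in the same scope drains in order. No multipart logic; all names are illustrative and this snippet is not part of the commit.

import java.io.OutputStream
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking

class ChannelBackedStream(private val sink: suspend (Int) -> Unit) : OutputStream() {
    private val workQueue = Channel<suspend () -> Unit>(Channel.UNLIMITED)

    // Drains the queue in a coroutine launched in the same scope; returns only after the
    // block has run, the stream is closed, and all queued work has completed.
    suspend fun runUsing(block: suspend (OutputStream) -> Unit) = coroutineScope {
        launch { for (item in workQueue) item() }  // suspends until the channel is closed
        this@ChannelBackedStream.use { block(it) } // use {} calls close() when block returns
    }

    // Blocking interface methods only enqueue work; send() on an UNLIMITED channel does not block.
    override fun write(b: Int) = runBlocking { workQueue.send { sink(b) } }
    override fun close() = runBlocking { workQueue.send { workQueue.close() } }
}

fun main() = runBlocking {
    val received = mutableListOf<Int>()
    ChannelBackedStream { received += it }.runUsing { out -> out.write(42) }
    println(received) // [42]
}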
@@ -36,12 +36,12 @@ data:
secretStore:
type: GSM
alias: airbyte-connector-testing-secret-store
- name: SECRET_DESTINATION-S3-V2-CSV-CONFIG
- name: SECRET_DESTINATION-S3-V2-CSV
fileName: s3_dest_v2_csv_config.json
secretStore:
type: GSM
alias: airbyte-connector-testing-secret-store
- name: SECRET_DESTINATION-S3-V2-CSV-GZIP-CONFIG
- name: SECRET_DESTINATION-S3-V2-CSV-GZIP
fileName: s3_dest_v2_csv_gzip_config.json
secretStore:
type: GSM
@@ -11,6 +11,7 @@ import io.airbyte.cdk.load.file.TimeProvider
import io.airbyte.cdk.load.file.object_storage.ObjectStoragePathFactory
import io.airbyte.cdk.load.file.s3.S3ClientFactory
import io.airbyte.cdk.load.file.s3.S3Object
import io.airbyte.cdk.load.util.write
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.exceptions.ConfigurationException
import jakarta.inject.Singleton
@@ -19,6 +19,7 @@ import io.airbyte.cdk.load.file.s3.S3Object
import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.util.serializeToString
import io.airbyte.cdk.load.util.write
import io.airbyte.cdk.load.write.DestinationWriter
import io.airbyte.cdk.load.write.StreamLoader
import jakarta.inject.Singleton
@@ -57,22 +58,26 @@ class S3V2Writer(
val partNumber = partNumber.getAndIncrement()
val key = pathFactory.getPathToFile(stream, partNumber, isStaging = true).toString()
val s3Object =
s3Client.streamingUpload(key) { writer ->
s3Client.streamingUpload(key) { outputStream ->
when (formatConfig.objectStorageFormatConfiguration) {
is JsonFormatConfiguration -> {
records.forEach {
val serialized =
recordDecorator.decorate(it).toJson().serializeToString()
writer.write(serialized)
writer.write("\n")
outputStream.write(serialized)
outputStream.write("\n")
}
}
is CSVFormatConfiguration -> {
stream.schemaWithMeta.toCsvPrinterWithHeader(writer).use { printer ->
records.forEach {
printer.printRecord(*recordDecorator.decorate(it).toCsvRecord())
stream.schemaWithMeta
.toCsvPrinterWithHeader(outputStream.writer())
.use { printer ->
records.forEach {
printer.printRecord(
*recordDecorator.decorate(it).toCsvRecord()
)
}
}
}
}
else -> throw IllegalStateException("Unsupported format")
}
