diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index c7918fffba..99600c94f9 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -37,3 +37,5 @@ jobs:
           working-directory: components/tls
           skip-cache: true
           args: --timeout 3m --verbose
+      - name: lint-dataflow
+        run: make -C scheduler/data-flow lint
diff --git a/scheduler/Makefile b/scheduler/Makefile
index 6097512d52..6f274ecf25 100644
--- a/scheduler/Makefile
+++ b/scheduler/Makefile
@@ -78,12 +78,19 @@ ${.GOLANGCILINT_PATH}/golangci-lint:
 	curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh \
 	| sh -s -- -b ${.GOLANGCILINT_PATH} ${.GOLANGCILINT_VERSION}
 
-.PHONY: lint
-lint: ${.GOLANGCILINT_PATH}/golangci-lint
+.PHONY: lint-go
+lint-go: ${.GOLANGCILINT_PATH}/golangci-lint
 	gofmt -w pkg
 	gofmt -w cmd
 	${.GOLANGCILINT_PATH}/golangci-lint run --fix
 
+.PHONY: lint-jvm
+lint-jvm:
+	make -C data-flow lint
+
+.PHONY: lint
+lint: lint-go lint-jvm
+
 .PHONY: test-go
 test-go:
 	go test ./pkg/... -coverprofile cover.out
diff --git a/scheduler/data-flow/.gitignore b/scheduler/data-flow/.gitignore
index 8fe2c2aa52..5dcf1836e6 100644
--- a/scheduler/data-flow/.gitignore
+++ b/scheduler/data-flow/.gitignore
@@ -341,4 +341,7 @@ $RECYCLE.BIN/
 # Windows shortcuts
 *.lnk
 
+# VSCode
+.vscode/
+
 # End of https://www.gitignore.io/api/java,linux,macos,windows,android,intellij,androidstudio
\ No newline at end of file
diff --git a/scheduler/data-flow/Makefile b/scheduler/data-flow/Makefile
index 5efcfd9bb6..ecd7b95dd9 100644
--- a/scheduler/data-flow/Makefile
+++ b/scheduler/data-flow/Makefile
@@ -7,3 +7,15 @@ licenses:
 	chmod +x ./scripts/generate_license.sh
 	./scripts/generate_license.sh licenses/dependency-license.json licenses/dependency-license.txt
 	cp ../../LICENSE licenses/license.txt
+
+.PHONY: lint
+lint:
+	./gradlew ktlintCheck --no-daemon --no-build-cache --continue
+
+.PHONY: format
+format:
+	./gradlew ktlintFormat --no-daemon --no-build-cache
+
+.PHONY: build
+build:
+	./gradlew build --no-daemon --no-build-cache
diff --git a/scheduler/data-flow/README.md b/scheduler/data-flow/README.md
index 31285243c6..fa003d639a 100644
--- a/scheduler/data-flow/README.md
+++ b/scheduler/data-flow/README.md
@@ -54,6 +54,14 @@ $ ./gradlew build
 
 BUILD SUCCESSFUL in 536ms
 6 actionable tasks: 6 up-to-date
 ```
+You can also run lint checks and formatting using `ktlint`; these are also available as make targets, i.e. `make lint` and `make format`.
+
+```bash
+$ ./gradlew ktlintCheck
+$ ./gradlew ktlintFormat
+```
+
+After a successful run, they generate reports in `build/reports/ktlint`.
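Almost all of the Kotlin hunks that follow are mechanical reformatting produced by the new `ktlint` setup rather than behavioural changes. As a rough, hypothetical illustration (this snippet is not taken from the repository), the style the plugin enforces looks like this:

```kotlin
// Illustrative only: a tiny file already formatted the way ktlint expects,
// mirroring the patterns applied throughout the diff below.
object ExampleConfig {
    // Private constants are renamed to SCREAMING_SNAKE_CASE (cf. ENV_VAR_PREFIX, STDOUT_SINK).
    private const val ENV_VAR_PREFIX = "SELDON_"

    // Multiline initializers are wrapped onto a new line after `=`,
    // and call sites get trailing commas.
    val retryPolicy =
        mapOf(
            "maxAttempts" to "100",
            "initialBackoff" to "1s",
        )

    // Multi-parameter signatures are split one parameter per line, with a trailing comma.
    fun describe(
        name: String,
        version: Int,
    ): String {
        // Braces around simple identifiers in string templates are dropped.
        return "$ENV_VAR_PREFIX$name:$version"
    }
}

fun main() {
    println(ExampleConfig.describe("pipeline", 1))
    println(ExampleConfig.retryPolicy)
}
```

The uppercase constants, wrapped initializers, and trailing commas above account for the bulk of the per-file changes in the remainder of this diff.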
Unsupported class file major version diff --git a/scheduler/data-flow/build.gradle.kts b/scheduler/data-flow/build.gradle.kts index 6c1ff53593..65930bb8e0 100644 --- a/scheduler/data-flow/build.gradle.kts +++ b/scheduler/data-flow/build.gradle.kts @@ -1,12 +1,12 @@ -import org.jetbrains.kotlin.gradle.tasks.KotlinCompile import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar +import org.jetbrains.kotlin.gradle.tasks.KotlinCompile plugins { id("com.github.hierynomus.license-report") version "0.16.1" id("com.github.johnrengelman.shadow") version "8.1.1" - kotlin("jvm") version "1.8.20" // the kotlin version - + kotlin("jvm") version "1.8.20" // the kotlin version + id("org.jlleitschuh.gradle.ktlint") version "12.1.0" java application } @@ -98,3 +98,12 @@ downloadLicenses { includeProjectDependencies = true dependencyConfiguration = "compileClasspath" } + +ktlint { + verbose = true + debug = true + // Ignore generated code from proto + filter { + exclude { element -> element.file.path.contains("apis/mlops") } + } +} diff --git a/scheduler/data-flow/settings.gradle.kts b/scheduler/data-flow/settings.gradle.kts index ca4d372114..9e9cc74e75 100644 --- a/scheduler/data-flow/settings.gradle.kts +++ b/scheduler/data-flow/settings.gradle.kts @@ -1,2 +1 @@ rootProject.name = "dataflow" - diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Cli.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Cli.kt index 8ef14d2dcb..1e5485aaea 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Cli.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Cli.kt @@ -9,14 +9,25 @@ the Change License after the Change Date as each is defined in accordance with t package io.seldon.dataflow -import com.natpryce.konfig.* +import com.natpryce.konfig.CommandLineOption +import com.natpryce.konfig.Configuration +import com.natpryce.konfig.ConfigurationProperties +import com.natpryce.konfig.EnvironmentVariables +import com.natpryce.konfig.Key +import com.natpryce.konfig.booleanType +import com.natpryce.konfig.enumType +import com.natpryce.konfig.intType +import com.natpryce.konfig.longType +import com.natpryce.konfig.overriding +import com.natpryce.konfig.parseArgs +import com.natpryce.konfig.stringType import io.klogging.Level import io.klogging.noCoLogger import io.seldon.dataflow.kafka.security.KafkaSaslMechanisms import io.seldon.dataflow.kafka.security.KafkaSecurityProtocols object Cli { - private const val envVarPrefix = "SELDON_" + private const val ENV_VAR_PREFIX = "SELDON_" private val logger = noCoLogger(Cli::class) // General setup @@ -94,18 +105,19 @@ object Cli { fun configWith(rawArgs: Array): Configuration { val fromProperties = ConfigurationProperties.fromResource("local.properties") - val fromEnv = EnvironmentVariables(prefix = envVarPrefix) + val fromEnv = EnvironmentVariables(prefix = ENV_VAR_PREFIX) val fromArgs = parseArguments(rawArgs) return fromArgs overriding fromEnv overriding fromProperties } private fun parseArguments(rawArgs: Array): Configuration { - val (config, unparsedArgs) = parseArgs( - rawArgs, - *this.args().map { CommandLineOption(it) }.toTypedArray(), - programName = "seldon-dataflow-engine", - ) + val (config, unparsedArgs) = + parseArgs( + rawArgs, + *this.args().map { CommandLineOption(it) }.toTypedArray(), + programName = "seldon-dataflow-engine", + ) if (unparsedArgs.isNotEmpty()) { logUnknownArguments(unparsedArgs) } diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/DataflowStatus.kt 
b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/DataflowStatus.kt index ca320069db..93ffa9c774 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/DataflowStatus.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/DataflowStatus.kt @@ -13,10 +13,10 @@ import kotlinx.coroutines.runBlocking * return values to indicate errors/status updates that require special handling in the code. */ interface DataflowStatus { - var exception : Exception? - var message : String? + var exception: Exception? + var message: String? - fun getDescription() : String? { + fun getDescription(): String? { val exceptionMsg = this.exception?.message return if (exceptionMsg != null) { "${this.message} Exception: $exceptionMsg" @@ -26,7 +26,10 @@ interface DataflowStatus { } // log status when logger is in a coroutine - fun log(logger: Klogger, levelIfNoException: Level) { + fun log( + logger: Klogger, + levelIfNoException: Level, + ) { val exceptionMsg = this.exception?.message val exceptionCause = this.exception?.cause ?: Exception("") val statusMsg = this.message @@ -42,7 +45,10 @@ interface DataflowStatus { } // log status when logger is outside coroutines - fun log(logger: NoCoLogger, levelIfNoException: Level) { + fun log( + logger: NoCoLogger, + levelIfNoException: Level, + ) { val exceptionMsg = this.exception?.message val exceptionCause = this.exception?.cause ?: Exception("") if (exceptionMsg != null) { @@ -53,13 +59,12 @@ interface DataflowStatus { } } -fun T.withException(e: Exception) : T { +fun T.withException(e: Exception): T { this.exception = e return this } -fun T.withMessage(msg: String): T { +fun T.withMessage(msg: String): T { this.message = msg return this } - diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/GrpcServiceConfigProvider.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/GrpcServiceConfigProvider.kt index 88ba4a3449..64862037ed 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/GrpcServiceConfigProvider.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/GrpcServiceConfigProvider.kt @@ -15,27 +15,32 @@ object GrpcServiceConfigProvider { // Details: https://github.com/grpc/proposal/blob/master/A6-client-retries.md#validation-of-retrypolicy // Example: https://github.com/grpc/grpc-java/blob/v1.35.0/examples/src/main/resources/io/grpc/examples/retrying/retrying_service_config.json // However does not work: https://github.com/grpc/grpc-kotlin/issues/277 - val config = mapOf( - "methodConfig" to listOf( - mapOf( - "name" to listOf( + val config = + mapOf( + "methodConfig" to + listOf( mapOf( - "service" to "io.seldon.mlops.chainer.Chainer", - "method" to "SubscribePipelineUpdates", + "name" to + listOf( + mapOf( + "service" to "io.seldon.mlops.chainer.Chainer", + "method" to "SubscribePipelineUpdates", + ), + ), + "retryPolicy" to + mapOf( + "maxAttempts" to "100", + "initialBackoff" to "1s", + "maxBackoff" to "30s", + "backoffMultiplier" to 1.5, + "retryableStatusCodes" to + listOf( + Status.UNAVAILABLE.code.toString(), + Status.CANCELLED.code.toString(), + Status.FAILED_PRECONDITION.code.toString(), + ), + ), ), ), - "retryPolicy" to mapOf( - "maxAttempts" to "100", - "initialBackoff" to "1s", - "maxBackoff" to "30s", - "backoffMultiplier" to 1.5, - "retryableStatusCodes" to listOf( - Status.UNAVAILABLE.code.toString(), - Status.CANCELLED.code.toString(), - Status.FAILED_PRECONDITION.code.toString(), - ) - ) - ), - ), - ) -} \ No newline at end of file + ) +} diff --git 
a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Logging.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Logging.kt index bea4b0762b..b0294fe203 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Logging.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Logging.kt @@ -15,21 +15,23 @@ import io.klogging.rendering.RENDER_ISO8601 import io.klogging.sending.STDOUT object Logging { - private const val stdoutSink = "stdout" + private const val STDOUT_SINK = "stdout" - fun configure(appLevel: Level = Level.INFO, kafkaLevel: Level = Level.WARN) = - loggingConfiguration { - kloggingMinLogLevel(appLevel) - sink(stdoutSink, RENDER_ISO8601, STDOUT) - logging { - fromLoggerBase("io.seldon") - toSink(stdoutSink) - } - logging { - fromMinLevel(kafkaLevel) { - fromLoggerBase("org.apache") - toSink(stdoutSink) - } + fun configure( + appLevel: Level = Level.INFO, + kafkaLevel: Level = Level.WARN, + ) = loggingConfiguration { + kloggingMinLogLevel(appLevel) + sink(STDOUT_SINK, RENDER_ISO8601, STDOUT) + logging { + fromLoggerBase("io.seldon") + toSink(STDOUT_SINK) + } + logging { + fromMinLevel(kafkaLevel) { + fromLoggerBase("org.apache") + toSink(STDOUT_SINK) } } -} \ No newline at end of file + } +} diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Main.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Main.kt index 05b7aaf7fa..8d4a899eaa 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Main.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Main.kt @@ -10,10 +10,15 @@ the Change License after the Change Date as each is defined in accordance with t package io.seldon.dataflow import io.klogging.noCoLogger -import io.seldon.dataflow.kafka.* +import io.seldon.dataflow.kafka.KafkaDomainParams +import io.seldon.dataflow.kafka.KafkaSecurityParams +import io.seldon.dataflow.kafka.KafkaStreamsParams +import io.seldon.dataflow.kafka.TopicWaitRetryParams +import io.seldon.dataflow.kafka.getKafkaAdminProperties +import io.seldon.dataflow.kafka.getKafkaProperties import io.seldon.dataflow.kafka.security.KafkaSaslMechanisms -import io.seldon.dataflow.mtls.CertificateConfig import io.seldon.dataflow.kafka.security.SaslConfig +import io.seldon.dataflow.mtls.CertificateConfig import kotlinx.coroutines.runBlocking object Main { @@ -32,74 +37,85 @@ object Main { val effectiveArgs = Cli.args().map { arg -> arg.name to config[arg] } logger.info { "initialised with config $effectiveArgs" } - val tlsCertConfig = CertificateConfig( - caCertPath = config[Cli.tlsCACertPath], - keyPath = config[Cli.tlsKeyPath], - certPath = config[Cli.tlsCertPath], - brokerCaCertPath = config[Cli.brokerCACertPath], - clientSecret = config[Cli.clientSecret], - brokerSecret = config[Cli.brokerSecret], - endpointIdentificationAlgorithm = config[Cli.endpointIdentificationAlgorithm], - ) - - val saslConfig = when (config[Cli.saslMechanism]) { - KafkaSaslMechanisms.PLAIN -> SaslConfig.Password.Plain( - secretName = config[Cli.saslSecret], - username = config[Cli.saslUsername], - passwordField = config[Cli.saslPasswordPath], - ) - KafkaSaslMechanisms.SCRAM_SHA_256 -> SaslConfig.Password.Scram256( - secretName = config[Cli.saslSecret], - username = config[Cli.saslUsername], - passwordField = config[Cli.saslPasswordPath], + val tlsCertConfig = + CertificateConfig( + caCertPath = config[Cli.tlsCACertPath], + keyPath = config[Cli.tlsKeyPath], + certPath = config[Cli.tlsCertPath], + brokerCaCertPath = config[Cli.brokerCACertPath], + clientSecret = 
config[Cli.clientSecret], + brokerSecret = config[Cli.brokerSecret], + endpointIdentificationAlgorithm = config[Cli.endpointIdentificationAlgorithm], ) - KafkaSaslMechanisms.SCRAM_SHA_512 -> SaslConfig.Password.Scram512( - secretName = config[Cli.saslSecret], - username = config[Cli.saslUsername], - passwordField = config[Cli.saslPasswordPath], + + val saslConfig = + when (config[Cli.saslMechanism]) { + KafkaSaslMechanisms.PLAIN -> + SaslConfig.Password.Plain( + secretName = config[Cli.saslSecret], + username = config[Cli.saslUsername], + passwordField = config[Cli.saslPasswordPath], + ) + KafkaSaslMechanisms.SCRAM_SHA_256 -> + SaslConfig.Password.Scram256( + secretName = config[Cli.saslSecret], + username = config[Cli.saslUsername], + passwordField = config[Cli.saslPasswordPath], + ) + KafkaSaslMechanisms.SCRAM_SHA_512 -> + SaslConfig.Password.Scram512( + secretName = config[Cli.saslSecret], + username = config[Cli.saslUsername], + passwordField = config[Cli.saslPasswordPath], + ) + KafkaSaslMechanisms.OAUTH_BEARER -> + SaslConfig.Oauth( + secretName = config[Cli.saslSecret], + ) + } + + val kafkaSecurityParams = + KafkaSecurityParams( + securityProtocol = config[Cli.kafkaSecurityProtocol], + certConfig = tlsCertConfig, + saslConfig = saslConfig, ) - KafkaSaslMechanisms.OAUTH_BEARER -> SaslConfig.Oauth( - secretName = config[Cli.saslSecret], + val kafkaStreamsParams = + KafkaStreamsParams( + bootstrapServers = config[Cli.kafkaBootstrapServers], + numPartitions = config[Cli.kafkaPartitions], + replicationFactor = config[Cli.kafkaReplicationFactor], + maxMessageSizeBytes = config[Cli.kafkaMaxMessageSizeBytes], + security = kafkaSecurityParams, ) - } - - val kafkaSecurityParams = KafkaSecurityParams( - securityProtocol = config[Cli.kafkaSecurityProtocol], - certConfig = tlsCertConfig, - saslConfig = saslConfig, - ) - val kafkaStreamsParams = KafkaStreamsParams( - bootstrapServers = config[Cli.kafkaBootstrapServers], - numPartitions = config[Cli.kafkaPartitions], - replicationFactor = config[Cli.kafkaReplicationFactor], - maxMessageSizeBytes = config[Cli.kafkaMaxMessageSizeBytes], - security = kafkaSecurityParams, - ) val kafkaProperties = getKafkaProperties(kafkaStreamsParams) val kafkaAdminProperties = getKafkaAdminProperties(kafkaStreamsParams) - val kafkaDomainParams = KafkaDomainParams( - useCleanState = config[Cli.kafkaUseCleanState], - joinWindowMillis = config[Cli.kafkaJoinWindowMillis], - ) - val topicWaitRetryParams = TopicWaitRetryParams( - createTimeoutMillis = config[Cli.topicCreateTimeoutMillis], - describeTimeoutMillis = config[Cli.topicDescribeTimeoutMillis], - describeRetries = config[Cli.topicDescribeRetries], - describeRetryDelayMillis = config[Cli.topicDescribeRetryDelayMillis] - ) - val subscriber = PipelineSubscriber( - "seldon-dataflow-engine", - kafkaProperties, - kafkaAdminProperties, - kafkaStreamsParams, - kafkaDomainParams, - topicWaitRetryParams, - config[Cli.upstreamHost], - config[Cli.upstreamPort], - GrpcServiceConfigProvider.config, - config[Cli.kafkaConsumerGroupIdPrefix], - config[Cli.namespace], - ) + val kafkaDomainParams = + KafkaDomainParams( + useCleanState = config[Cli.kafkaUseCleanState], + joinWindowMillis = config[Cli.kafkaJoinWindowMillis], + ) + val topicWaitRetryParams = + TopicWaitRetryParams( + createTimeoutMillis = config[Cli.topicCreateTimeoutMillis], + describeTimeoutMillis = config[Cli.topicDescribeTimeoutMillis], + describeRetries = config[Cli.topicDescribeRetries], + describeRetryDelayMillis = config[Cli.topicDescribeRetryDelayMillis], + 
) + val subscriber = + PipelineSubscriber( + "seldon-dataflow-engine", + kafkaProperties, + kafkaAdminProperties, + kafkaStreamsParams, + kafkaDomainParams, + topicWaitRetryParams, + config[Cli.upstreamHost], + config[Cli.upstreamPort], + GrpcServiceConfigProvider.config, + config[Cli.kafkaConsumerGroupIdPrefix], + config[Cli.namespace], + ) addShutdownHandler(subscriber) @@ -115,7 +131,7 @@ object Main { logger.info("received shutdown signal") subscriber.cancelPipelines("shutting down") } - } + }, ) } } diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/PipelineSubscriber.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/PipelineSubscriber.kt index 7b90b9bf99..2b120a30ba 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/PipelineSubscriber.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/PipelineSubscriber.kt @@ -13,10 +13,22 @@ import com.github.michaelbull.retry.policy.binaryExponentialBackoff import com.github.michaelbull.retry.retry import io.grpc.ManagedChannelBuilder import io.klogging.Level -import io.seldon.dataflow.kafka.* +import io.seldon.dataflow.kafka.KafkaAdmin +import io.seldon.dataflow.kafka.KafkaAdminProperties +import io.seldon.dataflow.kafka.KafkaDomainParams +import io.seldon.dataflow.kafka.KafkaProperties +import io.seldon.dataflow.kafka.KafkaStreamsParams +import io.seldon.dataflow.kafka.Pipeline +import io.seldon.dataflow.kafka.PipelineId +import io.seldon.dataflow.kafka.PipelineMetadata +import io.seldon.dataflow.kafka.PipelineStatus +import io.seldon.dataflow.kafka.TopicWaitRetryParams import io.seldon.mlops.chainer.ChainerGrpcKt -import io.seldon.mlops.chainer.ChainerOuterClass.* +import io.seldon.mlops.chainer.ChainerOuterClass.PipelineStepUpdate +import io.seldon.mlops.chainer.ChainerOuterClass.PipelineSubscriptionRequest +import io.seldon.mlops.chainer.ChainerOuterClass.PipelineUpdateMessage import io.seldon.mlops.chainer.ChainerOuterClass.PipelineUpdateMessage.PipelineOperation +import io.seldon.mlops.chainer.ChainerOuterClass.PipelineUpdateStatusMessage import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll @@ -42,21 +54,22 @@ class PipelineSubscriber( private val namespace: String, ) { private val kafkaAdmin = KafkaAdmin(kafkaAdminProperties, kafkaStreamsParams, topicWaitRetryParams) - private val channel = ManagedChannelBuilder - .forAddress(upstreamHost, upstreamPort) - .defaultServiceConfig(grpcServiceConfig) - .usePlaintext() // Use TLS - .enableRetry() - .build() + private val channel = + ManagedChannelBuilder + .forAddress(upstreamHost, upstreamPort) + .defaultServiceConfig(grpcServiceConfig) + .usePlaintext() // Use TLS + .enableRetry() + .build() private val client = ChainerGrpcKt.ChainerCoroutineStub(channel) private val pipelines = ConcurrentHashMap() suspend fun subscribe() { while (true) { - logger.info("will connect to ${upstreamHost}:${upstreamPort}") + logger.info("will connect to $upstreamHost:$upstreamPort") retry(binaryExponentialBackoff(50..5_000L)) { - logger.debug("retrying to connect to ${upstreamHost}:${upstreamPort}") + logger.debug("retrying to connect to $upstreamHost:$upstreamPort") subscribePipelines(kafkaConsumerGroupIdPrefix, namespace) } } @@ -69,18 +82,22 @@ class PipelineSubscriber( // Pipeline UID should be enough to uniquely key it, even across versions? // ... 
// - Add map of model name -> (weak) referrents/reference count to avoid recreation of streams - private suspend fun subscribePipelines(kafkaConsumerGroupIdPrefix: String, namespace: String) { + private suspend fun subscribePipelines( + kafkaConsumerGroupIdPrefix: String, + namespace: String, + ) { logger.info("Subscribing to pipeline updates") client .subscribePipelineUpdates(request = makeSubscriptionRequest()) .onEach { update -> logger.info("received request for ${update.pipeline}:${update.version} Id:${update.uid}") - val metadata = PipelineMetadata( - id = update.uid, - name = update.pipeline, - version = update.version, - ) + val metadata = + PipelineMetadata( + id = update.uid, + name = update.pipeline, + version = update.version, + ) when (update.op) { PipelineOperation.Create -> handleCreate(metadata, update.updatesList, kafkaConsumerGroupIdPrefix, namespace) @@ -97,12 +114,15 @@ class PipelineSubscriber( // Defend against any existing pipelines that have failed but are not yet stopped, so that // kafka streams may clean up resources (including temporary files). This is a catch-all // and indicates we've missed calling stop in a failure case. - if(it.value.status.isError) { - logger.debug("(bug) pipeline in error state when subscription terminates with error. pipeline id: {pipelineId}", it.key) + if (it.value.status.isError) { + logger.debug( + "(bug) pipeline in error state when subscription terminates with error. pipeline id: {pipelineId}", + it.key, + ) it.value.stop() } } - logger.error("pipeline subscription terminated with error ${cause}") + logger.error("pipeline subscription terminated with error $cause") } } .collect() @@ -125,16 +145,17 @@ class PipelineSubscriber( "Create pipeline {pipelineName} version: {pipelineVersion} id: {pipelineId}", metadata.name, metadata.version, - metadata.id - ) - val (pipeline, err) = Pipeline.forSteps( - metadata, - steps, - kafkaProperties, - kafkaDomainParams, - kafkaConsumerGroupIdPrefix, - namespace + metadata.id, ) + val (pipeline, err) = + Pipeline.forSteps( + metadata, + steps, + kafkaProperties, + kafkaDomainParams, + kafkaConsumerGroupIdPrefix, + namespace, + ) if (err != null) { err.log(logger, Level.ERROR) client.pipelineUpdateEvent( @@ -142,13 +163,13 @@ class PipelineSubscriber( metadata = metadata, operation = PipelineOperation.Create, success = false, - reason = err.getDescription() ?: "failed to initialize dataflow engine" - ) + reason = err.getDescription() ?: "failed to initialize dataflow engine", + ), ) return } - pipeline!! //assert pipeline is not null when err is null + pipeline!! 
// assert pipeline is not null when err is null if (pipeline.size != steps.size) { pipeline.stop() client.pipelineUpdateEvent( @@ -156,8 +177,8 @@ class PipelineSubscriber( metadata = metadata, operation = PipelineOperation.Create, success = false, - reason = "failed to create all pipeline steps" - ) + reason = "failed to create all pipeline steps", + ), ) return @@ -166,13 +187,14 @@ class PipelineSubscriber( val previous = pipelines.putIfAbsent(metadata.id, pipeline) var pipelineStatus: PipelineStatus if (previous == null) { - val err = kafkaAdmin.ensureTopicsExist(steps) - if (err == null) { + val errTopics = kafkaAdmin.ensureTopicsExist(steps) + if (errTopics == null) { pipelineStatus = pipeline.start() } else { - pipelineStatus = PipelineStatus.Error(null) - .withException(err) - .withMessage("kafka streams topic creation error") + pipelineStatus = + PipelineStatus.Error(null) + .withException(errTopics) + .withMessage("kafka streams topic creation error") pipeline.stop() } } else { @@ -190,7 +212,7 @@ class PipelineSubscriber( // pipeline has started. While states such as "StreamStarting" or "StreamStopped" are // not in themselves errors, if the pipeline is not running here then it can't // be marked as ready. - if(pipelineStatus !is PipelineStatus.Started) { + if (pipelineStatus !is PipelineStatus.Started) { pipelineStatus.isError = true } pipelineStatus.log(logger, Level.INFO) @@ -199,13 +221,18 @@ class PipelineSubscriber( metadata = metadata, operation = PipelineOperation.Create, success = !pipelineStatus.isError, - reason = pipelineStatus.getDescription() ?: "pipeline created" - ) + reason = pipelineStatus.getDescription() ?: "pipeline created", + ), ) } private suspend fun handleDelete(metadata: PipelineMetadata) { - logger.info("Delete pipeline {pipelineName} version: {pipelineVersion} id: {pipelineId}", metadata.name, metadata.version, metadata.id ) + logger.info( + "Delete pipeline {pipelineName} version: {pipelineVersion} id: {pipelineId}", + metadata.name, + metadata.version, + metadata.id, + ) pipelines .remove(metadata.id) ?.also { pipeline -> @@ -219,7 +246,7 @@ class PipelineSubscriber( operation = PipelineOperation.Delete, success = true, reason = "pipeline removed", - ) + ), ) } @@ -251,7 +278,7 @@ class PipelineSubscriber( .setPipeline(metadata.name) .setVersion(metadata.version) .setUid(metadata.id) - .build() + .build(), ) .build() } @@ -259,4 +286,4 @@ class PipelineSubscriber( companion object { private val logger = coLogger(PipelineSubscriber::class) } -} \ No newline at end of file +} diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/TypeExtensions.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/TypeExtensions.kt index ccbd18ffe6..90c6d0418c 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/TypeExtensions.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/TypeExtensions.kt @@ -21,14 +21,14 @@ import kotlinx.coroutines.flow.flow suspend fun Flow.parallel( scope: CoroutineScope, concurrency: Int = DEFAULT_CONCURRENCY, - transform: suspend (T) -> R + transform: suspend (T) -> R, ): Flow { return with(scope) { this@parallel .flatMapMerge(concurrency) { value -> flow { emit( - async { transform(value) } + async { transform(value) }, ) } } diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/hashutils/HashUtils.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/hashutils/HashUtils.kt index 30c9e83caa..f9874194cc 100644 --- 
a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/hashutils/HashUtils.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/hashutils/HashUtils.kt @@ -13,18 +13,18 @@ import java.math.BigInteger import java.security.MessageDigest object HashUtils { - private const val algoMD5 = "MD5" - private const val maxOutputLength = 16 + private const val ALGO_MD5 = "MD5" + private const val MAX_OUTPUT_LENGTH = 16 fun hashIfLong(input: String): String { - if (input.length <= maxOutputLength) { + if (input.length <= MAX_OUTPUT_LENGTH) { return input } - val md = MessageDigest.getInstance(algoMD5) + val md = MessageDigest.getInstance(ALGO_MD5) val hashedBytes = md.digest(input.toByteArray()) return BigInteger(1, hashedBytes) .toString(16) - .padStart(maxOutputLength, '0') + .padStart(MAX_OUTPUT_LENGTH, '0') } -} \ No newline at end of file +} diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/BatchProcessor.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/BatchProcessor.kt index 6b2a02ab21..d3c891a5d6 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/BatchProcessor.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/BatchProcessor.kt @@ -25,7 +25,6 @@ import org.apache.kafka.streams.state.Stores typealias TBatchRequest = KeyValue? typealias TBatchStore = KeyValueStore - class BatchProcessor(private val threshold: Int) : Transformer { private var ctx: ProcessorContext? = null private val aggregateStore: TBatchStore by lazy { @@ -39,7 +38,10 @@ class BatchProcessor(private val threshold: Int) : Transformer= threshold -> { - aggregateStore.delete(stateStoreKey) - KeyValue.pair(key, batchedRequest) + val returnValue = + when { + batchSize >= threshold -> { + aggregateStore.delete(stateStoreKey) + KeyValue.pair(key, batchedRequest) + } + else -> null } - else -> null - } return returnValue } @@ -65,12 +68,13 @@ class BatchProcessor(private val threshold: Int) : Transformer): ModelInferRequest { val batchReferenceRequest = requests.last() - val combinedRequest = ModelInferRequest - .newBuilder() - .setId(batchReferenceRequest.id) - .setModelName(batchReferenceRequest.modelName) - .setModelVersion(batchReferenceRequest.modelVersion) - .putAllParameters(batchReferenceRequest.parametersMap) + val combinedRequest = + ModelInferRequest + .newBuilder() + .setId(batchReferenceRequest.id) + .setModelName(batchReferenceRequest.modelName) + .setModelVersion(batchReferenceRequest.modelVersion) + .putAllParameters(batchReferenceRequest.parametersMap) when { requests.any { it.rawInputContentsCount > 0 } -> { @@ -96,48 +100,61 @@ class BatchProcessor(private val threshold: Int) : Transformer this.addAllUintContents( - tensors.flatMap { it.contents.uintContentsList } - ) - DataType.UINT64 -> this.addAllUint64Contents( - tensors.flatMap { it.contents.uint64ContentsList } - ) - DataType.INT8, - DataType.INT16, - DataType.INT32 -> this.addAllIntContents( - tensors.flatMap { it.contents.intContentsList } - ) - DataType.INT64 -> this.addAllInt64Contents( - tensors.flatMap { it.contents.int64ContentsList } - ) - DataType.FP16, // may need to handle this separately in future - DataType.FP32 -> this.addAllFp32Contents( - tensors.flatMap { it.contents.fp32ContentsList } - ) - DataType.FP64 -> this.addAllFp64Contents( - tensors.flatMap { it.contents.fp64ContentsList } - ) - DataType.BOOL -> this.addAllBoolContents( - tensors.flatMap { it.contents.boolContentsList } - ) - DataType.BYTES -> this.addAllBytesContents( - tensors.flatMap { 
it.contents.bytesContentsList } - ) + val contents = + InferTensorContents + .newBuilder() + .apply { + when (DataType.valueOf(datatype)) { + DataType.UINT8, + DataType.UINT16, + DataType.UINT32, + -> + this.addAllUintContents( + tensors.flatMap { it.contents.uintContentsList }, + ) + DataType.UINT64 -> + this.addAllUint64Contents( + tensors.flatMap { it.contents.uint64ContentsList }, + ) + DataType.INT8, + DataType.INT16, + DataType.INT32, + -> + this.addAllIntContents( + tensors.flatMap { it.contents.intContentsList }, + ) + DataType.INT64 -> + this.addAllInt64Contents( + tensors.flatMap { it.contents.int64ContentsList }, + ) + DataType.FP16, // may need to handle this separately in future + DataType.FP32, + -> + this.addAllFp32Contents( + tensors.flatMap { it.contents.fp32ContentsList }, + ) + DataType.FP64 -> + this.addAllFp64Contents( + tensors.flatMap { it.contents.fp64ContentsList }, + ) + DataType.BOOL -> + this.addAllBoolContents( + tensors.flatMap { it.contents.boolContentsList }, + ) + DataType.BYTES -> + this.addAllBytesContents( + tensors.flatMap { it.contents.bytesContentsList }, + ) + } } - } - .build() + .build() InferInputTensor .newBuilder() @@ -168,13 +185,14 @@ class BatchProcessor(private val threshold: Int) : Transformer> = Stores - .keyValueStoreBuilder( - Stores.inMemoryKeyValueStore(STATE_STORE_ID), - Serdes.String(), - Serdes.ByteArray(), - ) - .withLoggingDisabled() - .withCachingDisabled() + val stateStoreBuilder: StoreBuilder> = + Stores + .keyValueStoreBuilder( + Stores.inMemoryKeyValueStore(STATE_STORE_ID), + Serdes.String(), + Serdes.ByteArray(), + ) + .withLoggingDisabled() + .withCachingDisabled() } } diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/BinaryContent.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/BinaryContent.kt index 213fa18596..12e0359ed5 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/BinaryContent.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/BinaryContent.kt @@ -7,7 +7,6 @@ Use of this software is governed BY the Change License after the Change Date as each is defined in accordance with the LICENSE file. 
*/ - package io.seldon.dataflow.kafka import com.google.protobuf.kotlin.toByteString @@ -21,9 +20,17 @@ import java.nio.ByteOrder enum class DataType { BOOL, BYTES, - UINT8, UINT16, UINT32, UINT64, - INT8, INT16, INT32, INT64, - FP16, FP32, FP64, + UINT8, + UINT16, + UINT32, + UINT64, + INT8, + INT16, + INT32, + INT64, + FP16, + FP32, + FP64, } private val binaryContentsByteOrder = ByteOrder.LITTLE_ENDIAN @@ -33,21 +40,23 @@ fun List.withBinaryContents() = this.map { it.withBinaryConte fun ModelInferRequest.withBinaryContents(): ModelInferRequest { return this.toBuilder().run { inputsList.forEachIndexed { idx, input -> - val v = when (DataType.valueOf(input.datatype)) { - DataType.UINT8 -> input.contents.toUint8Bytes() - DataType.UINT16 -> input.contents.toUint16Bytes() - DataType.UINT32 -> input.contents.toUint32Bytes() - DataType.UINT64 -> input.contents.toUint64Bytes() - DataType.INT8 -> input.contents.toInt8Bytes() - DataType.INT16 -> input.contents.toInt16Bytes() - DataType.INT32 -> input.contents.toInt32Bytes() - DataType.INT64 -> input.contents.toInt64Bytes() - DataType.BOOL -> input.contents.toBoolBytes() - DataType.FP16, // may need to handle this separately in future - DataType.FP32 -> input.contents.toFp32Bytes() - DataType.FP64 -> input.contents.toFp64Bytes() - DataType.BYTES -> input.contents.toRawBytes() - } + val v = + when (DataType.valueOf(input.datatype)) { + DataType.UINT8 -> input.contents.toUint8Bytes() + DataType.UINT16 -> input.contents.toUint16Bytes() + DataType.UINT32 -> input.contents.toUint32Bytes() + DataType.UINT64 -> input.contents.toUint64Bytes() + DataType.INT8 -> input.contents.toInt8Bytes() + DataType.INT16 -> input.contents.toInt16Bytes() + DataType.INT32 -> input.contents.toInt32Bytes() + DataType.INT64 -> input.contents.toInt64Bytes() + DataType.BOOL -> input.contents.toBoolBytes() + DataType.FP16, // may need to handle this separately in future + DataType.FP32, + -> input.contents.toFp32Bytes() + DataType.FP64 -> input.contents.toFp64Bytes() + DataType.BYTES -> input.contents.toRawBytes() + } // Add binary data and clear corresponding contents. 
addRawInputContents(v.toByteString()) @@ -61,21 +70,23 @@ fun ModelInferRequest.withBinaryContents(): ModelInferRequest { fun ModelInferResponse.withBinaryContents(): ModelInferResponse { return this.toBuilder().run { outputsList.forEachIndexed { idx, output -> - val v = when (DataType.valueOf(output.datatype)) { - DataType.UINT8 -> output.contents.toUint8Bytes() - DataType.UINT16 -> output.contents.toUint16Bytes() - DataType.UINT32 -> output.contents.toUint32Bytes() - DataType.UINT64 -> output.contents.toUint64Bytes() - DataType.INT8 -> output.contents.toInt8Bytes() - DataType.INT16 -> output.contents.toInt16Bytes() - DataType.INT32 -> output.contents.toInt32Bytes() - DataType.INT64 -> output.contents.toInt64Bytes() - DataType.BOOL -> output.contents.toBoolBytes() - DataType.FP16, // may need to handle this separately in future - DataType.FP32 -> output.contents.toFp32Bytes() - DataType.FP64 -> output.contents.toFp64Bytes() - DataType.BYTES -> output.contents.toRawBytes() - } + val v = + when (DataType.valueOf(output.datatype)) { + DataType.UINT8 -> output.contents.toUint8Bytes() + DataType.UINT16 -> output.contents.toUint16Bytes() + DataType.UINT32 -> output.contents.toUint32Bytes() + DataType.UINT64 -> output.contents.toUint64Bytes() + DataType.INT8 -> output.contents.toInt8Bytes() + DataType.INT16 -> output.contents.toInt16Bytes() + DataType.INT32 -> output.contents.toInt32Bytes() + DataType.INT64 -> output.contents.toInt64Bytes() + DataType.BOOL -> output.contents.toBoolBytes() + DataType.FP16, // may need to handle this separately in future + DataType.FP32, + -> output.contents.toFp32Bytes() + DataType.FP64 -> output.contents.toFp64Bytes() + DataType.BYTES -> output.contents.toRawBytes() + } // Add binary data and clear corresponding contents. 
addRawOutputContents(v.toByteString()) @@ -86,131 +97,143 @@ fun ModelInferResponse.withBinaryContents(): ModelInferResponse { } } -private fun InferTensorContents.toUint8Bytes(): ByteArray = this.uintContentsList - .flatMap { - ByteBuffer - .allocate(1) - .put(it.toByte()) - .array() - .toList() - } - .toByteArray() - -private fun InferTensorContents.toUint16Bytes(): ByteArray = this.uintContentsList - .flatMap { - ByteBuffer - .allocate(UShort.SIZE_BYTES) - .order(binaryContentsByteOrder) - .putShort(it.toShort()) - .array() - .toList() - }.toByteArray() - -private fun InferTensorContents.toUint32Bytes(): ByteArray = this.uintContentsList - .flatMap { - ByteBuffer - .allocate(UInt.SIZE_BYTES) - .order(binaryContentsByteOrder) - .putInt(it) - .array() - .toList() - } - .toByteArray() - -private fun InferTensorContents.toUint64Bytes(): ByteArray = this.uint64ContentsList - .flatMap { - ByteBuffer - .allocate(ULong.SIZE_BYTES) - .order(binaryContentsByteOrder) - .putLong(it) - .array() - .toList() - } - .toByteArray() - -private fun InferTensorContents.toInt8Bytes(): ByteArray = this.intContentsList - .flatMap { - ByteBuffer - .allocate(1) - .put(it.toByte()) - .array() - .toList() - } - .toByteArray() - -private fun InferTensorContents.toInt16Bytes(): ByteArray = this.intContentsList - .flatMap { - ByteBuffer - .allocate(Short.SIZE_BYTES) - .order(binaryContentsByteOrder) - .putShort(it.toShort()) - .array() - .toList() - } - .toByteArray() - -private fun InferTensorContents.toInt32Bytes(): ByteArray = this.intContentsList - .flatMap { - ByteBuffer - .allocate(Int.SIZE_BYTES) - .order(binaryContentsByteOrder) - .putInt(it) - .array() - .toList() - } - .toByteArray() - -private fun InferTensorContents.toInt64Bytes(): ByteArray = this.int64ContentsList - .flatMap { - ByteBuffer - .allocate(Long.SIZE_BYTES) - .order(binaryContentsByteOrder) - .putLong(it) - .array() - .toList() - } - .toByteArray() - -private fun InferTensorContents.toFp32Bytes(): ByteArray = this.fp32ContentsList - .flatMap { - ByteBuffer - .allocate(Float.SIZE_BYTES) - .order(binaryContentsByteOrder) - .putFloat(it) - .array() - .toList() - } - .toByteArray() - -private fun InferTensorContents.toFp64Bytes(): ByteArray = this.fp64ContentsList - .flatMap { - ByteBuffer - .allocate(Double.SIZE_BYTES) - .order(binaryContentsByteOrder) - .putDouble(it) - .array() - .toList() - } - .toByteArray() - -private fun InferTensorContents.toBoolBytes(): ByteArray = this.boolContentsList - .flatMap { - ByteBuffer - .allocate(1) - .put(if (it) 1 else 0) - .array() - .toList() - } - .toByteArray() - -private fun InferTensorContents.toRawBytes(): ByteArray = this.bytesContentsList - .flatMap { - ByteBuffer - .allocate(it.size() + Int.SIZE_BYTES) - .order(binaryContentsByteOrder) - .putInt(it.size()) - .put(it.toByteArray()) - .array() - .toList() - } - .toByteArray() +private fun InferTensorContents.toUint8Bytes(): ByteArray = + this.uintContentsList + .flatMap { + ByteBuffer + .allocate(1) + .put(it.toByte()) + .array() + .toList() + } + .toByteArray() + +private fun InferTensorContents.toUint16Bytes(): ByteArray = + this.uintContentsList + .flatMap { + ByteBuffer + .allocate(UShort.SIZE_BYTES) + .order(binaryContentsByteOrder) + .putShort(it.toShort()) + .array() + .toList() + }.toByteArray() + +private fun InferTensorContents.toUint32Bytes(): ByteArray = + this.uintContentsList + .flatMap { + ByteBuffer + .allocate(UInt.SIZE_BYTES) + .order(binaryContentsByteOrder) + .putInt(it) + .array() + .toList() + } + .toByteArray() + +private fun 
InferTensorContents.toUint64Bytes(): ByteArray = + this.uint64ContentsList + .flatMap { + ByteBuffer + .allocate(ULong.SIZE_BYTES) + .order(binaryContentsByteOrder) + .putLong(it) + .array() + .toList() + } + .toByteArray() + +private fun InferTensorContents.toInt8Bytes(): ByteArray = + this.intContentsList + .flatMap { + ByteBuffer + .allocate(1) + .put(it.toByte()) + .array() + .toList() + } + .toByteArray() + +private fun InferTensorContents.toInt16Bytes(): ByteArray = + this.intContentsList + .flatMap { + ByteBuffer + .allocate(Short.SIZE_BYTES) + .order(binaryContentsByteOrder) + .putShort(it.toShort()) + .array() + .toList() + } + .toByteArray() + +private fun InferTensorContents.toInt32Bytes(): ByteArray = + this.intContentsList + .flatMap { + ByteBuffer + .allocate(Int.SIZE_BYTES) + .order(binaryContentsByteOrder) + .putInt(it) + .array() + .toList() + } + .toByteArray() + +private fun InferTensorContents.toInt64Bytes(): ByteArray = + this.int64ContentsList + .flatMap { + ByteBuffer + .allocate(Long.SIZE_BYTES) + .order(binaryContentsByteOrder) + .putLong(it) + .array() + .toList() + } + .toByteArray() + +private fun InferTensorContents.toFp32Bytes(): ByteArray = + this.fp32ContentsList + .flatMap { + ByteBuffer + .allocate(Float.SIZE_BYTES) + .order(binaryContentsByteOrder) + .putFloat(it) + .array() + .toList() + } + .toByteArray() + +private fun InferTensorContents.toFp64Bytes(): ByteArray = + this.fp64ContentsList + .flatMap { + ByteBuffer + .allocate(Double.SIZE_BYTES) + .order(binaryContentsByteOrder) + .putDouble(it) + .array() + .toList() + } + .toByteArray() + +private fun InferTensorContents.toBoolBytes(): ByteArray = + this.boolContentsList + .flatMap { + ByteBuffer + .allocate(1) + .put(if (it) 1 else 0) + .array() + .toList() + } + .toByteArray() + +private fun InferTensorContents.toRawBytes(): ByteArray = + this.bytesContentsList + .flatMap { + ByteBuffer + .allocate(it.size() + Int.SIZE_BYTES) + .order(binaryContentsByteOrder) + .putInt(it.size()) + .put(it.toByteArray()) + .array() + .toList() + } + .toByteArray() diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Chainer.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Chainer.kt index b8d9a4c48c..62c39f373f 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Chainer.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Chainer.kt @@ -50,9 +50,10 @@ class Chainer( } private fun buildPassThroughStream(builder: StreamsBuilder) { - val s1 = builder - .stream(inputTopic.topicName, consumerSerde) - .filterForPipeline(inputTopic.pipelineName) + val s1 = + builder + .stream(inputTopic.topicName, consumerSerde) + .filterForPipeline(inputTopic.pipelineName) addTriggerTopology( kafkaDomainParams, builder, @@ -68,14 +69,15 @@ class Chainer( } private fun buildInputOutputStream(builder: StreamsBuilder) { - val s1 = builder - .stream(inputTopic.topicName, consumerSerde) - .filterForPipeline(inputTopic.pipelineName) - .unmarshallInferenceV2Request() - .convertToResponse(inputTopic.pipelineName, inputTopic.topicName, tensors, tensorRenaming) - // handle cases where there are no tensors we want - .filter { _, value -> value.outputsList.size != 0 } - .marshallInferenceV2Response() + val s1 = + builder + .stream(inputTopic.topicName, consumerSerde) + .filterForPipeline(inputTopic.pipelineName) + .unmarshallInferenceV2Request() + .convertToResponse(inputTopic.pipelineName, inputTopic.topicName, tensors, tensorRenaming) + // handle cases where there are no 
tensors we want + .filter { _, value -> value.outputsList.size != 0 } + .marshallInferenceV2Response() addTriggerTopology( kafkaDomainParams, builder, @@ -91,14 +93,15 @@ class Chainer( } private fun buildOutputOutputStream(builder: StreamsBuilder) { - val s1 = builder - .stream(inputTopic.topicName, consumerSerde) - .filterForPipeline(inputTopic.pipelineName) - .unmarshallInferenceV2Response() - .filterResponses(inputTopic.pipelineName, inputTopic.topicName, tensors, tensorRenaming) - // handle cases where there are no tensors we want - .filter { _, value -> value.outputsList.size != 0 } - .marshallInferenceV2Response() + val s1 = + builder + .stream(inputTopic.topicName, consumerSerde) + .filterForPipeline(inputTopic.pipelineName) + .unmarshallInferenceV2Response() + .filterResponses(inputTopic.pipelineName, inputTopic.topicName, tensors, tensorRenaming) + // handle cases where there are no tensors we want + .filter { _, value -> value.outputsList.size != 0 } + .marshallInferenceV2Response() addTriggerTopology( kafkaDomainParams, builder, @@ -114,15 +117,16 @@ class Chainer( } private fun buildOutputInputStream(builder: StreamsBuilder) { - val s1 = builder - .stream(inputTopic.topicName, consumerSerde) - .filterForPipeline(inputTopic.pipelineName) - .unmarshallInferenceV2Response() - .convertToRequest(inputTopic.pipelineName, inputTopic.topicName, tensors, tensorRenaming) - // handle cases where there are no tensors we want - .filter { _, value -> value.inputsList.size != 0 } - .batchMessages(batchProperties) - .marshallInferenceV2Request() + val s1 = + builder + .stream(inputTopic.topicName, consumerSerde) + .filterForPipeline(inputTopic.pipelineName) + .unmarshallInferenceV2Response() + .convertToRequest(inputTopic.pipelineName, inputTopic.topicName, tensors, tensorRenaming) + // handle cases where there are no tensors we want + .filter { _, value -> value.inputsList.size != 0 } + .batchMessages(batchProperties) + .marshallInferenceV2Request() addTriggerTopology( kafkaDomainParams, builder, @@ -138,15 +142,16 @@ class Chainer( } private fun buildInputInputStream(builder: StreamsBuilder) { - val s1 = builder - .stream(inputTopic.topicName, consumerSerde) - .filterForPipeline(inputTopic.pipelineName) - .unmarshallInferenceV2Request() - .filterRequests(inputTopic.pipelineName, inputTopic.topicName, tensors, tensorRenaming) - // handle cases where there are no tensors we want - .filter { _, value -> value.inputsList.size != 0 } - .batchMessages(batchProperties) - .marshallInferenceV2Request() + val s1 = + builder + .stream(inputTopic.topicName, consumerSerde) + .filterForPipeline(inputTopic.pipelineName) + .unmarshallInferenceV2Request() + .filterRequests(inputTopic.pipelineName, inputTopic.topicName, tensors, tensorRenaming) + // handle cases where there are no tensors we want + .filter { _, value -> value.inputsList.size != 0 } + .batchMessages(batchProperties) + .marshallInferenceV2Request() addTriggerTopology( kafkaDomainParams, builder, @@ -164,4 +169,4 @@ class Chainer( companion object { private val logger = noCoLogger(Chainer::class) } -} \ No newline at end of file +} diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Configuration.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Configuration.kt index 0ee07de8a2..7e006d6cc4 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Configuration.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Configuration.kt @@ -26,7 +26,7 @@ import 
org.apache.kafka.streams.StreamsConfig import org.apache.kafka.streams.errors.DeserializationExceptionHandler import org.apache.kafka.streams.errors.ProductionExceptionHandler import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler -import java.util.* +import java.util.Properties const val KAFKA_UNCAUGHT_EXCEPTION_HANDLER_CLASS_CONFIG = "default.processing.exception.handler" @@ -50,10 +50,11 @@ data class KafkaDomainParams( ) data class TopicWaitRetryParams( - val createTimeoutMillis: Int, // int required by the underlying kafka-streams library + // int required by the underlying kafka-streams library + val createTimeoutMillis: Int, val describeTimeoutMillis: Long, val describeRetries: Int, - val describeRetryDelayMillis: Long + val describeRetryDelayMillis: Long, ) val kafkaTopicConfig = { maxMessageSizeBytes: Int -> @@ -71,11 +72,13 @@ fun getKafkaAdminProperties(params: KafkaStreamsParams): KafkaAdminProperties { } private fun getSecurityProperties(params: KafkaStreamsParams): Properties { - val authProperties = when (params.security.securityProtocol) { - SecurityProtocol.SSL -> getSslProperties(params) - SecurityProtocol.SASL_SSL -> getSaslProperties(params) - else -> Properties() // No authentication, so nothing to configure - } + val authProperties = + when (params.security.securityProtocol) { + SecurityProtocol.SSL -> getSslProperties(params) + SecurityProtocol.SASL_SSL -> getSaslProperties(params) + // No authentication, so nothing to configure + else -> Properties() + } return authProperties.apply { this[StreamsConfig.SECURITY_PROTOCOL_CONFIG] = params.security.securityProtocol.toString() @@ -121,34 +124,37 @@ private fun getSaslProperties(params: KafkaStreamsParams): Properties { when (params.security.saslConfig) { is SaslConfig.Password -> { - val module = when (params.security.saslConfig) { - is SaslConfig.Password.Plain -> "org.apache.kafka.common.security.plain.PlainLoginModule required" - is SaslConfig.Password.Scram256, - is SaslConfig.Password.Scram512 -> "org.apache.kafka.common.security.scram.ScramLoginModule required" - } + val module = + when (params.security.saslConfig) { + is SaslConfig.Password.Plain -> "org.apache.kafka.common.security.plain.PlainLoginModule required" + is SaslConfig.Password.Scram256, + is SaslConfig.Password.Scram512, + -> "org.apache.kafka.common.security.scram.ScramLoginModule required" + } val password = SaslPasswordProvider.default.getPassword(params.security.saslConfig) this[SaslConfigs.SASL_JAAS_CONFIG] = module + - """ username="${params.security.saslConfig.username}"""" + - """ password="$password";""" + """ username="${params.security.saslConfig.username}"""" + + """ password="$password";""" } is SaslConfig.Oauth -> { val oauthConfig = SaslOauthProvider.default.getOauthConfig(params.security.saslConfig) - val jaasConfig = buildString { - append("org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required") - append(""" clientId="${oauthConfig.clientId}"""") - append(""" clientSecret="${oauthConfig.clientSecret}"""") - oauthConfig.scope?.let { - append(""" scope="$it"""") - } - oauthConfig.extensions?.let { extensions -> - extensions.forEach { - append(""" $it""") + val jaasConfig = + buildString { + append("org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required") + append(""" clientId="${oauthConfig.clientId}"""") + append(""" clientSecret="${oauthConfig.clientSecret}"""") + oauthConfig.scope?.let { + append(""" scope="$it"""") + } + oauthConfig.extensions?.let { extensions -> + 
extensions.forEach { + append(""" $it""") + } } + append(";") } - append(";") - } this[SaslConfigs.SASL_JAAS_CONFIG] = jaasConfig this[SaslConfigs.SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL] = oauthConfig.tokenUrl @@ -190,7 +196,11 @@ fun getKafkaProperties(params: KafkaStreamsParams): KafkaProperties { } } -fun KafkaProperties.withAppId(namespace: String, consumerGroupIdPrefix: String, name: String): KafkaProperties { +fun KafkaProperties.withAppId( + namespace: String, + consumerGroupIdPrefix: String, + name: String, +): KafkaProperties { val properties = KafkaProperties() properties.putAll(this.toMap()) @@ -221,15 +231,17 @@ fun KafkaProperties.withStreamThreads(n: Int): KafkaProperties { return properties } -fun KafkaProperties.withErrorHandlers(deserializationExceptionHdl: DeserializationExceptionHandler?, - streamExceptionHdl: StreamsUncaughtExceptionHandler?, - productionExceptionHdl: ProductionExceptionHandler?): KafkaProperties { +fun KafkaProperties.withErrorHandlers( + deserializationExceptionHdl: DeserializationExceptionHandler?, + streamExceptionHdl: StreamsUncaughtExceptionHandler?, + productionExceptionHdl: ProductionExceptionHandler?, +): KafkaProperties { val properties = KafkaProperties() properties.putAll(this.toMap()) - deserializationExceptionHdl?.let { properties[StreamsConfig.DEFAULT_DESERIALIZATION_EXCEPTION_HANDLER_CLASS_CONFIG] = it::class.java } - streamExceptionHdl?.let { properties[KAFKA_UNCAUGHT_EXCEPTION_HANDLER_CLASS_CONFIG] = it::class.java } - productionExceptionHdl?.let { properties[StreamsConfig.DEFAULT_PRODUCTION_EXCEPTION_HANDLER_CLASS_CONFIG] = it::class.java } + deserializationExceptionHdl?.let { properties[StreamsConfig.DEFAULT_DESERIALIZATION_EXCEPTION_HANDLER_CLASS_CONFIG] = it::class.java } + streamExceptionHdl?.let { properties[KAFKA_UNCAUGHT_EXCEPTION_HANDLER_CLASS_CONFIG] = it::class.java } + productionExceptionHdl?.let { properties[StreamsConfig.DEFAULT_PRODUCTION_EXCEPTION_HANDLER_CLASS_CONFIG] = it::class.java } return properties } diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Joiner.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Joiner.kt index 46464eef2a..bb676cd408 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Joiner.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Joiner.kt @@ -43,7 +43,7 @@ class Joiner( triggerTensorsByTopic, triggerJoinType, dataStream, - null + null, ) .headerRemover() .headerSetter(pipelineName) @@ -65,78 +65,87 @@ class Joiner( val topic = inputTopics.first() val chainType = ChainType.create(topic.topicName, outputTopic.topicName) - logger.info("Creating stream ${chainType} for ${topic}->${outputTopic}") - val nextStream = when (chainType) { - ChainType.OUTPUT_INPUT -> buildOutputInputStream(topic, builder) - ChainType.INPUT_INPUT -> buildInputInputStream(topic, builder) - ChainType.OUTPUT_OUTPUT -> buildOutputOutputStream(topic, builder) - ChainType.INPUT_OUTPUT -> buildInputOutputStream(topic, builder) - else -> buildPassThroughStream(topic, builder) - } - val payloadJoiner = when (chainType) { - ChainType.OUTPUT_INPUT, ChainType.INPUT_INPUT -> ::joinRequests - ChainType.OUTPUT_OUTPUT, ChainType.INPUT_OUTPUT -> ::joinResponses - else -> throw Exception("Can't join custom data") - } + logger.info("Creating stream $chainType for $topic->$outputTopic") + val nextStream = + when (chainType) { + ChainType.OUTPUT_INPUT -> buildOutputInputStream(topic, builder) + ChainType.INPUT_INPUT -> buildInputInputStream(topic, builder) + 
ChainType.OUTPUT_OUTPUT -> buildOutputOutputStream(topic, builder) + ChainType.INPUT_OUTPUT -> buildInputOutputStream(topic, builder) + else -> buildPassThroughStream(topic, builder) + } + val payloadJoiner = + when (chainType) { + ChainType.OUTPUT_INPUT, ChainType.INPUT_INPUT -> ::joinRequests + ChainType.OUTPUT_OUTPUT, ChainType.INPUT_OUTPUT -> ::joinResponses + else -> throw Exception("Can't join custom data") + } when (joinType) { PipelineJoinType.Any -> { - val nextPending = pending - ?.outerJoin( - nextStream, - payloadJoiner, - //JoinWindows.ofTimeDifferenceAndGrace(Duration.ofMillis(1), Duration.ofMillis(1)), - // Required because this "fix" causes outer joins to wait for next record to come in if all streams - // don't produce a record during grace period. https://issues.apache.org/jira/browse/KAFKA-10847 - // Also see https://confluentcommunity.slack.com/archives/C6UJNMY67/p1649520904545229?thread_ts=1649324912.542999&cid=C6UJNMY67 - // Issue created at https://issues.apache.org/jira/browse/KAFKA-13813 - JoinWindows.of(Duration.ofMillis(1)), - joinSerde, - ) ?: nextStream - + val nextPending = + pending + ?.outerJoin( + nextStream, + payloadJoiner, + // JoinWindows.ofTimeDifferenceAndGrace(Duration.ofMillis(1), Duration.ofMillis(1)), + // Required because this "fix" causes outer joins to wait for next record to come in if all streams + // don't produce a record during grace period. https://issues.apache.org/jira/browse/KAFKA-10847 + // Also see https://confluentcommunity.slack.com/archives/C6UJNMY67/p1649520904545229?thread_ts=1649324912.542999&cid=C6UJNMY67 + // Issue created at https://issues.apache.org/jira/browse/KAFKA-13813 + JoinWindows.of(Duration.ofMillis(1)), + joinSerde, + ) ?: nextStream return buildTopology(builder, inputTopics.minus(topic), nextPending) } PipelineJoinType.Outer -> { - val nextPending = pending - ?.outerJoin( - nextStream, - payloadJoiner, - // See above for Any case as this will wait until next record comes in before emitting a result after window - JoinWindows.ofTimeDifferenceWithNoGrace( - Duration.ofMillis(kafkaDomainParams.joinWindowMillis), - ), - joinSerde, - ) ?: nextStream - + val nextPending = + pending + ?.outerJoin( + nextStream, + payloadJoiner, + // See above for Any case as this will wait until next record comes in before emitting a result after window + JoinWindows.ofTimeDifferenceWithNoGrace( + Duration.ofMillis(kafkaDomainParams.joinWindowMillis), + ), + joinSerde, + ) ?: nextStream return buildTopology(builder, inputTopics.minus(topic), nextPending) } else -> { - val nextPending = pending - ?.join( - nextStream, - payloadJoiner, - JoinWindows.ofTimeDifferenceWithNoGrace( - Duration.ofMillis(kafkaDomainParams.joinWindowMillis), - ), - joinSerde, - ) ?: nextStream + val nextPending = + pending + ?.join( + nextStream, + payloadJoiner, + JoinWindows.ofTimeDifferenceWithNoGrace( + Duration.ofMillis(kafkaDomainParams.joinWindowMillis), + ), + joinSerde, + ) ?: nextStream return buildTopology(builder, inputTopics.minus(topic), nextPending) } } } - private fun buildPassThroughStream(topic: TopicForPipeline, builder: StreamsBuilder): KStream { + private fun buildPassThroughStream( + topic: TopicForPipeline, + builder: StreamsBuilder, + ): KStream { return builder .stream(topic.topicName, consumerSerde) .filterForPipeline(topic.pipelineName) } - private fun buildInputOutputStream(topic: TopicForPipeline, builder: StreamsBuilder): KStream { + private fun buildInputOutputStream( + topic: TopicForPipeline, + builder: StreamsBuilder, + ): KStream 
{ return builder .stream(topic.topicName, consumerSerde) .filterForPipeline(topic.pipelineName) @@ -147,7 +156,10 @@ class Joiner( .marshallInferenceV2Response() } - private fun buildOutputOutputStream(topic: TopicForPipeline, builder: StreamsBuilder): KStream { + private fun buildOutputOutputStream( + topic: TopicForPipeline, + builder: StreamsBuilder, + ): KStream { return builder .stream(topic.topicName, consumerSerde) .filterForPipeline(topic.pipelineName) @@ -158,7 +170,10 @@ class Joiner( .marshallInferenceV2Response() } - private fun buildOutputInputStream(topic: TopicForPipeline, builder: StreamsBuilder): KStream { + private fun buildOutputInputStream( + topic: TopicForPipeline, + builder: StreamsBuilder, + ): KStream { return builder .stream(topic.topicName, consumerSerde) .filterForPipeline(topic.pipelineName) @@ -169,18 +184,24 @@ class Joiner( .marshallInferenceV2Request() } - private fun buildInputInputStream(topic: TopicForPipeline, builder: StreamsBuilder): KStream { + private fun buildInputInputStream( + topic: TopicForPipeline, + builder: StreamsBuilder, + ): KStream { return builder .stream(topic.topicName, consumerSerde) .filterForPipeline(topic.pipelineName) .unmarshallInferenceV2Request() - .filterRequests(topic.pipelineName,topic.topicName, tensorsByTopic?.get(topic), tensorRenaming) + .filterRequests(topic.pipelineName, topic.topicName, tensorsByTopic?.get(topic), tensorRenaming) // handle cases where there are no tensors we want .filter { _, value -> value.inputsList.size != 0 } .marshallInferenceV2Request() } - private fun joinRequests(left: ByteArray?, right: ByteArray?): ByteArray { + private fun joinRequests( + left: ByteArray?, + right: ByteArray?, + ): ByteArray { if (left == null) { return right!! } @@ -194,19 +215,23 @@ class Joiner( } else if (rightRequest.rawInputContentsCount > 0 && leftRequest.rawInputContentsCount == 0) { leftRequest = leftRequest.withBinaryContents() } - val request = V2Dataplane.ModelInferRequest - .newBuilder() - .setId(leftRequest.id) - .putAllParameters(leftRequest.parametersMap) - .addAllInputs(leftRequest.inputsList) - .addAllInputs(rightRequest.inputsList) - .addAllRawInputContents(leftRequest.rawInputContentsList) - .addAllRawInputContents(rightRequest.rawInputContentsList) - .build() + val request = + V2Dataplane.ModelInferRequest + .newBuilder() + .setId(leftRequest.id) + .putAllParameters(leftRequest.parametersMap) + .addAllInputs(leftRequest.inputsList) + .addAllInputs(rightRequest.inputsList) + .addAllRawInputContents(leftRequest.rawInputContentsList) + .addAllRawInputContents(rightRequest.rawInputContentsList) + .build() return request.toByteArray() } - private fun joinResponses(left: ByteArray?, right: ByteArray?): ByteArray { + private fun joinResponses( + left: ByteArray?, + right: ByteArray?, + ): ByteArray { if (left == null) { return right!! 
} @@ -220,19 +245,20 @@ class Joiner( } else if (rightResponse.rawOutputContentsCount > 0 && leftResponse.rawOutputContentsCount == 0) { leftResponse = leftResponse.withBinaryContents() } - val response = V2Dataplane.ModelInferResponse - .newBuilder() - .setId(leftResponse.id) - .putAllParameters(leftResponse.parametersMap) - .addAllOutputs(leftResponse.outputsList) - .addAllOutputs(rightResponse.outputsList) - .addAllRawOutputContents(leftResponse.rawOutputContentsList) - .addAllRawOutputContents(rightResponse.rawOutputContentsList) - .build() + val response = + V2Dataplane.ModelInferResponse + .newBuilder() + .setId(leftResponse.id) + .putAllParameters(leftResponse.parametersMap) + .addAllOutputs(leftResponse.outputsList) + .addAllOutputs(rightResponse.outputsList) + .addAllRawOutputContents(leftResponse.rawOutputContentsList) + .addAllRawOutputContents(rightResponse.rawOutputContentsList) + .build() return response.toByteArray() } companion object { private val logger = noCoLogger(Joiner::class) } -} \ No newline at end of file +} diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/KafkaAdmin.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/KafkaAdmin.kt index 157bccc910..6971f918dd 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/KafkaAdmin.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/KafkaAdmin.kt @@ -9,20 +9,20 @@ the Change License after the Change Date as each is defined in accordance with t package io.seldon.dataflow.kafka -import com.github.michaelbull.retry.policy.* +import com.github.michaelbull.retry.policy.constantDelay +import com.github.michaelbull.retry.policy.continueIf +import com.github.michaelbull.retry.policy.plus +import com.github.michaelbull.retry.policy.stopAtAttempts import com.github.michaelbull.retry.retry +import io.klogging.noCoLogger import io.seldon.mlops.chainer.ChainerOuterClass.PipelineStepUpdate import org.apache.kafka.clients.admin.Admin import org.apache.kafka.clients.admin.CreateTopicsOptions import org.apache.kafka.clients.admin.NewTopic -import org.apache.kafka.common.KafkaFuture import org.apache.kafka.common.errors.TimeoutException -import org.apache.kafka.common.errors.TopicExistsException import org.apache.kafka.common.errors.UnknownTopicOrPartitionException -import java.util.concurrent.ExecutionException import java.util.concurrent.TimeUnit import io.klogging.logger as coLogger -import io.klogging.noCoLogger class KafkaAdmin( adminConfig: KafkaAdminProperties, @@ -31,22 +31,22 @@ class KafkaAdmin( ) { private val adminClient = Admin.create(adminConfig) - suspend fun ensureTopicsExist( - steps: List, - ) : Exception? { - val missingTopicRetryPolicy = continueIf { (failure) -> - when (failure) { - is TimeoutException, - is UnknownTopicOrPartitionException -> true - else -> { - // We log here for dev purposes, to gather other kinds of exceptions that occur. In time, we should - // collate those and decide which are permanent errors. For permanent errors, it would be worth - // stopping the retries and returning false. - noCoLogger.warn("ignoring exception while waiting for topic creation: ${failure.message}") - true + suspend fun ensureTopicsExist(steps: List): Exception? { + val missingTopicRetryPolicy = + continueIf { (failure) -> + when (failure) { + is TimeoutException, + is UnknownTopicOrPartitionException, + -> true + else -> { + // We log here for dev purposes, to gather other kinds of exceptions that occur. 
In time, we should + // collate those and decide which are permanent errors. For permanent errors, it would be worth + // stopping the retries and returning false. + noCoLogger.warn("ignoring exception while waiting for topic creation: ${failure.message}") + true + } } } - } try { steps @@ -70,7 +70,7 @@ class KafkaAdmin( .run { adminClient.createTopics( this, - CreateTopicsOptions().timeoutMs(topicWaitRetryParams.createTimeoutMillis) + CreateTopicsOptions().timeoutMs(topicWaitRetryParams.createTimeoutMillis), ) } .values() @@ -80,9 +80,10 @@ class KafkaAdmin( // one broker. This is because the call to createTopics above returns before topics can actually // be subscribed to. retry( - missingTopicRetryPolicy + stopAtAttempts(topicWaitRetryParams.describeRetries) + constantDelay( - topicWaitRetryParams.describeRetryDelayMillis - ) + missingTopicRetryPolicy + stopAtAttempts(topicWaitRetryParams.describeRetries) + + constantDelay( + topicWaitRetryParams.describeRetryDelayMillis, + ), ) { logger.debug("Still waiting for all topics to be created...") // the KafkaFuture retrieved via .allTopicNames() only succeeds if all the topic diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Pipeline.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Pipeline.kt index 850b642739..3ba55c64fe 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Pipeline.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Pipeline.kt @@ -20,7 +20,6 @@ import org.apache.kafka.streams.KafkaStreams.StateListener import org.apache.kafka.streams.StreamsBuilder import org.apache.kafka.streams.StreamsConfig import org.apache.kafka.streams.Topology -import org.apache.kafka.streams.errors.StreamsException import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler import java.util.concurrent.CountDownLatch import kotlin.math.floor @@ -35,7 +34,6 @@ data class PipelineMetadata( val version: Int, ) - class Pipeline( private val metadata: PipelineMetadata, private val topology: Topology, @@ -44,12 +42,13 @@ class Pipeline( val size: Int, ) : StateListener { private val latch = CountDownLatch(1) + // Never update status properties in-place, because we need it to have atomic // properties. Instead, just assign new values to it. @Volatile - var status : PipelineStatus = PipelineStatus.StreamStopped(null) + var status: PipelineStatus = PipelineStatus.StreamStopped(null) - fun start() : PipelineStatus { + fun start(): PipelineStatus { if (kafkaDomainParams.useCleanState) { streams.cleanUp() } @@ -61,9 +60,10 @@ class Pipeline( } catch (e: Exception) { streams.close() streams.cleanUp() - status = PipelineStatus.Error(State.NOT_RUNNING) - .withException(e) - .withMessage("kafka streams: failed to start") + status = + PipelineStatus.Error(State.NOT_RUNNING) + .withException(e) + .withMessage("kafka streams: failed to start") return status } status = PipelineStatus.StreamStarting() @@ -93,11 +93,16 @@ class Pipeline( latch.countDown() } - override fun onChange(newState: State?, oldState: State?) 
{ + override fun onChange( + newState: State?, + oldState: State?, + ) { logger.info { - e("pipeline {pipelineName} (v{pipelineVersion}) changing to state $newState", + e( + "pipeline {pipelineName} (v{pipelineVersion}) changing to state $newState", metadata.name, - metadata.version) + metadata.version, + ) } if (newState == State.RUNNING) { // Only update the status if the pipeline is not already being stopped @@ -117,8 +122,9 @@ class Pipeline( // see: https://kafka.apache.org/28/javadoc/org/apache/kafka/streams/KafkaStreams.State.html if (newState != State.CREATED && newState != State.REBALANCING) { if (status !is PipelineStatus.StreamStopping) { - status = PipelineStatus.Error(newState) - .withMessage("pipeline data streams error: kafka streams state: $newState") + status = + PipelineStatus.Error(newState) + .withMessage("pipeline data streams error: kafka streams state: $newState") latch.countDown() } } @@ -140,28 +146,31 @@ class Pipeline( ): Pair { val (topology, numSteps) = buildTopology(metadata, steps, kafkaDomainParams) val pipelineProperties = localiseKafkaProperties(kafkaProperties, metadata, numSteps, kafkaConsumerGroupIdPrefix, namespace) - var streamsApp : KafkaStreams? + var streamsApp: KafkaStreams? var pipelineError: PipelineStatus.Error? try { streamsApp = KafkaStreams(topology, pipelineProperties) } catch (e: Exception) { - pipelineError = PipelineStatus.Error(null) - .withException(e) - .withMessage("failed to initialize kafka streams for pipeline") + pipelineError = + PipelineStatus.Error(null) + .withException(e) + .withMessage("failed to initialize kafka streams for pipeline") return null to pipelineError } - val uncaughtExceptionHandlerClass = pipelineProperties[KAFKA_UNCAUGHT_EXCEPTION_HANDLER_CLASS_CONFIG] as? Class? - uncaughtExceptionHandlerClass?.let{ + val uncaughtExceptionHandlerClass = + pipelineProperties[KAFKA_UNCAUGHT_EXCEPTION_HANDLER_CLASS_CONFIG] as? Class? 
+ uncaughtExceptionHandlerClass?.let { logger.debug("Setting custom Kafka streams uncaught exception handler") streamsApp.setUncaughtExceptionHandler(it.getDeclaredConstructor().newInstance()) } logger.info( - "Create pipeline stream for name:{pipelineName} id:{pipelineId} version:{pipelineVersion} stream with kstream app id:{kstreamAppId}", + "Create pipeline stream for name:{pipelineName} id:{pipelineId} " + + "version:{pipelineVersion} stream with kstream app id:{kstreamAppId}", metadata.name, metadata.id, metadata.version, - pipelineProperties[StreamsConfig.APPLICATION_ID_CONFIG] + pipelineProperties[StreamsConfig.APPLICATION_ID_CONFIG], ) return Pipeline(metadata, topology, streamsApp, kafkaDomainParams, numSteps) to null } @@ -172,21 +181,22 @@ class Pipeline( kafkaDomainParams: KafkaDomainParams, ): Pair { val builder = StreamsBuilder() - val topologySteps = steps - .mapNotNull { - stepFor( - builder, - metadata.name, - it.sourcesList, - it.triggersList, - it.tensorMapList, - it.sink, - it.inputJoinTy, - it.triggersJoinTy, - it.batch, - kafkaDomainParams, - ) - } + val topologySteps = + steps + .mapNotNull { + stepFor( + builder, + metadata.name, + it.sourcesList, + it.triggersList, + it.tensorMapList, + it.sink, + it.inputJoinTy, + it.triggersJoinTy, + it.batch, + kafkaDomainParams, + ) + } val topology = builder.build() return topology to topologySteps.size } @@ -196,7 +206,7 @@ class Pipeline( metadata: PipelineMetadata, numSteps: Int, kafkaConsumerGroupIdPrefix: String, - namespace: String + namespace: String, ): KafkaProperties { return kafkaProperties .withAppId( @@ -210,7 +220,7 @@ class Pipeline( .withErrorHandlers( StreamErrorHandling.StreamsDeserializationErrorHandler(), StreamErrorHandling.StreamsCustomUncaughtExceptionHandler(), - StreamErrorHandling.StreamsRecordProducerErrorHandler() + StreamErrorHandling.StreamsRecordProducerErrorHandler(), ) } @@ -219,4 +229,4 @@ class Pipeline( return max(1, scale.toInt()) } } -} \ No newline at end of file +} diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/PipelineStatus.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/PipelineStatus.kt index 3f732e4a02..4def1914b4 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/PipelineStatus.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/PipelineStatus.kt @@ -22,7 +22,7 @@ open class PipelineStatus(val state: KafkaStreams.State?, var isError: Boolean) this.isError = this.prevState?.isError ?: false } - override fun getDescription() : String? { + override fun getDescription(): String? 
{ val exceptionMsg = this.exception?.message var statusMsg = this.message val prevStateDescription = this.prevState?.getDescription() @@ -37,7 +37,10 @@ open class PipelineStatus(val state: KafkaStreams.State?, var isError: Boolean) } // log status when logger is in a coroutine - override fun log(logger: Klogger, levelIfNoException: Level) { + override fun log( + logger: Klogger, + levelIfNoException: Level, + ) { var exceptionMsg = this.exception?.message var exceptionCause = this.exception?.cause ?: Exception("") var statusMsg = this.message @@ -57,7 +60,10 @@ open class PipelineStatus(val state: KafkaStreams.State?, var isError: Boolean) } // log status when logger is outside coroutines - override fun log(logger: NoCoLogger, levelIfNoException: Level) { + override fun log( + logger: NoCoLogger, + levelIfNoException: Level, + ) { val exceptionMsg = this.exception?.message val exceptionCause = this.exception?.cause ?: Exception("") var statusMsg = this.message @@ -85,9 +91,8 @@ open class PipelineStatus(val state: KafkaStreams.State?, var isError: Boolean) override var message: String? = "pipeline data streams: ready" } - data class Error(val errorState: KafkaStreams.State?): PipelineStatus(errorState, true) + data class Error(val errorState: KafkaStreams.State?) : PipelineStatus(errorState, true) override var exception: Exception? = null override var message: String? = null - } diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/PipelineStep.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/PipelineStep.kt index 4d3938cc4f..d4ad59a98c 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/PipelineStep.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/PipelineStep.kt @@ -42,85 +42,93 @@ fun stepFor( val triggerTopicsToTensors = parseTriggers(triggerSources) return when (val result = parseSources(sources)) { is SourceProjection.Empty -> null - is SourceProjection.Single -> Chainer( - builder, - result.topicForPipeline, - TopicForPipeline(topicName = sink.topicName, pipelineName = sink.pipelineName), - null, - pipelineName, - tensorMap, - batchProperties, - kafkaDomainParams, - triggerTopicsToTensors.keys, - triggerJoinType, - triggerTopicsToTensors - ) - is SourceProjection.SingleSubset -> Chainer( - builder, - result.topicForPipeline, - TopicForPipeline(topicName = sink.topicName, pipelineName = sink.pipelineName), - result.tensors, - pipelineName, - tensorMap, - batchProperties, - kafkaDomainParams, - triggerTopicsToTensors.keys, - triggerJoinType, - triggerTopicsToTensors - ) - is SourceProjection.Many -> Joiner( - builder, - result.topicNames, - TopicForPipeline(topicName = sink.topicName, pipelineName = sink.pipelineName), - null, - pipelineName, - tensorMap, - kafkaDomainParams, - joinType, - triggerTopicsToTensors.keys, - triggerJoinType, - triggerTopicsToTensors - ) - is SourceProjection.ManySubsets -> Joiner( - builder, - result.tensorsByTopic.keys, - TopicForPipeline(topicName = sink.topicName, pipelineName = sink.pipelineName), - result.tensorsByTopic, - pipelineName, - tensorMap, - kafkaDomainParams, - joinType, - triggerTopicsToTensors.keys, - triggerJoinType, - triggerTopicsToTensors - ) + is SourceProjection.Single -> + Chainer( + builder, + result.topicForPipeline, + TopicForPipeline(topicName = sink.topicName, pipelineName = sink.pipelineName), + null, + pipelineName, + tensorMap, + batchProperties, + kafkaDomainParams, + triggerTopicsToTensors.keys, + triggerJoinType, + triggerTopicsToTensors, + 
) + is SourceProjection.SingleSubset -> + Chainer( + builder, + result.topicForPipeline, + TopicForPipeline(topicName = sink.topicName, pipelineName = sink.pipelineName), + result.tensors, + pipelineName, + tensorMap, + batchProperties, + kafkaDomainParams, + triggerTopicsToTensors.keys, + triggerJoinType, + triggerTopicsToTensors, + ) + is SourceProjection.Many -> + Joiner( + builder, + result.topicNames, + TopicForPipeline(topicName = sink.topicName, pipelineName = sink.pipelineName), + null, + pipelineName, + tensorMap, + kafkaDomainParams, + joinType, + triggerTopicsToTensors.keys, + triggerJoinType, + triggerTopicsToTensors, + ) + is SourceProjection.ManySubsets -> + Joiner( + builder, + result.tensorsByTopic.keys, + TopicForPipeline(topicName = sink.topicName, pipelineName = sink.pipelineName), + result.tensorsByTopic, + pipelineName, + tensorMap, + kafkaDomainParams, + joinType, + triggerTopicsToTensors.keys, + triggerJoinType, + triggerTopicsToTensors, + ) } } - sealed class SourceProjection { object Empty : SourceProjection() + data class Single(val topicForPipeline: TopicForPipeline) : SourceProjection() + data class SingleSubset(val topicForPipeline: TopicForPipeline, val tensors: Set) : SourceProjection() + data class Many(val topicNames: Set) : SourceProjection() + data class ManySubsets(val tensorsByTopic: Map>) : SourceProjection() } -fun parseTriggers(sources: List): Map> { +fun parseTriggers(sources: List): Map> { return sources .map { parseSource(it) } - .groupBy(keySelector = { it.first+":"+it.third }, valueTransform = { it.second }) + .groupBy(keySelector = { it.first + ":" + it.third }, valueTransform = { it.second }) .mapValues { it.value.filterNotNull().toSet() } .map { TopicTensors(TopicForPipeline(topicName = it.key.split(":")[0], pipelineName = it.key.split(":")[1]), it.value) } - .associate {it.topicForPipeline to it.tensors } + .associate { it.topicForPipeline to it.tensors } } fun parseSources(sources: List): SourceProjection { - val topicsAndTensors = sources - .map { parseSource(it) } - .groupBy(keySelector = { it.first+":"+it.third }, valueTransform = { it.second }) - .mapValues { it.value.filterNotNull().toSet() } - .map { TopicTensors(TopicForPipeline(topicName = it.key.split(":")[0], pipelineName = it.key.split(":")[1]), it.value) } + val topicsAndTensors = + sources + .map { parseSource(it) } + .groupBy(keySelector = { it.first + ":" + it.third }, valueTransform = { it.second }) + .mapValues { it.value.filterNotNull().toSet() } + .map { TopicTensors(TopicForPipeline(topicName = it.key.split(":")[0], pipelineName = it.key.split(":")[1]), it.value) } return when { topicsAndTensors.isEmpty() -> SourceProjection.Empty @@ -146,4 +154,4 @@ fun parseSource(source: PipelineTopic): Triple { if (source.tensor == "") null else source.tensor, source.pipelineName, ) -} \ No newline at end of file +} diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/StreamErrorHandling.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/StreamErrorHandling.kt index 55a117361c..ca52b6672f 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/StreamErrorHandling.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/StreamErrorHandling.kt @@ -10,16 +10,14 @@ import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler import org.apache.kafka.streams.processor.ProcessorContext class StreamErrorHandling { - - class StreamsDeserializationErrorHandler: DeserializationExceptionHandler { - + class 
StreamsDeserializationErrorHandler : DeserializationExceptionHandler { override fun configure(configs: MutableMap?) { } override fun handle( context: ProcessorContext?, record: ConsumerRecord?, - exception: Exception? + exception: Exception?, ): DeserializationExceptionHandler.DeserializationHandlerResponse { if (exception != null) { logger.error(exception, "Kafka streams: message deserialization error on ${record?.topic()}") @@ -28,39 +26,36 @@ class StreamErrorHandling { } } - class StreamsRecordProducerErrorHandler: ProductionExceptionHandler { - + class StreamsRecordProducerErrorHandler : ProductionExceptionHandler { override fun configure(configs: MutableMap?) { } override fun handle( record: ProducerRecord?, - exception: Exception? + exception: Exception?, ): ProductionExceptionHandler.ProductionExceptionHandlerResponse { if (exception != null) { logger.error(exception, "Kafka streams: error when writing to ${record?.topic()}") } return ProductionExceptionHandler.ProductionExceptionHandlerResponse.CONTINUE } - } - class StreamsCustomUncaughtExceptionHandler: StreamsUncaughtExceptionHandler { + class StreamsCustomUncaughtExceptionHandler : StreamsUncaughtExceptionHandler { override fun handle(exception: Throwable?): StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse { if (exception is StreamsException) { val originalException = exception.cause originalException?.let { logger.error(it, "Kafka streams: stream processing exception") - return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT; + return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT } } // try to continue - return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.REPLACE_THREAD; + return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.REPLACE_THREAD } } companion object { private val logger = noCoLogger(StreamErrorHandling::class) } - } diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/StreamTransforms.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/StreamTransforms.kt index cb88022329..44c6ed87a3 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/StreamTransforms.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/StreamTransforms.kt @@ -9,11 +9,11 @@ the Change License after the Change Date as each is defined in accordance with t package io.seldon.dataflow.kafka -import io.seldon.dataflow.kafka.headers.PipelineNameFilter import io.seldon.dataflow.kafka.headers.AlibiDetectRemover import io.seldon.dataflow.kafka.headers.PipelineHeaderSetter -import io.seldon.mlops.chainer.ChainerOuterClass.PipelineTensorMapping +import io.seldon.dataflow.kafka.headers.PipelineNameFilter import io.seldon.mlops.chainer.ChainerOuterClass.Batch +import io.seldon.mlops.chainer.ChainerOuterClass.PipelineTensorMapping import io.seldon.mlops.inference.v2.V2Dataplane.ModelInferRequest import io.seldon.mlops.inference.v2.V2Dataplane.ModelInferResponse import org.apache.kafka.streams.kstream.KStream @@ -55,14 +55,18 @@ fun KStream.marshallInferenceV2Response(): KStream request.toByteArray() } } - fun KStream.convertToRequest( inputPipeline: String, inputTopic: TopicName, desiredTensors: Set?, - tensorRenamingList: List + tensorRenamingList: List, ): KStream { - val tensorRenaming = tensorRenamingList.filter { it.pipelineName.equals(inputPipeline) }.map { it.topicAndTensor to it.tensorName }.toMap() + val tensorRenaming = + tensorRenamingList.filter { + it.pipelineName.equals( 
+ inputPipeline, + ) + }.map { it.topicAndTensor to it.tensorName }.toMap() return this .mapValues { inferResponse -> convertToRequest( @@ -80,8 +84,9 @@ fun KStream.convertToRequest( fun KStream.batchMessages(batchProperties: Batch): KStream { return when (batchProperties.size) { 0 -> this - else -> this - .transform({ BatchProcessor(batchProperties.size) }, BatchProcessor.STATE_STORE_ID) + else -> + this + .transform({ BatchProcessor(batchProperties.size) }, BatchProcessor.STATE_STORE_ID) } } @@ -103,11 +108,12 @@ private fun convertToRequest( response.outputsList .forEachIndexed { idx, tensor -> if (tensor.name in desiredTensors || desiredTensors == null || desiredTensors.isEmpty()) { - val newName = tensorRenaming - .getOrDefault( - "${inputTopic}.${tensor.name}", - tensor.name, - ) + val newName = + tensorRenaming + .getOrDefault( + "$inputTopic.${tensor.name}", + tensor.name, + ) val convertedTensor = convertOutputToInputTensor(newName, tensor, response.rawOutputContentsCount > 0) @@ -115,7 +121,7 @@ private fun convertToRequest( if (idx < response.rawOutputContentsCount) { // TODO - should add in appropriate index for raw input contents! addRawInputContents( - response.getRawOutputContents(idx) + response.getRawOutputContents(idx), ) } } @@ -126,14 +132,15 @@ private fun convertToRequest( private fun convertOutputToInputTensor( tensorName: TensorName, output: ModelInferResponse.InferOutputTensor, - rawContents: Boolean + rawContents: Boolean, ): ModelInferRequest.InferInputTensor { - val req = ModelInferRequest.InferInputTensor - .newBuilder() - .setName(tensorName) - .setDatatype(output.datatype) - .addAllShape(output.shapeList) - .putAllParameters(output.parametersMap) + val req = + ModelInferRequest.InferInputTensor + .newBuilder() + .setName(tensorName) + .setDatatype(output.datatype) + .addAllShape(output.shapeList) + .putAllParameters(output.parametersMap) if (!rawContents) { req.setContents(output.contents) } @@ -144,9 +151,14 @@ fun KStream.filterRequests( inputPipeline: String, inputTopic: TopicName, desiredTensors: Set?, - tensorRenamingList: List + tensorRenamingList: List, ): KStream { - val tensorRenaming = tensorRenamingList.filter { it.pipelineName.equals(inputPipeline) }.map { it.topicAndTensor to it.tensorName }.toMap() + val tensorRenaming = + tensorRenamingList.filter { + it.pipelineName.equals( + inputPipeline, + ) + }.map { it.topicAndTensor to it.tensorName }.toMap() return this .mapValues { inferResponse -> filterRequest( @@ -173,19 +185,20 @@ private fun filterRequest( request.inputsList .forEachIndexed { idx, tensor -> if (tensor.name in desiredTensors || desiredTensors == null || desiredTensors.isEmpty()) { - val newName = tensorRenaming - .getOrDefault( - "${inputTopic}.${tensor.name}", - tensor.name, - ) + val newName = + tensorRenaming + .getOrDefault( + "$inputTopic.${tensor.name}", + tensor.name, + ) - val convertedTensor = createInputTensor(newName, tensor, request.rawInputContentsCount>0) + val convertedTensor = createInputTensor(newName, tensor, request.rawInputContentsCount > 0) addInputs(convertedTensor) if (idx < request.rawInputContentsCount) { // TODO - should add in appropriate index for raw input contents! 
addRawInputContents( - request.getRawInputContents(idx) + request.getRawInputContents(idx), ) } } @@ -196,28 +209,33 @@ private fun filterRequest( private fun createInputTensor( tensorName: TensorName, input: ModelInferRequest.InferInputTensor, - rawContents: Boolean + rawContents: Boolean, ): ModelInferRequest.InferInputTensor { - val req = ModelInferRequest.InferInputTensor - .newBuilder() - .setName(tensorName) - .setDatatype(input.datatype) - .addAllShape(input.shapeList) - .putAllParameters(input.parametersMap) + val req = + ModelInferRequest.InferInputTensor + .newBuilder() + .setName(tensorName) + .setDatatype(input.datatype) + .addAllShape(input.shapeList) + .putAllParameters(input.parametersMap) if (!rawContents) { req.setContents(input.contents) } return req.build() } - fun KStream.filterResponses( inputPipeline: String, inputTopic: TopicName, desiredTensors: Set?, - tensorRenamingList: List + tensorRenamingList: List, ): KStream { - val tensorRenaming = tensorRenamingList.filter { it.pipelineName.equals(inputPipeline) }.map { it.topicAndTensor to it.tensorName }.toMap() + val tensorRenaming = + tensorRenamingList.filter { + it.pipelineName.equals( + inputPipeline, + ) + }.map { it.topicAndTensor to it.tensorName }.toMap() return this .mapValues { inferResponse -> filterResponse( @@ -244,19 +262,20 @@ private fun filterResponse( response.outputsList .forEachIndexed { idx, tensor -> if (tensor.name in desiredTensors || desiredTensors == null || desiredTensors.isEmpty()) { - val newName = tensorRenaming - .getOrDefault( - "${inputTopic}.${tensor.name}", - tensor.name, - ) + val newName = + tensorRenaming + .getOrDefault( + "$inputTopic.${tensor.name}", + tensor.name, + ) - val convertedTensor = createOutputTensor(newName, tensor, response.rawOutputContentsCount>0) + val convertedTensor = createOutputTensor(newName, tensor, response.rawOutputContentsCount > 0) addOutputs(convertedTensor) if (idx < response.rawOutputContentsCount) { // TODO - should add in appropriate index for raw input contents! 
addRawOutputContents( - response.getRawOutputContents(idx) + response.getRawOutputContents(idx), ) } } @@ -267,14 +286,15 @@ private fun filterResponse( private fun createOutputTensor( tensorName: TensorName, input: ModelInferResponse.InferOutputTensor, - rawContents: Boolean + rawContents: Boolean, ): ModelInferResponse.InferOutputTensor { - val res = ModelInferResponse.InferOutputTensor - .newBuilder() - .setName(tensorName) - .setDatatype(input.datatype) - .addAllShape(input.shapeList) - .putAllParameters(input.parametersMap) + val res = + ModelInferResponse.InferOutputTensor + .newBuilder() + .setName(tensorName) + .setDatatype(input.datatype) + .addAllShape(input.shapeList) + .putAllParameters(input.parametersMap) if (!rawContents) { res.setContents(input.contents) } @@ -285,9 +305,14 @@ fun KStream.convertToResponse( inputPipeline: String, inputTopic: TopicName, desiredTensors: Set?, - tensorRenamingList: List + tensorRenamingList: List, ): KStream { - val tensorRenaming = tensorRenamingList.filter { it.pipelineName.equals(inputPipeline) }.map { it.topicAndTensor to it.tensorName }.toMap() + val tensorRenaming = + tensorRenamingList.filter { + it.pipelineName.equals( + inputPipeline, + ) + }.map { it.topicAndTensor to it.tensorName }.toMap() return this .mapValues { inferResponse -> convertToResponse( @@ -317,18 +342,19 @@ private fun convertToResponse( request.inputsList .forEachIndexed { idx, tensor -> if (tensor.name in desiredTensors || desiredTensors == null || desiredTensors.isEmpty()) { - val newName = tensorRenaming - .getOrDefault( - "${inputTopic}.${tensor.name}", - tensor.name, - ) - val convertedTensor = convertInputToOutputTensor(newName, tensor, request.rawInputContentsCount>0) + val newName = + tensorRenaming + .getOrDefault( + "$inputTopic.${tensor.name}", + tensor.name, + ) + val convertedTensor = convertInputToOutputTensor(newName, tensor, request.rawInputContentsCount > 0) addOutputs(convertedTensor) if (idx < request.rawInputContentsCount) { // TODO - should add in appropriate index for raw input contents! 
addRawOutputContents( - request.getRawInputContents(idx) + request.getRawInputContents(idx), ) } } @@ -339,14 +365,15 @@ private fun convertToResponse( private fun convertInputToOutputTensor( tensorName: TensorName, input: ModelInferRequest.InferInputTensor, - rawContents: Boolean + rawContents: Boolean, ): ModelInferResponse.InferOutputTensor { - val req = ModelInferResponse.InferOutputTensor - .newBuilder() - .setName(tensorName) - .setDatatype(input.datatype) - .addAllShape(input.shapeList) - .putAllParameters(input.parametersMap) + val req = + ModelInferResponse.InferOutputTensor + .newBuilder() + .setName(tensorName) + .setDatatype(input.datatype) + .addAllShape(input.shapeList) + .putAllParameters(input.parametersMap) if (!rawContents) { req.setContents(input.contents) } diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/TriggerTransforms.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/TriggerTransforms.kt index 532fcd15ef..c387d36a92 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/TriggerTransforms.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/TriggerTransforms.kt @@ -27,75 +27,102 @@ fun addTriggerTopology( if (inputTopics.isEmpty()) { when (pending) { null -> return lastStream - else -> return lastStream - .join( - pending, - ::joinTriggerRequests, - JoinWindows.ofTimeDifferenceWithNoGrace( - Duration.ofMillis(kafkaDomainParams.joinWindowMillis), - ), - joinSerde, - ) + else -> + return lastStream + .join( + pending, + ::joinTriggerRequests, + JoinWindows.ofTimeDifferenceWithNoGrace( + Duration.ofMillis(kafkaDomainParams.joinWindowMillis), + ), + joinSerde, + ) } } val topic = inputTopics.first() - val nextStream = builder //TODO possible bug - not all streams will be v2 requests? Maybe v2 responses? - .stream(topic.topicName, consumerSerde) - .filterForPipeline(topic.pipelineName) - .unmarshallInferenceV2Response() - .convertToRequest(topic.pipelineName, topic.topicName, tensorsByTopic?.get(topic), emptyList()) - // handle cases where there are no tensors we want - .filter { _, value -> value.inputsList.size != 0} - .marshallInferenceV2Request() + val nextStream = + builder // TODO possible bug - not all streams will be v2 requests? Maybe v2 responses? + .stream(topic.topicName, consumerSerde) + .filterForPipeline(topic.pipelineName) + .unmarshallInferenceV2Response() + .convertToRequest(topic.pipelineName, topic.topicName, tensorsByTopic?.get(topic), emptyList()) + // handle cases where there are no tensors we want + .filter { _, value -> value.inputsList.size != 0 } + .marshallInferenceV2Request() when (joinType) { ChainerOuterClass.PipelineStepUpdate.PipelineJoinType.Any -> { - val nextPending = pending - ?.outerJoin( - nextStream, - ::joinTriggerRequests, - //JoinWindows.ofTimeDifferenceAndGrace(Duration.ofMillis(1), Duration.ofMillis(1)), - // Required because this "fix" causes outer joins to wait for next record to come in if all streams - // don't produce a record during grace period. 
https://issues.apache.org/jira/browse/KAFKA-10847 - // Also see https://confluentcommunity.slack.com/archives/C6UJNMY67/p1649520904545229?thread_ts=1649324912.542999&cid=C6UJNMY67 - // Issue created at https://issues.apache.org/jira/browse/KAFKA-13813 - JoinWindows.of(Duration.ofMillis(1)), - joinSerde, - ) ?: nextStream + val nextPending = + pending + ?.outerJoin( + nextStream, + ::joinTriggerRequests, + // JoinWindows.ofTimeDifferenceAndGrace(Duration.ofMillis(1), Duration.ofMillis(1)), + // Required because this "fix" causes outer joins to wait for next record to come in if all streams + // don't produce a record during grace period. https://issues.apache.org/jira/browse/KAFKA-10847 + // Also see https://confluentcommunity.slack.com/archives/C6UJNMY67/p1649520904545229?thread_ts=1649324912.542999&cid=C6UJNMY67 + // Issue created at https://issues.apache.org/jira/browse/KAFKA-13813 + JoinWindows.of(Duration.ofMillis(1)), + joinSerde, + ) ?: nextStream - - return addTriggerTopology(kafkaDomainParams, builder, inputTopics.minus(topic), tensorsByTopic, joinType, lastStream, nextPending) + return addTriggerTopology( + kafkaDomainParams, + builder, + inputTopics.minus(topic), + tensorsByTopic, + joinType, + lastStream, + nextPending, + ) } ChainerOuterClass.PipelineStepUpdate.PipelineJoinType.Outer -> { - val nextPending = pending - ?.outerJoin( - nextStream, - ::joinTriggerRequests, - // See above for Any case as this will wait until next record comes in before emitting a result after window - JoinWindows.ofTimeDifferenceWithNoGrace( - Duration.ofMillis(kafkaDomainParams.joinWindowMillis), - ), - joinSerde, - ) ?: nextStream - + val nextPending = + pending + ?.outerJoin( + nextStream, + ::joinTriggerRequests, + // See above for Any case as this will wait until next record comes in before emitting a result after window + JoinWindows.ofTimeDifferenceWithNoGrace( + Duration.ofMillis(kafkaDomainParams.joinWindowMillis), + ), + joinSerde, + ) ?: nextStream - return addTriggerTopology(kafkaDomainParams, builder, inputTopics.minus(topic), tensorsByTopic, joinType, lastStream, nextPending) + return addTriggerTopology( + kafkaDomainParams, + builder, + inputTopics.minus(topic), + tensorsByTopic, + joinType, + lastStream, + nextPending, + ) } else -> { - val nextPending = pending - ?.join( - nextStream, - ::joinTriggerRequests, - JoinWindows.ofTimeDifferenceWithNoGrace( - Duration.ofMillis(kafkaDomainParams.joinWindowMillis), - ), - joinSerde, - ) ?: nextStream + val nextPending = + pending + ?.join( + nextStream, + ::joinTriggerRequests, + JoinWindows.ofTimeDifferenceWithNoGrace( + Duration.ofMillis(kafkaDomainParams.joinWindowMillis), + ), + joinSerde, + ) ?: nextStream - return addTriggerTopology(kafkaDomainParams, builder, inputTopics.minus(topic), tensorsByTopic, joinType, lastStream, nextPending) + return addTriggerTopology( + kafkaDomainParams, + builder, + inputTopics.minus(topic), + tensorsByTopic, + joinType, + lastStream, + nextPending, + ) } } } @@ -103,6 +130,9 @@ fun addTriggerTopology( // For triggers eventually we always want the left item which is the real data returned as its // join join ... // However for triggers joined to triggers its ok to return anyone that is not null -private fun joinTriggerRequests(left: ByteArray?, right: ByteArray?): ByteArray { +private fun joinTriggerRequests( + left: ByteArray?, + right: ByteArray?, +): ByteArray { return left ?: right!! 
-} \ No newline at end of file +} diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/TypeExtensions.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/TypeExtensions.kt index d0699c7d63..dfe62ef02d 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/TypeExtensions.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/TypeExtensions.kt @@ -14,4 +14,3 @@ operator fun Set?.contains(tensor: TensorName): Boolean { tensor in this } ?: true } - diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Types.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Types.kt index 020b9b5f85..750064e3de 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Types.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/Types.kt @@ -13,7 +13,7 @@ import org.apache.kafka.common.serialization.Serdes import org.apache.kafka.streams.kstream.Consumed import org.apache.kafka.streams.kstream.Produced import org.apache.kafka.streams.kstream.StreamJoined -import java.util.* +import java.util.Properties typealias KafkaProperties = Properties typealias KafkaAdminProperties = Properties @@ -33,10 +33,14 @@ enum class ChainType { INPUT_INPUT, INPUT_OUTPUT, OUTPUT_INPUT, - PASSTHROUGH; + PASSTHROUGH, + ; companion object { - fun create(input: String, output: String): ChainType { + fun create( + input: String, + output: String, + ): ChainType { return when (input.substringAfterLast(".") to output.substringAfterLast(".")) { "inputs" to "inputs" -> INPUT_INPUT "inputs" to "outputs" -> INPUT_OUTPUT @@ -47,4 +51,3 @@ enum class ChainType { } } } - diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/headers/AlibiDetectRemover.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/headers/AlibiDetectRemover.kt index 82f690a0fa..84b358d80f 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/headers/AlibiDetectRemover.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/headers/AlibiDetectRemover.kt @@ -29,4 +29,4 @@ class AlibiDetectRemover : ValueTransformer { } override fun close() {} -} \ No newline at end of file +} diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/headers/PipelineHeaderSetter.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/headers/PipelineHeaderSetter.kt index ddfe98a6b5..4d2bd6883d 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/headers/PipelineHeaderSetter.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/headers/PipelineHeaderSetter.kt @@ -7,7 +7,6 @@ Use of this software is governed BY the Change License after the Change Date as each is defined in accordance with the LICENSE file. */ - package io.seldon.dataflow.kafka.headers import io.seldon.dataflow.kafka.TRecord @@ -22,10 +21,10 @@ class PipelineHeaderSetter(private val pipelineName: String) : ValueTransformer< } override fun transform(value: TRecord?): TRecord? 
{ - this.context?.headers()?.remove(SeldonHeaders.pipelineName) - this.context?.headers()?.add(SeldonHeaders.pipelineName, pipelineName.toByteArray()) + this.context?.headers()?.remove(SeldonHeaders.PIPELINENAME) + this.context?.headers()?.add(SeldonHeaders.PIPELINENAME, pipelineName.toByteArray()) return value } override fun close() {} -} \ No newline at end of file +} diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/headers/PipelineNameFilter.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/headers/PipelineNameFilter.kt index 4851ec831b..be317bf99a 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/headers/PipelineNameFilter.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/kafka/headers/PipelineNameFilter.kt @@ -21,13 +21,14 @@ class PipelineNameFilter(private val pipelineName: String) : ValueTransformer { val password = generatePassword() val location = generateLocation() - val trustStore = KeyStore.getInstance(storeType) + val trustStore = KeyStore.getInstance(STORE_TYPE) trustStore.load(null, password.toCharArray()) certsFromFile(certPaths.brokerCaCertPath) @@ -70,7 +69,7 @@ object Provider { return FileInputStream(certFile) .use { certStream -> CertificateFactory - .getInstance(certificateType) + .getInstance(CERTIFICATE_TYPE) .generateCertificates(certStream) } } @@ -86,7 +85,7 @@ object Provider { private fun keyStoreFromCerts(certPaths: CertificateConfig): Pair { val password = generatePassword() val location = generateLocation() - val keyStore = KeyStore.getInstance(storeType) + val keyStore = KeyStore.getInstance(STORE_TYPE) keyStore.load(null, password.toCharArray()) val privateKey = privateKeyFromFile(certPaths.keyPath) @@ -94,9 +93,10 @@ object Provider { val caCerts = certsFromFile(certPaths.caCertPath) // TODO - check if CA certs are required as part of the chain. Docs imply this, but unclear. keyStore.setKeyEntry( - keyName, + KEY_NAME, privateKey, - password.toCharArray(), // No password + // No password for private key + password.toCharArray(), certs.union(caCerts).toTypedArray(), ) @@ -112,10 +112,11 @@ object Provider { fun privateKeyFromFile(filename: FilePath): RSAPrivateKey { val file = File(filename) val key = String(Files.readAllBytes(file.toPath()), Charset.defaultCharset()) - val privateKeyPEM = key - .replace("-----BEGIN PRIVATE KEY-----", "") - .replace(System.lineSeparator().toRegex(), "") - .replace("-----END PRIVATE KEY-----", "") + val privateKeyPEM = + key + .replace("-----BEGIN PRIVATE KEY-----", "") + .replace(System.lineSeparator().toRegex(), "") + .replace("-----END PRIVATE KEY-----", "") val encoded: ByteArray = Base64.getDecoder().decode(privateKeyPEM) val keyFactory = KeyFactory.getInstance("RSA") val keySpec = PKCS8EncodedKeySpec(encoded) @@ -139,4 +140,4 @@ object Provider { ) .toFile() } -} \ No newline at end of file +} diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/sasl/KubernetesSecretProvider.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/sasl/KubernetesSecretProvider.kt index 33d4c21a2b..e09dabeccc 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/sasl/KubernetesSecretProvider.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/sasl/KubernetesSecretProvider.kt @@ -7,7 +7,6 @@ Use of this software is governed BY the Change License after the Change Date as each is defined in accordance with the LICENSE file. 
*/ - package io.seldon.dataflow.sasl import io.klogging.noCoLogger @@ -37,4 +36,4 @@ object KubernetesSecretProvider : SecretsProvider { mapOf() } } -} \ No newline at end of file +} diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/sasl/SaslOauthProvider.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/sasl/SaslOauthProvider.kt index e92b44e113..40695f7181 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/sasl/SaslOauthProvider.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/sasl/SaslOauthProvider.kt @@ -7,7 +7,6 @@ Use of this software is governed BY the Change License after the Change Date as each is defined in accordance with the LICENSE file. */ - package io.seldon.dataflow.sasl import io.klogging.noCoLogger @@ -31,11 +30,11 @@ class SaslOauthProvider(private val secretsProvider: SecretsProvider) { } private fun Map.toOauthConfig(): SaslOauthConfig { - val clientId = this.getValue(clientIdKey).toString(Charsets.UTF_8) - val clientSecret = this.getValue(clientSecretKey).toString(Charsets.UTF_8) - val tokenUrl = this.getValue(tokenUrlKey).toString(Charsets.UTF_8) - val scope = this.getValue(scopeKey).toString(Charsets.UTF_8) - val extensions = this.getValue(extensionsKey).toString(Charsets.UTF_8).toExtensions() + val clientId = this.getValue(CLIENT_ID_KEY).toString(Charsets.UTF_8) + val clientSecret = this.getValue(CLIENT_SECRET_KEY).toString(Charsets.UTF_8) + val tokenUrl = this.getValue(TOKEN_URL_KEY).toString(Charsets.UTF_8) + val scope = this.getValue(SCOPE_KEY).toString(Charsets.UTF_8) + val extensions = this.getValue(EXTENSIONS_KEY).toString(Charsets.UTF_8).toExtensions() return SaslOauthConfig( tokenUrl = tokenUrl, @@ -60,9 +59,10 @@ class SaslOauthProvider(private val secretsProvider: SecretsProvider) { .map { it.split("=", limit = 2) } .map { parts -> val k = parts.first() - val v = parts.last().let { - if (it.startsWith('"')) it else """"$it"""" - } + val v = + parts.last().let { + if (it.startsWith('"')) it else """"$it"""" + } k to v } @@ -77,11 +77,11 @@ class SaslOauthProvider(private val secretsProvider: SecretsProvider) { } companion object { - private const val clientIdKey = "client_id" - private const val clientSecretKey = "client_secret" - private const val tokenUrlKey = "token_endpoint_url" - private const val scopeKey = "scope" - private const val extensionsKey = "extensions" + private const val CLIENT_ID_KEY = "client_id" + private const val CLIENT_SECRET_KEY = "client_secret" + private const val TOKEN_URL_KEY = "token_endpoint_url" + private const val SCOPE_KEY = "scope" + private const val EXTENSIONS_KEY = "extensions" private val logger = noCoLogger(SaslOauthProvider::class) diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/sasl/SaslPasswordProvider.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/sasl/SaslPasswordProvider.kt index b0c2d07600..ffd7586242 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/sasl/SaslPasswordProvider.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/sasl/SaslPasswordProvider.kt @@ -7,14 +7,12 @@ Use of this software is governed BY the Change License after the Change Date as each is defined in accordance with the LICENSE file. 
*/ - package io.seldon.dataflow.sasl import io.klogging.noCoLogger import io.seldon.dataflow.kafka.security.SaslConfig class SaslPasswordProvider(private val secretsProvider: SecretsProvider) { - fun getPassword(config: SaslConfig.Password): String { logger.info("retrieving password for SASL user") @@ -22,7 +20,11 @@ class SaslPasswordProvider(private val secretsProvider: SecretsProvider) { return extractPassword(config.secretName, secret, config.passwordField) } - private fun extractPassword(secretName: String, secret: Map, fieldName: String): String { + private fun extractPassword( + secretName: String, + secret: Map, + fieldName: String, + ): String { return when (val password = secret[fieldName]) { null -> { logger.warn("unable to retrieve password for SASL user from secret $secretName at path $fieldName") diff --git a/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/CliTest.kt b/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/CliTest.kt index 87dd3b37b2..0899fc35c9 100644 --- a/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/CliTest.kt +++ b/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/CliTest.kt @@ -7,7 +7,6 @@ Use of this software is governed BY the Change License after the Change Date as each is defined in accordance with the LICENSE file. */ - package io.seldon.dataflow import io.seldon.dataflow.kafka.security.KafkaSaslMechanisms @@ -21,10 +20,12 @@ import strikt.assertions.isSuccess import java.util.stream.Stream internal class CliTest { - @ParameterizedTest(name = "{0}") @MethodSource("saslMechanisms") - fun getSaslMechanism(input: String, expectedMechanism: KafkaSaslMechanisms) { + fun getSaslMechanism( + input: String, + expectedMechanism: KafkaSaslMechanisms, + ) { val args = arrayOf("--kafka-sasl-mechanism", input) val cli = Cli.configWith(args) @@ -43,4 +44,4 @@ internal class CliTest { ) } } -} \ No newline at end of file +} diff --git a/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/PipelineSubscriberTest.kt b/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/PipelineSubscriberTest.kt index 78d20c4581..2b5679c43a 100644 --- a/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/PipelineSubscriberTest.kt +++ b/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/PipelineSubscriberTest.kt @@ -11,14 +11,17 @@ package io.seldon.dataflow import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.async -import kotlinx.coroutines.flow.* +import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.flatMapMerge +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.Test import java.time.LocalDateTime @OptIn(FlowPreview::class) internal class PipelineSubscriberTest { - @Test fun `should run sequentially`() { suspend fun waitAndPrint(i: Int) { @@ -45,7 +48,7 @@ internal class PipelineSubscriberTest { async { kotlinx.coroutines.delay(1000) println("${LocalDateTime.now()} - $it") - } + }, ) } } @@ -68,4 +71,4 @@ internal class PipelineSubscriberTest { .collect() } } -} \ No newline at end of file +} diff --git a/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/kafka/BatchProcessorTest.kt b/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/kafka/BatchProcessorTest.kt index 922b70f86f..b61921c978 100644 --- a/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/kafka/BatchProcessorTest.kt +++ b/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/kafka/BatchProcessorTest.kt @@ -22,7 
+22,10 @@ import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.Arguments.arguments import org.junit.jupiter.params.provider.MethodSource import strikt.api.expectThat -import strikt.assertions.* +import strikt.assertions.isEmpty +import strikt.assertions.isEqualTo +import strikt.assertions.isGreaterThan +import strikt.assertions.isNull import java.util.stream.Stream internal class BatchProcessorTest { @@ -31,11 +34,12 @@ internal class BatchProcessorTest { val mockContext = MockProcessorContext() val batcher = BatchProcessor(3) val expected = makeRequest("3", listOf(12.34F, 12.34F, 12.34F)) - val requests = listOf( - makeRequest("1", listOf(12.34F)), - makeRequest("2", listOf(12.34F)), - makeRequest("3", listOf(12.34F)), - ) + val requests = + listOf( + makeRequest("1", listOf(12.34F)), + makeRequest("2", listOf(12.34F)), + makeRequest("3", listOf(12.34F)), + ) batcher.init(mockContext) val merged = batcher.merge(requests) @@ -48,11 +52,12 @@ internal class BatchProcessorTest { val mockContext = MockProcessorContext() val batcher = BatchProcessor(3) val expected = makeRequest("3", emptyList()) - val requests = listOf( - makeRequest("1", emptyList()), - makeRequest("2", emptyList()), - makeRequest("3", emptyList()), - ) + val requests = + listOf( + makeRequest("1", emptyList()), + makeRequest("2", emptyList()), + makeRequest("3", emptyList()), + ) batcher.init(mockContext) val merged = batcher.merge(requests) @@ -65,11 +70,12 @@ internal class BatchProcessorTest { val mockContext = MockProcessorContext() val batcher = BatchProcessor(3) val expected = makeRequest("3", listOf(12.34F, 12.34F, 12.34F)).withBinaryContents() - val requests = listOf( - makeRequest("1", listOf(12.34F)).withBinaryContents(), - makeRequest("2", listOf(12.34F)).withBinaryContents(), - makeRequest("3", listOf(12.34F)).withBinaryContents(), - ) + val requests = + listOf( + makeRequest("1", listOf(12.34F)).withBinaryContents(), + makeRequest("2", listOf(12.34F)).withBinaryContents(), + makeRequest("3", listOf(12.34F)).withBinaryContents(), + ) batcher.init(mockContext) val merged = batcher.merge(requests) @@ -80,15 +86,16 @@ internal class BatchProcessorTest { @Test fun `should only forward when batch size met`() { val mockContext = MockProcessorContext() - val store = Stores - .keyValueStoreBuilder( - Stores.inMemoryKeyValueStore(BatchProcessor.STATE_STORE_ID), - Serdes.String(), - Serdes.ByteArray(), - ) - .withCachingDisabled() - .withLoggingDisabled() - .build() + val store = + Stores + .keyValueStoreBuilder( + Stores.inMemoryKeyValueStore(BatchProcessor.STATE_STORE_ID), + Serdes.String(), + Serdes.ByteArray(), + ) + .withCachingDisabled() + .withLoggingDisabled() + .build() val batchSize = 10 val batcher = BatchProcessor(batchSize) val streamKey = "789" @@ -109,13 +116,14 @@ internal class BatchProcessorTest { val batchRequest = makeRequest(batchSize.toString(), listOf(batchSize.toFloat())) val batched = batcher.transform(streamKey, batchRequest) - val expected = KeyValue( - streamKey, - makeRequest( - batchRequest.id, - (1..batchSize).map { it.toFloat() }, + val expected = + KeyValue( + streamKey, + makeRequest( + batchRequest.id, + (1..batchSize).map { it.toFloat() }, + ), ) - ) expectThat(batched).isEqualTo(expected) expectThat(mockContext.forwarded()).isEmpty() expectThat(store.approximateNumEntries()).isEqualTo(0) @@ -187,7 +195,10 @@ internal class BatchProcessorTest { ) } - private fun makeRequest(id: String, values: List): ModelInferRequest { + private fun makeRequest( + 
id: String, + values: List, + ): ModelInferRequest { return ModelInferRequest .newBuilder() .setId(id) @@ -197,14 +208,14 @@ internal class BatchProcessorTest { .setName("preprocessed_image") .setDatatype("FP32") .addAllShape( - listOf(values.size.toLong(), 1, 1, 1) + listOf(values.size.toLong(), 1, 1, 1), ) .setContents( InferTensorContents .newBuilder() .addAllFp32Contents(values) - .build() - ) + .build(), + ), ) .build() } diff --git a/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/kafka/PipelineStepTest.kt b/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/kafka/PipelineStepTest.kt index 9355b1ef14..b025101fdd 100644 --- a/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/kafka/PipelineStepTest.kt +++ b/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/kafka/PipelineStepTest.kt @@ -18,11 +18,12 @@ import org.junit.jupiter.params.provider.Arguments.arguments import org.junit.jupiter.params.provider.MethodSource import strikt.api.Assertion import strikt.api.expect -import strikt.assertions.* +import strikt.assertions.isEqualTo +import strikt.assertions.isNotNull +import strikt.assertions.isNull import java.util.stream.Stream internal class PipelineStepTest { - @ParameterizedTest(name = "{0}") @MethodSource fun stepFor( @@ -33,7 +34,7 @@ internal class PipelineStepTest { val result = stepFor( StreamsBuilder(), - defaultPipelineName, + DEFAULT_PIPELINENAME, sources, emptyList(), emptyList(), @@ -53,14 +54,14 @@ internal class PipelineStepTest { } companion object { - private const val defaultPipelineName = "some-pipeline" - private val defaultPipelineTopic = PipelineTopic.newBuilder() - .setTopicName("seldon.namespace.sinkModel.inputs") - .setPipelineName(defaultPipelineName).build() - private val defaultSink = TopicForPipeline(topicName = "seldon.namespace.sinkModel.inputs", pipelineName = defaultPipelineName) + private const val DEFAULT_PIPELINENAME = "some-pipeline" + private val defaultPipelineTopic = + PipelineTopic.newBuilder() + .setTopicName("seldon.namespace.sinkModel.inputs") + .setPipelineName(DEFAULT_PIPELINENAME).build() + private val defaultSink = TopicForPipeline(topicName = "seldon.namespace.sinkModel.inputs", pipelineName = DEFAULT_PIPELINENAME) private val kafkaDomainParams = KafkaDomainParams(useCleanState = true, joinWindowMillis = 1_000L) - @JvmStatic fun stepFor(): Stream = Stream.of( @@ -68,86 +69,132 @@ internal class PipelineStepTest { arguments( "single source, no tensors", makeChainerFor( - inputTopic = TopicForPipeline(topicName = "seldon.namespace.model.model11.outputs", pipelineName = defaultPipelineName), + inputTopic = + TopicForPipeline( + topicName = "seldon.namespace.model.model11.outputs", + pipelineName = DEFAULT_PIPELINENAME, + ), tensors = null, ), - listOf(PipelineTopic.newBuilder().setTopicName("seldon.namespace.model.model11.outputs").setPipelineName(defaultPipelineName).build()), + listOf( + PipelineTopic.newBuilder().setTopicName( + "seldon.namespace.model.model11.outputs", + ).setPipelineName(DEFAULT_PIPELINENAME).build(), + ), ), arguments( "single source, one tensor", makeChainerFor( - inputTopic = TopicForPipeline(topicName = "seldon.namespace.model.model1.outputs", pipelineName = defaultPipelineName), - tensors = setOf("tensorA") + inputTopic = + TopicForPipeline( + topicName = "seldon.namespace.model.model1.outputs", + pipelineName = DEFAULT_PIPELINENAME, + ), + tensors = setOf("tensorA"), + ), + listOf( + PipelineTopic.newBuilder().setTopicName( + "seldon.namespace.model.model1.outputs", + 
).setPipelineName(DEFAULT_PIPELINENAME).setTensor("tensorA").build(), ), - listOf(PipelineTopic.newBuilder().setTopicName("seldon.namespace.model.model1.outputs").setPipelineName(defaultPipelineName).setTensor("tensorA").build()), ), arguments( "single source, multiple tensors", makeChainerFor( - inputTopic = TopicForPipeline(topicName = "seldon.namespace.model.model1.outputs", pipelineName = defaultPipelineName), - tensors = setOf("tensorA", "tensorB") + inputTopic = + TopicForPipeline( + topicName = "seldon.namespace.model.model1.outputs", + pipelineName = DEFAULT_PIPELINENAME, + ), + tensors = setOf("tensorA", "tensorB"), ), listOf( PipelineTopic.newBuilder().setTopicName("seldon.namespace.model.model1.outputs").setPipelineName( - defaultPipelineName).setTensor("tensorA").build(), + DEFAULT_PIPELINENAME, + ).setTensor("tensorA").build(), PipelineTopic.newBuilder().setTopicName("seldon.namespace.model.model1.outputs").setPipelineName( - defaultPipelineName).setTensor("tensorB").build(), + DEFAULT_PIPELINENAME, + ).setTensor("tensorB").build(), ), ), arguments( "multiple sources, no tensors", makeJoinerFor( - inputTopics = setOf(TopicForPipeline(topicName = "seldon.namespace.model.modelA.outputs", pipelineName = defaultPipelineName), - TopicForPipeline(topicName = "seldon.namespace.model.modelB.outputs", pipelineName = defaultPipelineName)), + inputTopics = + setOf( + TopicForPipeline(topicName = "seldon.namespace.model.modelA.outputs", pipelineName = DEFAULT_PIPELINENAME), + TopicForPipeline(topicName = "seldon.namespace.model.modelB.outputs", pipelineName = DEFAULT_PIPELINENAME), + ), tensorsByTopic = null, ), listOf( PipelineTopic.newBuilder().setTopicName("seldon.namespace.model.modelA.outputs").setPipelineName( - defaultPipelineName).build(), + DEFAULT_PIPELINENAME, + ).build(), PipelineTopic.newBuilder().setTopicName("seldon.namespace.model.modelB.outputs").setPipelineName( - defaultPipelineName).build(), + DEFAULT_PIPELINENAME, + ).build(), ), ), arguments( "multiple sources, multiple tensors", makeJoinerFor( - inputTopics = setOf( - TopicForPipeline(topicName = "seldon.namespace.model.modelA.outputs", pipelineName = defaultPipelineName), - TopicForPipeline(topicName = "seldon.namespace.model.modelB.outputs", pipelineName = defaultPipelineName), - ), - tensorsByTopic = mapOf( - TopicForPipeline(topicName = "seldon.namespace.model.modelA.outputs", pipelineName = defaultPipelineName) to setOf("tensor1"), - TopicForPipeline(topicName = "seldon.namespace.model.modelB.outputs", pipelineName = defaultPipelineName) to setOf("tensor2"), - ), + inputTopics = + setOf( + TopicForPipeline(topicName = "seldon.namespace.model.modelA.outputs", pipelineName = DEFAULT_PIPELINENAME), + TopicForPipeline(topicName = "seldon.namespace.model.modelB.outputs", pipelineName = DEFAULT_PIPELINENAME), + ), + tensorsByTopic = + mapOf( + TopicForPipeline( + topicName = "seldon.namespace.model.modelA.outputs", + pipelineName = DEFAULT_PIPELINENAME, + ) to setOf("tensor1"), + TopicForPipeline( + topicName = "seldon.namespace.model.modelB.outputs", + pipelineName = DEFAULT_PIPELINENAME, + ) to setOf("tensor2"), + ), ), listOf( PipelineTopic.newBuilder().setTopicName("seldon.namespace.model.modelA.outputs").setPipelineName( - defaultPipelineName).setTensor("tensor1").build(), + DEFAULT_PIPELINENAME, + ).setTensor("tensor1").build(), PipelineTopic.newBuilder().setTopicName("seldon.namespace.model.modelB.outputs").setPipelineName( - defaultPipelineName).setTensor("tensor2").build(), + DEFAULT_PIPELINENAME, + 
).setTensor("tensor2").build(),
                 ),
             ),
             arguments(
                 "tensors override plain topic",
                 makeChainerFor(
-                    inputTopic = TopicForPipeline(topicName = "seldon.namespace.model.modelA.outputs", pipelineName = defaultPipelineName),
+                    inputTopic =
+                        TopicForPipeline(
+                            topicName = "seldon.namespace.model.modelA.outputs",
+                            pipelineName = DEFAULT_PIPELINENAME,
+                        ),
                     tensors = setOf("tensorA"),
                 ),
                 listOf(
                     PipelineTopic.newBuilder().setTopicName("seldon.namespace.model.modelA.outputs").setPipelineName(
-                        defaultPipelineName).setTensor("tensorA").build(),
+                        DEFAULT_PIPELINENAME,
+                    ).setTensor("tensorA").build(),
                     PipelineTopic.newBuilder().setTopicName("seldon.namespace.model.modelA.outputs").setPipelineName(
-                        defaultPipelineName).build(),
+                        DEFAULT_PIPELINENAME,
+                    ).build(),
                 ),
             ),
         )
 
-        private fun makeChainerFor(inputTopic: TopicForPipeline, tensors: Set<TensorName>?): Chainer =
+        private fun makeChainerFor(
+            inputTopic: TopicForPipeline,
+            tensors: Set<TensorName>?,
+        ): Chainer =
             Chainer(
                 StreamsBuilder(),
                 inputTopic = inputTopic,
                 tensors = tensors,
-                pipelineName = defaultPipelineName,
+                pipelineName = DEFAULT_PIPELINENAME,
                 outputTopic = defaultSink,
                 tensorRenaming = emptyList(),
                 kafkaDomainParams = kafkaDomainParams,
@@ -165,7 +212,7 @@ internal class PipelineStepTest {
                 StreamsBuilder(),
                 inputTopics = inputTopics,
                 tensorsByTopic = tensorsByTopic,
-                pipelineName = defaultPipelineName,
+                pipelineName = DEFAULT_PIPELINENAME,
                 outputTopic = defaultSink,
                 tensorRenaming = emptyList(),
                 kafkaDomainParams = kafkaDomainParams,
@@ -188,22 +235,24 @@
 fun Assertion.Builder<PipelineStep>.isSameTypeAs(other: PipelineStep) =
 
 fun Assertion.Builder<PipelineStep>.matches(expected: PipelineStep) =
     assert("Type and values are the same") {
         when {
-            it is Chainer && expected is Chainer -> expect {
-                that(it) {
-                    get { inputTopic }.isEqualTo(expected.inputTopic)
-                    get { outputTopic }.isEqualTo(expected.outputTopic)
-                    get { tensors }.isEqualTo(expected.tensors)
+            it is Chainer && expected is Chainer ->
+                expect {
+                    that(it) {
+                        get { inputTopic }.isEqualTo(expected.inputTopic)
+                        get { outputTopic }.isEqualTo(expected.outputTopic)
+                        get { tensors }.isEqualTo(expected.tensors)
+                    }
                 }
-            }
-            it is Joiner && expected is Joiner -> expect {
-                that(it) {
-                    get { inputTopics }.isEqualTo(expected.inputTopics)
-                    get { outputTopic }.isEqualTo(expected.outputTopic)
-                    get { tensorsByTopic }.isEqualTo(expected.tensorsByTopic)
-                    get { tensorRenaming }.isEqualTo(expected.tensorRenaming)
-                    get { kafkaDomainParams }.isEqualTo(expected.kafkaDomainParams)
+            it is Joiner && expected is Joiner ->
+                expect {
+                    that(it) {
+                        get { inputTopics }.isEqualTo(expected.inputTopics)
+                        get { outputTopic }.isEqualTo(expected.outputTopic)
+                        get { tensorsByTopic }.isEqualTo(expected.tensorsByTopic)
+                        get { tensorRenaming }.isEqualTo(expected.tensorRenaming)
+                        get { kafkaDomainParams }.isEqualTo(expected.kafkaDomainParams)
+                    }
                 }
-            }
             else -> fail(actual = expected)
         }
-    }
\ No newline at end of file
+    }