Skip to content

Commit

Permalink
feat(JAQPOT-432): support docker runtime + docker_llm runtime (#130)
Browse files Browse the repository at this point in the history
* feat(JAQPOT-432): support docker config

* feat: DOCKER_LLM model type
  • Loading branch information
alarv authored Dec 12, 2024
1 parent 1cf596d commit fe54f11
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/main/kotlin/org/jaqpot/api/entity/Model.kt
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class Model(
val rPbpkOdeSolver: String? = null,

@OneToOne(mappedBy = "model", fetch = FetchType.LAZY, cascade = [CascadeType.PERSIST])
val dockerConfig: DockerConfig? = null,
var dockerConfig: DockerConfig? = null,

@OneToMany(mappedBy = "model", cascade = [CascadeType.ALL], orphanRemoval = true)
@SQLRestriction("score_type = 'TEST'")
Expand Down
1 change: 1 addition & 0 deletions src/main/kotlin/org/jaqpot/api/entity/ModelType.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.jaqpot.api.entity
enum class ModelType {
// DOCKER models
DOCKER,
DOCKER_LLM,

SKLEARN_ONNX,

Expand Down
20 changes: 20 additions & 0 deletions src/main/kotlin/org/jaqpot/api/mapper/DockerConfigMapper.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package org.jaqpot.api.mapper

import org.jaqpot.api.entity.DockerConfig
import org.jaqpot.api.entity.Model
import org.jaqpot.api.model.DockerConfigDto

fun DockerConfig.toDto(): DockerConfigDto {
return DockerConfigDto(
appName = this.appName,
dockerImage = this.dockerImage
)
}

fun DockerConfigDto.toEntity(model: Model): DockerConfig {
return DockerConfig(
appName = this.appName,
model = model,
dockerImage = this.dockerImage
)
}
2 changes: 2 additions & 0 deletions src/main/kotlin/org/jaqpot/api/mapper/ModelMapper.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ fun Model.toDto(userDto: UserDto? = null, userCanEdit: Boolean? = null, isAdmin:
test = this.testScores?.map { it.toDto() },
crossValidation = this.crossValidationScores?.map { it.toDto() },
),
dockerConfig = this.dockerConfig?.toDto(),
archived = this.archived,
archivedAt = this.archivedAt,
createdAt = this.createdAt,
Expand Down Expand Up @@ -96,6 +97,7 @@ fun ModelDto.toEntity(creatorId: String): Model {
this.scores?.crossValidation?.let {
m.crossValidationScores = this.scores.crossValidation.map { it.toEntity(m, ScoreType.CROSS_VALIDATION) }
}
this.dockerConfig?.let { m.dockerConfig = it.toEntity(m) }

return m
}
Expand Down
2 changes: 2 additions & 0 deletions src/main/kotlin/org/jaqpot/api/mapper/ModelTypeMapper.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ fun ModelTypeDto.toEntity(): ModelType {
ModelTypeDto.QSAR_TOOLBOX_QSAR_MODEL -> ModelType.QSAR_TOOLBOX_QSAR_MODEL
ModelTypeDto.QSAR_TOOLBOX_PROFILER -> ModelType.QSAR_TOOLBOX_PROFILER
ModelTypeDto.DOCKER -> ModelType.DOCKER
ModelTypeDto.DOCKER_LLM -> ModelType.DOCKER_LLM
}
}

Expand All @@ -46,5 +47,6 @@ fun ModelType.toDto(): ModelTypeDto {
ModelType.QSAR_TOOLBOX_QSAR_MODEL -> ModelTypeDto.QSAR_TOOLBOX_QSAR_MODEL
ModelType.QSAR_TOOLBOX_PROFILER -> ModelTypeDto.QSAR_TOOLBOX_PROFILER
ModelType.DOCKER -> ModelTypeDto.DOCKER
ModelType.DOCKER_LLM -> ModelTypeDto.DOCKER_LLM
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,22 @@ import org.jaqpot.api.model.PredictionModelDto
import org.jaqpot.api.model.PredictionRequestDto
import org.jaqpot.api.repository.DockerConfigRepository
import org.jaqpot.api.service.prediction.runtime.config.RuntimeConfiguration
import org.jaqpot.api.service.prediction.runtime.runtimes.util.HttpClientUtil
import org.springframework.stereotype.Service
import org.springframework.web.util.UriComponentsBuilder
import reactor.netty.http.client.HttpClient
import java.net.URI

@Service
class JaqpotDockerRuntime(
private val runtimeConfiguration: RuntimeConfiguration,
private val dockerConfigRepository: DockerConfigRepository
) : RuntimeBase() {

companion object {
val dockerRuntimeHttpClient = HttpClientUtil.generateHttpClient(30, 30, 30, 30, 30)
}

override fun createRequestBody(
predictionModelDto: PredictionModelDto,
datasetDto: DatasetDto
Expand Down Expand Up @@ -43,4 +50,8 @@ class JaqpotDockerRuntime(
override fun getRuntimePath(predictionModelDto: PredictionModelDto): String {
return "/infer"
}

override fun getHttpClient(): HttpClient {
return dockerRuntimeHttpClient
}
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
package org.jaqpot.api.service.prediction.runtime.runtimes

import io.netty.channel.ChannelOption
import io.netty.handler.timeout.ReadTimeoutHandler
import io.netty.handler.timeout.WriteTimeoutHandler
import org.jaqpot.api.model.DatasetDto
import org.jaqpot.api.model.ModelTypeDto
import org.jaqpot.api.model.PredictionModelDto
import org.jaqpot.api.model.PredictionRequestDto
import org.jaqpot.api.service.prediction.runtime.config.RuntimeConfiguration
import org.jaqpot.api.service.prediction.runtime.runtimes.util.HttpClientUtil
import org.springframework.stereotype.Component
import reactor.netty.http.client.HttpClient
import reactor.netty.resources.ConnectionProvider
import java.time.Duration
import java.util.concurrent.TimeUnit

@Component
class JaqpotRV6Runtime(private val runtimeConfiguration: RuntimeConfiguration) : RuntimeBase() {
Expand All @@ -30,22 +25,7 @@ class JaqpotRV6Runtime(private val runtimeConfiguration: RuntimeConfiguration) :
ModelTypeDto.R_TREE_REGR to "predict_tree_regr",
)

val connectionProvider =
ConnectionProvider.builder("custom")
.maxIdleTime(Duration.ofMinutes(10))
.maxLifeTime(Duration.ofMinutes(10))
.build()

val RHttpClient = HttpClient.create(connectionProvider)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
.option(ChannelOption.SO_KEEPALIVE, true)
.responseTimeout(Duration.ofMinutes(10))
.doOnConnected { conn ->
conn.addHandlerLast(ReadTimeoutHandler(10, TimeUnit.MINUTES))
.addHandlerLast(WriteTimeoutHandler(10, TimeUnit.MINUTES))


}
val RHttpClient = HttpClientUtil.generateHttpClient(10, 10, 10, 10, 10)
}

override fun getRuntimeUrl(predictionModelDto: PredictionModelDto): String {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package org.jaqpot.api.service.prediction.runtime.runtimes.util

import io.netty.channel.ChannelOption
import io.netty.handler.timeout.ReadTimeoutHandler
import io.netty.handler.timeout.WriteTimeoutHandler
import reactor.netty.http.client.HttpClient
import reactor.netty.resources.ConnectionProvider
import java.time.Duration
import java.util.concurrent.TimeUnit

class HttpClientUtil {
companion object {

fun generateHttpClient(
maxIdleTimeInMin: Long,
maxLifeTimeInMin: Long,
responseTimeoutInMin: Long,
readTimeoutInMin: Long,
writeTimeoutInMin: Long
): HttpClient {
val connectionProvider =
ConnectionProvider.builder("custom")
.maxIdleTime(Duration.ofMinutes(maxIdleTimeInMin))
.maxLifeTime(Duration.ofMinutes(maxLifeTimeInMin))
.build()

return HttpClient.create(connectionProvider)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
.option(ChannelOption.SO_KEEPALIVE, true)
.responseTimeout(Duration.ofMinutes(responseTimeoutInMin))
.doOnConnected { conn ->
conn.addHandlerLast(ReadTimeoutHandler(readTimeoutInMin, TimeUnit.MINUTES))
.addHandlerLast(WriteTimeoutHandler(writeTimeoutInMin, TimeUnit.MINUTES))
}
}
}
}
40 changes: 40 additions & 0 deletions src/main/kotlin/org/jaqpot/api/storage/LocalStorage.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package org.jaqpot.api.storage

import org.jaqpot.api.error.JaqpotRuntimeException
import org.springframework.context.annotation.Profile
import org.springframework.stereotype.Service
import java.util.*

@Profile("local")
@Service
class LocalStorage : Storage {
override fun getObject(bucketName: String, keyName: String): Optional<ByteArray> {
return Optional.empty()
}

override fun getObjects(bucketName: String, keyNames: List<String>): Map<String, ByteArray> {
return emptyMap()
}

override fun listObjects(bucketName: String, prefix: String): List<String> {
return emptyList()
}

override fun putObject(bucketName: String, keyName: String, obj: ByteArray, metadata: Map<String, String>) {
throw JaqpotRuntimeException("Not implemented")
}

override fun putObject(
bucketName: String,
keyName: String,
contentType: String,
obj: ByteArray,
metadata: Map<String, String>
) {
throw JaqpotRuntimeException("Not implemented")
}

override fun deleteObject(bucketName: String, keyName: String) {
throw JaqpotRuntimeException("Not implemented")
}
}
2 changes: 2 additions & 0 deletions src/main/kotlin/org/jaqpot/api/storage/s3/S3Storage.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking
import org.jaqpot.api.aws.AWSConfig
import org.jaqpot.api.storage.Storage
import org.springframework.context.annotation.Profile
import org.springframework.stereotype.Service
import software.amazon.awssdk.core.sync.RequestBody
import software.amazon.awssdk.services.s3.S3Client
Expand All @@ -15,6 +16,7 @@ import software.amazon.awssdk.services.s3.model.PutObjectRequest
import java.util.*


@Profile("!local")
@Service
class S3Storage(
private val awsConfig: AWSConfig, private val s3Client: S3Client
Expand Down
1 change: 1 addition & 0 deletions src/main/resources/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1823,6 +1823,7 @@ components:
- R_TREE_CLASS
- R_TREE_REGR
- DOCKER
- DOCKER_LLM
- QSAR_TOOLBOX_CALCULATOR
- QSAR_TOOLBOX_QSAR_MODEL
- QSAR_TOOLBOX_PROFILER
Expand Down

0 comments on commit fe54f11

Please sign in to comment.