Skip to content

Commit

Permalink
feat(JAQPOT-432): streaming prediction (#131)
Browse files Browse the repository at this point in the history
* feat: working streaming of docker LLM

* fix: store dataset to storage before execution too to avoid warnings

* feat: add dataset name

* feat: allow creation of dataset and chat type

* feat: working version of streaming predictions

* fix: remove duplicate code

* fix: rename to jaqpot docker model runtime

* fix: create DockerRuntimeUtil
  • Loading branch information
alarv authored Dec 17, 2024
1 parent fe54f11 commit b37a931
Show file tree
Hide file tree
Showing 20 changed files with 579 additions and 77 deletions.
10 changes: 8 additions & 2 deletions src/main/kotlin/org/jaqpot/api/entity/Dataset.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class Dataset(
@JoinColumn(name = "model_id", updatable = false, nullable = false)
val model: Model,

@Column
val name: String? = null,

@Column(nullable = false)
val userId: String,

Expand All @@ -34,7 +37,7 @@ class Dataset(

@JdbcTypeCode(SqlTypes.JSON)
@Column(name = "input", columnDefinition = "jsonb")
val input: List<Any>? = null,
var input: List<Any>? = null,

@JdbcTypeCode(SqlTypes.JSON)
@Column(name = "result", columnDefinition = "jsonb")
Expand All @@ -48,8 +51,11 @@ class Dataset(

var executionFinishedAt: OffsetDateTime? = null
) : BaseEntity() {
/**
 * Tells whether a result is expected to exist for this dataset.
 *
 * This is to avoid querying s3 for non-existing results: a result is expected for a
 * prediction dataset whose status is [DatasetStatus.SUCCESS], and for every chat
 * dataset regardless of its status.
 *
 * @return `true` when a result should exist for this dataset
 */
fun shouldHaveResult(): Boolean {
    val isSuccessfulPrediction = this.type == DatasetType.PREDICTION && this.status == DatasetStatus.SUCCESS
    return isSuccessfulPrediction || this.type == DatasetType.CHAT
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/main/kotlin/org/jaqpot/api/entity/DatasetType.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.jaqpot.api.entity

/**
 * The kinds of dataset known to the persistence layer.
 */
enum class DatasetType {
    /** Dataset created for a model prediction. */
    PREDICTION,

    /** Dataset of the chat kind. */
    CHAT
}
2 changes: 2 additions & 0 deletions src/main/kotlin/org/jaqpot/api/mapper/DatasetMapper.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ fun Dataset.toDto(input: List<Any>, result: List<Any>?): DatasetDto {
id = this.id,
entryType = this.entryType.toDto(),
status = this.status.toDto(),
name = this.name,
failureReason = this.failureReason,
input = input,
result = result,
Expand All @@ -32,6 +33,7 @@ fun DatasetDto.toEntity(model: Model, userId: String, entryType: DatasetEntryTyp
id = this.id,
model = model,
userId = userId,
name = this.name,
entryType = entryType,
type = this.type.toEntity(),
status = DatasetStatus.CREATED,
Expand Down
2 changes: 2 additions & 0 deletions src/main/kotlin/org/jaqpot/api/mapper/DatasetTypeMapper.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import org.jaqpot.api.model.DatasetTypeDto
/**
 * Maps this DTO value onto the matching [DatasetType] entity value.
 *
 * The `when` has no `else` branch, so every DTO constant must be listed explicitly.
 */
fun DatasetTypeDto.toEntity(): DatasetType = when (this) {
    DatasetTypeDto.PREDICTION -> DatasetType.PREDICTION
    DatasetTypeDto.CHAT -> DatasetType.CHAT
}

/**
 * Maps this entity value onto the matching [DatasetTypeDto] value.
 *
 * The `when` has no `else` branch, so every entity constant must be listed explicitly.
 */
fun DatasetType.toDto(): DatasetTypeDto = when (this) {
    DatasetType.PREDICTION -> DatasetTypeDto.PREDICTION
    DatasetType.CHAT -> DatasetTypeDto.CHAT
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@ import org.springframework.data.jpa.repository.Query
import org.springframework.data.repository.CrudRepository
import org.springframework.data.repository.query.Param
import java.time.OffsetDateTime
import java.util.*


interface DatasetRepository : CrudRepository<Dataset, Long> {
fun findAllByUserId(userId: String, pageable: Pageable): Page<Dataset>
fun findAllByUserIdAndModelId(userId: String, modelId: Long, pageable: Pageable): Page<Dataset>

fun findAllByCreatedAtBefore(date: OffsetDateTime): List<Dataset>

fun findByIdAndModelId(id: Long, modelId: Long): Optional<Dataset>

@Modifying
@Transactional
@Query("UPDATE Dataset d SET d.input = NULL, d.result = NULL WHERE d.id = :id")
Expand Down
30 changes: 30 additions & 0 deletions src/main/kotlin/org/jaqpot/api/service/dataset/DatasetService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ package org.jaqpot.api.service.dataset
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.transaction.Transactional
import org.jaqpot.api.DatasetApiDelegate
import org.jaqpot.api.entity.DatasetEntryType
import org.jaqpot.api.error.JaqpotRuntimeException
import org.jaqpot.api.mapper.toDto
import org.jaqpot.api.mapper.toEntity
import org.jaqpot.api.mapper.toGetDatasets200ResponseDto
import org.jaqpot.api.model.DatasetDto
import org.jaqpot.api.model.GetDatasets200ResponseDto
import org.jaqpot.api.repository.DatasetRepository
import org.jaqpot.api.repository.ModelRepository
import org.jaqpot.api.service.authentication.AuthenticationFacade
import org.jaqpot.api.service.util.SortUtil.Companion.parseSortParameters
import org.jaqpot.api.storage.StorageService
Expand All @@ -22,6 +26,7 @@ import java.time.OffsetDateTime
@Service
class DatasetService(
private val datasetRepository: DatasetRepository,
private val modelRepository: ModelRepository,
private val authenticationFacade: AuthenticationFacade,
private val storageService: StorageService
) : DatasetApiDelegate {
Expand All @@ -30,6 +35,18 @@ class DatasetService(
private val logger = KotlinLogging.logger {}
}

/**
 * Creates a dataset for the given model, owned by the user id reported by the
 * authentication facade.
 *
 * The DTO is converted to an entity with entry type [DatasetEntryType.ARRAY], saved
 * through the dataset repository and then passed to the storage service.
 *
 * @param modelId id of the model the dataset is created for
 * @param datasetDto payload describing the dataset to create
 * @return a 200 response carrying the created dataset
 * @throws JaqpotRuntimeException when no model with [modelId] exists
 */
override fun createDataset(modelId: Long, datasetDto: DatasetDto): ResponseEntity<DatasetDto> {
    val userId = authenticationFacade.userId
    val model = modelRepository.findById(modelId).orElseThrow {
        // Report the id that was actually looked up, not the one carried in the payload.
        JaqpotRuntimeException("Model with id $modelId not found")
    }
    val dataset = datasetDto.toEntity(model, userId, DatasetEntryType.ARRAY)
    datasetRepository.save(dataset)
    storageService.storeRawDataset(dataset)

    // The entity's input is nullable; fall back to an empty list rather than failing
    // with a NullPointerException after the dataset has already been persisted.
    return ResponseEntity.ok(dataset.toDto(dataset.input.orEmpty(), dataset.result))
}

@PostAuthorize("@getDatasetAuthorizationLogic.decide(#root)")
override fun getDatasetById(id: Long): ResponseEntity<DatasetDto> {
val dataset = datasetRepository.findById(id)
Expand Down Expand Up @@ -63,6 +80,19 @@ class DatasetService(
return ResponseEntity.ok().body(datasets.toGetDatasets200ResponseDto(inputsMap, resultsMap))
}

/**
 * Returns one page of the current user's datasets that belong to the given model.
 *
 * Inputs and results are not loaded here: the response is built with empty
 * input/result maps.
 *
 * @param modelId id of the model whose datasets are listed
 * @param page zero-based page index handed to [PageRequest.of]
 * @param size page size handed to [PageRequest.of]
 * @param sort raw sort parameters, parsed with `parseSortParameters`
 * @return a 200 response carrying the requested page of datasets
 */
override fun getDatasetsByModelId(
    modelId: Long,
    page: Int,
    size: Int,
    sort: List<String>?
): ResponseEntity<GetDatasets200ResponseDto> {
    val currentUserId = authenticationFacade.userId
    val sortOrder = Sort.by(parseSortParameters(sort))
    val pageRequest = PageRequest.of(page, size, sortOrder)
    val datasetPage = datasetRepository.findAllByUserIdAndModelId(currentUserId, modelId, pageRequest)
    val responseBody = datasetPage.toGetDatasets200ResponseDto(emptyMap(), emptyMap())

    return ResponseEntity.ok().body(responseBody)
}

@Transactional
@Scheduled(cron = "0 0 3 * * *" /* every day at 3:00 AM */)
fun purgeExpiredDatasets() {
Expand Down
93 changes: 93 additions & 0 deletions src/main/kotlin/org/jaqpot/api/service/model/ModelApi.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package org.jaqpot.api.service.model

import io.swagger.v3.oas.annotations.Operation
import io.swagger.v3.oas.annotations.Parameter
import io.swagger.v3.oas.annotations.media.Content
import io.swagger.v3.oas.annotations.media.Schema
import io.swagger.v3.oas.annotations.responses.ApiResponse
import io.swagger.v3.oas.annotations.security.SecurityRequirement
import jakarta.validation.Valid
import org.jaqpot.api.service.model.dto.StreamPredictRequestDto
import org.springframework.stereotype.Controller
import org.springframework.web.bind.annotation.PathVariable
import org.springframework.web.bind.annotation.RequestBody
import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RequestMethod
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter

/**
 * Controller exposing the streaming prediction endpoint on top of [ModelService].
 */
@Controller
class ModelApi(private val modelService: ModelService) {

    /**
     * Starts a streaming prediction and relays every emitted item to the HTTP response.
     *
     * The stream returned by [ModelService.streamPredictWithModel] is subscribed to
     * immediately; each item is sent through a [ResponseBodyEmitter], which is completed
     * (normally or with the error) when the stream terminates. The subscription is
     * disposed when the emitter completes or times out.
     *
     * @param modelId id of the model used for the prediction
     * @param datasetId id of the dataset the prediction is attached to
     * @param streamPredictRequestDto request body carrying the prompt
     * @return the emitter through which the streamed items are written
     */
    @Operation(
        tags = ["model"],
        summary = "Stream predictions from LLM Model",
        operationId = "streamPredictWithModel",
        description = """Submit a prompt for streaming prediction using a specific LLM model""",
        responses = [
            ApiResponse(
                responseCode = "200",
                description = "Streaming response started",
                content = [Content(schema = Schema(implementation = kotlin.String::class))]
            ),
            ApiResponse(responseCode = "400", description = "Invalid Request"),
            ApiResponse(responseCode = "404", description = "Model not found"),
            ApiResponse(responseCode = "500", description = "Internal Server Error")
        ],
        security = [SecurityRequirement(name = "bearerAuth")]
    )
    @RequestMapping(
        method = [RequestMethod.POST],
        value = ["/v1/models/{modelId}/predict/stream/{datasetId}"],
        produces = ["text/event-stream"],
        consumes = ["application/json"]
    )
    fun streamPredictWithModel(
        @Parameter(
            description = "The ID of the LLM model to use for prediction",
            required = true
        ) @PathVariable("modelId") modelId: kotlin.Long,
        @Parameter(
            description = "The ID of the dataset for prediction",
            required = true
        ) @PathVariable("datasetId") datasetId: kotlin.Long,
        @Parameter(
            description = "",
            required = true
        ) @Valid @RequestBody streamPredictRequestDto: StreamPredictRequestDto
    ): ResponseBodyEmitter {
        val emitter = ResponseBodyEmitter()
        val predictionChunks = modelService.streamPredictWithModel(modelId, datasetId, streamPredictRequestDto)

        // Keep the subscription handle so the upstream can be cancelled later.
        val subscription = predictionChunks.subscribe(
            { chunk -> forwardChunk(emitter, chunk) },
            { failure -> emitter.completeWithError(failure) },
            { emitter.complete() }
        )

        // The same cancellation is used for both normal completion and timeout.
        val cancelUpstream = Runnable { subscription.dispose() }
        emitter.onCompletion(cancelUpstream)
        emitter.onTimeout(cancelUpstream)

        return emitter
    }

    /** Sends one chunk through the emitter; a failed send completes the emitter with that error. */
    private fun forwardChunk(emitter: ResponseBodyEmitter, chunk: String) {
        try {
            emitter.send(chunk)
        } catch (e: Exception) {
            emitter.completeWithError(e)
        }
    }
}
26 changes: 18 additions & 8 deletions src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ import org.jaqpot.api.service.authentication.UserService
import org.jaqpot.api.service.dataset.csv.CSVDataConverter
import org.jaqpot.api.service.dataset.csv.CSVParser
import org.jaqpot.api.service.model.config.ModelConfiguration
import org.jaqpot.api.service.model.dto.StreamPredictRequestDto
import org.jaqpot.api.service.prediction.PredictionService
import org.jaqpot.api.service.prediction.streaming.StreamingPredictionService
import org.jaqpot.api.service.ratelimit.WithRateLimitProtectionByUser
import org.jaqpot.api.service.util.SortUtil.Companion.parseSortParameters
import org.jaqpot.api.storage.StorageService
Expand All @@ -34,6 +36,7 @@ import org.springframework.security.access.prepost.PreAuthorize
import org.springframework.stereotype.Service
import org.springframework.web.server.ResponseStatusException
import org.springframework.web.servlet.support.ServletUriComponentsBuilder
import reactor.core.publisher.Flux
import java.net.URI
import java.time.OffsetDateTime

Expand All @@ -54,7 +57,8 @@ class ModelService(
private val csvDataConverter: CSVDataConverter,
private val storageService: StorageService,
private val doaService: DoaService,
private val modelConfiguration: ModelConfiguration
private val modelConfiguration: ModelConfiguration,
private val streamingPredictionService: StreamingPredictionService
) : ModelApiDelegate {

companion object {
Expand Down Expand Up @@ -215,10 +219,6 @@ class ModelService(
}
// TODO once there are no models with rawModel in the database, remove this
storeRawModelToStorage(model)
// TODO once there are no models with rawPreprocessor in the database, remove this
storeRawPreprocessorToStorage(model)
// TODO once there are no models with rawDoa in the database, remove this
model.doas.forEach(doaService::storeRawDoaToStorage)

if (model.archived) {
throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Model with id $modelId is archived")
Expand All @@ -238,7 +238,7 @@ class ModelService(
)
}

toEntity.input.forEachIndexed { index, it: Any ->
toEntity.input!!.forEachIndexed { index, it: Any ->
if (it is Map<*, *>)
(it as MutableMap<String, String>)[JAQPOT_ROW_ID_KEY] = index.toString()
}
Expand All @@ -264,8 +264,6 @@ class ModelService(
val userId = authenticationFacade.userId
// TODO once there are no models with rawModel in the database, remove this
storeRawModelToStorage(model)
// TODO once there are no models with rawPreprocessor in the database, remove this
storeRawPreprocessorToStorage(model)

val csvData = csvParser.readCsv(datasetCSVDto.inputFile.inputStream())

Expand All @@ -292,6 +290,18 @@ class ModelService(
throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Unknown dataset type", null)
}

/**
 * Delegates a streaming prediction request to [StreamingPredictionService].
 *
 * @param modelId id of the model to predict with
 * @param datasetId id of the dataset the prediction is attached to
 * @param streamPredictRequestDto request carrying the prompt
 * @return the stream of strings produced by the streaming prediction service
 */
fun streamPredictWithModel(
    modelId: Long,
    datasetId: Long,
    streamPredictRequestDto: StreamPredictRequestDto
): Flux<String> =
    streamingPredictionService.getStreamingPrediction(modelId, datasetId, streamPredictRequestDto)

private fun triggerPredictionAndReturnSuccessStatus(
model: Model,
dataset: Dataset
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package org.jaqpot.api.service.model.dto

/**
 * Request body of the streaming prediction endpoint.
 *
 * Declared as a data class so that instances compare by value and print readably.
 *
 * @property prompt the prompt submitted for the streaming prediction
 */
data class StreamPredictRequestDto(val prompt: String)
Loading

0 comments on commit b37a931

Please sign in to comment.