Skip to content

Commit

Permalink
fix: LLM fixes + datasets storage (#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
alarv authored Dec 18, 2024
1 parent 0bdbcbc commit 1d58273
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/main/kotlin/org/jaqpot/api/entity/Dataset.kt
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ class Dataset(

@JdbcTypeCode(SqlTypes.JSON)
@Column(name = "input", columnDefinition = "jsonb")
var input: List<Any>? = null,
var input: List<Any>? = emptyList(),

@JdbcTypeCode(SqlTypes.JSON)
@Column(name = "result", columnDefinition = "jsonb")
var result: List<Any>? = null,
var result: List<Any>? = emptyList(),

@Size(min = 3, max = 15000)
@Column(columnDefinition = "TEXT")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ class DatasetService(
}
val dataset = datasetDto.toEntity(model, userId, DatasetEntryType.ARRAY)
datasetRepository.save(dataset)
storageService.storeRawDataset(dataset)
if (storageService.storeRawDataset(dataset)) {
datasetRepository.setDatasetInputAndResultToNull(dataset.id)
}

return ResponseEntity.ok(dataset.toDto(dataset.input!!, dataset.result))
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.jaqpot.api.service.prediction.streaming

import io.github.oshai.kotlinlogging.KotlinLogging
import org.jaqpot.api.entity.Dataset
import org.jaqpot.api.entity.DatasetStatus
import org.jaqpot.api.error.JaqpotRuntimeException
import org.jaqpot.api.mapper.toDto
Expand All @@ -11,13 +12,12 @@ import org.jaqpot.api.service.dataset.DatasetService
import org.jaqpot.api.service.model.JAQPOT_ROW_ID_KEY
import org.jaqpot.api.service.model.dto.StreamPredictRequestDto
import org.jaqpot.api.service.prediction.runtime.runtimes.streaming.StreamingModelRuntime
import org.jaqpot.api.service.prediction.util.PredictionUtil
import org.jaqpot.api.service.prediction.util.PredictionUtil.Companion.updateDatasetToExecuting
import org.jaqpot.api.storage.StorageService
import org.springframework.http.HttpStatus
import org.springframework.stereotype.Service
import org.springframework.web.server.ResponseStatusException
import reactor.core.publisher.Flux
import java.time.OffsetDateTime

@Service
class StreamingPredictionService(
Expand Down Expand Up @@ -53,22 +53,17 @@ class StreamingPredictionService(
toEntity.input = storageService.readRawDatasetInput(toEntity)
toEntity.result = storageService.readRawDatasetResult(toEntity)


toEntity.input!!.forEachIndexed { index, it: Any ->
if (it is Map<*, *>)
(it as MutableMap<String, String>)[JAQPOT_ROW_ID_KEY] = index.toString()
}
toEntity.input = listOf(mapOf("prompt" to streamPredictRequestDto.prompt)) + (toEntity.input ?: emptyList())
toEntity.input = (toEntity.input ?: emptyList()) + listOf(mapOf("prompt" to streamPredictRequestDto.prompt))

val predictionModelDto = model.toPredictionModelDto(byteArrayOf(), emptyList(), byteArrayOf())
updateDatasetToExecuting(toEntity)
storageService.storeRawDataset(toEntity)

updateDatasetToExecuting(toEntity, datasetRepository, storageService)

val dataset = this.datasetRepository.save(toEntity)
if (storageService.storeRawDataset(dataset)) {
datasetRepository.setDatasetInputAndResultToNull(dataset.id)
}
val datasetDto = dataset.toDto(dataset.input!!, dataset.result)
val datasetDto = toEntity.toDto(toEntity.input!!, toEntity.result)

var output = ""

Expand All @@ -82,28 +77,52 @@ class StreamingPredictionService(
}
.doOnError { e ->
logger.error(e) { "Stream error for model ${predictionModelDto.id}: ${e.message}" }
PredictionUtil.storeDatasetFailure(dataset, e, datasetRepository, storageService)
storeDatasetFailure(toEntity, e, datasetRepository)
}
.doFinally { signal ->
logger.info { "Stream finished with signal $signal for model ${predictionModelDto.id}" }
storeDatasetSuccess(datasetDto.id!!, mapOf("output" to output))
}
}

/**
 * Flags [dataset] as currently executing and persists that change.
 *
 * Sets the status to [DatasetStatus.EXECUTING], stamps `executedAt` with the current time
 * and saves the entity through `datasetRepository`.
 *
 * @param dataset the entity to update; it is mutated in place
 */
private fun updateDatasetToExecuting(
    dataset: Dataset,
) {
    val startedAt = OffsetDateTime.now()
    dataset.also { entity ->
        entity.status = DatasetStatus.EXECUTING
        entity.executedAt = startedAt
    }
    datasetRepository.save(dataset)
}

/**
 * Marks the dataset identified by [datasetId] as successfully executed and appends [result]
 * to its previously stored results.
 *
 * The input and any earlier results are re-read through `storageService` before the entity
 * is saved; a missing input or result list is treated as empty.
 *
 * @param datasetId id of the dataset to update
 * @param result the new result entry, placed after the results that were already stored
 * @throws JaqpotRuntimeException if no dataset with [datasetId] exists
 */
private fun storeDatasetSuccess(datasetId: Long, result: Any) {
    // The supplier returns the exception; orElseThrow is what throws it.
    val dataset = datasetRepository.findById(datasetId).orElseThrow {
        JaqpotRuntimeException("Dataset with id $datasetId not found")
    }

    dataset.input = storageService.readRawDatasetInput(dataset).orEmpty()

    // Single assignment: earlier results first, the new one last.
    val previousResults = storageService.readRawDatasetResult(dataset).orEmpty()
    dataset.result = previousResults + listOf(result)

    dataset.status = DatasetStatus.SUCCESS
    dataset.executionFinishedAt = OffsetDateTime.now()
    datasetRepository.save(dataset)

    // NOTE(review): presumably `true` means the raw data now lives in external storage,
    // so the database copy of input/result can be cleared — confirm against StorageService.
    if (storageService.storeRawDataset(dataset)) {
        datasetRepository.setDatasetInputAndResultToNull(dataset.id)
    }
}


/**
 * Records a failed execution on [dataset] and persists it.
 *
 * The status becomes [DatasetStatus.FAILURE], the failure reason is the string form of [err],
 * and the finish timestamp is set to the current time.
 *
 * @param dataset the entity to update; it is mutated in place
 * @param err the error that caused the failure
 * @param datasetRepository repository used to save the updated entity
 */
private fun storeDatasetFailure(
    dataset: Dataset,
    err: Throwable,
    datasetRepository: DatasetRepository,
) {
    val reason = err.toString()
    dataset.also { failed ->
        failed.status = DatasetStatus.FAILURE
        failed.failureReason = reason
        failed.executionFinishedAt = OffsetDateTime.now()
    }

    datasetRepository.save(dataset)
}
}
72 changes: 65 additions & 7 deletions src/main/kotlin/org/jaqpot/api/storage/LocalStorage.kt
Original file line number Diff line number Diff line change
@@ -1,27 +1,67 @@
package org.jaqpot.api.storage

import org.jaqpot.api.error.JaqpotRuntimeException
import org.springframework.context.annotation.Profile
import org.springframework.stereotype.Service
import java.util.*
import java.util.concurrent.ConcurrentHashMap

@Profile("local")
@Service
class LocalStorage : Storage {
// Structure: Map<BucketName, Map<KeyName, StorageObject>>
private val storage = ConcurrentHashMap<String, ConcurrentHashMap<String, StorageObject>>()

/**
 * A stored payload together with its content type and metadata.
 *
 * `equals` and `hashCode` are overridden so that [data] is compared by content
 * rather than by array identity.
 *
 * @property data raw bytes of the object
 * @property contentType content type of the payload; defaults to `application/octet-stream`
 * @property metadata key/value metadata attached to the object; defaults to an empty map
 */
data class StorageObject(
    val data: ByteArray,
    val contentType: String = "application/octet-stream",
    val metadata: Map<String, String> = emptyMap()
) {
    override fun equals(other: Any?): Boolean {
        if (other === this) return true
        // Data classes are final, so a type check is equivalent to comparing runtime classes.
        if (other !is StorageObject) return false

        return data.contentEquals(other.data) &&
            contentType == other.contentType &&
            metadata == other.metadata
    }

    // Same 31-multiplier combination over (data content, contentType, metadata).
    override fun hashCode(): Int =
        31 * (31 * data.contentHashCode() + contentType.hashCode()) + metadata.hashCode()
}

/**
 * Looks up the bytes stored under [keyName] in [bucketName].
 *
 * @return the stored bytes, or an empty [Optional] when the bucket or the key does not exist
 */
override fun getObject(bucketName: String, keyName: String): Optional<ByteArray> {
    // The leftover stub `return Optional.empty()` made this lookup unreachable.
    return Optional.ofNullable(
        storage[bucketName]?.get(keyName)?.data
    )
}

/**
 * Fetches several objects from [bucketName] at once.
 *
 * @param keyNames keys to look up; keys with no stored object are left out of the result
 * @return a map from each found key to its stored bytes, or an empty map when the bucket
 *   does not exist
 */
override fun getObjects(bucketName: String, keyNames: List<String>): Map<String, ByteArray> {
    // The leftover stub `return emptyMap()` made this lookup unreachable.
    val bucket = storage[bucketName] ?: return emptyMap()
    return keyNames
        .mapNotNull { keyName -> bucket[keyName]?.let { stored -> keyName to stored.data } }
        .toMap()
}

/**
 * Lists the keys in [bucketName] that start with [prefix].
 *
 * @return the matching keys, or an empty list when the bucket does not exist
 */
override fun listObjects(bucketName: String, prefix: String): List<String> {
    // The leftover stub `return emptyList()` made this filter unreachable.
    // `filter` already yields a List, so no extra `toList()` is needed.
    return storage[bucketName]?.keys
        ?.filter { it.startsWith(prefix) }
        .orEmpty()
}

/**
 * Convenience overload that forwards to the content-type-aware `putObject`,
 * passing `application/octet-stream` as the content type.
 */
override fun putObject(bucketName: String, keyName: String, obj: ByteArray, metadata: Map<String, String>) {
    // The leftover `throw ... "Not implemented"` made this delegation unreachable.
    putObject(bucketName, keyName, "application/octet-stream", obj, metadata)
}

override fun putObject(
Expand All @@ -31,10 +71,28 @@ class LocalStorage : Storage {
obj: ByteArray,
metadata: Map<String, String>
) {
throw JaqpotRuntimeException("Not implemented")
val bucket = storage.computeIfAbsent(bucketName) { ConcurrentHashMap() }
bucket[keyName] = StorageObject(obj, contentType, metadata)
}

/**
 * Removes the object stored under [keyName] in [bucketName].
 *
 * Does nothing when the bucket does not exist; removing an absent key is likewise a no-op.
 */
override fun deleteObject(bucketName: String, keyName: String) {
    // The leftover `throw ... "Not implemented"` made this removal unreachable.
    storage[bucketName]?.remove(keyName)
}

// Additional helper methods for testing
/** Drops every bucket and all objects held in them. */
fun clearStorage() = storage.clear()

/**
 * Returns the names of all buckets currently present.
 *
 * The returned set is the storage map's own key set, not a copy.
 */
fun getBuckets(): Set<String> = storage.keys

/**
 * Returns the metadata of the object stored under [keyName] in [bucketName],
 * or `null` when the bucket or the key does not exist.
 */
fun getObjectMetadata(bucketName: String, keyName: String): Map<String, String>? {
    val stored = storage[bucketName]?.get(keyName)
    return stored?.metadata
}

/**
 * Returns the content type of the object stored under [keyName] in [bucketName],
 * or `null` when the bucket or the key does not exist.
 */
fun getObjectContentType(bucketName: String, keyName: String): String? {
    val stored = storage[bucketName]?.get(keyName)
    return stored?.contentType
}
}

0 comments on commit 1d58273

Please sign in to comment.