Skip to content

Commit

Permalink
feat(JAQPOT-114): send request to inference api (#25)
Browse files Browse the repository at this point in the history
* feat(JAQPOT-114): send request to inference api

* feat: send test request to jaqpot-inference

* fix: compilation error

* feat: complete end-to-end with prediction request

* fix: make async method run properly async
  • Loading branch information
alarv authored Jun 4, 2024
1 parent 3b377db commit b6e939c
Show file tree
Hide file tree
Showing 19 changed files with 321 additions and 11 deletions.
2 changes: 2 additions & 0 deletions src/main/kotlin/org/jaqpot/api/JaqpotApiApplication.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package org.jaqpot.api
import org.springframework.boot.autoconfigure.SpringBootApplication
import org.springframework.boot.context.properties.ConfigurationPropertiesScan
import org.springframework.boot.runApplication
import org.springframework.scheduling.annotation.EnableAsync

@SpringBootApplication
@ConfigurationPropertiesScan("org.jaqpot.api")
@EnableAsync
class JaqpotApiApplication

fun main(args: Array<String>) {
Expand Down
10 changes: 7 additions & 3 deletions src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,19 @@ class DataEntry(
@SequenceGenerator(name = "data_entry_id_seq", sequenceName = "data_entry_id_seq", allocationSize = 1)
val id: Long? = 0,

@ManyToOne
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "dataset_id", updatable = false, nullable = false)
val dataset: Dataset,

@Enumerated(EnumType.STRING)
@Column(nullable = false)
val type: DataEntryType,
val type: DataEntryType = DataEntryType.ARRAY,

@Enumerated(EnumType.STRING)
@Column(nullable = false)
val role: DataEntryRole,

@JdbcTypeCode(SqlTypes.JSON)
@Column(name = "values", columnDefinition = "jsonb", nullable = false)
val values: Any,
val values: List<Any>,
) : BaseEntity()
6 changes: 6 additions & 0 deletions src/main/kotlin/org/jaqpot/api/entity/DataEntryRole.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package org.jaqpot.api.entity

enum class DataEntryRole {
INPUT,
RESULTS
}
10 changes: 8 additions & 2 deletions src/main/kotlin/org/jaqpot/api/entity/Dataset.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.jaqpot.api.entity

import jakarta.persistence.*
import org.hibernate.annotations.SQLRestriction

@Entity
class Dataset(
Expand All @@ -9,7 +10,7 @@ class Dataset(
@SequenceGenerator(name = "dataset_id_seq", sequenceName = "dataset_id_seq", allocationSize = 1)
val id: Long? = 0,

@ManyToOne
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "model_id", updatable = false, nullable = false)
val model: Model,

Expand All @@ -21,5 +22,10 @@ class Dataset(
val type: DatasetType = DatasetType.PREDICTION,

@OneToMany(mappedBy = "dataset", cascade = [CascadeType.ALL], orphanRemoval = true)
val dataEntry: List<DataEntry>
@SQLRestriction("data_entry_role = 'INPUT'")
val input: MutableList<DataEntry>,

@OneToMany(mappedBy = "dataset", cascade = [CascadeType.ALL], orphanRemoval = true)
@SQLRestriction("data_entry_role = 'RESULTS'")
var results: MutableList<DataEntry>
) : BaseEntity()
2 changes: 1 addition & 1 deletion src/main/kotlin/org/jaqpot/api/entity/Feature.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Feature(
@SequenceGenerator(name = "feature_id_seq", sequenceName = "feature_id_seq", allocationSize = 1)
val id: Long? = 0,

@ManyToOne
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "model_id", nullable = false)
val model: Model,

Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/org/jaqpot/api/entity/Library.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class Library(
@SequenceGenerator(name = "library_id_seq", sequenceName = "library_id_seq", allocationSize = 1)
val id: Long? = 0,

@ManyToOne
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "model_id", updatable = false, nullable = false)
val model: Model,

Expand Down
26 changes: 26 additions & 0 deletions src/main/kotlin/org/jaqpot/api/mapper/DataEntryMapper.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package org.jaqpot.api.mapper

import org.jaqpot.api.entity.DataEntry
import org.jaqpot.api.entity.DataEntryRole
import org.jaqpot.api.entity.Dataset
import org.jaqpot.api.model.DataEntryDto

fun DataEntry.toDto(): DataEntryDto {
return DataEntryDto(
this.type.toDto(),
this.values,
this.id,
this.createdAt,
this.updatedAt
)
}

fun DataEntryDto.toEntity(dataset: Dataset, dataEntryRole: DataEntryRole): DataEntry {
return DataEntry(
this.id,
dataset,
this.type.toEntity(),
dataEntryRole,
this.propertyValues,
)
}
16 changes: 16 additions & 0 deletions src/main/kotlin/org/jaqpot/api/mapper/DataEntryTypeMapper.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.jaqpot.api.mapper

import org.jaqpot.api.entity.DataEntryType
import org.jaqpot.api.model.DataEntryDto

fun DataEntryDto.Type.toEntity(): DataEntryType {
return when (this) {
DataEntryDto.Type.ARRAY -> DataEntryType.ARRAY
}
}

fun DataEntryType.toDto(): DataEntryDto.Type {
return when (this) {
DataEntryType.ARRAY -> DataEntryDto.Type.ARRAY
}
}
34 changes: 34 additions & 0 deletions src/main/kotlin/org/jaqpot/api/mapper/DatasetMapper.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package org.jaqpot.api.mapper

import org.jaqpot.api.entity.DataEntryRole
import org.jaqpot.api.entity.Dataset
import org.jaqpot.api.entity.Model
import org.jaqpot.api.model.DatasetDto


fun Dataset.toDto(): DatasetDto {
return DatasetDto(
this.type.toDto(),
this.input.map { it.toDto() },
this.id,
this.results.map { it.toDto() },
this.createdAt,
this.updatedAt
)
}

fun DatasetDto.toEntity(model: Model, userId: String): Dataset {
val d = Dataset(
this.id,
model,
userId,
this.type.toEntity(),
mutableListOf(),
mutableListOf()
)

d.input.addAll(this.input.map { it -> it.toEntity(d, DataEntryRole.INPUT) })
d.results.addAll(this.results?.map { it -> it.toEntity(d, DataEntryRole.RESULTS) } ?: emptyList())

return d
}
16 changes: 16 additions & 0 deletions src/main/kotlin/org/jaqpot/api/mapper/DatasetTypeMapper.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.jaqpot.api.mapper

import org.jaqpot.api.entity.DatasetType
import org.jaqpot.api.model.DatasetDto

fun DatasetDto.Type.toEntity(): DatasetType {
return when (this) {
DatasetDto.Type.PREDICTION -> DatasetType.PREDICTION
}
}

fun DatasetType.toDto(): DatasetDto.Type {
return when (this) {
DatasetType.PREDICTION -> DatasetDto.Type.PREDICTION
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package org.jaqpot.api.repository

import org.jaqpot.api.entity.Dataset
import org.springframework.data.repository.CrudRepository

interface DatasetRepository : CrudRepository<Dataset, Long>
26 changes: 25 additions & 1 deletion src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@ import org.jaqpot.api.auth.AuthenticationFacade
import org.jaqpot.api.auth.UserService
import org.jaqpot.api.mapper.toDto
import org.jaqpot.api.mapper.toEntity
import org.jaqpot.api.model.DatasetDto
import org.jaqpot.api.model.ModelDto
import org.jaqpot.api.repository.DatasetRepository
import org.jaqpot.api.repository.ModelRepository
import org.springframework.data.repository.findByIdOrNull
import org.springframework.http.HttpStatus
import org.springframework.http.ResponseEntity
import org.springframework.stereotype.Service
import org.springframework.web.server.ResponseStatusException
import org.springframework.web.servlet.support.ServletUriComponentsBuilder
import java.net.URI

Expand All @@ -17,7 +22,8 @@ import java.net.URI
class ModelService(
private val authenticationFacade: AuthenticationFacade,
private val modelRepository: ModelRepository,
private val userService: UserService
private val userService: UserService,
private val predictionService: PredictionService, private val datasetRepository: DatasetRepository
) : ModelApiDelegate {
override fun createModel(modelDto: ModelDto): ResponseEntity<Unit> {
val userId = authenticationFacade.userId
Expand All @@ -37,5 +43,23 @@ class ModelService(
}
.orElse(ResponseEntity.notFound().build())
}

override fun predictWithModel(modelId: Long, datasetDto: DatasetDto): ResponseEntity<Unit> {
if (datasetDto.type == DatasetDto.Type.PREDICTION) {
val model = this.modelRepository.findByIdOrNull(modelId)
?: throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found")
val userId = authenticationFacade.userId
val dataset = this.datasetRepository.save(datasetDto.toEntity(model, userId))

this.predictionService.executePredictionAndSaveResults(model, dataset)

val location: URI = ServletUriComponentsBuilder
.fromCurrentRequest().path("/{id}")
.buildAndExpand(dataset.id).toUri()
return ResponseEntity.created(location).build()
}

throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Unknown dataset type", null)
}
}

62 changes: 62 additions & 0 deletions src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package org.jaqpot.api.service.model

import org.jaqpot.api.entity.*
import org.jaqpot.api.mapper.toDto
import org.jaqpot.api.repository.DatasetRepository
import org.jaqpot.api.service.model.dto.PredictionRequestDto
import org.jaqpot.api.service.model.dto.PredictionResponseDto
import org.jaqpot.api.service.runtime.RuntimeResolver
import org.springframework.http.HttpEntity
import org.springframework.scheduling.annotation.Async
import org.springframework.stereotype.Service
import org.springframework.web.client.RestTemplate
import java.util.*


@Service
class PredictionService(
private val datasetRepository: DatasetRepository,
private val runtimeResolver: RuntimeResolver
) {

@Async
fun executePredictionAndSaveResults(model: Model, dataset: Dataset) {
val rawModel = Base64.getEncoder().encodeToString(model.actualModel)
val request: HttpEntity<PredictionRequestDto> =
HttpEntity(PredictionRequestDto(listOf(rawModel), dataset.toDto()))

val results: List<Any> = makePredictionRequest(model, request)

storeResults(dataset, results)
}

private fun storeResults(dataset: Dataset, results: List<Any>) {
dataset.results.clear()
dataset.results.addAll(
listOf(
DataEntry(
null,
dataset,
DataEntryType.ARRAY,
DataEntryRole.RESULTS,
results
)
)
)

datasetRepository.save(dataset)
}

private fun makePredictionRequest(
model: Model,
request: HttpEntity<PredictionRequestDto>
): List<Any> {
val restTemplate = RestTemplate()
val inferenceUrl = "${runtimeResolver.resolveRuntime(model)}/predict/"
val response = restTemplate.postForEntity(inferenceUrl, request, PredictionResponseDto::class.java)

val results: List<Any> = response.body?.predictions ?: emptyList()
return results
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package org.jaqpot.api.service.model.dto

import com.fasterxml.jackson.annotation.JsonInclude
import org.jaqpot.api.model.DatasetDto

@JsonInclude(JsonInclude.Include.ALWAYS)
class PredictionRequestDto(
val rawModel: List<String>,
val dataset: DatasetDto,
val additionalInfo: String? = null,
val doaMatrix: String? = null
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package org.jaqpot.api.service.model.dto

import com.fasterxml.jackson.annotation.JsonInclude

@JsonInclude(JsonInclude.Include.ALWAYS)
class PredictionResponseDto(
val predictions: List<Map<String, Any>>
)
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package org.jaqpot.api.service.runtime

import org.jaqpot.api.entity.Model
import org.jaqpot.api.service.runtime.config.RuntimeProvider
import org.springframework.stereotype.Component

@Component
class RuntimeResolver(val runtimeProvider: RuntimeProvider) {

fun resolveRuntime(): String {
return runtimeProvider.jaqpotpyPretrainedUrl;
fun resolveRuntime(model: Model): String {
return runtimeProvider.jaqpotpyPretrainedUrl
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.jaqpot.api.service.runtime
package org.jaqpot.api.service.runtime.config

import org.springframework.beans.factory.annotation.Value
import org.springframework.context.annotation.Configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ CREATE TABLE data_entry
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL,
dataset_id BIGINT NOT NULL,
type VARCHAR(255) NOT NULL,
role VARCHAR(255) NOT NULL,
values JSONB NOT NULL,
CONSTRAINT pk_dataentry PRIMARY KEY (id)
);
Expand Down
Loading

0 comments on commit b6e939c

Please sign in to comment.