Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(JAQPOT-114): send request to inference api #25

Merged
merged 5 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/main/kotlin/org/jaqpot/api/JaqpotApiApplication.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package org.jaqpot.api
import org.springframework.boot.autoconfigure.SpringBootApplication
import org.springframework.boot.context.properties.ConfigurationPropertiesScan
import org.springframework.boot.runApplication
import org.springframework.scheduling.annotation.EnableAsync

@SpringBootApplication
@ConfigurationPropertiesScan("org.jaqpot.api")
@EnableAsync
class JaqpotApiApplication

fun main(args: Array<String>) {
Expand Down
10 changes: 7 additions & 3 deletions src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,19 @@ class DataEntry(
@SequenceGenerator(name = "data_entry_id_seq", sequenceName = "data_entry_id_seq", allocationSize = 1)
val id: Long? = 0,

@ManyToOne
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "dataset_id", updatable = false, nullable = false)
val dataset: Dataset,

@Enumerated(EnumType.STRING)
@Column(nullable = false)
val type: DataEntryType,
val type: DataEntryType = DataEntryType.ARRAY,

@Enumerated(EnumType.STRING)
@Column(nullable = false)
val role: DataEntryRole,

@JdbcTypeCode(SqlTypes.JSON)
@Column(name = "values", columnDefinition = "jsonb", nullable = false)
val values: Any,
val values: List<Any>,
) : BaseEntity()
6 changes: 6 additions & 0 deletions src/main/kotlin/org/jaqpot/api/entity/DataEntryRole.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package org.jaqpot.api.entity

enum class DataEntryRole {
INPUT,
RESULTS
}
10 changes: 8 additions & 2 deletions src/main/kotlin/org/jaqpot/api/entity/Dataset.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.jaqpot.api.entity

import jakarta.persistence.*
import org.hibernate.annotations.SQLRestriction

@Entity
class Dataset(
Expand All @@ -9,7 +10,7 @@ class Dataset(
@SequenceGenerator(name = "dataset_id_seq", sequenceName = "dataset_id_seq", allocationSize = 1)
val id: Long? = 0,

@ManyToOne
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "model_id", updatable = false, nullable = false)
val model: Model,

Expand All @@ -21,5 +22,10 @@ class Dataset(
val type: DatasetType = DatasetType.PREDICTION,

@OneToMany(mappedBy = "dataset", cascade = [CascadeType.ALL], orphanRemoval = true)
val dataEntry: List<DataEntry>
@SQLRestriction("data_entry_role = 'INPUT'")
val input: MutableList<DataEntry>,

@OneToMany(mappedBy = "dataset", cascade = [CascadeType.ALL], orphanRemoval = true)
@SQLRestriction("data_entry_role = 'RESULTS'")
var results: MutableList<DataEntry>
) : BaseEntity()
2 changes: 1 addition & 1 deletion src/main/kotlin/org/jaqpot/api/entity/Feature.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Feature(
@SequenceGenerator(name = "feature_id_seq", sequenceName = "feature_id_seq", allocationSize = 1)
val id: Long? = 0,

@ManyToOne
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "model_id", nullable = false)
val model: Model,

Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/org/jaqpot/api/entity/Library.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class Library(
@SequenceGenerator(name = "library_id_seq", sequenceName = "library_id_seq", allocationSize = 1)
val id: Long? = 0,

@ManyToOne
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "model_id", updatable = false, nullable = false)
val model: Model,

Expand Down
26 changes: 26 additions & 0 deletions src/main/kotlin/org/jaqpot/api/mapper/DataEntryMapper.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package org.jaqpot.api.mapper

import org.jaqpot.api.entity.DataEntry
import org.jaqpot.api.entity.DataEntryRole
import org.jaqpot.api.entity.Dataset
import org.jaqpot.api.model.DataEntryDto

fun DataEntry.toDto(): DataEntryDto {
return DataEntryDto(
this.type.toDto(),
this.values,
this.id,
this.createdAt,
this.updatedAt
)
}

fun DataEntryDto.toEntity(dataset: Dataset, dataEntryRole: DataEntryRole): DataEntry {
return DataEntry(
this.id,
dataset,
this.type.toEntity(),
dataEntryRole,
this.propertyValues,
)
}
16 changes: 16 additions & 0 deletions src/main/kotlin/org/jaqpot/api/mapper/DataEntryTypeMapper.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.jaqpot.api.mapper

import org.jaqpot.api.entity.DataEntryType
import org.jaqpot.api.model.DataEntryDto

fun DataEntryDto.Type.toEntity(): DataEntryType {
return when (this) {
DataEntryDto.Type.ARRAY -> DataEntryType.ARRAY
}
}

fun DataEntryType.toDto(): DataEntryDto.Type {
return when (this) {
DataEntryType.ARRAY -> DataEntryDto.Type.ARRAY
}
}
34 changes: 34 additions & 0 deletions src/main/kotlin/org/jaqpot/api/mapper/DatasetMapper.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package org.jaqpot.api.mapper

import org.jaqpot.api.entity.DataEntryRole
import org.jaqpot.api.entity.Dataset
import org.jaqpot.api.entity.Model
import org.jaqpot.api.model.DatasetDto


fun Dataset.toDto(): DatasetDto {
return DatasetDto(
this.type.toDto(),
this.input.map { it.toDto() },
this.id,
this.results.map { it.toDto() },
this.createdAt,
this.updatedAt
)
}

fun DatasetDto.toEntity(model: Model, userId: String): Dataset {
val d = Dataset(
this.id,
model,
userId,
this.type.toEntity(),
mutableListOf(),
mutableListOf()
)

d.input.addAll(this.input.map { it -> it.toEntity(d, DataEntryRole.INPUT) })
d.results.addAll(this.results?.map { it -> it.toEntity(d, DataEntryRole.RESULTS) } ?: emptyList())

return d
}
16 changes: 16 additions & 0 deletions src/main/kotlin/org/jaqpot/api/mapper/DatasetTypeMapper.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.jaqpot.api.mapper

import org.jaqpot.api.entity.DatasetType
import org.jaqpot.api.model.DatasetDto

fun DatasetDto.Type.toEntity(): DatasetType {
return when (this) {
DatasetDto.Type.PREDICTION -> DatasetType.PREDICTION
}
}

fun DatasetType.toDto(): DatasetDto.Type {
return when (this) {
DatasetType.PREDICTION -> DatasetDto.Type.PREDICTION
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package org.jaqpot.api.repository

import org.jaqpot.api.entity.Dataset
import org.springframework.data.repository.CrudRepository

interface DatasetRepository : CrudRepository<Dataset, Long>
26 changes: 25 additions & 1 deletion src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@ import org.jaqpot.api.auth.AuthenticationFacade
import org.jaqpot.api.auth.UserService
import org.jaqpot.api.mapper.toDto
import org.jaqpot.api.mapper.toEntity
import org.jaqpot.api.model.DatasetDto
import org.jaqpot.api.model.ModelDto
import org.jaqpot.api.repository.DatasetRepository
import org.jaqpot.api.repository.ModelRepository
import org.springframework.data.repository.findByIdOrNull
import org.springframework.http.HttpStatus
import org.springframework.http.ResponseEntity
import org.springframework.stereotype.Service
import org.springframework.web.server.ResponseStatusException
import org.springframework.web.servlet.support.ServletUriComponentsBuilder
import java.net.URI

Expand All @@ -17,7 +22,8 @@ import java.net.URI
class ModelService(
private val authenticationFacade: AuthenticationFacade,
private val modelRepository: ModelRepository,
private val userService: UserService
private val userService: UserService,
private val predictionService: PredictionService, private val datasetRepository: DatasetRepository
) : ModelApiDelegate {
override fun createModel(modelDto: ModelDto): ResponseEntity<Unit> {
val userId = authenticationFacade.userId
Expand All @@ -37,5 +43,23 @@ class ModelService(
}
.orElse(ResponseEntity.notFound().build())
}

override fun predictWithModel(modelId: Long, datasetDto: DatasetDto): ResponseEntity<Unit> {
if (datasetDto.type == DatasetDto.Type.PREDICTION) {
val model = this.modelRepository.findByIdOrNull(modelId)
?: throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found")
val userId = authenticationFacade.userId
val dataset = this.datasetRepository.save(datasetDto.toEntity(model, userId))

this.predictionService.executePredictionAndSaveResults(model, dataset)

val location: URI = ServletUriComponentsBuilder
.fromCurrentRequest().path("/{id}")
.buildAndExpand(dataset.id).toUri()
return ResponseEntity.created(location).build()
}

throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Unknown dataset type", null)
}
}

62 changes: 62 additions & 0 deletions src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package org.jaqpot.api.service.model

import org.jaqpot.api.entity.*
import org.jaqpot.api.mapper.toDto
import org.jaqpot.api.repository.DatasetRepository
import org.jaqpot.api.service.model.dto.PredictionRequestDto
import org.jaqpot.api.service.model.dto.PredictionResponseDto
import org.jaqpot.api.service.runtime.RuntimeResolver
import org.springframework.http.HttpEntity
import org.springframework.scheduling.annotation.Async
import org.springframework.stereotype.Service
import org.springframework.web.client.RestTemplate
import java.util.*


@Service
class PredictionService(
private val datasetRepository: DatasetRepository,
private val runtimeResolver: RuntimeResolver
) {

@Async
fun executePredictionAndSaveResults(model: Model, dataset: Dataset) {
val rawModel = Base64.getEncoder().encodeToString(model.actualModel)
val request: HttpEntity<PredictionRequestDto> =
HttpEntity(PredictionRequestDto(listOf(rawModel), dataset.toDto()))

val results: List<Any> = makePredictionRequest(model, request)

storeResults(dataset, results)
}

private fun storeResults(dataset: Dataset, results: List<Any>) {
dataset.results.clear()
dataset.results.addAll(
listOf(
DataEntry(
null,
dataset,
DataEntryType.ARRAY,
DataEntryRole.RESULTS,
results
)
)
)

datasetRepository.save(dataset)
}

private fun makePredictionRequest(
model: Model,
request: HttpEntity<PredictionRequestDto>
): List<Any> {
val restTemplate = RestTemplate()
val inferenceUrl = "${runtimeResolver.resolveRuntime(model)}/predict/"
val response = restTemplate.postForEntity(inferenceUrl, request, PredictionResponseDto::class.java)

val results: List<Any> = response.body?.predictions ?: emptyList()
return results
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package org.jaqpot.api.service.model.dto

import com.fasterxml.jackson.annotation.JsonInclude
import org.jaqpot.api.model.DatasetDto

@JsonInclude(JsonInclude.Include.ALWAYS)
class PredictionRequestDto(
val rawModel: List<String>,
val dataset: DatasetDto,
val additionalInfo: String? = null,
val doaMatrix: String? = null
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package org.jaqpot.api.service.model.dto

import com.fasterxml.jackson.annotation.JsonInclude

@JsonInclude(JsonInclude.Include.ALWAYS)
class PredictionResponseDto(
val predictions: List<Map<String, Any>>
)
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package org.jaqpot.api.service.runtime

import org.jaqpot.api.entity.Model
import org.jaqpot.api.service.runtime.config.RuntimeProvider
import org.springframework.stereotype.Component

@Component
class RuntimeResolver(val runtimeProvider: RuntimeProvider) {

fun resolveRuntime(): String {
return runtimeProvider.jaqpotpyPretrainedUrl;
fun resolveRuntime(model: Model): String {
return runtimeProvider.jaqpotpyPretrainedUrl
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.jaqpot.api.service.runtime
package org.jaqpot.api.service.runtime.config

import org.springframework.beans.factory.annotation.Value
import org.springframework.context.annotation.Configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ CREATE TABLE data_entry
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL,
dataset_id BIGINT NOT NULL,
type VARCHAR(255) NOT NULL,
role VARCHAR(255) NOT NULL,
values JSONB NOT NULL,
CONSTRAINT pk_dataentry PRIMARY KEY (id)
);
Expand Down
Loading
Loading