From b6e939c11c57b688c916f215b4aa49fac0948473 Mon Sep 17 00:00:00 2001 From: Alex Arvanitidis Date: Tue, 4 Jun 2024 11:27:57 +0300 Subject: [PATCH] feat(JAQPOT-114): send request to inference api (#25) * feat(JAQPOT-114): send request to inference api * feat: send test request to jaqpot-inference * fix: compilation error * feat: complete end-to-end with prediction request * fix: make async method run properly async --- .../org/jaqpot/api/JaqpotApiApplication.kt | 2 + .../kotlin/org/jaqpot/api/entity/DataEntry.kt | 10 ++- .../org/jaqpot/api/entity/DataEntryRole.kt | 6 ++ .../kotlin/org/jaqpot/api/entity/Dataset.kt | 10 ++- .../kotlin/org/jaqpot/api/entity/Feature.kt | 2 +- .../kotlin/org/jaqpot/api/entity/Library.kt | 2 +- .../org/jaqpot/api/mapper/DataEntryMapper.kt | 26 ++++++ .../jaqpot/api/mapper/DataEntryTypeMapper.kt | 16 ++++ .../org/jaqpot/api/mapper/DatasetMapper.kt | 34 ++++++++ .../jaqpot/api/mapper/DatasetTypeMapper.kt | 16 ++++ .../api/repository/DatasetRepository.kt | 6 ++ .../jaqpot/api/service/model/ModelService.kt | 26 +++++- .../api/service/model/PredictionService.kt | 62 ++++++++++++++ .../service/model/dto/PredictionRequestDto.kt | 12 +++ .../model/dto/PredictionResponseDto.kt | 8 ++ .../api/service/runtime/RuntimeResolver.kt | 6 +- .../runtime/{ => config}/RuntimeProvider.kt | 2 +- .../db/migration/V2__create_main_entities.sql | 1 + src/main/resources/openapi.yaml | 85 +++++++++++++++++++ 19 files changed, 321 insertions(+), 11 deletions(-) create mode 100644 src/main/kotlin/org/jaqpot/api/entity/DataEntryRole.kt create mode 100644 src/main/kotlin/org/jaqpot/api/mapper/DataEntryMapper.kt create mode 100644 src/main/kotlin/org/jaqpot/api/mapper/DataEntryTypeMapper.kt create mode 100644 src/main/kotlin/org/jaqpot/api/mapper/DatasetMapper.kt create mode 100644 src/main/kotlin/org/jaqpot/api/mapper/DatasetTypeMapper.kt create mode 100644 src/main/kotlin/org/jaqpot/api/repository/DatasetRepository.kt create mode 100644 src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt create mode 100644 src/main/kotlin/org/jaqpot/api/service/model/dto/PredictionRequestDto.kt create mode 100644 src/main/kotlin/org/jaqpot/api/service/model/dto/PredictionResponseDto.kt rename src/main/kotlin/org/jaqpot/api/service/runtime/{ => config}/RuntimeProvider.kt (84%) diff --git a/src/main/kotlin/org/jaqpot/api/JaqpotApiApplication.kt b/src/main/kotlin/org/jaqpot/api/JaqpotApiApplication.kt index ac620c8..4966326 100644 --- a/src/main/kotlin/org/jaqpot/api/JaqpotApiApplication.kt +++ b/src/main/kotlin/org/jaqpot/api/JaqpotApiApplication.kt @@ -3,9 +3,11 @@ package org.jaqpot.api import org.springframework.boot.autoconfigure.SpringBootApplication import org.springframework.boot.context.properties.ConfigurationPropertiesScan import org.springframework.boot.runApplication +import org.springframework.scheduling.annotation.EnableAsync @SpringBootApplication @ConfigurationPropertiesScan("org.jaqpot.api") +@EnableAsync class JaqpotApiApplication fun main(args: Array) { diff --git a/src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt b/src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt index cf1146a..4a6ea3c 100644 --- a/src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt +++ b/src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt @@ -11,15 +11,19 @@ class DataEntry( @SequenceGenerator(name = "data_entry_id_seq", sequenceName = "data_entry_id_seq", allocationSize = 1) val id: Long? = 0, - @ManyToOne + @ManyToOne(fetch = FetchType.LAZY) @JoinColumn(name = "dataset_id", updatable = false, nullable = false) val dataset: Dataset, @Enumerated(EnumType.STRING) @Column(nullable = false) - val type: DataEntryType, + val type: DataEntryType = DataEntryType.ARRAY, + + @Enumerated(EnumType.STRING) + @Column(nullable = false) + val role: DataEntryRole, @JdbcTypeCode(SqlTypes.JSON) @Column(name = "values", columnDefinition = "jsonb", nullable = false) - val values: Any, + val values: List, ) : BaseEntity() diff --git a/src/main/kotlin/org/jaqpot/api/entity/DataEntryRole.kt b/src/main/kotlin/org/jaqpot/api/entity/DataEntryRole.kt new file mode 100644 index 0000000..e0d7f6e --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/entity/DataEntryRole.kt @@ -0,0 +1,6 @@ +package org.jaqpot.api.entity + +enum class DataEntryRole { + INPUT, + RESULTS +} diff --git a/src/main/kotlin/org/jaqpot/api/entity/Dataset.kt b/src/main/kotlin/org/jaqpot/api/entity/Dataset.kt index ef5cafb..9b47db8 100644 --- a/src/main/kotlin/org/jaqpot/api/entity/Dataset.kt +++ b/src/main/kotlin/org/jaqpot/api/entity/Dataset.kt @@ -1,6 +1,7 @@ package org.jaqpot.api.entity import jakarta.persistence.* +import org.hibernate.annotations.SQLRestriction @Entity class Dataset( @@ -9,7 +10,7 @@ class Dataset( @SequenceGenerator(name = "dataset_id_seq", sequenceName = "dataset_id_seq", allocationSize = 1) val id: Long? = 0, - @ManyToOne + @ManyToOne(fetch = FetchType.LAZY) @JoinColumn(name = "model_id", updatable = false, nullable = false) val model: Model, @@ -21,5 +22,10 @@ class Dataset( val type: DatasetType = DatasetType.PREDICTION, @OneToMany(mappedBy = "dataset", cascade = [CascadeType.ALL], orphanRemoval = true) - val dataEntry: List + @SQLRestriction("data_entry_role = 'INPUT'") + val input: MutableList, + + @OneToMany(mappedBy = "dataset", cascade = [CascadeType.ALL], orphanRemoval = true) + @SQLRestriction("data_entry_role = 'RESULTS'") + var results: MutableList ) : BaseEntity() diff --git a/src/main/kotlin/org/jaqpot/api/entity/Feature.kt b/src/main/kotlin/org/jaqpot/api/entity/Feature.kt index 3e09367..fa1336a 100644 --- a/src/main/kotlin/org/jaqpot/api/entity/Feature.kt +++ b/src/main/kotlin/org/jaqpot/api/entity/Feature.kt @@ -11,7 +11,7 @@ class Feature( @SequenceGenerator(name = "feature_id_seq", sequenceName = "feature_id_seq", allocationSize = 1) val id: Long? = 0, - @ManyToOne + @ManyToOne(fetch = FetchType.LAZY) @JoinColumn(name = "model_id", nullable = false) val model: Model, diff --git a/src/main/kotlin/org/jaqpot/api/entity/Library.kt b/src/main/kotlin/org/jaqpot/api/entity/Library.kt index 0130d8e..438cc2f 100644 --- a/src/main/kotlin/org/jaqpot/api/entity/Library.kt +++ b/src/main/kotlin/org/jaqpot/api/entity/Library.kt @@ -9,7 +9,7 @@ class Library( @SequenceGenerator(name = "library_id_seq", sequenceName = "library_id_seq", allocationSize = 1) val id: Long? = 0, - @ManyToOne + @ManyToOne(fetch = FetchType.LAZY) @JoinColumn(name = "model_id", updatable = false, nullable = false) val model: Model, diff --git a/src/main/kotlin/org/jaqpot/api/mapper/DataEntryMapper.kt b/src/main/kotlin/org/jaqpot/api/mapper/DataEntryMapper.kt new file mode 100644 index 0000000..a92793f --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/mapper/DataEntryMapper.kt @@ -0,0 +1,26 @@ +package org.jaqpot.api.mapper + +import org.jaqpot.api.entity.DataEntry +import org.jaqpot.api.entity.DataEntryRole +import org.jaqpot.api.entity.Dataset +import org.jaqpot.api.model.DataEntryDto + +fun DataEntry.toDto(): DataEntryDto { + return DataEntryDto( + this.type.toDto(), + this.values, + this.id, + this.createdAt, + this.updatedAt + ) +} + +fun DataEntryDto.toEntity(dataset: Dataset, dataEntryRole: DataEntryRole): DataEntry { + return DataEntry( + this.id, + dataset, + this.type.toEntity(), + dataEntryRole, + this.propertyValues, + ) +} diff --git a/src/main/kotlin/org/jaqpot/api/mapper/DataEntryTypeMapper.kt b/src/main/kotlin/org/jaqpot/api/mapper/DataEntryTypeMapper.kt new file mode 100644 index 0000000..dece6ee --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/mapper/DataEntryTypeMapper.kt @@ -0,0 +1,16 @@ +package org.jaqpot.api.mapper + +import org.jaqpot.api.entity.DataEntryType +import org.jaqpot.api.model.DataEntryDto + +fun DataEntryDto.Type.toEntity(): DataEntryType { + return when (this) { + DataEntryDto.Type.ARRAY -> DataEntryType.ARRAY + } +} + +fun DataEntryType.toDto(): DataEntryDto.Type { + return when (this) { + DataEntryType.ARRAY -> DataEntryDto.Type.ARRAY + } +} diff --git a/src/main/kotlin/org/jaqpot/api/mapper/DatasetMapper.kt b/src/main/kotlin/org/jaqpot/api/mapper/DatasetMapper.kt new file mode 100644 index 0000000..0621639 --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/mapper/DatasetMapper.kt @@ -0,0 +1,34 @@ +package org.jaqpot.api.mapper + +import org.jaqpot.api.entity.DataEntryRole +import org.jaqpot.api.entity.Dataset +import org.jaqpot.api.entity.Model +import org.jaqpot.api.model.DatasetDto + + +fun Dataset.toDto(): DatasetDto { + return DatasetDto( + this.type.toDto(), + this.input.map { it.toDto() }, + this.id, + this.results.map { it.toDto() }, + this.createdAt, + this.updatedAt + ) +} + +fun DatasetDto.toEntity(model: Model, userId: String): Dataset { + val d = Dataset( + this.id, + model, + userId, + this.type.toEntity(), + mutableListOf(), + mutableListOf() + ) + + d.input.addAll(this.input.map { it -> it.toEntity(d, DataEntryRole.INPUT) }) + d.results.addAll(this.results?.map { it -> it.toEntity(d, DataEntryRole.RESULTS) } ?: emptyList()) + + return d +} diff --git a/src/main/kotlin/org/jaqpot/api/mapper/DatasetTypeMapper.kt b/src/main/kotlin/org/jaqpot/api/mapper/DatasetTypeMapper.kt new file mode 100644 index 0000000..d7766e7 --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/mapper/DatasetTypeMapper.kt @@ -0,0 +1,16 @@ +package org.jaqpot.api.mapper + +import org.jaqpot.api.entity.DatasetType +import org.jaqpot.api.model.DatasetDto + +fun DatasetDto.Type.toEntity(): DatasetType { + return when (this) { + DatasetDto.Type.PREDICTION -> DatasetType.PREDICTION + } +} + +fun DatasetType.toDto(): DatasetDto.Type { + return when (this) { + DatasetType.PREDICTION -> DatasetDto.Type.PREDICTION + } +} diff --git a/src/main/kotlin/org/jaqpot/api/repository/DatasetRepository.kt b/src/main/kotlin/org/jaqpot/api/repository/DatasetRepository.kt new file mode 100644 index 0000000..027c38f --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/repository/DatasetRepository.kt @@ -0,0 +1,6 @@ +package org.jaqpot.api.repository + +import org.jaqpot.api.entity.Dataset +import org.springframework.data.repository.CrudRepository + +interface DatasetRepository : CrudRepository diff --git a/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt b/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt index e4f88a7..438ff82 100644 --- a/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt +++ b/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt @@ -5,10 +5,15 @@ import org.jaqpot.api.auth.AuthenticationFacade import org.jaqpot.api.auth.UserService import org.jaqpot.api.mapper.toDto import org.jaqpot.api.mapper.toEntity +import org.jaqpot.api.model.DatasetDto import org.jaqpot.api.model.ModelDto +import org.jaqpot.api.repository.DatasetRepository import org.jaqpot.api.repository.ModelRepository +import org.springframework.data.repository.findByIdOrNull +import org.springframework.http.HttpStatus import org.springframework.http.ResponseEntity import org.springframework.stereotype.Service +import org.springframework.web.server.ResponseStatusException import org.springframework.web.servlet.support.ServletUriComponentsBuilder import java.net.URI @@ -17,7 +22,8 @@ import java.net.URI class ModelService( private val authenticationFacade: AuthenticationFacade, private val modelRepository: ModelRepository, - private val userService: UserService + private val userService: UserService, + private val predictionService: PredictionService, private val datasetRepository: DatasetRepository ) : ModelApiDelegate { override fun createModel(modelDto: ModelDto): ResponseEntity { val userId = authenticationFacade.userId @@ -37,5 +43,23 @@ class ModelService( } .orElse(ResponseEntity.notFound().build()) } + + override fun predictWithModel(modelId: Long, datasetDto: DatasetDto): ResponseEntity { + if (datasetDto.type == DatasetDto.Type.PREDICTION) { + val model = this.modelRepository.findByIdOrNull(modelId) + ?: throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found") + val userId = authenticationFacade.userId + val dataset = this.datasetRepository.save(datasetDto.toEntity(model, userId)) + + this.predictionService.executePredictionAndSaveResults(model, dataset) + + val location: URI = ServletUriComponentsBuilder + .fromCurrentRequest().path("/{id}") + .buildAndExpand(dataset.id).toUri() + return ResponseEntity.created(location).build() + } + + throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Unknown dataset type", null) + } } diff --git a/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt b/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt new file mode 100644 index 0000000..f4c420c --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt @@ -0,0 +1,62 @@ +package org.jaqpot.api.service.model + +import org.jaqpot.api.entity.* +import org.jaqpot.api.mapper.toDto +import org.jaqpot.api.repository.DatasetRepository +import org.jaqpot.api.service.model.dto.PredictionRequestDto +import org.jaqpot.api.service.model.dto.PredictionResponseDto +import org.jaqpot.api.service.runtime.RuntimeResolver +import org.springframework.http.HttpEntity +import org.springframework.scheduling.annotation.Async +import org.springframework.stereotype.Service +import org.springframework.web.client.RestTemplate +import java.util.* + + +@Service +class PredictionService( + private val datasetRepository: DatasetRepository, + private val runtimeResolver: RuntimeResolver +) { + + @Async + fun executePredictionAndSaveResults(model: Model, dataset: Dataset) { + val rawModel = Base64.getEncoder().encodeToString(model.actualModel) + val request: HttpEntity = + HttpEntity(PredictionRequestDto(listOf(rawModel), dataset.toDto())) + + val results: List = makePredictionRequest(model, request) + + storeResults(dataset, results) + } + + private fun storeResults(dataset: Dataset, results: List) { + dataset.results.clear() + dataset.results.addAll( + listOf( + DataEntry( + null, + dataset, + DataEntryType.ARRAY, + DataEntryRole.RESULTS, + results + ) + ) + ) + + datasetRepository.save(dataset) + } + + private fun makePredictionRequest( + model: Model, + request: HttpEntity + ): List { + val restTemplate = RestTemplate() + val inferenceUrl = "${runtimeResolver.resolveRuntime(model)}/predict/" + val response = restTemplate.postForEntity(inferenceUrl, request, PredictionResponseDto::class.java) + + val results: List = response.body?.predictions ?: emptyList() + return results + } +} + diff --git a/src/main/kotlin/org/jaqpot/api/service/model/dto/PredictionRequestDto.kt b/src/main/kotlin/org/jaqpot/api/service/model/dto/PredictionRequestDto.kt new file mode 100644 index 0000000..577c96f --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/service/model/dto/PredictionRequestDto.kt @@ -0,0 +1,12 @@ +package org.jaqpot.api.service.model.dto + +import com.fasterxml.jackson.annotation.JsonInclude +import org.jaqpot.api.model.DatasetDto + +@JsonInclude(JsonInclude.Include.ALWAYS) +class PredictionRequestDto( + val rawModel: List, + val dataset: DatasetDto, + val additionalInfo: String? = null, + val doaMatrix: String? = null +) diff --git a/src/main/kotlin/org/jaqpot/api/service/model/dto/PredictionResponseDto.kt b/src/main/kotlin/org/jaqpot/api/service/model/dto/PredictionResponseDto.kt new file mode 100644 index 0000000..a40d61a --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/service/model/dto/PredictionResponseDto.kt @@ -0,0 +1,8 @@ +package org.jaqpot.api.service.model.dto + +import com.fasterxml.jackson.annotation.JsonInclude + +@JsonInclude(JsonInclude.Include.ALWAYS) +class PredictionResponseDto( + val predictions: List> +) diff --git a/src/main/kotlin/org/jaqpot/api/service/runtime/RuntimeResolver.kt b/src/main/kotlin/org/jaqpot/api/service/runtime/RuntimeResolver.kt index 0bd60b7..ec1bd73 100644 --- a/src/main/kotlin/org/jaqpot/api/service/runtime/RuntimeResolver.kt +++ b/src/main/kotlin/org/jaqpot/api/service/runtime/RuntimeResolver.kt @@ -1,11 +1,13 @@ package org.jaqpot.api.service.runtime +import org.jaqpot.api.entity.Model +import org.jaqpot.api.service.runtime.config.RuntimeProvider import org.springframework.stereotype.Component @Component class RuntimeResolver(val runtimeProvider: RuntimeProvider) { - fun resolveRuntime(): String { - return runtimeProvider.jaqpotpyPretrainedUrl; + fun resolveRuntime(model: Model): String { + return runtimeProvider.jaqpotpyPretrainedUrl } } diff --git a/src/main/kotlin/org/jaqpot/api/service/runtime/RuntimeProvider.kt b/src/main/kotlin/org/jaqpot/api/service/runtime/config/RuntimeProvider.kt similarity index 84% rename from src/main/kotlin/org/jaqpot/api/service/runtime/RuntimeProvider.kt rename to src/main/kotlin/org/jaqpot/api/service/runtime/config/RuntimeProvider.kt index 1e44389..e083bc0 100644 --- a/src/main/kotlin/org/jaqpot/api/service/runtime/RuntimeProvider.kt +++ b/src/main/kotlin/org/jaqpot/api/service/runtime/config/RuntimeProvider.kt @@ -1,4 +1,4 @@ -package org.jaqpot.api.service.runtime +package org.jaqpot.api.service.runtime.config import org.springframework.beans.factory.annotation.Value import org.springframework.context.annotation.Configuration diff --git a/src/main/resources/db/migration/V2__create_main_entities.sql b/src/main/resources/db/migration/V2__create_main_entities.sql index fd1d8e1..4b451e7 100644 --- a/src/main/resources/db/migration/V2__create_main_entities.sql +++ b/src/main/resources/db/migration/V2__create_main_entities.sql @@ -57,6 +57,7 @@ CREATE TABLE data_entry updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL, dataset_id BIGINT NOT NULL, type VARCHAR(255) NOT NULL, + role VARCHAR(255) NOT NULL, values JSONB NOT NULL, CONSTRAINT pk_dataentry PRIMARY KEY (id) ); diff --git a/src/main/resources/openapi.yaml b/src/main/resources/openapi.yaml index 3da0496..1e00247 100644 --- a/src/main/resources/openapi.yaml +++ b/src/main/resources/openapi.yaml @@ -69,6 +69,39 @@ paths: $ref: '#/components/schemas/Model' '404': description: Model not found + /v1/models/{modelId}/predict: + post: + summary: Predict with Model + description: Submit a dataset for prediction using a specific model + security: + - bearerAuth: [ ] + tags: + - model + operationId: predictWithModel + parameters: + - name: modelId + in: path + required: true + description: The ID of the model to use for prediction + schema: + type: integer + format: int64 + example: 0 + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/Dataset' + responses: + '201': + description: Prediction created successfully + '400': + description: Invalid Request + '404': + description: Model not found + '500': + description: Internal Server Error components: securitySchemes: bearerAuth: @@ -202,6 +235,56 @@ components: type: LocalDateTime description: The date and time when the feature was last updated. example: '2023-01-01T12:00:00Z' + Dataset: + type: object + required: + - type + - input + properties: + id: + type: integer + format: int64 + example: 1 + type: + type: string + enum: + - PREDICTION + example: "PREDICTION" + input: + type: array + items: + $ref: '#/components/schemas/DataEntry' + results: + type: array + items: + $ref: '#/components/schemas/DataEntry' + created_at: + type: LocalDateTime + updated_at: + type: LocalDateTime + DataEntry: + type: object + required: + - type + - values + properties: + id: + type: integer + format: int64 + example: 1 + created_at: + type: LocalDateTime + updated_at: + type: LocalDateTime + type: + type: string + enum: + - ARRAY + example: "ARRAY" + values: + type: array + items: { } + User: title: User x-stoplight: @@ -216,3 +299,5 @@ components: type: string x-stoplight: id: x0pm25vavfibz + AnyValue: + description: Can be any value - string, number, boolean, array or object.