From c340d3264ee03eccbd6d2b5124261d9ce2acdf28 Mon Sep 17 00:00:00 2001 From: alarv Date: Mon, 3 Jun 2024 19:15:00 +0300 Subject: [PATCH 1/5] feat(JAQPOT-114): send request to inference api --- .../kotlin/org/jaqpot/api/entity/DataEntry.kt | 2 +- .../kotlin/org/jaqpot/api/entity/Dataset.kt | 2 +- .../org/jaqpot/api/mapper/DataEntryMapper.kt | 25 ++++++ .../jaqpot/api/mapper/DataEntryTypeMapper.kt | 16 ++++ .../org/jaqpot/api/mapper/DatasetMapper.kt | 31 +++++++ .../jaqpot/api/mapper/DatasetTypeMapper.kt | 16 ++++ .../api/repository/DatasetRepository.kt | 6 ++ .../jaqpot/api/service/model/ModelService.kt | 24 +++++- .../api/service/model/PredictionService.kt | 29 +++++++ src/main/resources/openapi.yaml | 81 +++++++++++++++++++ 10 files changed, 229 insertions(+), 3 deletions(-) create mode 100644 src/main/kotlin/org/jaqpot/api/mapper/DataEntryMapper.kt create mode 100644 src/main/kotlin/org/jaqpot/api/mapper/DataEntryTypeMapper.kt create mode 100644 src/main/kotlin/org/jaqpot/api/mapper/DatasetMapper.kt create mode 100644 src/main/kotlin/org/jaqpot/api/mapper/DatasetTypeMapper.kt create mode 100644 src/main/kotlin/org/jaqpot/api/repository/DatasetRepository.kt create mode 100644 src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt diff --git a/src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt b/src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt index cf1146a..c711f27 100644 --- a/src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt +++ b/src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt @@ -21,5 +21,5 @@ class DataEntry( @JdbcTypeCode(SqlTypes.JSON) @Column(name = "values", columnDefinition = "jsonb", nullable = false) - val values: Any, + val values: List, ) : BaseEntity() diff --git a/src/main/kotlin/org/jaqpot/api/entity/Dataset.kt b/src/main/kotlin/org/jaqpot/api/entity/Dataset.kt index ef5cafb..ecc6789 100644 --- a/src/main/kotlin/org/jaqpot/api/entity/Dataset.kt +++ b/src/main/kotlin/org/jaqpot/api/entity/Dataset.kt @@ -21,5 +21,5 @@ class Dataset( val type: DatasetType = DatasetType.PREDICTION, @OneToMany(mappedBy = "dataset", cascade = [CascadeType.ALL], orphanRemoval = true) - val dataEntry: List + val dataEntries: MutableList ) : BaseEntity() diff --git a/src/main/kotlin/org/jaqpot/api/mapper/DataEntryMapper.kt b/src/main/kotlin/org/jaqpot/api/mapper/DataEntryMapper.kt new file mode 100644 index 0000000..8c63350 --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/mapper/DataEntryMapper.kt @@ -0,0 +1,25 @@ +package org.jaqpot.api.mapper + +import org.jaqpot.api.entity.DataEntry +import org.jaqpot.api.entity.Dataset +import org.jaqpot.api.model.DataEntryDto +import org.jaqpot.api.model.UserDto + +fun DataEntry.toDto(userDto: UserDto): DataEntryDto { + return DataEntryDto( + this.type.toDto(), + this.values, + this.id, + this.createdAt, + this.updatedAt + ) +} + +fun DataEntryDto.toEntity(dataset: Dataset, userId: String): DataEntry { + return DataEntry( + this.id, + dataset, + this.type.toEntity(), + this.propertyValues, + ) +} diff --git a/src/main/kotlin/org/jaqpot/api/mapper/DataEntryTypeMapper.kt b/src/main/kotlin/org/jaqpot/api/mapper/DataEntryTypeMapper.kt new file mode 100644 index 0000000..dece6ee --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/mapper/DataEntryTypeMapper.kt @@ -0,0 +1,16 @@ +package org.jaqpot.api.mapper + +import org.jaqpot.api.entity.DataEntryType +import org.jaqpot.api.model.DataEntryDto + +fun DataEntryDto.Type.toEntity(): DataEntryType { + return when (this) { + DataEntryDto.Type.ARRAY -> DataEntryType.ARRAY + } +} + +fun DataEntryType.toDto(): DataEntryDto.Type { + return when (this) { + DataEntryType.ARRAY -> DataEntryDto.Type.ARRAY + } +} diff --git a/src/main/kotlin/org/jaqpot/api/mapper/DatasetMapper.kt b/src/main/kotlin/org/jaqpot/api/mapper/DatasetMapper.kt new file mode 100644 index 0000000..f875ae7 --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/mapper/DatasetMapper.kt @@ -0,0 +1,31 @@ +package org.jaqpot.api.mapper + +import org.jaqpot.api.entity.Dataset +import org.jaqpot.api.entity.Model +import org.jaqpot.api.model.DatasetDto +import org.jaqpot.api.model.UserDto + + +fun Dataset.toDto(userDto: UserDto): DatasetDto { + return DatasetDto( + this.type.toDto(), + this.dataEntries.map { it.toDto(userDto) }, + this.id, + this.createdAt, + this.updatedAt + ) +} + +fun DatasetDto.toEntity(model: Model, userId: String): Dataset { + val d = Dataset( + this.id, + model, + userId, + this.type.toEntity(), + mutableListOf() + ) + + d.dataEntries.addAll(this.dataEntries.map { it -> it.toEntity(d, userId) }) + + return d; +} diff --git a/src/main/kotlin/org/jaqpot/api/mapper/DatasetTypeMapper.kt b/src/main/kotlin/org/jaqpot/api/mapper/DatasetTypeMapper.kt new file mode 100644 index 0000000..d7766e7 --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/mapper/DatasetTypeMapper.kt @@ -0,0 +1,16 @@ +package org.jaqpot.api.mapper + +import org.jaqpot.api.entity.DatasetType +import org.jaqpot.api.model.DatasetDto + +fun DatasetDto.Type.toEntity(): DatasetType { + return when (this) { + DatasetDto.Type.PREDICTION -> DatasetType.PREDICTION + } +} + +fun DatasetType.toDto(): DatasetDto.Type { + return when (this) { + DatasetType.PREDICTION -> DatasetDto.Type.PREDICTION + } +} diff --git a/src/main/kotlin/org/jaqpot/api/repository/DatasetRepository.kt b/src/main/kotlin/org/jaqpot/api/repository/DatasetRepository.kt new file mode 100644 index 0000000..027c38f --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/repository/DatasetRepository.kt @@ -0,0 +1,6 @@ +package org.jaqpot.api.repository + +import org.jaqpot.api.entity.Dataset +import org.springframework.data.repository.CrudRepository + +interface DatasetRepository : CrudRepository diff --git a/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt b/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt index e4f88a7..281d46f 100644 --- a/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt +++ b/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt @@ -5,10 +5,14 @@ import org.jaqpot.api.auth.AuthenticationFacade import org.jaqpot.api.auth.UserService import org.jaqpot.api.mapper.toDto import org.jaqpot.api.mapper.toEntity +import org.jaqpot.api.model.DatasetDto import org.jaqpot.api.model.ModelDto import org.jaqpot.api.repository.ModelRepository +import org.springframework.data.repository.findByIdOrNull +import org.springframework.http.HttpStatus import org.springframework.http.ResponseEntity import org.springframework.stereotype.Service +import org.springframework.web.server.ResponseStatusException import org.springframework.web.servlet.support.ServletUriComponentsBuilder import java.net.URI @@ -17,7 +21,8 @@ import java.net.URI class ModelService( private val authenticationFacade: AuthenticationFacade, private val modelRepository: ModelRepository, - private val userService: UserService + private val userService: UserService, + private val predictionService: PredictionService ) : ModelApiDelegate { override fun createModel(modelDto: ModelDto): ResponseEntity { val userId = authenticationFacade.userId @@ -37,5 +42,22 @@ class ModelService( } .orElse(ResponseEntity.notFound().build()) } + + override fun predictWithModel(modelId: Long, datasetDto: DatasetDto): ResponseEntity { + if (datasetDto.type == DatasetDto.Type.PREDICTION) { + val model = this.modelRepository.findByIdOrNull(modelId) + ?: throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found"); + val userId = authenticationFacade.userId + + val predictionDataset = this.predictionService.createPrediction(model, userId, datasetDto) + + val location: URI = ServletUriComponentsBuilder + .fromCurrentRequest().path("/{id}") + .buildAndExpand(predictionDataset.id).toUri() + return ResponseEntity.created(location).build() + } + + throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Unknown dataset type", null); + } } diff --git a/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt b/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt new file mode 100644 index 0000000..c65f73c --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt @@ -0,0 +1,29 @@ +package org.jaqpot.api.service.model + +import org.jaqpot.api.entity.Dataset +import org.jaqpot.api.entity.Model +import org.jaqpot.api.mapper.toEntity +import org.jaqpot.api.model.DatasetDto +import org.jaqpot.api.repository.DatasetRepository +import org.springframework.scheduling.annotation.Async +import org.springframework.stereotype.Service + + +@Service +class PredictionService( + private val datasetRepository: DatasetRepository, +) { + + fun createPrediction(model: Model, userId: String, datasetDto: DatasetDto): Dataset { + val dataset = this.datasetRepository.save(datasetDto.toEntity(model, userId)) + this.sendPredictionToInference(dataset) + + return dataset + } + + @Async + fun sendPredictionToInference(dataset: Dataset) { + + } +} + diff --git a/src/main/resources/openapi.yaml b/src/main/resources/openapi.yaml index 3da0496..256ebff 100644 --- a/src/main/resources/openapi.yaml +++ b/src/main/resources/openapi.yaml @@ -69,6 +69,39 @@ paths: $ref: '#/components/schemas/Model' '404': description: Model not found + /v1/models/{modelId}/predict: + post: + summary: Predict with Model + description: Submit a dataset for prediction using a specific model + security: + - bearerAuth: [ ] + tags: + - model + operationId: predictWithModel + parameters: + - name: modelId + in: path + required: true + description: The ID of the model to use for prediction + schema: + type: integer + format: int64 + example: 0 + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/Dataset' + responses: + '201': + description: Prediction created successfully + '400': + description: Invalid Request + '404': + description: Model not found + '500': + description: Internal Server Error components: securitySchemes: bearerAuth: @@ -202,6 +235,52 @@ components: type: LocalDateTime description: The date and time when the feature was last updated. example: '2023-01-01T12:00:00Z' + Dataset: + type: object + required: + - type + - dataEntries + properties: + id: + type: integer + format: int64 + example: 1 + type: + type: string + enum: + - PREDICTION + example: "PREDICTION" + dataEntries: + type: array + items: + $ref: '#/components/schemas/DataEntry' + created_at: + type: LocalDateTime + updated_at: + type: LocalDateTime + DataEntry: + type: object + required: + - type + - values + properties: + id: + type: integer + format: int64 + example: 1 + created_at: + type: LocalDateTime + updated_at: + type: LocalDateTime + type: + type: string + enum: + - ARRAY + example: "ARRAY" + values: + type: array + items: { } + User: title: User x-stoplight: @@ -216,3 +295,5 @@ components: type: string x-stoplight: id: x0pm25vavfibz + AnyValue: + description: Can be any value - string, number, boolean, array or object. From 100040d6bc1cb571077f0edc282c553ce49021c3 Mon Sep 17 00:00:00 2001 From: alarv Date: Mon, 3 Jun 2024 21:04:51 +0300 Subject: [PATCH 2/5] feat: send test request to jaqpot-inference --- .../kotlin/org/jaqpot/api/entity/DataEntry.kt | 2 +- .../kotlin/org/jaqpot/api/entity/Dataset.kt | 2 +- .../kotlin/org/jaqpot/api/entity/Feature.kt | 2 +- .../kotlin/org/jaqpot/api/entity/Library.kt | 2 +- .../jaqpot/api/service/model/ModelService.kt | 1 - .../api/service/model/PredictionRequestDto.kt | 12 ++++++++++++ .../api/service/model/PredictionService.kt | 17 +++++++++++++++-- 7 files changed, 31 insertions(+), 7 deletions(-) create mode 100644 src/main/kotlin/org/jaqpot/api/service/model/PredictionRequestDto.kt diff --git a/src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt b/src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt index c711f27..aca5e67 100644 --- a/src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt +++ b/src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt @@ -11,7 +11,7 @@ class DataEntry( @SequenceGenerator(name = "data_entry_id_seq", sequenceName = "data_entry_id_seq", allocationSize = 1) val id: Long? = 0, - @ManyToOne + @ManyToOne(fetch = FetchType.LAZY) @JoinColumn(name = "dataset_id", updatable = false, nullable = false) val dataset: Dataset, diff --git a/src/main/kotlin/org/jaqpot/api/entity/Dataset.kt b/src/main/kotlin/org/jaqpot/api/entity/Dataset.kt index ecc6789..cfdad39 100644 --- a/src/main/kotlin/org/jaqpot/api/entity/Dataset.kt +++ b/src/main/kotlin/org/jaqpot/api/entity/Dataset.kt @@ -9,7 +9,7 @@ class Dataset( @SequenceGenerator(name = "dataset_id_seq", sequenceName = "dataset_id_seq", allocationSize = 1) val id: Long? = 0, - @ManyToOne + @ManyToOne(fetch = FetchType.LAZY) @JoinColumn(name = "model_id", updatable = false, nullable = false) val model: Model, diff --git a/src/main/kotlin/org/jaqpot/api/entity/Feature.kt b/src/main/kotlin/org/jaqpot/api/entity/Feature.kt index 3e09367..fa1336a 100644 --- a/src/main/kotlin/org/jaqpot/api/entity/Feature.kt +++ b/src/main/kotlin/org/jaqpot/api/entity/Feature.kt @@ -11,7 +11,7 @@ class Feature( @SequenceGenerator(name = "feature_id_seq", sequenceName = "feature_id_seq", allocationSize = 1) val id: Long? = 0, - @ManyToOne + @ManyToOne(fetch = FetchType.LAZY) @JoinColumn(name = "model_id", nullable = false) val model: Model, diff --git a/src/main/kotlin/org/jaqpot/api/entity/Library.kt b/src/main/kotlin/org/jaqpot/api/entity/Library.kt index 0130d8e..438cc2f 100644 --- a/src/main/kotlin/org/jaqpot/api/entity/Library.kt +++ b/src/main/kotlin/org/jaqpot/api/entity/Library.kt @@ -9,7 +9,7 @@ class Library( @SequenceGenerator(name = "library_id_seq", sequenceName = "library_id_seq", allocationSize = 1) val id: Long? = 0, - @ManyToOne + @ManyToOne(fetch = FetchType.LAZY) @JoinColumn(name = "model_id", updatable = false, nullable = false) val model: Model, diff --git a/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt b/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt index 281d46f..1f076b2 100644 --- a/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt +++ b/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt @@ -3,7 +3,6 @@ package org.jaqpot.api.service.model import org.jaqpot.api.ModelApiDelegate import org.jaqpot.api.auth.AuthenticationFacade import org.jaqpot.api.auth.UserService -import org.jaqpot.api.mapper.toDto import org.jaqpot.api.mapper.toEntity import org.jaqpot.api.model.DatasetDto import org.jaqpot.api.model.ModelDto diff --git a/src/main/kotlin/org/jaqpot/api/service/model/PredictionRequestDto.kt b/src/main/kotlin/org/jaqpot/api/service/model/PredictionRequestDto.kt new file mode 100644 index 0000000..64233f6 --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/service/model/PredictionRequestDto.kt @@ -0,0 +1,12 @@ +package org.jaqpot.api.service.model + +import com.fasterxml.jackson.annotation.JsonInclude +import org.jaqpot.api.entity.Dataset + +@JsonInclude(JsonInclude.Include.ALWAYS) +class PredictionRequestDto( + val rawModel: List, + val dataset: Dataset, + val additionalInfo: String? = null, + val doaMatrix: String? = null +) diff --git a/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt b/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt index c65f73c..42e134f 100644 --- a/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt +++ b/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt @@ -5,8 +5,11 @@ import org.jaqpot.api.entity.Model import org.jaqpot.api.mapper.toEntity import org.jaqpot.api.model.DatasetDto import org.jaqpot.api.repository.DatasetRepository +import org.springframework.http.HttpEntity import org.springframework.scheduling.annotation.Async import org.springframework.stereotype.Service +import org.springframework.web.client.RestTemplate +import java.util.* @Service @@ -16,14 +19,24 @@ class PredictionService( fun createPrediction(model: Model, userId: String, datasetDto: DatasetDto): Dataset { val dataset = this.datasetRepository.save(datasetDto.toEntity(model, userId)) - this.sendPredictionToInference(dataset) + this.sendPredictionToInference(model, dataset) return dataset } @Async - fun sendPredictionToInference(dataset: Dataset) { + fun sendPredictionToInference(model: Model, dataset: Dataset) { + // TODO send request properly + val restTemplate = RestTemplate() + val fooResourceUrl = "http://localhost:8002/predict/" + val rawModel = Base64.getEncoder().encodeToString(model.actualModel) + val request: HttpEntity = + HttpEntity(PredictionRequestDto(listOf(rawModel), dataset)) + + val response = restTemplate.postForEntity(fooResourceUrl, request, String::class.java) + + println(response) } } From d3d142a930a31260281adbf5c57d7ba18aa3ef8b Mon Sep 17 00:00:00 2001 From: alarv Date: Mon, 3 Jun 2024 21:36:21 +0300 Subject: [PATCH 3/5] fix: compilation error --- src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt b/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt index 1f076b2..c0c36d5 100644 --- a/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt +++ b/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt @@ -3,6 +3,7 @@ package org.jaqpot.api.service.model import org.jaqpot.api.ModelApiDelegate import org.jaqpot.api.auth.AuthenticationFacade import org.jaqpot.api.auth.UserService +import org.jaqpot.api.mapper.toDto import org.jaqpot.api.mapper.toEntity import org.jaqpot.api.model.DatasetDto import org.jaqpot.api.model.ModelDto @@ -45,7 +46,7 @@ class ModelService( override fun predictWithModel(modelId: Long, datasetDto: DatasetDto): ResponseEntity { if (datasetDto.type == DatasetDto.Type.PREDICTION) { val model = this.modelRepository.findByIdOrNull(modelId) - ?: throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found"); + ?: throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found") val userId = authenticationFacade.userId val predictionDataset = this.predictionService.createPrediction(model, userId, datasetDto) @@ -56,7 +57,7 @@ class ModelService( return ResponseEntity.created(location).build() } - throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Unknown dataset type", null); + throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Unknown dataset type", null) } } From 139e33856cef3a3e450f4c7c2dd65fd9ca1a7b15 Mon Sep 17 00:00:00 2001 From: alarv Date: Mon, 3 Jun 2024 22:36:41 +0300 Subject: [PATCH 4/5] feat: complete end-to-end with prediction request --- .../kotlin/org/jaqpot/api/entity/DataEntry.kt | 6 ++- .../org/jaqpot/api/entity/DataEntryRole.kt | 6 +++ .../kotlin/org/jaqpot/api/entity/Dataset.kt | 8 ++- .../org/jaqpot/api/mapper/DataEntryMapper.kt | 7 +-- .../org/jaqpot/api/mapper/DatasetMapper.kt | 15 +++--- .../jaqpot/api/service/model/ModelService.kt | 2 +- .../api/service/model/PredictionService.kt | 53 ++++++++++++++----- .../model/{ => dto}/PredictionRequestDto.kt | 6 +-- .../model/dto/PredictionResponseDto.kt | 8 +++ .../api/service/runtime/RuntimeResolver.kt | 6 ++- .../runtime/{ => config}/RuntimeProvider.kt | 2 +- .../db/migration/V2__create_main_entities.sql | 1 + src/main/resources/openapi.yaml | 8 ++- 13 files changed, 96 insertions(+), 32 deletions(-) create mode 100644 src/main/kotlin/org/jaqpot/api/entity/DataEntryRole.kt rename src/main/kotlin/org/jaqpot/api/service/model/{ => dto}/PredictionRequestDto.kt (67%) create mode 100644 src/main/kotlin/org/jaqpot/api/service/model/dto/PredictionResponseDto.kt rename src/main/kotlin/org/jaqpot/api/service/runtime/{ => config}/RuntimeProvider.kt (84%) diff --git a/src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt b/src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt index aca5e67..4a6ea3c 100644 --- a/src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt +++ b/src/main/kotlin/org/jaqpot/api/entity/DataEntry.kt @@ -17,7 +17,11 @@ class DataEntry( @Enumerated(EnumType.STRING) @Column(nullable = false) - val type: DataEntryType, + val type: DataEntryType = DataEntryType.ARRAY, + + @Enumerated(EnumType.STRING) + @Column(nullable = false) + val role: DataEntryRole, @JdbcTypeCode(SqlTypes.JSON) @Column(name = "values", columnDefinition = "jsonb", nullable = false) diff --git a/src/main/kotlin/org/jaqpot/api/entity/DataEntryRole.kt b/src/main/kotlin/org/jaqpot/api/entity/DataEntryRole.kt new file mode 100644 index 0000000..e0d7f6e --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/entity/DataEntryRole.kt @@ -0,0 +1,6 @@ +package org.jaqpot.api.entity + +enum class DataEntryRole { + INPUT, + RESULTS +} diff --git a/src/main/kotlin/org/jaqpot/api/entity/Dataset.kt b/src/main/kotlin/org/jaqpot/api/entity/Dataset.kt index cfdad39..9b47db8 100644 --- a/src/main/kotlin/org/jaqpot/api/entity/Dataset.kt +++ b/src/main/kotlin/org/jaqpot/api/entity/Dataset.kt @@ -1,6 +1,7 @@ package org.jaqpot.api.entity import jakarta.persistence.* +import org.hibernate.annotations.SQLRestriction @Entity class Dataset( @@ -21,5 +22,10 @@ class Dataset( val type: DatasetType = DatasetType.PREDICTION, @OneToMany(mappedBy = "dataset", cascade = [CascadeType.ALL], orphanRemoval = true) - val dataEntries: MutableList + @SQLRestriction("data_entry_role = 'INPUT'") + val input: MutableList, + + @OneToMany(mappedBy = "dataset", cascade = [CascadeType.ALL], orphanRemoval = true) + @SQLRestriction("data_entry_role = 'RESULTS'") + var results: MutableList ) : BaseEntity() diff --git a/src/main/kotlin/org/jaqpot/api/mapper/DataEntryMapper.kt b/src/main/kotlin/org/jaqpot/api/mapper/DataEntryMapper.kt index 8c63350..a92793f 100644 --- a/src/main/kotlin/org/jaqpot/api/mapper/DataEntryMapper.kt +++ b/src/main/kotlin/org/jaqpot/api/mapper/DataEntryMapper.kt @@ -1,11 +1,11 @@ package org.jaqpot.api.mapper import org.jaqpot.api.entity.DataEntry +import org.jaqpot.api.entity.DataEntryRole import org.jaqpot.api.entity.Dataset import org.jaqpot.api.model.DataEntryDto -import org.jaqpot.api.model.UserDto -fun DataEntry.toDto(userDto: UserDto): DataEntryDto { +fun DataEntry.toDto(): DataEntryDto { return DataEntryDto( this.type.toDto(), this.values, @@ -15,11 +15,12 @@ fun DataEntry.toDto(userDto: UserDto): DataEntryDto { ) } -fun DataEntryDto.toEntity(dataset: Dataset, userId: String): DataEntry { +fun DataEntryDto.toEntity(dataset: Dataset, dataEntryRole: DataEntryRole): DataEntry { return DataEntry( this.id, dataset, this.type.toEntity(), + dataEntryRole, this.propertyValues, ) } diff --git a/src/main/kotlin/org/jaqpot/api/mapper/DatasetMapper.kt b/src/main/kotlin/org/jaqpot/api/mapper/DatasetMapper.kt index f875ae7..0621639 100644 --- a/src/main/kotlin/org/jaqpot/api/mapper/DatasetMapper.kt +++ b/src/main/kotlin/org/jaqpot/api/mapper/DatasetMapper.kt @@ -1,16 +1,17 @@ package org.jaqpot.api.mapper +import org.jaqpot.api.entity.DataEntryRole import org.jaqpot.api.entity.Dataset import org.jaqpot.api.entity.Model import org.jaqpot.api.model.DatasetDto -import org.jaqpot.api.model.UserDto -fun Dataset.toDto(userDto: UserDto): DatasetDto { +fun Dataset.toDto(): DatasetDto { return DatasetDto( this.type.toDto(), - this.dataEntries.map { it.toDto(userDto) }, + this.input.map { it.toDto() }, this.id, + this.results.map { it.toDto() }, this.createdAt, this.updatedAt ) @@ -22,10 +23,12 @@ fun DatasetDto.toEntity(model: Model, userId: String): Dataset { model, userId, this.type.toEntity(), + mutableListOf(), mutableListOf() ) - d.dataEntries.addAll(this.dataEntries.map { it -> it.toEntity(d, userId) }) - - return d; + d.input.addAll(this.input.map { it -> it.toEntity(d, DataEntryRole.INPUT) }) + d.results.addAll(this.results?.map { it -> it.toEntity(d, DataEntryRole.RESULTS) } ?: emptyList()) + + return d } diff --git a/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt b/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt index c0c36d5..125a576 100644 --- a/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt +++ b/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt @@ -49,7 +49,7 @@ class ModelService( ?: throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found") val userId = authenticationFacade.userId - val predictionDataset = this.predictionService.createPrediction(model, userId, datasetDto) + val predictionDataset = this.predictionService.createAndPredictDataset(model, userId, datasetDto) val location: URI = ServletUriComponentsBuilder .fromCurrentRequest().path("/{id}") diff --git a/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt b/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt index 42e134f..2b8972b 100644 --- a/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt +++ b/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt @@ -1,10 +1,13 @@ package org.jaqpot.api.service.model -import org.jaqpot.api.entity.Dataset -import org.jaqpot.api.entity.Model +import org.jaqpot.api.entity.* +import org.jaqpot.api.mapper.toDto import org.jaqpot.api.mapper.toEntity import org.jaqpot.api.model.DatasetDto import org.jaqpot.api.repository.DatasetRepository +import org.jaqpot.api.service.model.dto.PredictionRequestDto +import org.jaqpot.api.service.model.dto.PredictionResponseDto +import org.jaqpot.api.service.runtime.RuntimeResolver import org.springframework.http.HttpEntity import org.springframework.scheduling.annotation.Async import org.springframework.stereotype.Service @@ -15,28 +18,54 @@ import java.util.* @Service class PredictionService( private val datasetRepository: DatasetRepository, + private val runtimeResolver: RuntimeResolver ) { - fun createPrediction(model: Model, userId: String, datasetDto: DatasetDto): Dataset { + fun createAndPredictDataset(model: Model, userId: String, datasetDto: DatasetDto): Dataset { val dataset = this.datasetRepository.save(datasetDto.toEntity(model, userId)) - this.sendPredictionToInference(model, dataset) + this.executePredictionAndSaveResults(model, dataset) return dataset } @Async - fun sendPredictionToInference(model: Model, dataset: Dataset) { - // TODO send request properly - val restTemplate = RestTemplate() - val fooResourceUrl = "http://localhost:8002/predict/" - + fun executePredictionAndSaveResults(model: Model, dataset: Dataset) { val rawModel = Base64.getEncoder().encodeToString(model.actualModel) val request: HttpEntity = - HttpEntity(PredictionRequestDto(listOf(rawModel), dataset)) + HttpEntity(PredictionRequestDto(listOf(rawModel), dataset.toDto())) + + val results: List = makePredictionRequest(model, request) + + storeResults(dataset, results) + } - val response = restTemplate.postForEntity(fooResourceUrl, request, String::class.java) + private fun storeResults(dataset: Dataset, results: List) { + dataset.results.clear() + dataset.results.addAll( + listOf( + DataEntry( + null, + dataset, + DataEntryType.ARRAY, + DataEntryRole.RESULTS, + results + ) + ) + ) + + datasetRepository.save(dataset) + } + + private fun makePredictionRequest( + model: Model, + request: HttpEntity + ): List { + val restTemplate = RestTemplate() + val inferenceUrl = "${runtimeResolver.resolveRuntime(model)}/predict/" + val response = restTemplate.postForEntity(inferenceUrl, request, PredictionResponseDto::class.java) - println(response) + val results: List = response.body?.predictions ?: emptyList() + return results } } diff --git a/src/main/kotlin/org/jaqpot/api/service/model/PredictionRequestDto.kt b/src/main/kotlin/org/jaqpot/api/service/model/dto/PredictionRequestDto.kt similarity index 67% rename from src/main/kotlin/org/jaqpot/api/service/model/PredictionRequestDto.kt rename to src/main/kotlin/org/jaqpot/api/service/model/dto/PredictionRequestDto.kt index 64233f6..577c96f 100644 --- a/src/main/kotlin/org/jaqpot/api/service/model/PredictionRequestDto.kt +++ b/src/main/kotlin/org/jaqpot/api/service/model/dto/PredictionRequestDto.kt @@ -1,12 +1,12 @@ -package org.jaqpot.api.service.model +package org.jaqpot.api.service.model.dto import com.fasterxml.jackson.annotation.JsonInclude -import org.jaqpot.api.entity.Dataset +import org.jaqpot.api.model.DatasetDto @JsonInclude(JsonInclude.Include.ALWAYS) class PredictionRequestDto( val rawModel: List, - val dataset: Dataset, + val dataset: DatasetDto, val additionalInfo: String? = null, val doaMatrix: String? = null ) diff --git a/src/main/kotlin/org/jaqpot/api/service/model/dto/PredictionResponseDto.kt b/src/main/kotlin/org/jaqpot/api/service/model/dto/PredictionResponseDto.kt new file mode 100644 index 0000000..a40d61a --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/service/model/dto/PredictionResponseDto.kt @@ -0,0 +1,8 @@ +package org.jaqpot.api.service.model.dto + +import com.fasterxml.jackson.annotation.JsonInclude + +@JsonInclude(JsonInclude.Include.ALWAYS) +class PredictionResponseDto( + val predictions: List> +) diff --git a/src/main/kotlin/org/jaqpot/api/service/runtime/RuntimeResolver.kt b/src/main/kotlin/org/jaqpot/api/service/runtime/RuntimeResolver.kt index 0bd60b7..ec1bd73 100644 --- a/src/main/kotlin/org/jaqpot/api/service/runtime/RuntimeResolver.kt +++ b/src/main/kotlin/org/jaqpot/api/service/runtime/RuntimeResolver.kt @@ -1,11 +1,13 @@ package org.jaqpot.api.service.runtime +import org.jaqpot.api.entity.Model +import org.jaqpot.api.service.runtime.config.RuntimeProvider import org.springframework.stereotype.Component @Component class RuntimeResolver(val runtimeProvider: RuntimeProvider) { - fun resolveRuntime(): String { - return runtimeProvider.jaqpotpyPretrainedUrl; + fun resolveRuntime(model: Model): String { + return runtimeProvider.jaqpotpyPretrainedUrl } } diff --git a/src/main/kotlin/org/jaqpot/api/service/runtime/RuntimeProvider.kt b/src/main/kotlin/org/jaqpot/api/service/runtime/config/RuntimeProvider.kt similarity index 84% rename from src/main/kotlin/org/jaqpot/api/service/runtime/RuntimeProvider.kt rename to src/main/kotlin/org/jaqpot/api/service/runtime/config/RuntimeProvider.kt index 1e44389..e083bc0 100644 --- a/src/main/kotlin/org/jaqpot/api/service/runtime/RuntimeProvider.kt +++ b/src/main/kotlin/org/jaqpot/api/service/runtime/config/RuntimeProvider.kt @@ -1,4 +1,4 @@ -package org.jaqpot.api.service.runtime +package org.jaqpot.api.service.runtime.config import org.springframework.beans.factory.annotation.Value import org.springframework.context.annotation.Configuration diff --git a/src/main/resources/db/migration/V2__create_main_entities.sql b/src/main/resources/db/migration/V2__create_main_entities.sql index fd1d8e1..4b451e7 100644 --- a/src/main/resources/db/migration/V2__create_main_entities.sql +++ b/src/main/resources/db/migration/V2__create_main_entities.sql @@ -57,6 +57,7 @@ CREATE TABLE data_entry updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL, dataset_id BIGINT NOT NULL, type VARCHAR(255) NOT NULL, + role VARCHAR(255) NOT NULL, values JSONB NOT NULL, CONSTRAINT pk_dataentry PRIMARY KEY (id) ); diff --git a/src/main/resources/openapi.yaml b/src/main/resources/openapi.yaml index 256ebff..1e00247 100644 --- a/src/main/resources/openapi.yaml +++ b/src/main/resources/openapi.yaml @@ -239,7 +239,7 @@ components: type: object required: - type - - dataEntries + - input properties: id: type: integer @@ -250,7 +250,11 @@ components: enum: - PREDICTION example: "PREDICTION" - dataEntries: + input: + type: array + items: + $ref: '#/components/schemas/DataEntry' + results: type: array items: $ref: '#/components/schemas/DataEntry' From fee6d6f49a691057854531e37bd6230bf9a13f97 Mon Sep 17 00:00:00 2001 From: Alex Arvanitidis Date: Tue, 4 Jun 2024 11:25:18 +0300 Subject: [PATCH 5/5] fix: make async method run properly async --- src/main/kotlin/org/jaqpot/api/JaqpotApiApplication.kt | 2 ++ .../kotlin/org/jaqpot/api/service/model/ModelService.kt | 8 +++++--- .../org/jaqpot/api/service/model/PredictionService.kt | 9 --------- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/src/main/kotlin/org/jaqpot/api/JaqpotApiApplication.kt b/src/main/kotlin/org/jaqpot/api/JaqpotApiApplication.kt index ac620c8..4966326 100644 --- a/src/main/kotlin/org/jaqpot/api/JaqpotApiApplication.kt +++ b/src/main/kotlin/org/jaqpot/api/JaqpotApiApplication.kt @@ -3,9 +3,11 @@ package org.jaqpot.api import org.springframework.boot.autoconfigure.SpringBootApplication import org.springframework.boot.context.properties.ConfigurationPropertiesScan import org.springframework.boot.runApplication +import org.springframework.scheduling.annotation.EnableAsync @SpringBootApplication @ConfigurationPropertiesScan("org.jaqpot.api") +@EnableAsync class JaqpotApiApplication fun main(args: Array) { diff --git a/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt b/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt index 125a576..438ff82 100644 --- a/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt +++ b/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt @@ -7,6 +7,7 @@ import org.jaqpot.api.mapper.toDto import org.jaqpot.api.mapper.toEntity import org.jaqpot.api.model.DatasetDto import org.jaqpot.api.model.ModelDto +import org.jaqpot.api.repository.DatasetRepository import org.jaqpot.api.repository.ModelRepository import org.springframework.data.repository.findByIdOrNull import org.springframework.http.HttpStatus @@ -22,7 +23,7 @@ class ModelService( private val authenticationFacade: AuthenticationFacade, private val modelRepository: ModelRepository, private val userService: UserService, - private val predictionService: PredictionService + private val predictionService: PredictionService, private val datasetRepository: DatasetRepository ) : ModelApiDelegate { override fun createModel(modelDto: ModelDto): ResponseEntity { val userId = authenticationFacade.userId @@ -48,12 +49,13 @@ class ModelService( val model = this.modelRepository.findByIdOrNull(modelId) ?: throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found") val userId = authenticationFacade.userId + val dataset = this.datasetRepository.save(datasetDto.toEntity(model, userId)) - val predictionDataset = this.predictionService.createAndPredictDataset(model, userId, datasetDto) + this.predictionService.executePredictionAndSaveResults(model, dataset) val location: URI = ServletUriComponentsBuilder .fromCurrentRequest().path("/{id}") - .buildAndExpand(predictionDataset.id).toUri() + .buildAndExpand(dataset.id).toUri() return ResponseEntity.created(location).build() } diff --git a/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt b/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt index 2b8972b..f4c420c 100644 --- a/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt +++ b/src/main/kotlin/org/jaqpot/api/service/model/PredictionService.kt @@ -2,8 +2,6 @@ package org.jaqpot.api.service.model import org.jaqpot.api.entity.* import org.jaqpot.api.mapper.toDto -import org.jaqpot.api.mapper.toEntity -import org.jaqpot.api.model.DatasetDto import org.jaqpot.api.repository.DatasetRepository import org.jaqpot.api.service.model.dto.PredictionRequestDto import org.jaqpot.api.service.model.dto.PredictionResponseDto @@ -21,13 +19,6 @@ class PredictionService( private val runtimeResolver: RuntimeResolver ) { - fun createAndPredictDataset(model: Model, userId: String, datasetDto: DatasetDto): Dataset { - val dataset = this.datasetRepository.save(datasetDto.toEntity(model, userId)) - this.executePredictionAndSaveResults(model, dataset) - - return dataset - } - @Async fun executePredictionAndSaveResults(model: Model, dataset: Dataset) { val rawModel = Base64.getEncoder().encodeToString(model.actualModel)