diff --git a/src/main/kotlin/org/jaqpot/api/entity/Model.kt b/src/main/kotlin/org/jaqpot/api/entity/Model.kt index 06f82ae..2045b2f 100644 --- a/src/main/kotlin/org/jaqpot/api/entity/Model.kt +++ b/src/main/kotlin/org/jaqpot/api/entity/Model.kt @@ -47,6 +47,14 @@ class Model( @SQLRestriction("feature_dependency = 'INDEPENDENT'") val independentFeatures: MutableList, + @ManyToMany + @JoinTable( + name = "organization_models", + joinColumns = [JoinColumn(name = "model_id")], + inverseJoinColumns = [JoinColumn(name = "organization_id")] + ) + val organizations: MutableSet = mutableSetOf(), + @Enumerated(EnumType.STRING) @Column(nullable = false) val visibility: ModelVisibility, diff --git a/src/main/kotlin/org/jaqpot/api/entity/Organization.kt b/src/main/kotlin/org/jaqpot/api/entity/Organization.kt index 2551208..676dd8c 100644 --- a/src/main/kotlin/org/jaqpot/api/entity/Organization.kt +++ b/src/main/kotlin/org/jaqpot/api/entity/Organization.kt @@ -31,7 +31,6 @@ class Organization( @Column(name = "user_id", nullable = false) val userIds: Set = mutableSetOf(), - @ManyToMany @JoinTable( name = "organization_models", diff --git a/src/main/kotlin/org/jaqpot/api/mapper/ModelMapper.kt b/src/main/kotlin/org/jaqpot/api/mapper/ModelMapper.kt index 6bd1bfd..7b96f55 100644 --- a/src/main/kotlin/org/jaqpot/api/mapper/ModelMapper.kt +++ b/src/main/kotlin/org/jaqpot/api/mapper/ModelMapper.kt @@ -18,6 +18,7 @@ fun Model.toDto(userDto: UserDto): ModelDto { this.id, this.meta, this.type, + this.organizations.map { it.toDto() }, this.reliability, this.pretrained, userDto, @@ -38,6 +39,7 @@ fun ModelDto.toEntity(creatorId: String): Model { mutableListOf(), mutableListOf(), mutableListOf(), + mutableSetOf(), this.visibility.toEntity(), this.reliability, this.pretrained, diff --git a/src/main/kotlin/org/jaqpot/api/mapper/OrganizationMapper.kt b/src/main/kotlin/org/jaqpot/api/mapper/OrganizationMapper.kt index 49785f7..ee23be5 100644 --- a/src/main/kotlin/org/jaqpot/api/mapper/OrganizationMapper.kt +++ b/src/main/kotlin/org/jaqpot/api/mapper/OrganizationMapper.kt @@ -2,9 +2,8 @@ package org.jaqpot.api.mapper import org.jaqpot.api.entity.Organization import org.jaqpot.api.model.OrganizationDto -import org.jaqpot.api.model.UserDto -fun Organization.toDto(userDto: UserDto): OrganizationDto { +fun Organization.toDto(): OrganizationDto { return OrganizationDto( this.name, this.contactEmail, diff --git a/src/main/kotlin/org/jaqpot/api/mapper/PageMapper.kt b/src/main/kotlin/org/jaqpot/api/mapper/PageMapper.kt new file mode 100644 index 0000000..9cbd5e4 --- /dev/null +++ b/src/main/kotlin/org/jaqpot/api/mapper/PageMapper.kt @@ -0,0 +1,16 @@ +package org.jaqpot.api.mapper + +import org.jaqpot.api.entity.Model +import org.jaqpot.api.model.GetModels200ResponseDto +import org.jaqpot.api.model.UserDto +import org.springframework.data.domain.Page + +fun Page.toGetModels200ResponseDto(creatorDto: UserDto): GetModels200ResponseDto { + return GetModels200ResponseDto( + this.content.map { it.toDto(creatorDto) }, + this.totalElements.toInt(), + this.totalPages, + this.pageable.pageSize, + this.pageable.pageNumber + ) +} diff --git a/src/main/kotlin/org/jaqpot/api/repository/ModelRepository.kt b/src/main/kotlin/org/jaqpot/api/repository/ModelRepository.kt index 2de3aab..cdb5cf3 100644 --- a/src/main/kotlin/org/jaqpot/api/repository/ModelRepository.kt +++ b/src/main/kotlin/org/jaqpot/api/repository/ModelRepository.kt @@ -1,6 +1,11 @@ package org.jaqpot.api.repository import org.jaqpot.api.entity.Model +import org.springframework.data.domain.Page +import org.springframework.data.domain.Pageable import org.springframework.data.repository.CrudRepository +import org.springframework.data.repository.PagingAndSortingRepository -interface ModelRepository : CrudRepository +interface ModelRepository : PagingAndSortingRepository, CrudRepository { + fun findAllByCreatorId(creatorId: String, pageable: Pageable): Page +} diff --git a/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt b/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt index 58de3ee..f603803 100644 --- a/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt +++ b/src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt @@ -1,15 +1,17 @@ package org.jaqpot.api.service.model +import jakarta.transaction.Transactional import org.jaqpot.api.ModelApiDelegate import org.jaqpot.api.mapper.toDto import org.jaqpot.api.mapper.toEntity -import org.jaqpot.api.model.DatasetDto -import org.jaqpot.api.model.ModelDto +import org.jaqpot.api.mapper.toGetModels200ResponseDto +import org.jaqpot.api.model.* import org.jaqpot.api.repository.DatasetRepository import org.jaqpot.api.repository.ModelRepository +import org.jaqpot.api.repository.OrganizationRepository import org.jaqpot.api.service.authentication.AuthenticationFacade import org.jaqpot.api.service.authentication.UserService -import org.springframework.data.repository.findByIdOrNull +import org.springframework.data.domain.PageRequest import org.springframework.http.HttpStatus import org.springframework.http.ResponseEntity import org.springframework.security.access.prepost.PostAuthorize @@ -25,8 +27,20 @@ class ModelService( private val authenticationFacade: AuthenticationFacade, private val modelRepository: ModelRepository, private val userService: UserService, - private val predictionService: PredictionService, private val datasetRepository: DatasetRepository + private val predictionService: PredictionService, + private val datasetRepository: DatasetRepository, + private val organizationRepository: OrganizationRepository, ) : ModelApiDelegate { + + override fun getModels(page: Int, size: Int): ResponseEntity { + val creatorId = authenticationFacade.userId + val pageable = PageRequest.of(page, size) + val modelsPage = modelRepository.findAllByCreatorId(creatorId, pageable) + val creator = userService.getUserById(creatorId) + + return ResponseEntity.ok().body(modelsPage.toGetModels200ResponseDto(creator)) + } + override fun createModel(modelDto: ModelDto): ResponseEntity { if (modelDto.id != null) { throw IllegalStateException("ID should not be provided for resource creation.") @@ -53,8 +67,9 @@ class ModelService( @PreAuthorize("@predictModelAuthorizationLogic.decide(#root, #modelId)") override fun predictWithModel(modelId: Long, datasetDto: DatasetDto): ResponseEntity { if (datasetDto.type == DatasetDto.Type.PREDICTION) { - val model = this.modelRepository.findByIdOrNull(modelId) - ?: throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found") + val model = modelRepository.findById(modelId).orElseThrow { + throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found") + } val userId = authenticationFacade.userId val dataset = this.datasetRepository.save(datasetDto.toEntity(model, userId)) @@ -68,5 +83,29 @@ class ModelService( throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Unknown dataset type", null) } + + // TODO add authorization + @Transactional + override fun updateModelOrganizations( + modelId: kotlin.Long, + updateModelOrganizationsRequestDto: UpdateModelOrganizationsRequestDto + ): ResponseEntity { + val model = modelRepository.findById(modelId).orElseThrow { + throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found") + } + + // Fetch the organizations to be associated + val organizations = organizationRepository.findAllById(updateModelOrganizationsRequestDto.organizationIds!!) + + // Clear the current associations + model.organizations.clear() + + // Update with new associations + model.organizations.addAll(organizations) + + // Persist the changes + modelRepository.save(model) + return ResponseEntity.ok(UpdateModelOrganizations200ResponseDto("Organizations updated successfully!")) + } } diff --git a/src/main/kotlin/org/jaqpot/api/service/organization/OrganizationService.kt b/src/main/kotlin/org/jaqpot/api/service/organization/OrganizationService.kt index f4771f1..35d6e5a 100644 --- a/src/main/kotlin/org/jaqpot/api/service/organization/OrganizationService.kt +++ b/src/main/kotlin/org/jaqpot/api/service/organization/OrganizationService.kt @@ -41,7 +41,6 @@ class OrganizationService( val organization = organizationRepository.findByName(name) ?: throw NotFoundException("Organization with name $name not found.") - val creatorDto = userService.getUserById(organization.creatorId) - return ResponseEntity.ok(organization.toDto(creatorDto)) + return ResponseEntity.ok(organization.toDto()) } } diff --git a/src/main/resources/openapi.yaml b/src/main/resources/openapi.yaml index 4cc4ebf..f35c7a8 100644 --- a/src/main/resources/openapi.yaml +++ b/src/main/resources/openapi.yaml @@ -42,6 +42,49 @@ paths: description: Invalid input x-stoplight: id: 9ffjy7o77jc41 + get: + x-spring-paginated: true + summary: Get paginated models + security: + - bearerAuth: [ ] + tags: + - model + operationId: getModels + parameters: + - name: page + in: query + required: false + schema: + type: integer + default: 0 + - name: size + in: query + required: false + schema: + type: integer + default: 10 + responses: + '200': + description: Paginated list of models + content: + application/json: + schema: + type: object + properties: + content: + type: array + items: + $ref: '#/components/schemas/Model' + totalElements: + type: integer + totalPages: + type: integer + pageSize: + type: integer + pageNumber: + type: integer + '400': + description: Invalid input '/v1/models/{id}': get: summary: Get a Model @@ -102,6 +145,44 @@ paths: description: Model not found '500': description: Internal Server Error + /v1/models/{modelId}/organizations: + put: + summary: Update organizations for a model + operationId: updateModelOrganizations + tags: + - Model + parameters: + - name: modelId + in: path + required: true + schema: + type: integer + format: int64 + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + organizationIds: + type: array + items: + type: integer + format: int64 + responses: + '200': + description: Organizations updated successfully + content: + application/json: + schema: + type: object + properties: + message: + type: string + example: Organizations updated successfully + '404': + description: Model or Organization not found '/v1/datasets/{id}': get: summary: Get a Dataset @@ -264,6 +345,10 @@ components: type: array items: $ref: '#/components/schemas/Feature' + organizations: + type: array + items: + $ref: '#/components/schemas/Organization' visibility: type: string enum: