Skip to content

Commit

Permalink
feat(JAQPOT-128): update model organizations endpoint (#27)
Browse files Browse the repository at this point in the history
* feat(JAQPOT-128): update-model-organizations endpoint

* feat: add get models paginated
  • Loading branch information
alarv authored Jun 9, 2024
1 parent 501f131 commit 60bfebe
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 12 deletions.
8 changes: 8 additions & 0 deletions src/main/kotlin/org/jaqpot/api/entity/Model.kt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ class Model(
@SQLRestriction("feature_dependency = 'INDEPENDENT'")
val independentFeatures: MutableList<Feature>,

@ManyToMany
@JoinTable(
name = "organization_models",
joinColumns = [JoinColumn(name = "model_id")],
inverseJoinColumns = [JoinColumn(name = "organization_id")]
)
val organizations: MutableSet<Organization> = mutableSetOf(),

@Enumerated(EnumType.STRING)
@Column(nullable = false)
val visibility: ModelVisibility,
Expand Down
1 change: 0 additions & 1 deletion src/main/kotlin/org/jaqpot/api/entity/Organization.kt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class Organization(
@Column(name = "user_id", nullable = false)
val userIds: Set<String> = mutableSetOf(),


@ManyToMany
@JoinTable(
name = "organization_models",
Expand Down
2 changes: 2 additions & 0 deletions src/main/kotlin/org/jaqpot/api/mapper/ModelMapper.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ fun Model.toDto(userDto: UserDto): ModelDto {
this.id,
this.meta,
this.type,
this.organizations.map { it.toDto() },
this.reliability,
this.pretrained,
userDto,
Expand All @@ -38,6 +39,7 @@ fun ModelDto.toEntity(creatorId: String): Model {
mutableListOf(),
mutableListOf(),
mutableListOf(),
mutableSetOf(),
this.visibility.toEntity(),
this.reliability,
this.pretrained,
Expand Down
3 changes: 1 addition & 2 deletions src/main/kotlin/org/jaqpot/api/mapper/OrganizationMapper.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ package org.jaqpot.api.mapper

import org.jaqpot.api.entity.Organization
import org.jaqpot.api.model.OrganizationDto
import org.jaqpot.api.model.UserDto

fun Organization.toDto(userDto: UserDto): OrganizationDto {
fun Organization.toDto(): OrganizationDto {
return OrganizationDto(
this.name,
this.contactEmail,
Expand Down
16 changes: 16 additions & 0 deletions src/main/kotlin/org/jaqpot/api/mapper/PageMapper.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.jaqpot.api.mapper

import org.jaqpot.api.entity.Model
import org.jaqpot.api.model.GetModels200ResponseDto
import org.jaqpot.api.model.UserDto
import org.springframework.data.domain.Page

fun Page<Model>.toGetModels200ResponseDto(creatorDto: UserDto): GetModels200ResponseDto {
return GetModels200ResponseDto(
this.content.map { it.toDto(creatorDto) },
this.totalElements.toInt(),
this.totalPages,
this.pageable.pageSize,
this.pageable.pageNumber
)
}
7 changes: 6 additions & 1 deletion src/main/kotlin/org/jaqpot/api/repository/ModelRepository.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package org.jaqpot.api.repository

import org.jaqpot.api.entity.Model
import org.springframework.data.domain.Page
import org.springframework.data.domain.Pageable
import org.springframework.data.repository.CrudRepository
import org.springframework.data.repository.PagingAndSortingRepository

interface ModelRepository : CrudRepository<Model, Long>
interface ModelRepository : PagingAndSortingRepository<Model, Long>, CrudRepository<Model, Long> {
fun findAllByCreatorId(creatorId: String, pageable: Pageable): Page<Model>
}
51 changes: 45 additions & 6 deletions src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package org.jaqpot.api.service.model

import jakarta.transaction.Transactional
import org.jaqpot.api.ModelApiDelegate
import org.jaqpot.api.mapper.toDto
import org.jaqpot.api.mapper.toEntity
import org.jaqpot.api.model.DatasetDto
import org.jaqpot.api.model.ModelDto
import org.jaqpot.api.mapper.toGetModels200ResponseDto
import org.jaqpot.api.model.*
import org.jaqpot.api.repository.DatasetRepository
import org.jaqpot.api.repository.ModelRepository
import org.jaqpot.api.repository.OrganizationRepository
import org.jaqpot.api.service.authentication.AuthenticationFacade
import org.jaqpot.api.service.authentication.UserService
import org.springframework.data.repository.findByIdOrNull
import org.springframework.data.domain.PageRequest
import org.springframework.http.HttpStatus
import org.springframework.http.ResponseEntity
import org.springframework.security.access.prepost.PostAuthorize
Expand All @@ -25,8 +27,20 @@ class ModelService(
private val authenticationFacade: AuthenticationFacade,
private val modelRepository: ModelRepository,
private val userService: UserService,
private val predictionService: PredictionService, private val datasetRepository: DatasetRepository
private val predictionService: PredictionService,
private val datasetRepository: DatasetRepository,
private val organizationRepository: OrganizationRepository,
) : ModelApiDelegate {

override fun getModels(page: Int, size: Int): ResponseEntity<GetModels200ResponseDto> {
val creatorId = authenticationFacade.userId
val pageable = PageRequest.of(page, size)
val modelsPage = modelRepository.findAllByCreatorId(creatorId, pageable)
val creator = userService.getUserById(creatorId)

return ResponseEntity.ok().body(modelsPage.toGetModels200ResponseDto(creator))
}

override fun createModel(modelDto: ModelDto): ResponseEntity<Unit> {
if (modelDto.id != null) {
throw IllegalStateException("ID should not be provided for resource creation.")
Expand All @@ -53,8 +67,9 @@ class ModelService(
@PreAuthorize("@predictModelAuthorizationLogic.decide(#root, #modelId)")
override fun predictWithModel(modelId: Long, datasetDto: DatasetDto): ResponseEntity<Unit> {
if (datasetDto.type == DatasetDto.Type.PREDICTION) {
val model = this.modelRepository.findByIdOrNull(modelId)
?: throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found")
val model = modelRepository.findById(modelId).orElseThrow {
throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found")
}
val userId = authenticationFacade.userId
val dataset = this.datasetRepository.save(datasetDto.toEntity(model, userId))

Expand All @@ -68,5 +83,29 @@ class ModelService(

throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Unknown dataset type", null)
}

// TODO add authorization
@Transactional
override fun updateModelOrganizations(
modelId: kotlin.Long,
updateModelOrganizationsRequestDto: UpdateModelOrganizationsRequestDto
): ResponseEntity<UpdateModelOrganizations200ResponseDto> {
val model = modelRepository.findById(modelId).orElseThrow {
throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found")
}

// Fetch the organizations to be associated
val organizations = organizationRepository.findAllById(updateModelOrganizationsRequestDto.organizationIds!!)

// Clear the current associations
model.organizations.clear()

// Update with new associations
model.organizations.addAll(organizations)

// Persist the changes
modelRepository.save(model)
return ResponseEntity.ok(UpdateModelOrganizations200ResponseDto("Organizations updated successfully!"))
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class OrganizationService(
val organization = organizationRepository.findByName(name)
?: throw NotFoundException("Organization with name $name not found.")

val creatorDto = userService.getUserById(organization.creatorId)
return ResponseEntity.ok(organization.toDto(creatorDto))
return ResponseEntity.ok(organization.toDto())
}
}
85 changes: 85 additions & 0 deletions src/main/resources/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,49 @@ paths:
description: Invalid input
x-stoplight:
id: 9ffjy7o77jc41
get:
x-spring-paginated: true
summary: Get paginated models
security:
- bearerAuth: [ ]
tags:
- model
operationId: getModels
parameters:
- name: page
in: query
required: false
schema:
type: integer
default: 0
- name: size
in: query
required: false
schema:
type: integer
default: 10
responses:
'200':
description: Paginated list of models
content:
application/json:
schema:
type: object
properties:
content:
type: array
items:
$ref: '#/components/schemas/Model'
totalElements:
type: integer
totalPages:
type: integer
pageSize:
type: integer
pageNumber:
type: integer
'400':
description: Invalid input
'/v1/models/{id}':
get:
summary: Get a Model
Expand Down Expand Up @@ -102,6 +145,44 @@ paths:
description: Model not found
'500':
description: Internal Server Error
/v1/models/{modelId}/organizations:
put:
summary: Update organizations for a model
operationId: updateModelOrganizations
tags:
- Model
parameters:
- name: modelId
in: path
required: true
schema:
type: integer
format: int64
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
organizationIds:
type: array
items:
type: integer
format: int64
responses:
'200':
description: Organizations updated successfully
content:
application/json:
schema:
type: object
properties:
message:
type: string
example: Organizations updated successfully
'404':
description: Model or Organization not found
'/v1/datasets/{id}':
get:
summary: Get a Dataset
Expand Down Expand Up @@ -264,6 +345,10 @@ components:
type: array
items:
$ref: '#/components/schemas/Feature'
organizations:
type: array
items:
$ref: '#/components/schemas/Organization'
visibility:
type: string
enum:
Expand Down

0 comments on commit 60bfebe

Please sign in to comment.