Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(JAQPOT-414): archive model #122

Merged
merged 8 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/main/kotlin/org/jaqpot/api/entity/Model.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import jakarta.validation.constraints.Size
import org.hibernate.annotations.JdbcTypeCode
import org.hibernate.annotations.SQLRestriction
import org.hibernate.type.SqlTypes
import java.time.OffsetDateTime

@Entity
class Model(
Expand Down Expand Up @@ -67,6 +68,11 @@ class Model(

val selectedFeatures: List<String>?,

@Column(nullable = false)
var archived: Boolean = false,

var archivedAt: OffsetDateTime? = null,

@Size(min = 3, max = 1000)
@Column(columnDefinition = "TEXT")
var tags: String?,
Expand Down
2 changes: 2 additions & 0 deletions src/main/kotlin/org/jaqpot/api/mapper/ModelMapper.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ fun Model.toDto(userDto: UserDto? = null, userCanEdit: Boolean? = null, isAdmin:
test = this.testScores?.map { it.toDto() },
crossValidation = this.crossValidationScores?.map { it.toDto() },
),
archived = this.archived,
archivedAt = this.archivedAt,
createdAt = this.createdAt,
updatedAt = this.updatedAt,
)
Expand Down
13 changes: 10 additions & 3 deletions src/main/kotlin/org/jaqpot/api/repository/ModelRepository.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,22 @@ import org.springframework.data.jpa.repository.Query
import org.springframework.data.repository.CrudRepository
import org.springframework.data.repository.PagingAndSortingRepository
import org.springframework.data.repository.query.Param
import java.time.OffsetDateTime
import java.util.*


interface ModelRepository : PagingAndSortingRepository<Model, Long>, CrudRepository<Model, Long> {

fun findAllByCreatorId(creatorId: String, pageable: Pageable): Page<Model>
fun findAllByCreatorIdAndArchivedIsFalse(creatorId: String, pageable: Pageable): Page<Model>
fun findAllByCreatorIdAndArchivedIsTrue(creatorId: String, pageable: Pageable): Page<Model>

fun findOneByLegacyId(legacyId: String): Optional<Model>

@Query(
"""
SELECT m FROM Model m
WHERE m.visibility = 'ORG_SHARED'
WHERE m.visibility = 'ORG_SHARED'
AND m.archived = false
AND EXISTS (
SELECT 1 FROM m.sharedWithOrganizations o
JOIN o.organization org
Expand All @@ -35,6 +39,7 @@ interface ModelRepository : PagingAndSortingRepository<Model, Long>, CrudReposit
"""
SELECT m FROM Model m
WHERE m.visibility = 'ORG_SHARED'
AND m.archived = false
AND EXISTS (
SELECT 1 FROM m.sharedWithOrganizations o
JOIN o.organization org
Expand All @@ -49,7 +54,7 @@ interface ModelRepository : PagingAndSortingRepository<Model, Long>, CrudReposit
value = """
SELECT *, ts_rank_cd(textsearchable_index_col, to_tsquery(:query)) AS rank
FROM model, to_tsquery(:query) query
WHERE model.visibility = 'PUBLIC' AND textsearchable_index_col @@ query
WHERE model.visibility = 'PUBLIC' AND model.archived = false AND textsearchable_index_col @@ query
ORDER BY rank DESC
""",
// countQuery = """
Expand All @@ -70,4 +75,6 @@ interface ModelRepository : PagingAndSortingRepository<Model, Long>, CrudReposit
@Transactional
@Query("UPDATE Model m SET m.rawPreprocessor = NULL WHERE m.id = :id")
fun setRawPreprocessorToNull(@Param("id") id: Long?)

fun findAllByArchivedIsTrueAndArchivedAtBefore(date: OffsetDateTime): List<Model>
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import org.springframework.security.access.expression.method.MethodSecurityExpre
import org.springframework.stereotype.Component
import org.springframework.web.server.ResponseStatusException

@Component("partialFeatureUpdateAuthorizationLogic")
class PartialFeatureUpdateAuthorizationLogic(
@Component("modelUpdateAuthorizationLogic")
class ModelUpdateAuthorizationLogic(
private val modelRepository: ModelRepository,
private val authenticationFacade: AuthenticationFacade
) {
Expand All @@ -21,8 +21,6 @@ class PartialFeatureUpdateAuthorizationLogic(
throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found")
}



return authenticationFacade.userId == model.creatorId
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.jaqpot.api.service.dataset

import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.transaction.Transactional
import org.jaqpot.api.DatasetApiDelegate
import org.jaqpot.api.mapper.toDto
import org.jaqpot.api.mapper.toGetDatasets200ResponseDto
Expand Down Expand Up @@ -52,6 +53,7 @@ class DatasetService(
return ResponseEntity.ok().body(datasets.toGetDatasets200ResponseDto(inputsMap, resultsMap))
}

@Transactional
@Scheduled(cron = "0 0 3 * * *" /* every day at 3:00 AM */)
fun purgeExpiredDatasets() {
logger.info { "Purging expired datasets" }
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/org/jaqpot/api/service/model/DoaService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class DoaService(private val doaRepository: DoaRepository, private val storageSe
return
}
logger.info { "Storing raw doa to storage for doa with id ${doa.id} and model ${doa.model.id}" }
if (storageService.storeDoa(doa)) {
if (storageService.storeRawDoa(doa)) {
logger.info { "Successfully moved raw doa to storage for doa ${doa.id} and model ${doa.model.id}" }
doaRepository.setRawDoaToNull(doa.id)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import org.springframework.web.server.ResponseStatusException
@Service
class FeatureService(private val modelRepository: ModelRepository, private val featureRepository: FeatureRepository) :
FeatureApiDelegate {
@PreAuthorize("@partialFeatureUpdateAuthorizationLogic.decide(#root, #modelId)")
@PreAuthorize("@modelUpdateAuthorizationLogic.decide(#root, #modelId)")
override fun partiallyUpdateModelFeature(
modelId: Long,
featureId: Long,
Expand Down
133 changes: 127 additions & 6 deletions src/main/kotlin/org/jaqpot/api/service/model/ModelService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ import org.springframework.data.domain.PageRequest
import org.springframework.data.domain.Sort
import org.springframework.http.HttpStatus
import org.springframework.http.ResponseEntity
import org.springframework.scheduling.annotation.Scheduled
import org.springframework.security.access.prepost.PostAuthorize
import org.springframework.security.access.prepost.PreAuthorize
import org.springframework.stereotype.Service
import org.springframework.web.server.ResponseStatusException
import org.springframework.web.servlet.support.ServletUriComponentsBuilder
import java.net.URI
import java.time.OffsetDateTime


private val logger = KotlinLogging.logger {}
const val JAQPOT_METADATA_KEY = "jaqpotMetadata"
const val JAQPOT_ROW_ID_KEY = "jaqpotRowId"
const val JAQPOT_ROW_LABEL_KEY = "jaqpotRowLabel"
Expand Down Expand Up @@ -72,6 +74,8 @@ class ModelService(
ModelTypeDto.R_TREE_CLASS,
ModelTypeDto.R_TREE_REGR
)
const val ARCHIVED_MODEL_EXPIRATION_DAYS = 30L
private val logger = KotlinLogging.logger {}
}

@PreAuthorize("hasAnyAuthority('admin', 'upci')")
Expand All @@ -89,7 +93,7 @@ class ModelService(
override fun getModels(page: Int, size: Int, sort: List<String>?): ResponseEntity<GetModels200ResponseDto> {
val creatorId = authenticationFacade.userId
val pageable = PageRequest.of(page, size, Sort.by(parseSortParameters(sort)))
val modelsPage = modelRepository.findAllByCreatorId(creatorId, pageable)
val modelsPage = modelRepository.findAllByCreatorIdAndArchivedIsFalse(creatorId, pageable)
val modelIdToUserMap = modelsPage.content.associateBy(
{ it.id!! },
{ userService.getUserById(it.creatorId).orElse(UserDto(it.creatorId)) }
Expand Down Expand Up @@ -213,6 +217,10 @@ class ModelService(
// TODO once there are no models with rawPreprocessor in the database, remove this
storeRawPreprocessorToStorage(model)

if (model.archived) {
throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Model with id $modelId is archived")
}

val userId = authenticationFacade.userId
val toEntity = datasetDto.toEntity(
model,
Expand Down Expand Up @@ -363,6 +371,50 @@ class ModelService(
return ResponseEntity.ok(model.toDto(modelCreator, userCanEdit, isAdmin))
}

@WithRateLimitProtectionByUser(limit = 10, intervalInSeconds = 60)
@PreAuthorize("@modelUpdateAuthorizationLogic.decide(#root, #modelId)")
override fun archiveModel(modelId: Long): ResponseEntity<ArchiveModel200ResponseDto> {
val existingModel = modelRepository.findById(modelId).orElseThrow {
throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found")
}

if (existingModel.archived) {
throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Model with id $modelId is already archived")
}

existingModel.archived = true
existingModel.archivedAt = OffsetDateTime.now()

modelRepository.save(existingModel)

return ResponseEntity(
ArchiveModel200ResponseDto(id = modelId, archivedAt = existingModel.archivedAt),
HttpStatus.OK
)
}

@WithRateLimitProtectionByUser(limit = 10, intervalInSeconds = 60)
@PreAuthorize("@modelUpdateAuthorizationLogic.decide(#root, #modelId)")
override fun unarchiveModel(modelId: Long): ResponseEntity<UnarchiveModel200ResponseDto> {
val existingModel = modelRepository.findById(modelId).orElseThrow {
throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $modelId not found")
}

if (!existingModel.archived) {
throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Model with id $modelId is not archived")
}

existingModel.archived = false
existingModel.archivedAt = null

modelRepository.save(existingModel)

return ResponseEntity(
UnarchiveModel200ResponseDto(id = modelId),
HttpStatus.OK
)
}

@Cacheable(CacheKeys.SEARCH_MODELS)
override fun searchModels(
query: String,
Expand All @@ -380,13 +432,82 @@ class ModelService(
return ResponseEntity.ok(modelsPage.toGetModels200ResponseDto(modelIdToUserMap))
}

override fun getArchivedModels(
page: Int,
size: Int,
sort: List<String>?
): ResponseEntity<GetModels200ResponseDto> {
val userId = authenticationFacade.userId
val pageable = PageRequest.of(page, size, Sort.by(parseSortParameters(sort)))

val archivedModelsPage = modelRepository.findAllByCreatorIdAndArchivedIsTrue(userId, pageable)

val modelIdToUserMap = archivedModelsPage.content.associateBy(
{ it.id!! },
{ userService.getUserById(it.creatorId).orElse(UserDto(it.creatorId)) }
)

return ResponseEntity.ok().body(archivedModelsPage.toGetModels200ResponseDto(modelIdToUserMap))
}

@CacheEvict("searchModels", allEntries = true)
@PreAuthorize("hasAuthority('admin')")
override fun deleteModelById(id: Long): ResponseEntity<Unit> {
modelRepository.delete(modelRepository.findById(id).orElseThrow {
throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $id not found")
})
return ResponseEntity.noContent().build()
throw ResponseStatusException(HttpStatus.BAD_REQUEST, "This endpoint is not supported")
// val model = modelRepository.findById(id).orElseThrow {
// throw ResponseStatusException(HttpStatus.NOT_FOUND, "Model with id $id not found")
// }
//
// deleteModel(model)
//
// return ResponseEntity.noContent().build()
}

private fun deleteModel(model: Model) {
logger.info { "Deleting model with id ${model.id}" }

if (model.doas.isNotEmpty()) {
model.doas.forEach {
logger.info { "Deleting DOA with id ${it.id} for model with id ${model.id}" }
val deletedRawDoa = storageService.deleteRawDoa(it)
logger.info { "Deleted raw DOA for model with id ${model.id}: $deletedRawDoa" }
}
}

logger.info { "Deleting raw preprocessor for model with id ${model.id}" }
val deletedRawPreprocessor = storageService.deleteRawPreprocessor(model)
logger.info { "Deleted raw preprocessor for model with id ${model.id}: $deletedRawPreprocessor" }

logger.info { "Deleting raw model for model with id ${model.id}" }
val deletedRawModel = storageService.deleteRawModel(model)
logger.info { "Deleted raw model for model with id ${model.id}: $deletedRawModel" }

modelRepository.delete(model)

logger.info { "Deleted model with id ${model.id}" }
}

@Transactional
@Scheduled(cron = "0 0 3 * * *" /* every day at 3:00 AM */)
fun purgeExpiredArchivedModels() {
logger.info { "Purging expired archived models" }

val expiredArchivedModels = modelRepository.findAllByArchivedIsTrueAndArchivedAtBefore(
OffsetDateTime.now().minusDays(ARCHIVED_MODEL_EXPIRATION_DAYS)
)

var deletionCount = 0

expiredArchivedModels.forEach {
try {
this.deleteModel(it)
deletionCount++
} catch (e: Exception) {
logger.error(e) { "Failed to delete model with id ${it.id}" }
}
}

logger.info { "Purged $deletionCount expired archived models" }
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class PredictionService(
dataset.result = results
dataset.executionFinishedAt = OffsetDateTime.now()
datasetRepository.save(dataset)
if (storageService.storeDataset(dataset)) {
if (storageService.storeRawDataset(dataset)) {
datasetRepository.setDatasetInputAndResultToNull(dataset.id)
}
}
Expand All @@ -70,7 +70,7 @@ class PredictionService(
dataset.failureReason = err.toString()

datasetRepository.save(dataset)
if (storageService.storeDataset(dataset)) {
if (storageService.storeRawDataset(dataset)) {
datasetRepository.setDatasetInputAndResultToNull(dataset.id)
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/main/kotlin/org/jaqpot/api/storage/Storage.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,9 @@ interface Storage {
obj: ByteArray,
metadata: Map<String, String> = mapOf(),
)

fun deleteObject(
bucketName: String,
keyName: String
)
}
Loading
Loading