Skip to content

Commit

Permalink
feat(authz): add support for ids params in get authorized entities en…
Browse files Browse the repository at this point in the history
…dpoint (#1055)
  • Loading branch information
bobeal authored Nov 30, 2023
1 parent 1dadb88 commit 4dc1319
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ class EnabledAuthorizationService(
sub,
accessRights,
entitiesQuery.typeSelection,
entitiesQuery.paginationQuery.limit,
entitiesQuery.paginationQuery.offset
entitiesQuery.ids,
entitiesQuery.paginationQuery
).bind()

// for each entity user is admin of, retrieve the full details of rights other users have on it
Expand Down Expand Up @@ -133,7 +133,8 @@ class EnabledAuthorizationService(
val count = entityAccessRightsService.getSubjectAccessRightsCount(
sub,
accessRights,
entitiesQuery.typeSelection
entitiesQuery.typeSelection,
entitiesQuery.ids
).bind()

Pair(count, entitiesAccessControlWithSubjectRights)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ import com.egm.stellio.search.authorization.EntityAccessRights.SubjectRightInfo
import com.egm.stellio.search.service.EntityPayloadService
import com.egm.stellio.search.util.*
import com.egm.stellio.shared.config.ApplicationProperties
import com.egm.stellio.shared.model.APIException
import com.egm.stellio.shared.model.AccessDeniedException
import com.egm.stellio.shared.model.ResourceNotFoundException
import com.egm.stellio.shared.model.*
import com.egm.stellio.shared.util.*
import com.egm.stellio.shared.util.AccessRight.*
import com.egm.stellio.shared.util.AuthContextModel.AUTH_TERM_CLIENT_ID
Expand Down Expand Up @@ -163,9 +161,9 @@ class EntityAccessRightsService(
suspend fun getSubjectAccessRights(
sub: Option<Sub>,
accessRights: List<AccessRight>,
type: String? = null,
limit: Int,
offset: Int
type: EntityTypeSelection? = null,
ids: Set<URI>? = null,
paginationQuery: PaginationQuery,
): Either<APIException, List<EntityAccessRights>> = either {
val subjectUuids = subjectReferentialService.getSubjectAndGroupsUUID(sub).bind()
val isStellioAdmin = subjectReferentialService.hasStellioAdminRole(subjectUuids).bind()
Expand All @@ -177,15 +175,16 @@ class EntityAccessRightsService(
FROM entity_access_rights ear
LEFT JOIN entity_payload ep ON ear.entity_id = ep.entity_id
WHERE ${if (isStellioAdmin) "1 = 1" else "subject_id IN (:subject_uuids)" }
${if (accessRights.isNotEmpty()) " AND access_right in (:access_rights)" else ""}
${if (accessRights.isNotEmpty()) " AND access_right IN (:access_rights)" else ""}
${if (!type.isNullOrEmpty()) " AND ${buildTypeQuery(type)}" else ""}
${if (!ids.isNullOrEmpty()) " AND ear.entity_id IN (:entities_ids)" else ""}
ORDER BY entity_id
LIMIT :limit
OFFSET :offset;
""".trimIndent()
)
.bind("limit", limit)
.bind("offset", offset)
.bind("limit", paginationQuery.limit)
.bind("offset", paginationQuery.offset)
.let {
if (!isStellioAdmin)
it.bind("subject_uuids", subjectUuids)
Expand All @@ -196,6 +195,11 @@ class EntityAccessRightsService(
it.bind("access_rights", accessRights.map { it.attributeName })
else it
}
.let {
if (!ids.isNullOrEmpty())
it.bind("entities_ids", ids)
else it
}
.allToMappedList { rowToEntityAccessControl(it, isStellioAdmin) }
.groupBy { it.id }
// a user may have multiple rights on a given entity (e.g., through groups memberships)
Expand All @@ -214,7 +218,8 @@ class EntityAccessRightsService(
suspend fun getSubjectAccessRightsCount(
sub: Option<Sub>,
accessRights: List<AccessRight>,
type: String? = null
type: EntityTypeSelection? = null,
ids: Set<URI>? = null
): Either<APIException, Int> = either {
val subjectUuids = subjectReferentialService.getSubjectAndGroupsUUID(sub).bind()
val isStellioAdmin = subjectReferentialService.hasStellioAdminRole(subjectUuids).bind()
Expand All @@ -226,8 +231,9 @@ class EntityAccessRightsService(
FROM entity_access_rights ear
LEFT JOIN entity_payload ep ON ear.entity_id = ep.entity_id
WHERE ${if (isStellioAdmin) "1 = 1" else "subject_id IN (:subject_uuids)" }
${if (accessRights.isNotEmpty()) " AND access_right in (:access_rights)" else ""}
${if (accessRights.isNotEmpty()) " AND access_right IN (:access_rights)" else ""}
${if (!type.isNullOrEmpty()) " AND ${buildTypeQuery(type)}" else ""}
${if (!ids.isNullOrEmpty()) " AND ear.entity_id IN (:entities_ids)" else ""}
""".trimIndent()
)
.let {
Expand All @@ -240,6 +246,11 @@ class EntityAccessRightsService(
it.bind("access_rights", accessRights.map { it.attributeName })
else it
}
.let {
if (!ids.isNullOrEmpty())
it.bind("entities_ids", ids)
else it
}
.oneToResult { toInt(it["count"]) }
.bind()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,9 @@ class EnabledAuthorizationServiceTests {
right = AccessRight.R_CAN_WRITE
)
).right()
coEvery { entityAccessRightsService.getSubjectAccessRightsCount(any(), any(), any()) } returns Either.Right(1)
coEvery {
entityAccessRightsService.getSubjectAccessRightsCount(any(), any(), any(), any())
} returns Either.Right(1)
coEvery {
entityAccessRightsService.getAccessRightsForEntities(any(), any())
} returns emptyMap<URI, Map<AccessRight, List<SubjectRightInfo>>>().right()
Expand Down Expand Up @@ -380,7 +382,9 @@ class EnabledAuthorizationServiceTests {
right = AccessRight.R_CAN_WRITE
)
).right()
coEvery { entityAccessRightsService.getSubjectAccessRightsCount(any(), any(), any()) } returns Either.Right(1)
coEvery {
entityAccessRightsService.getSubjectAccessRightsCount(any(), any(), any(), any())
} returns Either.Right(1)
coEvery {
entityAccessRightsService.getAccessRightsForEntities(any(), any())
} returns mapOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.egm.stellio.search.model.EntityPayload
import com.egm.stellio.search.service.EntityPayloadService
import com.egm.stellio.search.support.WithTimescaleContainer
import com.egm.stellio.shared.model.AccessDeniedException
import com.egm.stellio.shared.model.PaginationQuery
import com.egm.stellio.shared.util.*
import com.egm.stellio.shared.util.AuthContextModel.AUTHORIZATION_COMPOUND_CONTEXT
import com.egm.stellio.shared.util.AuthContextModel.AUTH_TERM_NAME
Expand Down Expand Up @@ -260,8 +261,7 @@ class EntityAccessRightsServiceTests : WithTimescaleContainer {
entityAccessRightsService.getSubjectAccessRights(
Some(subjectUuid),
emptyList(),
limit = 100,
offset = 0
paginationQuery = PaginationQuery(limit = 100, offset = 0)
).shouldSucceedWith {
assertEquals(1, it.size)
val entityAccessControl = it[0]
Expand Down Expand Up @@ -292,8 +292,7 @@ class EntityAccessRightsServiceTests : WithTimescaleContainer {
entityAccessRightsService.getSubjectAccessRights(
Some(subjectUuid),
emptyList(),
limit = 100,
offset = 0
paginationQuery = PaginationQuery(limit = 100, offset = 0)
).shouldSucceedWith {
assertEquals(2, it.size)
it.forEach { entityAccessControl ->
Expand Down Expand Up @@ -325,8 +324,7 @@ class EntityAccessRightsServiceTests : WithTimescaleContainer {
Some(subjectUuid),
emptyList(),
BEEHIVE_TYPE,
100,
0
paginationQuery = PaginationQuery(limit = 100, offset = 0)
).shouldSucceedWith {
assertEquals(1, it.size)
val entityAccessControl = it[0]
Expand All @@ -343,6 +341,72 @@ class EntityAccessRightsServiceTests : WithTimescaleContainer {
}
}

@Test
fun `it should get all entities an user has access to wrt ids`() = runTest {
val entityId03 = "urn:ngsi-ld:Entity:03".toUri()

createEntityPayload(entityId01, setOf(BEEHIVE_TYPE), AUTH_READ)
createEntityPayload(entityId02, setOf(BEEHIVE_TYPE))
createEntityPayload(entityId03, setOf(APIARY_TYPE))
entityAccessRightsService.setRoleOnEntity(subjectUuid, entityId01, AccessRight.R_CAN_WRITE).shouldSucceed()
entityAccessRightsService.setRoleOnEntity(subjectUuid, entityId03, AccessRight.R_CAN_WRITE).shouldSucceed()
entityAccessRightsService.setRoleOnEntity(UUID.randomUUID().toString(), entityId02, AccessRight.R_CAN_WRITE)
.shouldSucceed()

entityAccessRightsService.getSubjectAccessRights(
Some(subjectUuid),
emptyList(),
null,
setOf(entityId01, entityId02),
paginationQuery = PaginationQuery(limit = 100, offset = 0)
).shouldSucceedWith {
assertEquals(1, it.size)
assertEquals(entityId01, it[0].id)
}

entityAccessRightsService.getSubjectAccessRightsCount(
Some(subjectUuid),
emptyList(),
BEEHIVE_TYPE,
setOf(entityId01, entityId03)
).shouldSucceedWith {
assertEquals(1, it)
}
}

@Test
fun `it should get all entities an user has access to wrt ids and types`() = runTest {
val entityId03 = "urn:ngsi-ld:Entity:03".toUri()

createEntityPayload(entityId01, setOf(BEEHIVE_TYPE), AUTH_READ)
createEntityPayload(entityId02, setOf(BEEHIVE_TYPE))
createEntityPayload(entityId03, setOf(APIARY_TYPE))
entityAccessRightsService.setRoleOnEntity(subjectUuid, entityId01, AccessRight.R_CAN_WRITE).shouldSucceed()
entityAccessRightsService.setRoleOnEntity(subjectUuid, entityId03, AccessRight.R_CAN_WRITE).shouldSucceed()
entityAccessRightsService.setRoleOnEntity(UUID.randomUUID().toString(), entityId02, AccessRight.R_CAN_WRITE)
.shouldSucceed()

entityAccessRightsService.getSubjectAccessRights(
Some(subjectUuid),
emptyList(),
BEEHIVE_TYPE,
setOf(entityId01, entityId03),
paginationQuery = PaginationQuery(limit = 100, offset = 0)
).shouldSucceedWith {
assertEquals(1, it.size)
assertEquals(entityId01, it[0].id)
}

entityAccessRightsService.getSubjectAccessRightsCount(
Some(subjectUuid),
emptyList(),
BEEHIVE_TYPE,
setOf(entityId01, entityId03)
).shouldSucceedWith {
assertEquals(1, it)
}
}

@Test
fun `it should get all entities an user has access to wrt access rights`() = runTest {
val entityId03 = "urn:ngsi-ld:Entity:03".toUri()
Expand All @@ -358,8 +422,7 @@ class EntityAccessRightsServiceTests : WithTimescaleContainer {
entityAccessRightsService.getSubjectAccessRights(
Some(subjectUuid),
listOf(AccessRight.R_CAN_WRITE),
limit = 100,
offset = 0
paginationQuery = PaginationQuery(limit = 100, offset = 0)
).shouldSucceedWith {
assertEquals(1, it.size)
val entityAccessControl = it[0]
Expand Down Expand Up @@ -388,8 +451,7 @@ class EntityAccessRightsServiceTests : WithTimescaleContainer {
Some(subjectUuid),
emptyList(),
BEEHIVE_TYPE,
100,
0
paginationQuery = PaginationQuery(limit = 100, offset = 0)
).shouldSucceedWith {
assertEquals(1, it.size)
val entityAccessControl = it[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.test.context.ActiveProfiles
import java.net.URI
import java.time.Instant
import java.time.ZoneOffset
import java.time.ZonedDateTime

@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.NONE, classes = [QueryService::class])
Expand All @@ -49,7 +47,7 @@ class QueryServiceTests {
@MockkBean
private lateinit var temporalEntityAttributeService: TemporalEntityAttributeService

private val now = Instant.now().atZone(ZoneOffset.UTC)
private val now = ngsiLdDateTime()

private val entityUri = "urn:ngsi-ld:BeeHive:TESTC".toUri()

Expand Down

0 comments on commit 4dc1319

Please sign in to comment.