Skip to content

Commit

Permalink
fix: Validate user authorized when mode ALL in enrollments [TECH-1589] (
Browse files Browse the repository at this point in the history
#15583)

* fix: Remove unnecessary capture mode check [TECH-1589]

* fix: Add test to validate mode ALL behavior [TECH-1589]

* fix: Remove unnecessary test param [TECH-1589]

* fix: Validate tracked entity returned in test [TECH-1589]

* fix: Validate user authorized when mode ALL in enrollments [TECH-1589]

* fix: Add org unit mode in enrollment service tests [TECH-1589]

* fix: Add org unit mode in enrollment service tests [TECH-1589]

* fix: Add org unit mode in enrollment service tests [TECH-1589]

* fix: Validate mode ALL on enrollments old API [TECH-1589]
  • Loading branch information
muilpp authored Nov 7, 2023
1 parent 376fc9b commit b6648ef
Show file tree
Hide file tree
Showing 11 changed files with 290 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public static void validateOrgUnitMode(
private static void validateUserCanSearchOrgUnitModeALL(User user) throws BadRequestException {
if (user != null
&& !(user.isSuper()
|| user.isAuthorized(F_TRACKED_ENTITY_INSTANCE_SEARCH_IN_ALL_ORGUNITS.name()))) {
|| user.isAuthorized(F_TRACKED_ENTITY_INSTANCE_SEARCH_IN_ALL_ORGUNITS))) {
throw new BadRequestException(
"Current user is not authorized to query across all organisation units");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
*/
package org.hisp.dhis.tracker.export.enrollment;

import static org.hisp.dhis.common.OrganisationUnitSelectionMode.ACCESSIBLE;
import static org.hisp.dhis.common.OrganisationUnitSelectionMode.CAPTURE;
import static org.hisp.dhis.common.OrganisationUnitSelectionMode.CHILDREN;
import static org.hisp.dhis.common.OrganisationUnitSelectionMode.DESCENDANTS;
Expand Down Expand Up @@ -292,14 +291,6 @@ public void validate(EnrollmentQueryParams params) throws IllegalQueryException
throw new IllegalQueryException("Params cannot be null");
}

User user = params.getUser();

if (params.isOrganisationUnitMode(ACCESSIBLE)
&& (user == null || !user.hasDataViewOrganisationUnitWithFallback())) {
violation =
"Current user must be associated with at least one organisation unit when selection mode is ACCESSIBLE";
}

if (params.hasProgram() && params.hasTrackedEntityType()) {
violation = "Program and tracked entity cannot be specified simultaneously";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
*/
package org.hisp.dhis.tracker.export.enrollment;

import static org.hisp.dhis.tracker.export.OperationsParamsValidator.validateOrgUnitMode;

import java.util.HashSet;
import java.util.Set;
import lombok.RequiredArgsConstructor;
Expand Down Expand Up @@ -72,6 +74,7 @@ public EnrollmentQueryParams map(EnrollmentOperationParams operationParams)

User user = currentUserService.getCurrentUser();
Set<OrganisationUnit> orgUnits = validateOrgUnits(operationParams.getOrgUnitUids(), user);
validateOrgUnitMode(operationParams.getOrgUnitMode(), user, program);

EnrollmentQueryParams params = new EnrollmentQueryParams();
params.setProgram(program);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import java.util.Set;
import javax.annotation.Nonnull;
import lombok.RequiredArgsConstructor;
import org.hisp.dhis.common.OrganisationUnitSelectionMode;
import org.hisp.dhis.common.QueryFilter;
import org.hisp.dhis.common.UID;
import org.hisp.dhis.feedback.BadRequestException;
Expand Down Expand Up @@ -76,9 +75,7 @@ public TrackedEntityQueryParams map(TrackedEntityOperationParams operationParams
validateTrackedEntityType(operationParams.getTrackedEntityTypeUid());

User user = operationParams.getUser();
Set<OrganisationUnit> orgUnits =
validateOrgUnits(
user, operationParams.getOrganisationUnits(), operationParams.getOrgUnitMode());
Set<OrganisationUnit> orgUnits = validateOrgUnits(user, operationParams.getOrganisationUnits());

TrackedEntityQueryParams params = new TrackedEntityQueryParams();
mapAttributeFilters(params, operationParams.getFilters());
Expand Down Expand Up @@ -135,8 +132,7 @@ private void mapAttributeFilters(
}
}

private Set<OrganisationUnit> validateOrgUnits(
User user, Set<String> orgUnitIds, OrganisationUnitSelectionMode orgUnitMode)
private Set<OrganisationUnit> validateOrgUnits(User user, Set<String> orgUnitIds)
throws BadRequestException, ForbiddenException {
Set<OrganisationUnit> orgUnits = new HashSet<>();
for (String orgUnitUid : orgUnitIds) {
Expand All @@ -156,10 +152,6 @@ private Set<OrganisationUnit> validateOrgUnits(
orgUnits.add(orgUnit);
}

if (orgUnitMode == OrganisationUnitSelectionMode.CAPTURE && user != null) {
orgUnits.addAll(user.getOrganisationUnits());
}

return orgUnits;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
*/
package org.hisp.dhis.tracker.export.enrollment;

import static org.hisp.dhis.common.OrganisationUnitSelectionMode.ACCESSIBLE;
import static org.hisp.dhis.common.OrganisationUnitSelectionMode.SELECTED;
import static org.hisp.dhis.utils.Assertions.assertContainsOnly;
import static org.hisp.dhis.utils.Assertions.assertIsEmpty;
import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -118,6 +120,8 @@ void setUp() {
orgUnit2.getUid(), user.getTeiSearchOrganisationUnitsWithFallback()))
.thenReturn(true);

user.setTeiSearchOrganisationUnits(Set.of(orgUnit1, orgUnit2));

program = new Program();
program.setUid(PROGRAM_UID);
when(programService.getProgram(PROGRAM_UID)).thenReturn(program);
Expand All @@ -135,7 +139,8 @@ void setUp() {
@Test
void shouldMapWithoutFetchingNullParamsWhenParamsAreNotSpecified()
throws BadRequestException, ForbiddenException {
EnrollmentOperationParams operationParams = EnrollmentOperationParams.EMPTY;
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder().orgUnitMode(ACCESSIBLE).build();

mapper.map(operationParams);

Expand All @@ -151,10 +156,17 @@ void shouldMapOrgUnitsWhenOrgUnitUidsAreSpecified()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(ORG_UNIT_1_UID, ORG_UNIT_2_UID))
.orgUnitMode(SELECTED)
.programUid(program.getUid())
.build();
when(trackerAccessManager.canAccess(user, program, orgUnit1)).thenReturn(true);
when(trackerAccessManager.canAccess(user, program, orgUnit2)).thenReturn(true);
when(organisationUnitService.isInUserHierarchy(
orgUnit1.getUid(), user.getTeiSearchOrganisationUnitsWithFallback()))
.thenReturn(true);
when(organisationUnitService.isInUserHierarchy(
orgUnit2.getUid(), user.getTeiSearchOrganisationUnitsWithFallback()))
.thenReturn(true);

EnrollmentQueryParams params = mapper.map(operationParams);

Expand All @@ -166,6 +178,7 @@ void shouldThrowExceptionWhenOrgUnitNotFound() {
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of("JW6BrFd0HLu", ORG_UNIT_2_UID))
.orgUnitMode(SELECTED)
.programUid(PROGRAM_UID)
.build();

Expand Down Expand Up @@ -193,7 +206,7 @@ void shouldThrowExceptionWhenOrgUnitNotInScope() {
@Test
void shouldMapProgramWhenProgramUidIsSpecified() throws BadRequestException, ForbiddenException {
EnrollmentOperationParams requestParams =
EnrollmentOperationParams.builder().programUid(PROGRAM_UID).build();
EnrollmentOperationParams.builder().programUid(PROGRAM_UID).orgUnitMode(ACCESSIBLE).build();

EnrollmentQueryParams params = mapper.map(requestParams);

Expand All @@ -214,7 +227,10 @@ void shouldThrowExceptionWhenProgramNotFound() {
void shouldMapTrackedEntityTypeWhenTrackedEntityTypeUidIsSpecified()
throws BadRequestException, ForbiddenException {
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder().trackedEntityTypeUid(TRACKED_ENTITY_TYPE_UID).build();
EnrollmentOperationParams.builder()
.trackedEntityTypeUid(TRACKED_ENTITY_TYPE_UID)
.orgUnitMode(ACCESSIBLE)
.build();

EnrollmentQueryParams params = mapper.map(operationParams);

Expand All @@ -235,7 +251,10 @@ void shouldThrowExceptionWhenTrackedEntityTypeNotFound() {
void shouldMapTrackedEntityWhenTrackedEntityUidIsSpecified()
throws BadRequestException, ForbiddenException {
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder().trackedEntityUid(TRACKED_ENTITY_UID).build();
EnrollmentOperationParams.builder()
.trackedEntityUid(TRACKED_ENTITY_UID)
.orgUnitMode(ACCESSIBLE)
.build();

EnrollmentQueryParams params = mapper.map(operationParams);

Expand All @@ -259,6 +278,7 @@ void shouldMapOrderInGivenOrder() throws BadRequestException, ForbiddenException
EnrollmentOperationParams.builder()
.orderBy("enrollmentDate", SortDirection.ASC)
.orderBy("created", SortDirection.DESC)
.orgUnitMode(ACCESSIBLE)
.build();

EnrollmentQueryParams params = mapper.map(operationParams);
Expand All @@ -273,9 +293,10 @@ void shouldMapOrderInGivenOrder() throws BadRequestException, ForbiddenException
@Test
void shouldMapNullOrderingParamsWhenNoOrderingParamsAreSpecified()
throws BadRequestException, ForbiddenException {
EnrollmentOperationParams requestParams = EnrollmentOperationParams.EMPTY;
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder().orgUnitMode(ACCESSIBLE).build();

EnrollmentQueryParams params = mapper.map(requestParams);
EnrollmentQueryParams params = mapper.map(operationParams);

assertIsEmpty(params.getOrder());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ void shouldReturnPaginatedEnrollmentsGivenNonDefaultPageSize()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnit.getUid()))
.orgUnitMode(SELECTED)
.orderBy("enrollmentDate", SortDirection.ASC)
.build();

Expand Down Expand Up @@ -553,6 +554,7 @@ void shouldReturnPaginatedEnrollmentsGivenNonDefaultPageSizeAndTotalPages()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnit.getUid()))
.orgUnitMode(SELECTED)
.orderBy("enrollmentDate", SortDirection.ASC)
.build();

Expand Down Expand Up @@ -590,7 +592,10 @@ void shouldOrderEnrollmentsByPrimaryKeyDescByDefault()
.toList();

EnrollmentOperationParams params =
EnrollmentOperationParams.builder().orgUnitUids(Set.of(orgUnit.getUid())).build();
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnit.getUid()))
.orgUnitMode(SELECTED)
.build();

List<String> enrollments = getEnrollments(params);

Expand All @@ -603,6 +608,7 @@ void shouldOrderEnrollmentsByEnrolledAtAsc()
EnrollmentOperationParams params =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnit.getUid()))
.orgUnitMode(SELECTED)
.orderBy("enrollmentDate", SortDirection.ASC)
.build();

Expand All @@ -617,6 +623,7 @@ void shouldOrderEnrollmentsByEnrolledAtDesc()
EnrollmentOperationParams params =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnit.getUid()))
.orgUnitMode(SELECTED)
.orderBy("enrollmentDate", SortDirection.DESC)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
*/
package org.hisp.dhis.tracker.export.enrollment;

import static org.hisp.dhis.common.OrganisationUnitSelectionMode.ACCESSIBLE;
import static org.hisp.dhis.common.OrganisationUnitSelectionMode.ALL;
import static org.hisp.dhis.common.OrganisationUnitSelectionMode.CAPTURE;
import static org.hisp.dhis.common.OrganisationUnitSelectionMode.SELECTED;
import static org.hisp.dhis.tracker.TrackerTestUtils.oneHourAfter;
import static org.hisp.dhis.tracker.TrackerTestUtils.oneHourBefore;
import static org.hisp.dhis.tracker.TrackerTestUtils.uids;
Expand Down Expand Up @@ -69,6 +72,7 @@
import org.hisp.dhis.trackedentityattributevalue.TrackedEntityAttributeValue;
import org.hisp.dhis.user.User;
import org.hisp.dhis.user.UserService;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;

Expand Down Expand Up @@ -410,7 +414,10 @@ void shouldGetEnrollmentsWhenUserHasReadAccessToProgramAndNoOrgUnitNorOrgUnitMod
manager.updateNoAcl(programA);

EnrollmentOperationParams params =
EnrollmentOperationParams.builder().programUid(programA.getUid()).build();
EnrollmentOperationParams.builder()
.programUid(programA.getUid())
.orgUnitMode(ACCESSIBLE)
.build();

List<Enrollment> enrollments = enrollmentService.getEnrollments(params);

Expand Down Expand Up @@ -447,6 +454,7 @@ void shouldGetEnrollmentWhenEnrollmentsAndOtherParamsAreSpecified()
EnrollmentOperationParams.builder()
.programUid(programA.getUid())
.enrollmentUids(Set.of(enrollmentA.getUid()))
.orgUnitMode(ACCESSIBLE)
.build();

List<Enrollment> enrollments = enrollmentService.getEnrollments(params);
Expand All @@ -464,6 +472,7 @@ void shouldGetEnrollmentsByTrackedEntityWhenUserHasAccessToTrackedEntityType()
EnrollmentOperationParams params =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(trackedEntityA.getOrganisationUnit().getUid()))
.orgUnitMode(SELECTED)
.trackedEntityUid(trackedEntityA.getUid())
.build();

Expand All @@ -485,6 +494,7 @@ void shouldFailGettingEnrollmentsByTrackedEntityWhenUserHasNoAccessToTrackedEnti
EnrollmentOperationParams params =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(trackedEntityA.getOrganisationUnit().getUid()))
.orgUnitMode(SELECTED)
.trackedEntityUid(trackedEntityA.getUid())
.build();

Expand All @@ -501,6 +511,7 @@ void shouldReturnEnrollmentIfEnrollmentWasUpdatedBeforePassedDateAndTime()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnitA.getUid()))
.orgUnitMode(SELECTED)
.lastUpdated(oneHourBeforeLastUpdated)
.build();

Expand All @@ -517,6 +528,7 @@ void shouldReturnEmptyIfEnrollmentWasUpdatedAfterPassedDateAndTime()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnitA.getUid()))
.orgUnitMode(SELECTED)
.lastUpdated(oneHourAfterLastUpdated)
.build();

Expand All @@ -534,6 +546,7 @@ void shouldReturnEnrollmentIfEnrollmentStartedBeforePassedDateAndTime()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnitA.getUid()))
.orgUnitMode(SELECTED)
.programUid(programA.getUid())
.programStartDate(oneHourBeforeEnrollmentDate)
.build();
Expand All @@ -552,6 +565,7 @@ void shouldReturnEmptyIfEnrollmentStartedAfterPassedDateAndTime()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnitA.getUid()))
.orgUnitMode(SELECTED)
.programUid(programA.getUid())
.programStartDate(oneHourAfterEnrollmentDate)
.build();
Expand All @@ -570,6 +584,7 @@ void shouldReturnEnrollmentIfEnrollmentEndedAfterPassedDateAndTime()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnitA.getUid()))
.orgUnitMode(SELECTED)
.programUid(programA.getUid())
.programEndDate(oneHourAfterEnrollmentDate)
.build();
Expand All @@ -588,6 +603,7 @@ void shouldReturnEmptyIfEnrollmentEndedBeforePassedDateAndTime()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnitA.getUid()))
.orgUnitMode(SELECTED)
.programUid(programA.getUid())
.programEndDate(oneHourBeforeEnrollmentDate)
.build();
Expand All @@ -597,6 +613,31 @@ void shouldReturnEmptyIfEnrollmentEndedBeforePassedDateAndTime()
assertIsEmpty(enrollments);
}

@Test
void shouldFailWhenOrgUnitModeAllAndUserNotAuthorized() {
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder().orgUnitMode(ALL).build();

BadRequestException exception =
Assertions.assertThrows(
BadRequestException.class, () -> enrollmentService.getEnrollments(operationParams));
Assertions.assertEquals(
"Current user is not authorized to query across all organisation units",
exception.getMessage());
}

@Test
void shouldReturnAllEnrollmentsWhenOrgUnitModeAllAndUserAuthorized()
throws ForbiddenException, BadRequestException, NotFoundException {
injectSecurityContext(admin);

EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder().orgUnitMode(ALL).build();

List<Enrollment> enrollments = enrollmentService.getEnrollments(operationParams);
assertContainsOnly(List.of(enrollmentA, enrollmentB, enrollmentChildA), enrollments);
}

private static List<String> attributeUids(Enrollment enrollment) {
return enrollment.getTrackedEntity().getTrackedEntityAttributeValues().stream()
.map(v -> v.getAttribute().getUid())
Expand Down
Loading

0 comments on commit b6648ef

Please sign in to comment.