Skip to content

Commit

Permalink
Use custom queries for getting users for course, assignments and orga…
Browse files Browse the repository at this point in the history
…nizations

This queries are different from the existing `get_users` one in a few
different ways:

- They use LMSUser / LMSCourse instead of grouping and user.

This avoid the de duplication step needed on the older tables.

- Use one query per case instead of one generic one.

While this leads to more code duplication the generic query has some
inefficiencies trying to cater to all cases.

We still don't handle filter by segments and cases where we filter by
multiple courses or assignments. Those will more complex and we'll
tackle them after including the roster concept on the existing one.
  • Loading branch information
marcospri committed Nov 20, 2024
1 parent 948d770 commit d265985
Show file tree
Hide file tree
Showing 6 changed files with 335 additions and 37 deletions.
29 changes: 25 additions & 4 deletions lms/services/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,20 @@ def __init__( # noqa: PLR0913

def get_request_assignment(self, request) -> Assignment:
"""Get and authorize an assignment for the given request."""
assigment_id = request.matchdict.get(
"assignment_id"
) or request.parsed_params.get("assignment_id")
# Requests that are scoped to one assignment on the URL parameter
assigment_id = request.matchdict.get("assignment_id")
if not assigment_id:
# Request that are scoped to a single assignment but as a query parameter
assigment_id = request.parsed_params.get("assignment_id")

if (
not assigment_id
and request.parsed_params.get("assignment_ids")
and len(request.parsed_params["assignment_ids"]) == 1
):
# Request that take a list of assignments, but we only recieved one, the requests is scoped to that one assignment
assigment_id = request.parsed_params["assignment_ids"][0]

assignment = self._assignment_service.get_by_id(assigment_id)
if not assignment:
raise HTTPNotFound()
Expand All @@ -64,7 +75,17 @@ def get_request_assignment(self, request) -> Assignment:

def get_request_course(self, request):
"""Get and authorize a course for the given request."""
course = self._course_service.get_by_id(request.matchdict["course_id"])
# Requests that are scoped to one course on the URL parameter
course_id = request.matchdict.get("course_id")
if (
not course_id
and request.parsed_params.get("course_ids")
and len(request.parsed_params["course_ids"]) == 1
):
# Request that take a list of courses, but we only recieved one, the requests is scoped to that one course
course_id = request.parsed_params["course_ids"][0]

course = self._course_service.get_by_id(course_id)
if not course:
raise HTTPNotFound()

Expand Down
103 changes: 103 additions & 0 deletions lms/services/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
AssignmentMembership,
Grouping,
GroupingMembership,
LMSCourse,
LMSCourseMembership,
LMSUser,
LMSUserApplicationInstance,
LMSUserAssignmentMembership,
LTIParams,
LTIRole,
LTIUser,
Expand Down Expand Up @@ -139,6 +142,106 @@ def _user_search_query(self, application_instance_id, user_id) -> Select:

return query

def get_users_for_assignment(
self,
role_scope: RoleScope,
role_type: RoleType,
assignment_id: int,
h_userids: list[str] | None = None,
):
"""Get the users that belong to one assignment."""
query = (
select(LMSUser)
.distinct()
.join(
LMSUserAssignmentMembership,
LMSUserAssignmentMembership.lms_user_id == LMSUser.id,
)
.where(
LMSUserAssignmentMembership.assignment_id == assignment_id,
LMSUserAssignmentMembership.lti_role_id.in_(
select(LTIRole.id).where(
LTIRole.scope == role_scope, LTIRole.type == role_type
)
),
)
)
if h_userids:
query = query.where(LMSUser.h_userid.in_(h_userids))

return query.order_by(LMSUser.display_name, LMSUser.id)

def get_users_for_course(
self,
role_scope: RoleScope,
role_type: RoleType,
course_id: int,
h_userids: list[str] | None = None,
):
"""Get the users that belong to one course."""
query = (
select(LMSUser)
.distinct()
.join(
LMSCourseMembership,
LMSCourseMembership.lms_user_id == LMSUser.id,
)
.join(LMSCourse, LMSCourse.id == LMSCourseMembership.lms_course_id)
# course_id is the PK on Grouping, we need to join with LMSCourse by authority_provided_id
.join(
Grouping,
Grouping.authority_provided_id == LMSCourse.h_authority_provided_id,
)
.where(
Grouping.id == course_id,
LMSCourseMembership.lti_role_id.in_(
select(LTIRole.id).where(
LTIRole.scope == role_scope, LTIRole.type == role_type
)
),
)
)
if h_userids:
query = query.where(LMSUser.h_userid.in_(h_userids))

return query.order_by(LMSUser.display_name, LMSUser.id)

def get_users_for_organization( # noqa: PLR0913
self,
role_scope: RoleScope,
role_type: RoleType,
instructor_h_userid: str | None = None,
admin_organization_ids: list[int] | None = None,
h_userids: list[str] | None = None,
) -> Select[tuple[LMSUser]]:
candidate_courses = CourseService.courses_permission_check_query(
instructor_h_userid, admin_organization_ids, course_ids=None
).cte("candidate_courses")

query = (
select(LMSUser)
.distinct()
.join(LMSCourseMembership, LMSCourseMembership.lms_user_id == LMSUser.id)
.join(LMSCourse, LMSCourseMembership.lms_course_id == LMSCourse.id)
.join(
Grouping,
Grouping.authority_provided_id == LMSCourse.h_authority_provided_id,
)
.join(candidate_courses, candidate_courses.c[0] == Grouping.id)
.where(
LMSCourseMembership.lti_role_id.in_(
select(LTIRole.id).where(
LTIRole.scope == role_scope, LTIRole.type == role_type
)
)
)
)

if h_userids:
query = query.where(LMSUser.h_userid.in_(h_userids))

return query.order_by(LMSUser.display_name, LMSUser.id)

def get_users( # noqa: PLR0913, PLR0917
self,
role_scope: RoleScope,
Expand Down
42 changes: 42 additions & 0 deletions lms/views/dashboard/api/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,51 @@ def _students_query(
segment_authority_provided_ids: list[str],
h_userids: list[str] | None = None,
):
course_ids = self.request.parsed_params.get("course_ids")
# Single assigment fetch
if (
assignment_ids
and len(assignment_ids) == 1
and not segment_authority_provided_ids
):
# Fetch the assignment to be sure the current user has access to it.
assignment = self.dashboard_service.get_request_assignment(self.request)

return self.user_service.get_users_for_assignment(
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
assignment_id=assignment.id,
h_userids=h_userids,
)

# Single course fetch
if course_ids and len(course_ids) == 1 and not segment_authority_provided_ids:
# Fetch the course to be sure the current user has access to it.
course = self.dashboard_service.get_request_course(self.request)

return self.user_service.get_users_for_course(
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
course_id=course.id,
h_userids=h_userids,
)

admin_organizations = self.dashboard_service.get_request_admin_organizations(
self.request
)
# Full organization fetch
if not course_ids and not assignment_ids and not segment_authority_provided_ids:
return self.user_service.get_users_for_organization(
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
h_userids=h_userids,
# Users the current user has access to see
instructor_h_userid=self.request.user.h_userid
if self.request.user
else None,
admin_organization_ids=[org.id for org in admin_organizations],
)

return self.user_service.get_users(
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/lms/services/dashboard_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ def test_get_request_assignment_for_admin(

assert svc.get_request_assignment(pyramid_request)

def test_get_request_assignment_for_parsed_params_assignment_id(
self, pyramid_request, assignment_service, svc
):
pyramid_request.parsed_params = {"assignment_ids": [sentinel.parsed_params_id]}

svc.get_request_assignment(pyramid_request)

assignment_service.get_by_id.assert_called_once_with(sentinel.parsed_params_id)

def test_get_request_course_404(
self,
pyramid_request,
Expand Down Expand Up @@ -94,6 +103,15 @@ def test_get_request_course_for_staff(

assert svc.get_request_course(pyramid_request)

def test_get_request_for_parsed_params_course_ids(
self, pyramid_request, course_service, svc
):
pyramid_request.parsed_params = {"course_ids": [sentinel.parsed_params_id]}

svc.get_request_course(pyramid_request)

course_service.get_by_id.assert_called_once_with(sentinel.parsed_params_id)

def test_get_request_course_for_admin(
self,
pyramid_request,
Expand Down
Loading

0 comments on commit d265985

Please sign in to comment.