Skip to content

Commit

Permalink
Take course/assignment ID on get_request_XXX methods
Browse files Browse the repository at this point in the history
Instead of trying different locations in secession (matchdict, parsed
params...) pass the right location directly form the view.
  • Loading branch information
marcospri committed Nov 20, 2024
1 parent d265985 commit 61eb886
Show file tree
Hide file tree
Showing 12 changed files with 73 additions and 91 deletions.
28 changes: 2 additions & 26 deletions lms/services/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,8 @@ def __init__( # noqa: PLR0913
self._organization_service = organization_service
self._h_authority = h_authority

def get_request_assignment(self, request) -> Assignment:
def get_request_assignment(self, request, assigment_id: int) -> Assignment:
"""Get and authorize an assignment for the given request."""
# Requests that are scoped to one assignment on the URL parameter
assigment_id = request.matchdict.get("assignment_id")
if not assigment_id:
# Request that are scoped to a single assignment but as a query parameter
assigment_id = request.parsed_params.get("assignment_id")

if (
not assigment_id
and request.parsed_params.get("assignment_ids")
and len(request.parsed_params["assignment_ids"]) == 1
):
# Request that take a list of assignments, but we only recieved one, the requests is scoped to that one assignment
assigment_id = request.parsed_params["assignment_ids"][0]

assignment = self._assignment_service.get_by_id(assigment_id)
if not assignment:
raise HTTPNotFound()
Expand All @@ -73,18 +59,8 @@ def get_request_assignment(self, request) -> Assignment:

return assignment

def get_request_course(self, request):
def get_request_course(self, request, course_id: int):
"""Get and authorize a course for the given request."""
# Requests that are scoped to one course on the URL parameter
course_id = request.matchdict.get("course_id")
if (
not course_id
and request.parsed_params.get("course_ids")
and len(request.parsed_params["course_ids"]) == 1
):
# Request that take a list of courses, but we only recieved one, the requests is scoped to that one course
course_id = request.parsed_params["course_ids"][0]

course = self._course_service.get_by_id(course_id)
if not course:
raise HTTPNotFound()
Expand Down
8 changes: 6 additions & 2 deletions lms/views/dashboard/api/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def assignments(self) -> APIAssignments:
permission=Permissions.DASHBOARD_VIEW,
)
def assignment(self) -> APIAssignment:
assignment = self.dashboard_service.get_request_assignment(self.request)
assignment = self.dashboard_service.get_request_assignment(
self.request, self.request.matchdict["assignment_id"]
)
api_assignment = APIAssignment(
id=assignment.id,
title=assignment.title,
Expand Down Expand Up @@ -135,7 +137,9 @@ def course_assignments_metrics(self) -> APIAssignments:
self.request
)

course = self.dashboard_service.get_request_course(self.request)
course = self.dashboard_service.get_request_course(
self.request, self.request.matchdict["course_id"]
)
course_students = self.request.db.scalars(
self.user_service.get_users(
course_ids=[course.id],
Expand Down
4 changes: 3 additions & 1 deletion lms/views/dashboard/api/course.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ def courses_metrics(self) -> APICourses:
permission=Permissions.DASHBOARD_VIEW,
)
def course(self) -> APICourse:
course = self.dashboard_service.get_request_course(self.request)
course = self.dashboard_service.get_request_course(
self.request, self.request.matchdict["course_id"]
)
return {
"id": course.id,
"title": course.lms_name,
Expand Down
8 changes: 6 additions & 2 deletions lms/views/dashboard/api/grading.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def __init__(self, request) -> None:
schema=AutoGradeSyncSchema,
)
def create_grading_sync(self):
assignment = self.dashboard_service.get_request_assignment(self.request)
assignment = self.dashboard_service.get_request_assignment(
self.request, self.request.matchdict["assignment_id"]
)

if self.auto_grading_service.get_in_progress_sync(assignment):
self.request.response.status_int = 400
Expand Down Expand Up @@ -85,7 +87,9 @@ def create_grading_sync(self):
permission=Permissions.GRADE_ASSIGNMENT,
)
def get_grading_sync(self):
assignment = self.dashboard_service.get_request_assignment(self.request)
assignment = self.dashboard_service.get_request_assignment(
self.request, self.request.matchdict["assignment_id"]
)
if grading_sync := self.auto_grading_service.get_last_sync(assignment):
return self._serialize_grading_sync(grading_sync)

Expand Down
12 changes: 9 additions & 3 deletions lms/views/dashboard/api/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def students(self) -> APIStudents:
)
def students_metrics(self) -> APIStudents:
"""Fetch the stats for one particular assignment."""
assignment = self.dashboard_service.get_request_assignment(self.request)
assignment = self.dashboard_service.get_request_assignment(
self.request, self.request.parsed_params["assignment_id"]
)

request_segment_authority_provided_ids = self.request.parsed_params.get(
"segment_authority_provided_ids"
Expand Down Expand Up @@ -211,7 +213,9 @@ def _students_query(
and not segment_authority_provided_ids
):
# Fetch the assignment to be sure the current user has access to it.
assignment = self.dashboard_service.get_request_assignment(self.request)
assignment = self.dashboard_service.get_request_assignment(
self.request, assignment_ids[0]
)

return self.user_service.get_users_for_assignment(
role_scope=RoleScope.COURSE,
Expand All @@ -223,7 +227,9 @@ def _students_query(
# Single course fetch
if course_ids and len(course_ids) == 1 and not segment_authority_provided_ids:
# Fetch the course to be sure the current user has access to it.
course = self.dashboard_service.get_request_course(self.request)
course = self.dashboard_service.get_request_course(
self.request, course_id=course_ids[0]
)

return self.user_service.get_users_for_course(
role_scope=RoleScope.COURSE,
Expand Down
8 changes: 6 additions & 2 deletions lms/views/dashboard/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def assignment_show(self):
Authenticated via the LTIUser present in a cookie making this endpoint accessible directly in the browser.
"""
assignment = self.dashboard_service.get_request_assignment(self.request)
assignment = self.dashboard_service.get_request_assignment(
self.request, self.request.matchdict["assignment_id"]
)
self.request.context.js_config.enable_dashboard_mode(
AUTHORIZATION_DURATION_SECONDS
)
Expand All @@ -105,7 +107,9 @@ def course_show(self):
Authenticated via the LTIUser present in a cookie making this endpoint accessible directly in the browser.
"""
course = self.dashboard_service.get_request_course(self.request)
course = self.dashboard_service.get_request_course(
self.request, self.request.matchdict["course_id"]
)
self.request.context.js_config.enable_dashboard_mode(
AUTHORIZATION_DURATION_SECONDS
)
Expand Down
49 changes: 10 additions & 39 deletions tests/unit/lms/services/dashboard_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,35 +12,31 @@

class TestDashboardService:
def test_get_request_assignment_404(self, pyramid_request, assignment_service, svc):
pyramid_request.matchdict["assignment_id"] = sentinel.id
assignment_service.get_by_id.return_value = None

with pytest.raises(HTTPNotFound):
svc.get_request_assignment(pyramid_request)
svc.get_request_assignment(pyramid_request, sentinel.id)

def test_get_request_assignment_403(self, pyramid_request, course_service, svc):
pyramid_request.matchdict["assignment_id"] = sentinel.id
course_service.is_member.return_value = False

with pytest.raises(HTTPUnauthorized):
svc.get_request_assignment(pyramid_request)
svc.get_request_assignment(pyramid_request, sentinel.id)

def test_get_request_assignment_for_staff(
self, pyramid_request, assignment_service, pyramid_config, svc
):
pyramid_config.testing_securitypolicy(permissive=True)
pyramid_request.matchdict["assignment_id"] = sentinel.id
assignment_service.is_member.return_value = False

assert svc.get_request_assignment(pyramid_request)
assert svc.get_request_assignment(pyramid_request, sentinel.id)

def test_get_request_assignment(
self, pyramid_request, course_service, svc, assignment_service
):
pyramid_request.matchdict["assignment_id"] = sentinel.id
course_service.is_member.return_value = True

assert svc.get_request_assignment(pyramid_request)
assert svc.get_request_assignment(pyramid_request, sentinel.id)

course_service.is_member.assert_called_once_with(
assignment_service.get_by_id.return_value.course,
Expand All @@ -62,55 +58,32 @@ def test_get_request_assignment_for_admin(
assignment_service.get_by_id.return_value = assignment
get_request_admin_organizations.return_value = [organization]

pyramid_request.matchdict["assignment_id"] = sentinel.id

assert svc.get_request_assignment(pyramid_request)

def test_get_request_assignment_for_parsed_params_assignment_id(
self, pyramid_request, assignment_service, svc
):
pyramid_request.parsed_params = {"assignment_ids": [sentinel.parsed_params_id]}

svc.get_request_assignment(pyramid_request)

assignment_service.get_by_id.assert_called_once_with(sentinel.parsed_params_id)
assert svc.get_request_assignment(pyramid_request, sentinel.id)

def test_get_request_course_404(
self,
pyramid_request,
course_service,
svc,
):
pyramid_request.matchdict["course_id"] = sentinel.id
course_service.get_by_id.return_value = None

with pytest.raises(HTTPNotFound):
svc.get_request_course(pyramid_request)
svc.get_request_course(pyramid_request, sentinel.id)

def test_get_request_course_403(self, pyramid_request, course_service, svc):
pyramid_request.matchdict["course_id"] = sentinel.id
course_service.is_member.return_value = False

with pytest.raises(HTTPUnauthorized):
svc.get_request_course(pyramid_request)
svc.get_request_course(pyramid_request, sentinel.id)

def test_get_request_course_for_staff(
self, pyramid_request, course_service, pyramid_config, svc
):
pyramid_config.testing_securitypolicy(permissive=True)
pyramid_request.matchdict["course_id"] = sentinel.id
course_service.is_member.return_value = False

assert svc.get_request_course(pyramid_request)

def test_get_request_for_parsed_params_course_ids(
self, pyramid_request, course_service, svc
):
pyramid_request.parsed_params = {"course_ids": [sentinel.parsed_params_id]}

svc.get_request_course(pyramid_request)

course_service.get_by_id.assert_called_once_with(sentinel.parsed_params_id)
assert svc.get_request_course(pyramid_request, sentinel.id)

def test_get_request_course_for_admin(
self,
Expand All @@ -123,15 +96,13 @@ def test_get_request_course_for_admin(
):
course_service.get_by_id.return_value = course
get_request_admin_organizations.return_value = [organization]
pyramid_request.matchdict["course_id"] = sentinel.id

assert svc.get_request_course(pyramid_request)
assert svc.get_request_course(pyramid_request, sentinel.id)

def test_get_request_course(self, pyramid_request, course_service, svc):
pyramid_request.matchdict["course_id"] = sentinel.id
course_service.is_member.return_value = True

assert svc.get_request_course(pyramid_request)
assert svc.get_request_course(pyramid_request, sentinel.id)

def test_add_dashboard_admin(self, svc, db_session):
admin = svc.add_dashboard_admin(
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/lms/views/dashboard/api/assignment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_assignment(
response = views.assignment()

dashboard_service.get_request_assignment.assert_called_once_with(
pyramid_request
pyramid_request, sentinel.id
)

assert response == {
Expand All @@ -91,7 +91,7 @@ def test_assignment_with_auto_grading(
response = views.assignment()

dashboard_service.get_request_assignment.assert_called_once_with(
pyramid_request
pyramid_request, sentinel.id
)

assert response == {
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_assignment_with_groups(
response = views.assignment()

dashboard_service.get_request_assignment.assert_called_once_with(
pyramid_request
pyramid_request, sentinel.id
)
assignment_service.get_assignment_groups.assert_called_once_with(assignment)

Expand All @@ -146,7 +146,7 @@ def test_assignment_with_sections(
response = views.assignment()

dashboard_service.get_request_assignment.assert_called_once_with(
pyramid_request
pyramid_request, sentinel.id
)
assignment_service.get_assignment_sections.assert_called_once_with(assignment)

Expand Down
4 changes: 3 additions & 1 deletion tests/unit/lms/views/dashboard/api/course_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def test_course(self, views, pyramid_request, dashboard_service):

response = views.course()

dashboard_service.get_request_course.assert_called_once_with(pyramid_request)
dashboard_service.get_request_course.assert_called_once_with(
pyramid_request, sentinel.id
)

assert response == {
"id": course.id,
Expand Down
15 changes: 10 additions & 5 deletions tests/unit/lms/views/dashboard/api/grading_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import Mock
from unittest.mock import Mock, sentinel

import pytest
from h_matchers import Any
Expand All @@ -20,6 +20,7 @@ def test_create_grading_sync_with_no_lms_users(
dashboard_service,
assignment,
):
pyramid_request.matchdict = {"assignment_id": sentinel.id}
dashboard_service.get_request_assignment.return_value = assignment
auto_grading_service.get_in_progress_sync.return_value = None
pyramid_request.parsed_params["grades"] = [
Expand All @@ -29,7 +30,7 @@ def test_create_grading_sync_with_no_lms_users(
views.create_grading_sync()

dashboard_service.get_request_assignment.assert_called_once_with(
pyramid_request
pyramid_request, sentinel.id
)
auto_grading_service.get_in_progress_sync.assert_called_once_with(
dashboard_service.get_request_assignment.return_value
Expand All @@ -44,12 +45,13 @@ def test_create_grading_sync_with_existing_sync(
dashboard_service,
assignment,
):
pyramid_request.matchdict = {"assignment_id": sentinel.id}
dashboard_service.get_request_assignment.return_value = assignment

views.create_grading_sync()

dashboard_service.get_request_assignment.assert_called_once_with(
pyramid_request
pyramid_request, sentinel.id
)
auto_grading_service.get_in_progress_sync.assert_called_once_with(
dashboard_service.get_request_assignment.return_value
Expand All @@ -65,6 +67,7 @@ def test_create_grading_sync(
assignment,
db_session,
):
pyramid_request.matchdict = {"assignment_id": sentinel.id}
pyramid_request.parsed_params["grades"] = [
{"h_userid": "STUDENT_1", "grade": 0.5},
{"h_userid": "STUDENT_2", "grade": 1},
Expand All @@ -78,7 +81,7 @@ def test_create_grading_sync(
response = views.create_grading_sync()

dashboard_service.get_request_assignment.assert_called_once_with(
pyramid_request
pyramid_request, sentinel.id
)
auto_grading_service.get_in_progress_sync.assert_called_once_with(
dashboard_service.get_request_assignment.return_value
Expand All @@ -105,11 +108,12 @@ def test_get_grading_sync(
dashboard_service,
grading_sync,
):
pyramid_request.matchdict = {"assignment_id": sentinel.id}
auto_grading_service.get_last_sync.return_value = grading_sync
response = views.get_grading_sync()

dashboard_service.get_request_assignment.assert_called_once_with(
pyramid_request
pyramid_request, sentinel.id
)
auto_grading_service.get_last_sync.assert_called_once_with(
dashboard_service.get_request_assignment.return_value
Expand All @@ -132,6 +136,7 @@ def test_get_grading_sync(
def test_get_grading_sync_not_found(
self, auto_grading_service, views, pyramid_request
):
pyramid_request.matchdict = {"assignment_id": sentinel.id}
auto_grading_service.get_last_sync.return_value = None

response = views.get_grading_sync()
Expand Down
Loading

0 comments on commit 61eb886

Please sign in to comment.