From a8490dd310dacbed02b1865a6edfbb3cda12f9e0 Mon Sep 17 00:00:00 2001 From: John Baldwin Date: Mon, 21 Sep 2020 16:58:11 +0200 Subject: [PATCH] Improved query performance for LearnerMetricsViewSet * Added prefetch_related to the User model in LearnerMetricsViewSet to query the related "CourseEnrollment" records * Updated test coverage to include users, courses, and enrollments for two sites in order to ensure we are retrieving only the data we want and avoid "site bleed". For this, created an initial generalized fixture in "tests/conftest.py". This initial fixture contains User, CourseOverview, CourseEnrollment, Organization, and Site records. If the version of "organizations" is our fork (supports multi-site) then OrganizationCourse and UserOrganizationMapping records are also filled This commit lays the groundwork for additional performance improvement to reduce the number of queries needed in the serializers used for LearnerMetricsViewSet --- figures/views.py | 40 ++++- tests/conftest.py | 81 +++++++-- tests/views/helpers.py | 21 +++ tests/views/test_learner_metrics_viewset.py | 176 ++++++++++---------- 4 files changed, 210 insertions(+), 108 deletions(-) diff --git a/figures/views.py b/figures/views.py index 545a53d7a..43d4382f5 100644 --- a/figures/views.py +++ b/figures/views.py @@ -427,17 +427,45 @@ def query_param_course_keys(self): cid_list = self.request.GET.getlist('course') return [CourseKey.from_string(elem.replace(' ', '+')) for elem in cid_list] + def get_enrolled_users(self, site, course_keys): + """Get users enrolled in the specific courses for the specified site + + Args: + site: The site for which is being called + course_keys: list of Open edX course keys + + Returns: + Django QuerySet of users enrolled in the specified courses + + Note: + We should move this to `figures.sites` + """ + qs = figures.sites.get_users_for_site(site).filter( + courseenrollment__course_id__in=course_keys + ).prefetch_related('courseenrollment_set') + return qs + def 
get_queryset(self): """ - This function has a hack to filter users until we can get the `filter_class` - working + If one or more course keys are given as query parameters, then + * Course key filtering mode is ued. Any invalid keys are filtered out + from the list + * If no valid course keys are found, then an empty list is returned from + this view """ site = django.contrib.sites.shortcuts.get_current_site(self.request) - queryset = figures.sites.get_users_for_site(site) - course_keys = self.query_param_course_keys() + try: + course_keys = self.query_param_course_keys() + except InvalidKeyError: + raise NotFound() if course_keys: - queryset = figures.sites.users_enrolled_in_courses(course_keys) - return queryset + site_course_keys = figures.sites.get_course_keys_for_site(site) + if not set(course_keys).issubset(set(site_course_keys)): + raise NotFound() + return self.get_enrolled_users(site=site, course_keys=course_keys) + else: + return figures.sites.get_users_for_site(site).prefetch_related( + 'courseenrollment_set') def get_serializer_context(self): context = super(LearnerMetricsViewSet, self).get_serializer_context() diff --git a/tests/conftest.py b/tests/conftest.py index 0b1d3e4a7..92ce30bb0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,11 +4,13 @@ from django.utils.timezone import utc from tests.factories import ( + CourseEnrollmentFactory, CourseOverviewFactory, OrganizationFactory, OrganizationCourseFactory, StudentModuleFactory, SiteFactory, + UserFactory, ) from tests.helpers import organizations_support_sites @@ -16,6 +18,10 @@ if organizations_support_sites(): from tests.factories import UserOrganizationMappingFactory + def map_users_to_org(org, users): + [UserOrganizationMappingFactory(user=user, + organization=org) for user in users] + @pytest.fixture @pytest.mark.django_db @@ -33,26 +39,79 @@ def sm_test_data(db): sm = [] for co in course_overviews: sm += [StudentModuleFactory(course_id=co.id, - created=created_date, - 
modified=modified_date, - ) for co in course_overviews] - + created=created_date, + modified=modified_date) for co in course_overviews] if organizations_support_sites(): org = OrganizationFactory(sites=[site]) for co in course_overviews: OrganizationCourseFactory(organization=org, course_id=str(co.id)) for rec in sm: - UserOrganizationMappingFactory(user=rec.student, - organization=org) + UserOrganizationMappingFactory(user=rec.student, organization=org) else: org = OrganizationFactory() + return dict(site=site, + organization=org, + course_overviews=course_overviews, + student_modules=sm, + year_for=year_for, + month_for=month_for) + + +@pytest.mark.django_db +def make_site_data(num_users=3, num_courses=2): + + site = SiteFactory() + if organizations_support_sites(): + org = OrganizationFactory(sites=[site]) + else: + org = OrganizationFactory() + courses = [CourseOverviewFactory() for i in range(num_courses)] + users = [UserFactory() for i in range(num_users)] + enrollments = [] + + users = [UserFactory() for i in range(num_users)] + + enrollments = [] + for i, user in enumerate(users): + # Create increasing number of enrollments for each user, maximum to one less + # than the number of courses + for j in range(i): + enrollments.append( + CourseEnrollmentFactory(course=courses[j-1], user=user) + ) + + if organizations_support_sites(): + for course in courses: + OrganizationCourseFactory(organization=org, + course_id=str(course.id)) + + # Set up user mappings + map_users_to_org(org, users) + return dict( site=site, - organization=org, - course_overviews=course_overviews, - student_modules=sm, - year_for=year_for, - month_for=month_for + org=org, + courses=courses, + users=users, + enrollments=enrollments, ) + + +@pytest.fixture +@pytest.mark.django_db +def lm_test_data(db, settings): + """Learner Metrics Test Data + + user0 not enrolled in any courses + user1 enrolled in 1 course + user2 enrolled in 2 courses + + """ + if organizations_support_sites(): + 
settings.FEATURES['FIGURES_IS_MULTISITE'] = True + + our_site_data = make_site_data() + other_site_data = make_site_data() + return dict(us=our_site_data, them=other_site_data) diff --git a/tests/views/helpers.py b/tests/views/helpers.py index bf51f0f27..103f2669e 100644 --- a/tests/views/helpers.py +++ b/tests/views/helpers.py @@ -4,6 +4,13 @@ from tests.factories import UserFactory +from tests.helpers import organizations_support_sites + + +if organizations_support_sites(): + from tests.factories import UserOrganizationMappingFactory + + def create_test_users(): ''' Creates four test users to test the combination of permissions @@ -31,3 +38,17 @@ def is_response_paginated(response_data): # If we can't get keys, wer'e certainly not paginated return False return set(keys) == set([u'count', u'next', u'previous', u'results']) + + +def make_caller(org): + """Convenience method to create the API caller user + """ + if organizations_support_sites(): + # TODO: set is_staff to False after we have test coverage + caller = UserFactory(is_staff=True) + UserOrganizationMappingFactory(user=caller, + organization=org, + is_amc_admin=True) + else: + caller = UserFactory(is_staff=True) + return caller diff --git a/tests/views/test_learner_metrics_viewset.py b/tests/views/test_learner_metrics_viewset.py index f6ea014b5..29c48398f 100644 --- a/tests/views/test_learner_metrics_viewset.py +++ b/tests/views/test_learner_metrics_viewset.py @@ -10,54 +10,14 @@ from figures.sites import get_user_ids_for_site from figures.views import LearnerMetricsViewSet -from tests.factories import ( - CourseEnrollmentFactory, - CourseOverviewFactory, - # LearnerCourseGradeMetricsFactory, - OrganizationFactory, - SiteFactory, - UserFactory, -) - from tests.helpers import organizations_support_sites from tests.views.base import BaseViewTest -from tests.views.helpers import is_response_paginated - -if organizations_support_sites(): - from tests.factories import UserOrganizationMappingFactory - - def 
map_users_to_org_site(caller, site, users): - org = OrganizationFactory(sites=[site]) - UserOrganizationMappingFactory(user=caller, - organization=org, - is_amc_admin=True) - [UserOrganizationMappingFactory(user=user, - organization=org) for user in users] - # return created objects that the test will need - return caller +from tests.views.helpers import is_response_paginated, make_caller -@pytest.fixture -def enrollment_test_data(): - """Stands up shared test data. We need to revisit this - """ - num_courses = 2 - site = SiteFactory() - course_overviews = [CourseOverviewFactory() for i in range(num_courses)] - # Create a number of enrollments for each course - enrollments = [] - for num_enroll, co in enumerate(course_overviews, 1): - enrollments += [CourseEnrollmentFactory( - course_id=co.id) for i in range(num_enroll)] - - # This is a convenience for the test method - users = [enrollment.user for enrollment in enrollments] - return dict( - site=site, - course_overviews=course_overviews, - enrollments=enrollments, - users=users, - ) +def filter_enrollments(enrollments, courses): + course_ids = [elem.id for elem in courses] + return [elem for elem in enrollments if elem.course_id in course_ids] @pytest.mark.django_db @@ -103,17 +63,6 @@ def setup(self, db, settings): settings.FEATURES['FIGURES_IS_MULTISITE'] = True super(TestLearnerMetricsViewSet, self).setup(db) - def make_caller(self, site, users): - """Convenience method to create the API caller user - """ - if organizations_support_sites(): - # TODO: set is_staff to False after we have test coverage - caller = UserFactory(is_staff=True) - map_users_to_org_site(caller=caller, site=site, users=users) - else: - caller = UserFactory(is_staff=True) - return caller - def make_request(self, monkeypatch, request_path, site, caller, action): """Convenience method to make the API request @@ -128,7 +77,7 @@ def make_request(self, monkeypatch, request_path, site, caller, action): view = self.view_class.as_view({'get': 
action}) return view(request) - def test_list_method_all(self, monkeypatch, enrollment_test_data): + def test_list_method_all(self, monkeypatch, lm_test_data): """Partial test coverage to check we get all site users Checks returned user ids against all user ids for the site @@ -137,16 +86,16 @@ def test_list_method_all(self, monkeypatch, enrollment_test_data): Does NOT check values in the `enrollments` key. This should be done as follow up work """ - site = enrollment_test_data['site'] - users = enrollment_test_data['users'] - - caller = self.make_caller(site, users) - other_site = SiteFactory() - assert site.domain != other_site.domain + us = lm_test_data['us'] + them = lm_test_data['them'] + our_courses = us['courses'] + caller = make_caller(us['org']) + assert us['site'].domain != them['site'].domain + assert len(our_courses) > 1 response = self.make_request(request_path=self.base_request_path, monkeypatch=monkeypatch, - site=site, + site=us['site'], caller=caller, action='list') @@ -155,32 +104,31 @@ def test_list_method_all(self, monkeypatch, enrollment_test_data): results = response.data['results'] # Check user ids result_ids = [obj['id'] for obj in results] - user_ids = get_user_ids_for_site(site=site) + user_ids = get_user_ids_for_site(site=us['site']) assert set(result_ids) == set(user_ids) # Spot check the first record top_keys = ['id', 'username', 'email', 'fullname', 'is_active', 'date_joined', 'enrollments'] assert set(results[0].keys()) == set(top_keys) - def test_course_param_single(self, monkeypatch, enrollment_test_data): + def test_course_param_single(self, monkeypatch, lm_test_data): """Test that the 'course' query parameter works """ - site = enrollment_test_data['site'] - users = enrollment_test_data['users'] - enrollments = enrollment_test_data['enrollments'] - course_overviews = enrollment_test_data['course_overviews'] + us = lm_test_data['us'] + them = lm_test_data['them'] + our_enrollments = us['enrollments'] + our_courses = us['courses'] - 
caller = self.make_caller(site, users) - other_site = SiteFactory() - assert site.domain != other_site.domain - assert len(course_overviews) > 1 - query_params = '?course={}'.format(str(course_overviews[0].id)) + caller = make_caller(us['org']) + assert us['site'].domain != them['site'].domain + assert len(our_courses) > 1 + query_params = '?course={}'.format(str(our_courses[0].id)) request_path = self.base_request_path + query_params response = self.make_request(request_path=request_path, monkeypatch=monkeypatch, - site=site, + site=us['site'], caller=caller, action='list') @@ -190,37 +138,83 @@ def test_course_param_single(self, monkeypatch, enrollment_test_data): # Check user ids result_ids = [obj['id'] for obj in results] - our_enrollments = [elem for elem in enrollments if elem.course_id == course_overviews[0].id] - expected_user_ids = [obj.user.id for obj in our_enrollments] + course_enrollments = [elem for elem in our_enrollments + if elem.course_id == our_courses[0].id] + expected_user_ids = [obj.user.id for obj in course_enrollments] assert set(result_ids) == set(expected_user_ids) - def test_course_param_multiple(self, monkeypatch, enrollment_test_data): + def test_course_param_multiple(self, monkeypatch, lm_test_data): """Test that the 'course' query parameter works """ - site = enrollment_test_data['site'] - users = enrollment_test_data['users'] - enrollments = enrollment_test_data['enrollments'] - course_overviews = enrollment_test_data['course_overviews'] - - caller = self.make_caller(site, users) - other_site = SiteFactory() - assert site.domain != other_site.domain - assert len(course_overviews) > 1 - query_params = '?course={}&course={}'.format(str(course_overviews[0].id), - str(course_overviews[1].id)) + us = lm_test_data['us'] + them = lm_test_data['them'] + our_enrollments = us['enrollments'] + our_courses = us['courses'] + caller = make_caller(us['org']) + assert us['site'].domain != them['site'].domain + assert len(our_courses) > 1 + + 
filtered_courses = our_courses[:2] + + # TODO: build params from 'filtered_courses' + query_params = '?course={}&course={}'.format(str(our_courses[0].id), + str(our_courses[1].id)) request_path = self.base_request_path + query_params response = self.make_request(request_path=request_path, monkeypatch=monkeypatch, - site=site, + site=us['site'], caller=caller, action='list') + # Continue updating here assert response.status_code == status.HTTP_200_OK assert is_response_paginated(response.data) results = response.data['results'] # Check user ids result_ids = [obj['id'] for obj in results] - expected_user_ids = [obj.user.id for obj in enrollments] + expected_enrollments = filter_enrollments(enrollments=our_enrollments, + courses=filtered_courses) + expected_user_ids = [obj.user.id for obj in expected_enrollments] assert set(result_ids) == set(expected_user_ids) + + def invalid_course_ids_raise_404(self, monkeypatch, lm_test_data, query_params): + us = lm_test_data['us'] + them = lm_test_data['them'] + + caller = make_caller(us['org']) + assert us['site'].domain != them['site'].domain + request_path = self.base_request_path + query_params + response = self.make_request(request_path=request_path, + monkeypatch=monkeypatch, + site=us['site'], + caller=caller, + action='list') + return response.status_code == status.HTTP_404_NOT_FOUND + + def test_valid_and_course_param_from_other_site_invalid(self, + monkeypatch, + lm_test_data): + """Test that the 'course' query parameter works + + """ + our_courses = lm_test_data['us']['courses'] + their_courses = lm_test_data['them']['courses'] + query_params = '?course={}&course={}'.format(str(our_courses[0].id), + str(their_courses[0].id)) + assert self.invalid_course_ids_raise_404(monkeypatch, + lm_test_data, + query_params) + + def test_valid_and_mangled_course_param_invalid(self, monkeypatch, lm_test_data): + """Test that the 'course' query parameter works + + """ + our_courses = lm_test_data['us']['courses'] + mangled_course_id 
= 'she-sell-seashells-by-the-seashore' + query_params = '?course={}&course={}'.format(str(our_courses[0].id), + mangled_course_id) + assert self.invalid_course_ids_raise_404(monkeypatch, + lm_test_data, + query_params)