Skip to content

Commit

Permalink
Improved query performance for LearnerMetricsViewSet
Browse files Browse the repository at this point in the history
* Added prefetch_related to the User model in LearnerMetricsViewSet to
query the related "CourseEnrollment" records
* Updated test coverage to include users, courses, and enrollments for
two sites in order to ensure we are retrieving only the data we want and
avoid "site bleed". For this, created an initial generalized fixture in
"tests/conftest.py". This initial fixture contains User, CourseOverview,
CourseEnrollment, Organization, and Site records. If the version of
"organizations" is our fork (supports multi-site) then
OrganizationCourse and UserOrganizationMapping records are also filled

This commit lays the groundwork for additional performance imporovement
to reduce the number of queries needed in the serializers used for
LearnerMetricsViewSet
  • Loading branch information
johnbaldwin committed Sep 21, 2020
1 parent 592cf3b commit a8490dd
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 108 deletions.
40 changes: 34 additions & 6 deletions figures/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,17 +427,45 @@ def query_param_course_keys(self):
cid_list = self.request.GET.getlist('course')
return [CourseKey.from_string(elem.replace(' ', '+')) for elem in cid_list]

def get_enrolled_users(self, site, course_keys):
"""Get users enrolled in the specific courses for the specified site
Args:
site: The site for which is being called
course_keys: list of Open edX course keys
Returns:
Django QuerySet of users enrolled in the specified courses
Note:
We should move this to `figures.sites`
"""
qs = figures.sites.get_users_for_site(site).filter(
courseenrollment__course_id__in=course_keys
).prefetch_related('courseenrollment_set')
return qs

def get_queryset(self):
"""
This function has a hack to filter users until we can get the `filter_class`
working
If one or more course keys are given as query parameters, then
* Course key filtering mode is ued. Any invalid keys are filtered out
from the list
* If no valid course keys are found, then an empty list is returned from
this view
"""
site = django.contrib.sites.shortcuts.get_current_site(self.request)
queryset = figures.sites.get_users_for_site(site)
course_keys = self.query_param_course_keys()
try:
course_keys = self.query_param_course_keys()
except InvalidKeyError:
raise NotFound()
if course_keys:
queryset = figures.sites.users_enrolled_in_courses(course_keys)
return queryset
site_course_keys = figures.sites.get_course_keys_for_site(site)
if not set(course_keys).issubset(set(site_course_keys)):
raise NotFound()
return self.get_enrolled_users(site=site, course_keys=course_keys)
else:
return figures.sites.get_users_for_site(site).prefetch_related(
'courseenrollment_set')

def get_serializer_context(self):
context = super(LearnerMetricsViewSet, self).get_serializer_context()
Expand Down
81 changes: 70 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,24 @@
from django.utils.timezone import utc

from tests.factories import (
CourseEnrollmentFactory,
CourseOverviewFactory,
OrganizationFactory,
OrganizationCourseFactory,
StudentModuleFactory,
SiteFactory,
UserFactory,
)
from tests.helpers import organizations_support_sites


if organizations_support_sites():
from tests.factories import UserOrganizationMappingFactory

def map_users_to_org(org, users):
[UserOrganizationMappingFactory(user=user,
organization=org) for user in users]


@pytest.fixture
@pytest.mark.django_db
Expand All @@ -33,26 +39,79 @@ def sm_test_data(db):
sm = []
for co in course_overviews:
sm += [StudentModuleFactory(course_id=co.id,
created=created_date,
modified=modified_date,
) for co in course_overviews]

created=created_date,
modified=modified_date) for co in course_overviews]

if organizations_support_sites():
org = OrganizationFactory(sites=[site])
for co in course_overviews:
OrganizationCourseFactory(organization=org, course_id=str(co.id))
for rec in sm:
UserOrganizationMappingFactory(user=rec.student,
organization=org)
UserOrganizationMappingFactory(user=rec.student, organization=org)
else:
org = OrganizationFactory()

return dict(site=site,
organization=org,
course_overviews=course_overviews,
student_modules=sm,
year_for=year_for,
month_for=month_for)


@pytest.mark.django_db
def make_site_data(num_users=3, num_courses=2):

site = SiteFactory()
if organizations_support_sites():
org = OrganizationFactory(sites=[site])
else:
org = OrganizationFactory()
courses = [CourseOverviewFactory() for i in range(num_courses)]
users = [UserFactory() for i in range(num_users)]
enrollments = []

users = [UserFactory() for i in range(num_users)]

enrollments = []
for i, user in enumerate(users):
# Create increasing number of enrollments for each user, maximum to one less
# than the number of courses
for j in range(i):
enrollments.append(
CourseEnrollmentFactory(course=courses[j-1], user=user)
)

if organizations_support_sites():
for course in courses:
OrganizationCourseFactory(organization=org,
course_id=str(course.id))

# Set up user mappings
map_users_to_org(org, users)

return dict(
site=site,
organization=org,
course_overviews=course_overviews,
student_modules=sm,
year_for=year_for,
month_for=month_for
org=org,
courses=courses,
users=users,
enrollments=enrollments,
)


@pytest.fixture
@pytest.mark.django_db
def lm_test_data(db, settings):
"""Learner Metrics Test Data
user0 not enrolled in any courses
user1 enrolled in 1 course
user2 enrolled in 2 courses
"""
if organizations_support_sites():
settings.FEATURES['FIGURES_IS_MULTISITE'] = True

our_site_data = make_site_data()
other_site_data = make_site_data()
return dict(us=our_site_data, them=other_site_data)
21 changes: 21 additions & 0 deletions tests/views/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@

from tests.factories import UserFactory

from tests.helpers import organizations_support_sites


if organizations_support_sites():
from tests.factories import UserOrganizationMappingFactory


def create_test_users():
'''
Creates four test users to test the combination of permissions
Expand Down Expand Up @@ -31,3 +38,17 @@ def is_response_paginated(response_data):
# If we can't get keys, wer'e certainly not paginated
return False
return set(keys) == set([u'count', u'next', u'previous', u'results'])


def make_caller(org):
"""Convenience method to create the API caller user
"""
if organizations_support_sites():
# TODO: set is_staff to False after we have test coverage
caller = UserFactory(is_staff=True)
UserOrganizationMappingFactory(user=caller,
organization=org,
is_amc_admin=True)
else:
caller = UserFactory(is_staff=True)
return caller
Loading

0 comments on commit a8490dd

Please sign in to comment.