diff --git a/.gitignore b/.gitignore index fbaf2392..06a54cfb 100644 --- a/.gitignore +++ b/.gitignore @@ -53,6 +53,7 @@ coverage.xml # Django stuff: *.log local_settings.py +devsite/devsite/.env # Flask stuff: instance/ diff --git a/devsite/devsite/.env.example b/devsite/devsite/.env.example new file mode 100644 index 00000000..20c3ed96 --- /dev/null +++ b/devsite/devsite/.env.example @@ -0,0 +1,8 @@ +# Figures devsite environment settings example file + +# Set to true to enable Figures multisite environment in devsite +FIGURES_IS_MULTISITE=true + +# Set synthetic data seed options +SEED_DAYS_BACK=60 +SEED_NUM_LEARNERS_PER_COURSE=25 diff --git a/devsite/devsite/seed.py b/devsite/devsite/seed.py index 3b386021..efff87e0 100644 --- a/devsite/devsite/seed.py +++ b/devsite/devsite/seed.py @@ -10,6 +10,7 @@ import faker import random +from django.conf import settings from django.contrib.auth import get_user_model from django.contrib.sites.models import Site from django.db.utils import IntegrityError @@ -19,28 +20,45 @@ from openedx.core.djangoapps.content.course_overviews.models import CourseOverview from student.models import CourseAccessRole, CourseEnrollment, UserProfile +from organizations.models import Organization, OrganizationCourse + from figures.compat import RELEASE_LINE, GeneratedCertificate from figures.models import ( CourseDailyMetrics, LearnerCourseGradeMetrics, SiteDailyMetrics, ) -from figures.helpers import as_course_key, as_datetime, days_from, prev_day +from figures.helpers import ( + as_course_key, + as_datetime, + days_from, + prev_day, + is_multisite, +) from figures.pipeline import course_daily_metrics as pipeline_cdm from figures.pipeline import site_daily_metrics as pipeline_sdm from devsite import cans +if is_multisite(): + # First trying this without capturing 'ImportError' + from organizations.models import UserOrganizationMapping + + FAKE = faker.Faker() LAST_DAY = days_from(datetime.datetime.now(), -2).replace(tzinfo=utc) -DAYS_BACK = 180 -NO_LEARNERS_PER_COURSE = 50 + +DAYS_BACK = settings.DEVSITE_SEED['DAYS_BACK'] +NUM_LEARNERS_PER_COURSE = settings.DEVSITE_SEED['NUM_LEARNERS_PER_COURSE'] # Quick and dirty debuging VERBOSE = False +FOO_ORG = 'FOO' + + def get_site(): """ In demo mode, we have just one site (for now) @@ -82,7 +100,7 @@ def seed_course_overviews(data=None): if not data: data = cans.COURSE_OVERVIEW_DATA # append with randomly generated course overviews to test pagination - new_courses = [generate_course_overview(i, org='FOO') for i in xrange(20)] + new_courses = [generate_course_overview(i, org=FOO_ORG) for i in xrange(20)] data += new_courses for rec in data: @@ -166,7 +184,7 @@ def seed_course_enrollments(): TODO: make the number of users variable """ for co in CourseOverview.objects.all(): - users = seed_users(cans.users.UserGenerator(NO_LEARNERS_PER_COURSE)) + users = seed_users(cans.users.UserGenerator(NUM_LEARNERS_PER_COURSE)) seed_course_enrollments_for_course(co.id, users, DAYS_BACK) @@ -318,14 +336,46 @@ def seed_lcgm_all(): seed_lcgm_for_course(**seed_args) +def hotwire_multisite(): + """ + This is a quick and dirty implementation of a single site in multisite mode + """ + params = dict( + name='Foo Organization', + short_name='FOO', + description='Foo org description', + logo=None, + active=True, + ) + org = Organization.objects.create(**params) + if is_multisite(): + org.sites = [get_site()] + org.save() + + for course in CourseOverview.objects.all(): + OrganizationCourse.objects.create(course_id=str(course.id), + organization=org, + active=True) + for user in get_user_model().objects.all(): + # For now, not setting is_amc_admin roles + UserOrganizationMapping.objects.create(user=user, + organization=org, + is_active=True) + + def wipe(): + print('Wiping synthetic data...') clear_non_admin_users() CourseEnrollment.objects.all().delete() StudentModule.objects.all().delete() CourseOverview.objects.all().delete() CourseDailyMetrics.objects.all().delete() SiteDailyMetrics.objects.all().delete() - LearnerCourseGradeMetrics.all().delete() + LearnerCourseGradeMetrics.objects.all().delete() + Organization.objects.all().delete() + OrganizationCourse.objects.all().delete() + if is_multisite(): + UserOrganizationMapping.objects.all().delete() def seed_all(): @@ -335,9 +385,14 @@ def seed_all(): seed_course_overviews() print("seeding users...") seed_users() + print("seeding course enrollments...") seed_course_enrollments() + if is_multisite(): + print("Hotwiring multisite...") + hotwire_multisite() + print("- skipping seeding seed_course_access_roles, broken") # print("seeding course enrollments...") # seed_course_access_roles() diff --git a/devsite/devsite/settings.py b/devsite/devsite/settings.py index a5e5c97e..7a08e138 100644 --- a/devsite/devsite/settings.py +++ b/devsite/devsite/settings.py @@ -9,6 +9,8 @@ import os import sys +import environ + from figures.settings.lms_production import ( update_webpack_loader, update_celerybeat_schedule, @@ -16,6 +18,16 @@ OPENEDX_RELEASE = os.environ.get('OPENEDX_RELEASE', 'HAWTHORN').upper() + +env = environ.Env( + FIGURES_IS_MULTISITE=(bool, False), + SEED_DAYS_BACK=(int, 60), + SEED_NUM_LEARNERS_PER_COURSE=(int, 25) +) + +environ.Env.read_env() + + if OPENEDX_RELEASE == 'GINKGO': MOCKS_DIR = 'mocks/ginkgo' else: @@ -199,9 +211,10 @@ # Included here for completeness in having this settings file match behavior in # the LMS settings CELERYBEAT_SCHEDULE = {} -FEATURES = {} +FEATURES = { + 'FIGURES_IS_MULTISITE': env('FIGURES_IS_MULTISITE') +} -FEATURES = {} # The LMS defines ``ENV_TOKENS`` to load settings declared in `lms.env.json` # We have an empty dict here to replicate behavior in the LMS @@ -214,3 +227,8 @@ INTERNAL_IPS = [ '127.0.0.1' ] + +DEVSITE_SEED = { + 'DAYS_BACK': env('SEED_DAYS_BACK'), + 'NUM_LEARNERS_PER_COURSE': env('SEED_NUM_LEARNERS_PER_COURSE') +} diff --git a/devsite/requirements/ginkgo.txt b/devsite/requirements/ginkgo.txt index 2692bc0a..16279179 100644 --- a/devsite/requirements/ginkgo.txt +++ b/devsite/requirements/ginkgo.txt @@ -35,6 +35,7 @@ django-webpack-loader==0.4.1 # appsembler/gingko/master users 0.4.1 django-model-utils==2.3.1 +django-environ==0.4.5 django-celery==3.2.1 jsonfield==1.0.3 # Version used in Ginkgo. Hawthorn uses version 2.0.2 diff --git a/devsite/requirements/hawthorn_community.txt b/devsite/requirements/hawthorn_community.txt index 78ab7589..173bbcac 100644 --- a/devsite/requirements/hawthorn_community.txt +++ b/devsite/requirements/hawthorn_community.txt @@ -33,6 +33,7 @@ django-countries==4.6.1 django-filter==1.0.4 django-webpack-loader==0.6.0 django-model-utils==3.0.0 +django-environ==0.4.5 jsonfield==2.0.2 diff --git a/devsite/requirements/hawthorn_multisite.txt b/devsite/requirements/hawthorn_multisite.txt index 9a448b63..396aa350 100644 --- a/devsite/requirements/hawthorn_multisite.txt +++ b/devsite/requirements/hawthorn_multisite.txt @@ -34,6 +34,7 @@ django-countries==4.6.1 django-filter==1.0.4 django-webpack-loader==0.6.0 django-model-utils==3.0.0 +django-environ==0.4.5 jsonfield==2.0.2 diff --git a/figures/models.py b/figures/models.py index 234bc59e..8fb08d60 100644 --- a/figures/models.py +++ b/figures/models.py @@ -192,11 +192,19 @@ class LearnerCourseGradeMetricsManager(models.Manager): """Custom model manager for LearnerCourseGrades model """ def most_recent_for_learner_course(self, user, course_id): - queryset = self.filter(user=user, course_id=str(course_id)) - if queryset: - return queryset.order_by('-date_for')[0] - else: - return None + """Gets the most recent record for the given user and course + + We have this because we implement sparse data, meaning we only create + new records when data has changed. this means that for a given course, + learners may not have the same "most recent date" + + This means we have to be careful of where we use this method in our + API as it costs a query per call. We will likely require restructuring + or augmenting our data if we need to bulk retrieve + """ + queryset = self.filter(user=user, + course_id=str(course_id)).order_by('-date_for') + return queryset[0] if queryset else None def most_recent_for_course(self, course_id): statement = """ \ diff --git a/figures/serializers.py b/figures/serializers.py index beffe632..e8f61503 100644 --- a/figures/serializers.py +++ b/figures/serializers.py @@ -798,7 +798,7 @@ class EnrollmentMetricsSerializerV2(serializers.ModelSerializer): course_id = serializers.CharField() date_enrolled = serializers.DateTimeField(source='created', format="%Y-%m-%d") - is_enrolled = serializers.BooleanField() + is_enrolled = serializers.BooleanField(source='is_active') progress_percent = serializers.SerializerMethodField() progress_details = serializers.SerializerMethodField() @@ -821,14 +821,6 @@ def to_representation(self, instance): user=instance.user, course_id=str(instance.course_id)) return super(EnrollmentMetricsSerializerV2, self).to_representation(instance) - def get_is_enrolled(self, obj): - """ - CourseEnrollment has to do some work to get this value - TODO: inspect CourseEnrollment._get_enrollment_state to see how we - can speed this up, avoiding construction of `CourseEnrollmentState` - """ - return CourseEnrollment.is_enrolled(obj.user, obj.course_id) - def get_progress_percent(self, obj): # pylint: disable=unused-argument value = self._lcgm.progress_percent if self._lcgm else 0 return float(Decimal(value).quantize(Decimal('.00'))) @@ -839,6 +831,25 @@ def get_progress_details(self, obj): # pylint: disable=unused-argument return self._lcgm.progress_details if self._lcgm else None +class LearnerMetricsListSerializer(serializers.ListSerializer): + """ + See if we need to add to class: # pylint: disable=abstract-method + """ + def __init__(self, instance=None, data=empty, **kwargs): + """instance is a queryset of users + + TODO: Ensure that we only have our own site's course keys + """ + self.site = kwargs['context'].get('site') + self.course_keys = kwargs['context'].get('course_keys') + + if not self.course_keys: + self.course_keys = figures.sites.get_course_keys_for_site(self.site) + + super(LearnerMetricsListSerializer, self).__init__( + instance=instance, data=data, **kwargs) + + class LearnerMetricsSerializer(serializers.ModelSerializer): fullname = serializers.CharField(source='profile.name', default=None) # enrollments = EnrollmentMetricsSerializerV2(source='courseenrollment_set', @@ -847,16 +858,17 @@ class LearnerMetricsSerializer(serializers.ModelSerializer): class Meta: model = get_user_model() + list_serializer_class = LearnerMetricsListSerializer fields = ('id', 'username', 'email', 'fullname', 'is_active', 'date_joined', 'enrollments') read_only_fields = fields def get_enrollments(self, user): - site_enrollments = figures.sites.get_course_enrollments_for_site( - self.context.get('site')) - user_enrollments = site_enrollments.filter(user=user) - course_keys = self.context.get('course_keys') - if course_keys: - user_enrollments = user_enrollments.filter(course_id__in=course_keys) + """ + Use the course ids identified in this serializer's list serializer to + filter enrollments + """ + user_enrollments = user.courseenrollment_set.filter( + course_id__in=self.parent.course_keys) return EnrollmentMetricsSerializerV2(user_enrollments, many=True).data diff --git a/figures/sites.py b/figures/sites.py index d61b8a95..857fc2f1 100644 --- a/figures/sites.py +++ b/figures/sites.py @@ -136,6 +136,13 @@ def get_organizations_for_site(site): def get_course_keys_for_site(site): + """ + + Developer note: We could improve this function with caching + Question is which is the most efficient way to know cache expiry + + We may also be able to reduce the queries here to also improve performance + """ if figures.helpers.is_multisite(): orgs = organizations.models.Organization.objects.filter(sites__in=[site]) org_courses = organizations.models.OrganizationCourse.objects.filter( diff --git a/figures/views.py b/figures/views.py index 545a53d7..d90d314d 100644 --- a/figures/views.py +++ b/figures/views.py @@ -427,17 +427,44 @@ def query_param_course_keys(self): cid_list = self.request.GET.getlist('course') return [CourseKey.from_string(elem.replace(' ', '+')) for elem in cid_list] + def get_enrolled_users(self, site, course_keys): + """Get users enrolled in the specific courses for the specified site + + Args: + site: The site for which is being called + course_keys: list of Open edX course keys + + Returns: + Django QuerySet of users enrolled in the specified courses + + Note: + We should move this to `figures.sites` + """ + qs = figures.sites.get_users_for_site(site).filter( + courseenrollment__course_id__in=course_keys + ).select_related('profile').prefetch_related('courseenrollment_set') + return qs + def get_queryset(self): """ - This function has a hack to filter users until we can get the `filter_class` - working + If one or more course keys are given as query parameters, then + * Course key filtering mode is ued. Any invalid keys are filtered out + from the list + * If no valid course keys are found, then an empty list is returned from + this view """ site = django.contrib.sites.shortcuts.get_current_site(self.request) - queryset = figures.sites.get_users_for_site(site) - course_keys = self.query_param_course_keys() - if course_keys: - queryset = figures.sites.users_enrolled_in_courses(course_keys) - return queryset + course_keys = figures.sites.get_course_keys_for_site(site) + try: + param_course_keys = self.query_param_course_keys() + except InvalidKeyError: + raise NotFound() + if param_course_keys: + if not set(param_course_keys).issubset(set(course_keys)): + raise NotFound() + else: + course_keys = param_course_keys + return self.get_enrolled_users(site=site, course_keys=course_keys) def get_serializer_context(self): context = super(LearnerMetricsViewSet, self).get_serializer_context() diff --git a/tests/conftest.py b/tests/conftest.py index 0b1d3e4a..92ce30bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,11 +4,13 @@ from django.utils.timezone import utc from tests.factories import ( + CourseEnrollmentFactory, CourseOverviewFactory, OrganizationFactory, OrganizationCourseFactory, StudentModuleFactory, SiteFactory, + UserFactory, ) from tests.helpers import organizations_support_sites @@ -16,6 +18,10 @@ if organizations_support_sites(): from tests.factories import UserOrganizationMappingFactory + def map_users_to_org(org, users): + [UserOrganizationMappingFactory(user=user, + organization=org) for user in users] + @pytest.fixture @pytest.mark.django_db @@ -33,26 +39,79 @@ def sm_test_data(db): sm = [] for co in course_overviews: sm += [StudentModuleFactory(course_id=co.id, - created=created_date, - modified=modified_date, - ) for co in course_overviews] - + created=created_date, + modified=modified_date) for co in course_overviews] if organizations_support_sites(): org = OrganizationFactory(sites=[site]) for co in course_overviews: OrganizationCourseFactory(organization=org, course_id=str(co.id)) for rec in sm: - UserOrganizationMappingFactory(user=rec.student, - organization=org) + UserOrganizationMappingFactory(user=rec.student, organization=org) else: org = OrganizationFactory() + return dict(site=site, + organization=org, + course_overviews=course_overviews, + student_modules=sm, + year_for=year_for, + month_for=month_for) + + +@pytest.mark.django_db +def make_site_data(num_users=3, num_courses=2): + + site = SiteFactory() + if organizations_support_sites(): + org = OrganizationFactory(sites=[site]) + else: + org = OrganizationFactory() + courses = [CourseOverviewFactory() for i in range(num_courses)] + users = [UserFactory() for i in range(num_users)] + enrollments = [] + + users = [UserFactory() for i in range(num_users)] + + enrollments = [] + for i, user in enumerate(users): + # Create increasing number of enrollments for each user, maximum to one less + # than the number of courses + for j in range(i): + enrollments.append( + CourseEnrollmentFactory(course=courses[j-1], user=user) + ) + + if organizations_support_sites(): + for course in courses: + OrganizationCourseFactory(organization=org, + course_id=str(course.id)) + + # Set up user mappings + map_users_to_org(org, users) + return dict( site=site, - organization=org, - course_overviews=course_overviews, - student_modules=sm, - year_for=year_for, - month_for=month_for + org=org, + courses=courses, + users=users, + enrollments=enrollments, ) + + +@pytest.fixture +@pytest.mark.django_db +def lm_test_data(db, settings): + """Learner Metrics Test Data + + user0 not enrolled in any courses + user1 enrolled in 1 course + user2 enrolled in 2 courses + + """ + if organizations_support_sites(): + settings.FEATURES['FIGURES_IS_MULTISITE'] = True + + our_site_data = make_site_data() + other_site_data = make_site_data() + return dict(us=our_site_data, them=other_site_data) diff --git a/tests/views/helpers.py b/tests/views/helpers.py index bf51f0f2..103f2669 100644 --- a/tests/views/helpers.py +++ b/tests/views/helpers.py @@ -4,6 +4,13 @@ from tests.factories import UserFactory +from tests.helpers import organizations_support_sites + + +if organizations_support_sites(): + from tests.factories import UserOrganizationMappingFactory + + def create_test_users(): ''' Creates four test users to test the combination of permissions @@ -31,3 +38,17 @@ def is_response_paginated(response_data): # If we can't get keys, wer'e certainly not paginated return False return set(keys) == set([u'count', u'next', u'previous', u'results']) + + +def make_caller(org): + """Convenience method to create the API caller user + """ + if organizations_support_sites(): + # TODO: set is_staff to False after we have test coverage + caller = UserFactory(is_staff=True) + UserOrganizationMappingFactory(user=caller, + organization=org, + is_amc_admin=True) + else: + caller = UserFactory(is_staff=True) + return caller diff --git a/tests/views/test_learner_metrics_viewset.py b/tests/views/test_learner_metrics_viewset.py index f6ea014b..5dc943e3 100644 --- a/tests/views/test_learner_metrics_viewset.py +++ b/tests/views/test_learner_metrics_viewset.py @@ -7,57 +7,21 @@ from rest_framework import status from rest_framework.test import APIRequestFactory, force_authenticate -from figures.sites import get_user_ids_for_site -from figures.views import LearnerMetricsViewSet - -from tests.factories import ( - CourseEnrollmentFactory, - CourseOverviewFactory, - # LearnerCourseGradeMetricsFactory, - OrganizationFactory, - SiteFactory, - UserFactory, +from figures.helpers import as_course_key +from figures.sites import ( + get_course_keys_for_site, + users_enrolled_in_courses, ) +from figures.views import LearnerMetricsViewSet from tests.helpers import organizations_support_sites from tests.views.base import BaseViewTest -from tests.views.helpers import is_response_paginated +from tests.views.helpers import is_response_paginated, make_caller -if organizations_support_sites(): - from tests.factories import UserOrganizationMappingFactory - def map_users_to_org_site(caller, site, users): - org = OrganizationFactory(sites=[site]) - UserOrganizationMappingFactory(user=caller, - organization=org, - is_amc_admin=True) - [UserOrganizationMappingFactory(user=user, - organization=org) for user in users] - # return created objects that the test will need - return caller - - -@pytest.fixture -def enrollment_test_data(): - """Stands up shared test data. We need to revisit this - """ - num_courses = 2 - site = SiteFactory() - course_overviews = [CourseOverviewFactory() for i in range(num_courses)] - # Create a number of enrollments for each course - enrollments = [] - for num_enroll, co in enumerate(course_overviews, 1): - enrollments += [CourseEnrollmentFactory( - course_id=co.id) for i in range(num_enroll)] - - # This is a convenience for the test method - users = [enrollment.user for enrollment in enrollments] - return dict( - site=site, - course_overviews=course_overviews, - enrollments=enrollments, - users=users, - ) +def filter_enrollments(enrollments, courses): + course_ids = [elem.id for elem in courses] + return [elem for elem in enrollments if elem.course_id in course_ids] @pytest.mark.django_db @@ -103,17 +67,6 @@ def setup(self, db, settings): settings.FEATURES['FIGURES_IS_MULTISITE'] = True super(TestLearnerMetricsViewSet, self).setup(db) - def make_caller(self, site, users): - """Convenience method to create the API caller user - """ - if organizations_support_sites(): - # TODO: set is_staff to False after we have test coverage - caller = UserFactory(is_staff=True) - map_users_to_org_site(caller=caller, site=site, users=users) - else: - caller = UserFactory(is_staff=True) - return caller - def make_request(self, monkeypatch, request_path, site, caller, action): """Convenience method to make the API request @@ -128,7 +81,17 @@ def make_request(self, monkeypatch, request_path, site, caller, action): view = self.view_class.as_view({'get': action}) return view(request) - def test_list_method_all(self, monkeypatch, enrollment_test_data): + def matching_enrollment_set_to_course_ids(self, enrollments, course_ids): + """ + enrollment course ids need to be a subset of course_ids + It is ok if there are none or fewer enrollments than course_ids because + a learner might not be enrolled in all the courses on which we are + filtering + """ + enroll_course_ids = set([rec['course_id'] for rec in enrollments]) + return enroll_course_ids.issubset(set([str(rec) for rec in course_ids])) + + def test_list_method_all(self, monkeypatch, lm_test_data): """Partial test coverage to check we get all site users Checks returned user ids against all user ids for the site @@ -137,50 +100,52 @@ def test_list_method_all(self, monkeypatch, enrollment_test_data): Does NOT check values in the `enrollments` key. This should be done as follow up work """ - site = enrollment_test_data['site'] - users = enrollment_test_data['users'] - - caller = self.make_caller(site, users) - other_site = SiteFactory() - assert site.domain != other_site.domain + us = lm_test_data['us'] + them = lm_test_data['them'] + our_courses = us['courses'] + caller = make_caller(us['org']) + assert us['site'].domain != them['site'].domain + assert len(our_courses) > 1 response = self.make_request(request_path=self.base_request_path, monkeypatch=monkeypatch, - site=site, + site=us['site'], caller=caller, action='list') assert response.status_code == status.HTTP_200_OK assert is_response_paginated(response.data) results = response.data['results'] - # Check user ids + # Check users result_ids = [obj['id'] for obj in results] - user_ids = get_user_ids_for_site(site=site) + # Get all enrolled users + course_keys = get_course_keys_for_site(site=us['site']) + users = users_enrolled_in_courses(course_keys) + user_ids = [user.id for user in users] assert set(result_ids) == set(user_ids) # Spot check the first record top_keys = ['id', 'username', 'email', 'fullname', 'is_active', 'date_joined', 'enrollments'] assert set(results[0].keys()) == set(top_keys) - def test_course_param_single(self, monkeypatch, enrollment_test_data): + def test_course_param_single(self, monkeypatch, lm_test_data): """Test that the 'course' query parameter works """ - site = enrollment_test_data['site'] - users = enrollment_test_data['users'] - enrollments = enrollment_test_data['enrollments'] - course_overviews = enrollment_test_data['course_overviews'] + us = lm_test_data['us'] + them = lm_test_data['them'] + our_enrollments = us['enrollments'] + our_courses = us['courses'] - caller = self.make_caller(site, users) - other_site = SiteFactory() - assert site.domain != other_site.domain - assert len(course_overviews) > 1 - query_params = '?course={}'.format(str(course_overviews[0].id)) + caller = make_caller(us['org']) + assert us['site'].domain != them['site'].domain + assert len(our_courses) > 1 + query_params = '?course={}'.format(str(our_courses[0].id)) request_path = self.base_request_path + query_params response = self.make_request(request_path=request_path, monkeypatch=monkeypatch, - site=site, + site=us['site'], caller=caller, action='list') @@ -190,37 +155,109 @@ def test_course_param_single(self, monkeypatch, enrollment_test_data): # Check user ids result_ids = [obj['id'] for obj in results] - our_enrollments = [elem for elem in enrollments if elem.course_id == course_overviews[0].id] - expected_user_ids = [obj.user.id for obj in our_enrollments] + course_enrollments = [elem for elem in our_enrollments + if elem.course_id == our_courses[0].id] + expected_user_ids = [obj.user.id for obj in course_enrollments] assert set(result_ids) == set(expected_user_ids) - def test_course_param_multiple(self, monkeypatch, enrollment_test_data): + for rec in results: + assert self.matching_enrollment_set_to_course_ids( + rec['enrollments'], [our_courses[0].id]) + + def test_course_param_multiple(self, monkeypatch, lm_test_data): """Test that the 'course' query parameter works """ - site = enrollment_test_data['site'] - users = enrollment_test_data['users'] - enrollments = enrollment_test_data['enrollments'] - course_overviews = enrollment_test_data['course_overviews'] - - caller = self.make_caller(site, users) - other_site = SiteFactory() - assert site.domain != other_site.domain - assert len(course_overviews) > 1 - query_params = '?course={}&course={}'.format(str(course_overviews[0].id), - str(course_overviews[1].id)) + us = lm_test_data['us'] + them = lm_test_data['them'] + our_enrollments = us['enrollments'] + our_courses = us['courses'] + caller = make_caller(us['org']) + assert us['site'].domain != them['site'].domain + assert len(our_courses) > 1 + + filtered_courses = our_courses[:2] + + # TODO: build params from 'filtered_courses' + query_params = '?course={}&course={}'.format(str(our_courses[0].id), + str(our_courses[1].id)) request_path = self.base_request_path + query_params response = self.make_request(request_path=request_path, monkeypatch=monkeypatch, - site=site, + site=us['site'], caller=caller, action='list') + # Continue updating here assert response.status_code == status.HTTP_200_OK assert is_response_paginated(response.data) results = response.data['results'] # Check user ids result_ids = [obj['id'] for obj in results] - expected_user_ids = [obj.user.id for obj in enrollments] + expected_enrollments = filter_enrollments(enrollments=our_enrollments, + courses=filtered_courses) + expected_user_ids = [obj.user.id for obj in expected_enrollments] assert set(result_ids) == set(expected_user_ids) + for rec in results: + assert self.matching_enrollment_set_to_course_ids( + rec['enrollments'], [rec.id for rec in filtered_courses]) + + def invalid_course_ids_raise_404(self, monkeypatch, lm_test_data, query_params): + """ + Helper method to test expected 404 calls + """ + us = lm_test_data['us'] + them = lm_test_data['them'] + + caller = make_caller(us['org']) + assert us['site'].domain != them['site'].domain + request_path = self.base_request_path + query_params + response = self.make_request(request_path=request_path, + monkeypatch=monkeypatch, + site=us['site'], + caller=caller, + action='list') + return response.status_code == status.HTTP_404_NOT_FOUND + + @pytest.mark.skipif(not organizations_support_sites(), + reason='Organizations support sites') + def test_valid_and_course_param_from_other_site_invalid(self, + monkeypatch, + lm_test_data): + """Test that the 'course' query parameter works + + """ + our_courses = lm_test_data['us']['courses'] + their_courses = lm_test_data['them']['courses'] + query_params = '?course={}&course={}'.format(str(our_courses[0].id), + str(their_courses[0].id)) + assert self.invalid_course_ids_raise_404(monkeypatch, + lm_test_data, + query_params) + + def test_valid_and_mangled_course_param_invalid(self, + monkeypatch, + lm_test_data): + """Test that the 'course' query parameter works + + """ + our_courses = lm_test_data['us']['courses'] + mangled_course_id = 'she-sell-seashells-by-the-seashore' + query_params = '?course={}&course={}'.format(str(our_courses[0].id), + mangled_course_id) + assert self.invalid_course_ids_raise_404(monkeypatch, + lm_test_data, + query_params) + + def test_unlinked_course_id_param_invalid(self, monkeypatch, lm_test_data): + """Test that the 'course' query parameter works + + """ + our_courses = lm_test_data['us']['courses'] + unlinked_course_id = as_course_key('course-v1:UnlinkedCourse+UMK+1999') + query_params = '?course={}&course={}'.format(str(our_courses[0].id), + unlinked_course_id) + assert self.invalid_course_ids_raise_404(monkeypatch, + lm_test_data, + query_params)