From 926f45d21a230d1bc2e1719d6358cfbd8a3492ba Mon Sep 17 00:00:00 2001
From: John Baldwin <jlbaldwin@gmail.com>
Date: Fri, 31 Jul 2020 20:20:56 -0400
Subject: [PATCH] Learner metrics endpoint - Initial commit

This is the initial commit so Matej and work on the front end

The endpoint is `/figures/api/learner-metrics/`

* There is a basic viewset just to exercise the code. The test requires
test data to be filled out and tested in the response

* UserFilterSet needs to be updated or an alternate filter set needs to
be used in order to provide more filtering, in particular
 * Show only users who have enrollments
 * Show only users who do not have enrollments
 * Show only users who have completed
 * Show only users who have not completed

* List serializers need to be added to prefetch data to improve API
performance
* test_learner_metrics_viewset needs to be completed
* Updated the CourseEnrollment mock to provide the `is_enrolled` method
---
 figures/serializers.py                      |  79 ++++++++++
 figures/urls.py                             |  10 +-
 figures/views.py                            |  30 ++++
 mocks/hawthorn/student/models.py            |  44 +++++-
 tests/views/test_learner_metrics_viewset.py | 165 ++++++++++++++++++++
 5 files changed, 326 insertions(+), 2 deletions(-)
 create mode 100644 tests/views/test_learner_metrics_viewset.py

diff --git a/figures/serializers.py b/figures/serializers.py
index 2cfc7ff2..7e589c8a 100644
--- a/figures/serializers.py
+++ b/figures/serializers.py
@@ -18,11 +18,13 @@
 """
 
 import datetime
+from decimal import Decimal
 
 from django.contrib.auth import get_user_model
 from django.contrib.sites.models import Site
 from django_countries import Countries
 from rest_framework import serializers
+from rest_framework.fields import empty
 
 from openedx.core.djangoapps.content.course_overviews.models import CourseOverview  # noqa pylint: disable=import-error
 from openedx.core.djangoapps.user_api.accounts.serializers import AccountLegacyProfileSerializer  # noqa pylint: disable=import-error
@@ -791,6 +793,10 @@ class CourseMauLiveMetricsSerializer(serializers.Serializer):
 
 class EnrollmentMetricsSerializer(serializers.ModelSerializer):
     """Serializer for LearnerCourseGradeMetrics
+
+    This is a prototype serializer for exploring API endpoints
+
+    It provides an enrollment major, use minor view
     """
     user = UserIndexSerializer(read_only=True)
     progress_percent = serializers.DecimalField(max_digits=3,
@@ -808,5 +814,78 @@ class Meta:
 
 
 class CourseCompletedSerializer(serializers.Serializer):
+    """Provides course id and user id for course completions
+
+    This serializer is used in the `enrollment-metrics` endpoint
+    """
     course_id = serializers.CharField()
     user_id = serializers.IntegerField()
+
+
+class EnrollmentMetricsSerializerV2(serializers.ModelSerializer):
+    """Provides serialization for an enrollment
+
+    This serializer note not identify the learner. It is used in
+    LearnerMetricsSerializer
+    """
+    course_id = serializers.CharField()
+    date_enrolled = serializers.DateTimeField(source='created',
+                                              format="%Y-%m-%d")
+    is_enrolled = serializers.BooleanField()
+    progress_percent = serializers.SerializerMethodField()
+    progress_details = serializers.SerializerMethodField()
+
+    def __init__(self, instance=None, data=empty, **kwargs):
+        self._lcgm = None
+        super(EnrollmentMetricsSerializerV2, self).__init__(
+            instance=None, data=empty, **kwargs)
+
+    class Meta:
+        model = CourseEnrollment
+        fields = ['id', 'course_id', 'date_enrolled', 'is_enrolled',
+                  'progress_percent', 'progress_details']
+        read_only_fields = fields
+
+    def to_representation(self, instance):
+        """
+        Get the most recent LCGM record for the enrollment, if it exists
+        """
+        self._lcgm = LearnerCourseGradeMetrics.objects.most_recent_for_learner_course(
+            user=instance.user, course_id=str(instance.course_id))
+        return super(EnrollmentMetricsSerializerV2, self).to_representation(instance)
+
+    def get_is_enrolled(self, obj):
+        """
+        CourseEnrollment has to do some work to get this value
+        TODO: inspect CourseEnrollment._get_enrollment_state to see how we
+        can speed this up, avoiding construction of `CourseEnrollmentState`
+        """
+        return CourseEnrollment.is_enrolled(obj.user, obj.course_id)
+
+    def get_progress_percent(self, obj):  # pylint: disable=unused-argument
+        value = self._lcgm.progress_percent if self._lcgm else 0
+        return float(Decimal(value).quantize(Decimal('.00')))
+
+    def get_progress_details(self, obj):  # pylint: disable=unused-argument
+        """Get progress data for a single enrollment
+        """
+        return self._lcgm.progress_details if self._lcgm else None
+
+
+class LearnerMetricsSerializer(serializers.ModelSerializer):
+    fullname = serializers.CharField(source='profile.name', default=None)
+    # enrollments = EnrollmentMetricsSerializerV2(source='courseenrollment_set',
+    #     many=True)
+    enrollments = serializers.SerializerMethodField()
+
+    class Meta:
+        model = get_user_model()
+        fields = ('id', 'username', 'email', 'fullname', 'is_active',
+                  'date_joined', 'enrollments')
+        read_only_fields = fields
+
+    def get_enrollments(self, user):
+        site_enrollments = figures.sites.get_course_enrollments_for_site(
+            self.context.get('site'))
+        user_enrollments = site_enrollments.filter(user=user)
+        return EnrollmentMetricsSerializerV2(user_enrollments, many=True).data
diff --git a/figures/urls.py b/figures/urls.py
index af92fe7d..cfe4d734 100644
--- a/figures/urls.py
+++ b/figures/urls.py
@@ -100,13 +100,21 @@
     views.UserIndexViewSet,
     base_name='user-index')
 
-# Experimental
+
+# New endpoints in development (unstable)
+# Unstable here means the code is subject to change without notice
 
 router.register(
     r'enrollment-metrics',
     views.EnrollmentMetricsViewSet,
     base_name='enrollment-metrics')
 
+router.register(
+    r'learner-metrics',
+    views.LearnerMetricsViewSet,
+    base_name='learner-metrics')
+
+
 urlpatterns = [
 
     # UI Templates
diff --git a/figures/views.py b/figures/views.py
index b0381904..ef157b77 100644
--- a/figures/views.py
+++ b/figures/views.py
@@ -64,6 +64,7 @@
     EnrollmentMetricsSerializer,
     GeneralCourseDataSerializer,
     LearnerDetailsSerializer,
+    LearnerMetricsSerializer,
     SiteDailyMetricsSerializer,
     SiteMauMetricsSerializer,
     SiteMauLiveMetricsSerializer,
@@ -398,6 +399,35 @@ def get_serializer_context(self):
         return context
 
 
+class LearnerMetricsViewSet(CommonAuthMixin, viewsets.ReadOnlyModelViewSet):
+    """Provides user identity and nested enrollment data
+
+    This view is unders active development and subject to change
+
+    TODO: After we get this class tests running, restructure this module:
+    * Group all user model based viewsets together
+    * Make a base user viewset with the `get_queryset` and `get_serializer_context`
+      methods
+    """
+    model = get_user_model()
+    pagination_class = FiguresLimitOffsetPagination
+    serializer_class = LearnerMetricsSerializer
+    filter_backends = (DjangoFilterBackend, )
+
+    # TODO: Improve this filter
+    filter_class = UserFilterSet
+
+    def get_queryset(self):
+        site = django.contrib.sites.shortcuts.get_current_site(self.request)
+        queryset = figures.sites.get_users_for_site(site)
+        return queryset
+
+    def get_serializer_context(self):
+        context = super(LearnerMetricsViewSet, self).get_serializer_context()
+        context['site'] = django.contrib.sites.shortcuts.get_current_site(self.request)
+        return context
+
+
 class EnrollmentMetricsViewSet(CommonAuthMixin, viewsets.ReadOnlyModelViewSet):
     """Initial viewset for enrollment metrics
 
diff --git a/mocks/hawthorn/student/models.py b/mocks/hawthorn/student/models.py
index d912fb55..f8c8a36f 100644
--- a/mocks/hawthorn/student/models.py
+++ b/mocks/hawthorn/student/models.py
@@ -1,5 +1,5 @@
 
-from collections import defaultdict
+from collections import defaultdict, namedtuple
 from datetime import datetime
 
 from pytz import UTC
@@ -129,6 +129,12 @@ def enrollment_counts(self, course_id):
         return enroll_dict
 
 
+# Named tuple for fields pertaining to the state of
+# CourseEnrollment for a user in a course.  This type
+# is used to cache the state in the request cache.
+CourseEnrollmentState = namedtuple('CourseEnrollmentState', 'mode, is_active')
+
+
 class CourseEnrollment(models.Model):
     '''
     The production model is student.models.CourseEnrollment
@@ -191,6 +197,42 @@ def __init__(self, *args, **kwargs):
         # When the property .course_overview is accessed for the first time, this variable will be set.
         self._course_overview = None
 
+    @classmethod
+    def is_enrolled(cls, user, course_key):
+        """
+        Returns True if the user is enrolled in the course (the entry must exist
+        and it must have `is_active=True`). Otherwise, returns False.
+
+        `user` is a Django User object. If it hasn't been saved yet (no `.id`
+               attribute), this method will automatically save it before
+               adding an enrollment for it.
+
+        `course_id` is our usual course_id string (e.g. "edX/Test101/2013_Fall)
+        """
+        enrollment_state = cls._get_enrollment_state(user, course_key)
+        return enrollment_state.is_active or False
+
+    @classmethod
+    def _get_enrollment_state(cls, user, course_key):
+        """
+        Returns the CourseEnrollmentState for the given user
+        and course_key, caching the result for later retrieval.
+
+        Figures note: removed the caching after copying this method
+        """
+        assert user
+
+        if user.is_anonymous:
+            return CourseEnrollmentState(None, None)
+
+        try:
+            record = cls.objects.get(user=user, course_id=course_key)
+            enrollment_state = CourseEnrollmentState(record.mode, record.is_active)
+        except cls.DoesNotExist:
+            enrollment_state = CourseEnrollmentState(None, None)
+
+        return enrollment_state
+
 
 class CourseAccessRole(models.Model):
     user = models.ForeignKey(User)
diff --git a/tests/views/test_learner_metrics_viewset.py b/tests/views/test_learner_metrics_viewset.py
new file mode 100644
index 00000000..997c4da2
--- /dev/null
+++ b/tests/views/test_learner_metrics_viewset.py
@@ -0,0 +1,165 @@
+"""Tests Figures learner-metrics viewset
+"""
+
+import pytest
+
+import django.contrib.sites.shortcuts
+from rest_framework import status
+from rest_framework.test import APIRequestFactory, force_authenticate
+
+from figures.sites import get_user_ids_for_site
+from figures.views import LearnerMetricsViewSet
+
+from tests.factories import (
+    CourseEnrollmentFactory,
+    CourseOverviewFactory,
+    # LearnerCourseGradeMetricsFactory,
+    OrganizationFactory,
+    SiteFactory,
+    UserFactory,
+)
+
+from tests.helpers import organizations_support_sites
+from tests.views.base import BaseViewTest
+from tests.views.helpers import is_response_paginated
+
+if organizations_support_sites():
+    from tests.factories import UserOrganizationMappingFactory
+
+    def map_users_to_org_site(caller, site, users):
+        org = OrganizationFactory(sites=[site])
+        UserOrganizationMappingFactory(user=caller,
+                                       organization=org,
+                                       is_amc_admin=True)
+        [UserOrganizationMappingFactory(user=user,
+                                        organization=org) for user in users]
+        # return created objects that the test will need
+        return caller
+
+
+@pytest.fixture
+def enrollment_test_data():
+    """Stands up shared test data. We need to revisit this
+    """
+    num_courses = 2
+    site = SiteFactory()
+    course_overviews = [CourseOverviewFactory() for i in range(num_courses)]
+    # Create a number of enrollments for each course
+    enrollments = []
+    for num_enroll, co in enumerate(course_overviews, 1):
+        enrollments += [CourseEnrollmentFactory(
+            course_id=co.id) for i in range(num_enroll)]
+
+    # This is a convenience for the test method
+    users = [enrollment.user for enrollment in enrollments]
+    return dict(
+        site=site,
+        course_overviews=course_overviews,
+        enrollments=enrollments,
+        users=users,
+    )
+
+
+@pytest.mark.django_db
+class TestLearnerMetricsViewSet(BaseViewTest):
+    """Tests the learner metrics viewset
+
+    The tests are incomplete
+
+    The list action will return a list of the following records:
+
+    ```
+        {
+            "id": 109,
+            "username": "chasecynthia",
+            "email": "msnyder@gmail.com",
+            "fullname": "Brandon Meyers",
+            "is_active": true,
+            "date_joined": "2020-06-03T00:00:00Z",
+            "enrollments": [
+                {
+                    "id": 9,
+                    "course_id": "course-v1:StarFleetAcademy+SFA01+2161",
+                    "date_enrolled": "2020-02-24",
+                    "is_enrolled": true,
+                    "progress_percent": 1.0,
+                    "progress_details": {
+                        "sections_worked": 20,
+                        "points_possible": 100.0,
+                        "sections_possible": 20,
+                        "points_earned": 50.0
+                    }
+                }
+            ]
+        }
+    ```
+    """
+    base_request_path = 'api/learner-metrics/'
+    view_class = LearnerMetricsViewSet
+
+    @pytest.fixture(autouse=True)
+    def setup(self, db, settings):
+        if organizations_support_sites():
+            settings.FEATURES['FIGURES_IS_MULTISITE'] = True
+        super(TestLearnerMetricsViewSet, self).setup(db)
+
+    def make_caller(self, site, users):
+        """Convenience method to create the API caller user
+        """
+        if organizations_support_sites():
+            # TODO: set is_staff to False after we have test coverage
+            caller = UserFactory(is_staff=True)
+            map_users_to_org_site(caller=caller, site=site, users=users)
+        else:
+            caller = UserFactory(is_staff=True)
+        return caller
+
+    def make_request(self, monkeypatch, request_path, site, caller, action):
+        """Convenience method to make the API request
+
+        Returns the response object
+        """
+        request = APIRequestFactory().get(request_path)
+        request.META['HTTP_HOST'] = site.domain
+        monkeypatch.setattr(django.contrib.sites.shortcuts,
+                            'get_current_site',
+                            lambda req: site)
+        force_authenticate(request, user=caller)
+        view = self.view_class.as_view({'get': action})
+        return view(request)
+
+    def test_list_method_all(self, monkeypatch, enrollment_test_data):
+        """Partial test coverage to check we get all site users
+
+        Checks returned user ids against all user ids for the site
+        Checks top level keys
+
+        Does NOT check values in the `enrollments` key. This should be done as
+        follow up work
+        """
+        site = enrollment_test_data['site']
+        users = enrollment_test_data['users']
+        enrollments = enrollment_test_data['enrollments']
+
+        caller = self.make_caller(site, users)
+        other_site = SiteFactory()
+        assert site.domain != other_site.domain
+
+        response = self.make_request(request_path=self.base_request_path,
+                                     monkeypatch=monkeypatch,
+                                     site=site,
+                                     caller=caller,
+                                     action='list')
+
+        assert response.status_code == status.HTTP_200_OK
+        assert is_response_paginated(response.data)
+        results = response.data['results']
+
+        # Check user ids
+        result_ids = [obj['id'] for obj in results]
+        user_ids = get_user_ids_for_site(site=site)
+        assert set(result_ids) == set(user_ids)
+        # Spot check the first record
+        top_keys = ['id', 'username', 'email', 'fullname', 'is_active',
+                    'date_joined', 'enrollments']
+        assert set(results[0].keys()) == set(top_keys)