diff --git a/kolibri/auth/api.py b/kolibri/auth/api.py index 40a65b34951..065670293b8 100644 --- a/kolibri/auth/api.py +++ b/kolibri/auth/api.py @@ -1,6 +1,6 @@ from __future__ import absolute_import, print_function, unicode_literals -from django.contrib.auth import get_user +from django.contrib.auth import authenticate, get_user, login, logout from rest_framework import filters, permissions, viewsets from rest_framework.response import Response @@ -120,3 +120,24 @@ class LearnerGroupViewSet(viewsets.ModelViewSet): serializer_class = LearnerGroupSerializer filter_fields = ('parent',) + + +class SessionViewSet(viewsets.ViewSet): + + def destroy(self, request, pk=None): + logout(request) + return Response("successfully deleted!") + + def create(self, request): + username = request.POST.get('username', '') + password = request.POST.get('password', '') + facility_id = request.POST.get('facility', '') + user = authenticate(username=username, password=password, facility=facility_id) + if user is not None and user.is_active: + # Correct password, and the user is marked "active" + login(request, user) + # Success! + return Response("Successfully logged in!") + else: + # Respond with error + return Response("User does not exist with those credentials!") diff --git a/kolibri/auth/api_urls.py b/kolibri/auth/api_urls.py index 2c8daffaa72..b4cb99f1d63 100644 --- a/kolibri/auth/api_urls.py +++ b/kolibri/auth/api_urls.py @@ -1,7 +1,9 @@ from rest_framework import routers from .api import ( - ClassroomViewSet, CurrentFacilityViewSet, DeviceOwnerViewSet, FacilityUserViewSet, FacilityViewSet, LearnerGroupViewSet, MembershipViewSet, RoleViewSet + ClassroomViewSet, CurrentFacilityViewSet, DeviceOwnerViewSet, + FacilityUserViewSet, FacilityViewSet, LearnerGroupViewSet, + MembershipViewSet, RoleViewSet, SessionViewSet ) router = routers.SimpleRouter() @@ -14,5 +16,6 @@ router.register(r'currentfacility', CurrentFacilityViewSet, base_name='currentfacility') router.register(r'classroom', ClassroomViewSet) router.register(r'learnergroup', LearnerGroupViewSet) +router.register(r'session', SessionViewSet, base_name='session') urlpatterns = router.urls diff --git a/kolibri/auth/backends.py b/kolibri/auth/backends.py index f206d2ca8e8..6d123af5c6b 100644 --- a/kolibri/auth/backends.py +++ b/kolibri/auth/backends.py @@ -48,7 +48,7 @@ class DeviceOwnerBackend(object): A class that implements authentication for DeviceOwners. """ - def authenticate(self, username=None, password=None): + def authenticate(self, username=None, password=None, **kwargs): """ Authenticates the user if the credentials correspond to a DeviceOwner. diff --git a/kolibri/auth/test/test_api.py b/kolibri/auth/test/test_api.py index f656e2c5aff..9f410000b23 100644 --- a/kolibri/auth/test/test_api.py +++ b/kolibri/auth/test/test_api.py @@ -8,6 +8,7 @@ from rest_framework import status from rest_framework.test import APITestCase as BaseTestCase +from django.contrib.sessions.models import Session from .. import models @@ -226,3 +227,32 @@ def test_creating_facility_user_via_api_sets_password_correctly(self): self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertTrue(models.FacilityUser.objects.get(username=new_username).check_password(new_password)) self.assertFalse(models.FacilityUser.objects.get(username=new_username).check_password(bad_password)) + + +class LoginLogoutTestCase(APITestCase): + + def setUp(self): + self.device_owner = DeviceOwnerFactory.create() + self.facility = FacilityFactory.create() + self.user = FacilityUserFactory.create(facility=self.facility) + + def test_login_and_logout_device_owner(self): + self.client.post(reverse('session-list'), data={"username": self.device_owner.username, "password": DUMMY_PASSWORD}) + sessions = Session.objects.all() + self.assertEqual(len(sessions), 1) + session_pk = sessions[0].session_key + self.client.delete(reverse('session-detail', kwargs={'pk': session_pk})) + self.assertEqual(len(Session.objects.all()), 0) + + def test_login_and_logout_facility_user(self): + self.client.post(reverse('session-list'), data={"username": self.user.username, "password": DUMMY_PASSWORD, "facility": self.facility.id}) + sessions = Session.objects.all() + self.assertEqual(len(sessions), 1) + session_pk = sessions[0].session_key + self.client.delete(reverse('session-detail', kwargs={'pk': session_pk})) + self.assertEqual(len(Session.objects.all()), 0) + + def test_incorrect_credentials_does_not_log_in_user(self): + self.client.post(reverse('session-list'), data={"username": self.user.username, "password": "foo", "facility": self.facility.id}) + sessions = Session.objects.all() + self.assertEqual(len(sessions), 0)