diff --git a/h/migrations/versions/e554d862135f_add_index_to_flag_user_id.py b/h/migrations/versions/e554d862135f_add_index_to_flag_user_id.py new file mode 100644 index 00000000000..72b8776a1c0 --- /dev/null +++ b/h/migrations/versions/e554d862135f_add_index_to_flag_user_id.py @@ -0,0 +1,25 @@ +""" +Add index to flag.user_id + +Revision ID: e554d862135f +Revises: 5655d56d7c29 +Create Date: 2017-03-16 12:35:45.791202 +""" + +from __future__ import unicode_literals + +from alembic import op + + +revision = 'e554d862135f' +down_revision = '5655d56d7c29' + + +def upgrade(): + op.execute('COMMIT') + op.create_index(op.f('ix__flag__user_id'), 'flag', ['user_id'], + unique=False, postgresql_concurrently=True) + + +def downgrade(): + op.drop_index(op.f('ix__flag__user_id'), 'flag') diff --git a/h/models/flag.py b/h/models/flag.py index 691a86eb969..f9432c1444c 100644 --- a/h/models/flag.py +++ b/h/models/flag.py @@ -32,7 +32,7 @@ class Flag(Base, Timestamps): user_id = sa.Column(sa.Integer, sa.ForeignKey('user.id', ondelete='cascade'), - nullable=False) + nullable=False, index=True) #: The user who created the flag. user = sa.orm.relationship('User') diff --git a/h/services/flag.py b/h/services/flag.py index d984832a669..bfff20366e4 100644 --- a/h/services/flag.py +++ b/h/services/flag.py @@ -2,7 +2,10 @@ from __future__ import unicode_literals +from memex import uri + from h import models +from h import storage class FlagService(object): @@ -50,6 +53,45 @@ def create(self, user, annotation): annotation=annotation) self.session.add(flag) + def list(self, user, group=None, uris=None): + """ + Return a list of flags made by the given user. + + :param user: The user to filter flags on. + :type user: h.models.User + + :param group: The annotation group pubid for filtering flags. + :type group: unicode + + :param uris: A list of annotation uris for filtering flags. + :type uris: list of unicode + + :returns: list of flags (``h.models.Flag``) + :rtype: list + """ + + query = self.session.query(models.Flag).filter_by(user=user) + + joined_annotation = False + + if group is not None: + joined_annotation = True + query = query.join(models.Annotation) \ + .filter(models.Annotation.groupid == group) + + if uris: + query_uris = set() + for u in uris: + expanded = storage.expand_uri(self.session, u) + query_uris.update([uri.normalize(e) for e in expanded]) + + if not joined_annotation: + joined_annotation = True + query = query.join(models.Annotation) + query = query.filter(models.Annotation.target_uri_normalized.in_(query_uris)) + + return query + def flag_service_factory(context, request): return FlagService(request.db) diff --git a/h/views/api_flags.py b/h/views/api_flags.py index 8da0ece6cae..f58009a6cf7 100644 --- a/h/views/api_flags.py +++ b/h/views/api_flags.py @@ -21,6 +21,23 @@ def create(context, request): return HTTPNoContent() +@api_config(route_name='api.flags', + request_method='GET', + link_name='flag.index', + description='List a users flagged annotations for review.', + effective_principals=security.Authenticated) +def index(request): + group = request.GET.get('group') + if not group: + group = None + + uris = request.GET.getall('uri') + + svc = request.find_service(name='flag') + flags = svc.list(request.authenticated_user, group=group, uris=uris) + return [{'annotation': flag.annotation_id} for flag in flags] + + def _fetch_annotation(context, request): try: annotation_id = request.json_body.get('annotation') diff --git a/tests/h/conftest.py b/tests/h/conftest.py index 4086dd9d3f7..129943f8179 100644 --- a/tests/h/conftest.py +++ b/tests/h/conftest.py @@ -17,6 +17,7 @@ from pyramid import testing from pyramid.request import apply_request_extensions from sqlalchemy.orm import sessionmaker +from webob.multidict import MultiDict from h import db from h import models # noqa: ensure base class set for memex @@ -229,6 +230,9 @@ def pyramid_request(db_session, fake_feature, pyramid_settings): request.matched_route = mock.Mock() request.registry.settings = pyramid_settings request.is_xhr = False + request.params = MultiDict() + request.GET = request.params + request.POST = request.params return request diff --git a/tests/h/services/flag_test.py b/tests/h/services/flag_test.py index 0e376bbe4bd..311adcfc29c 100644 --- a/tests/h/services/flag_test.py +++ b/tests/h/services/flag_test.py @@ -6,7 +6,7 @@ from h.services import flag from h import models -from h._compat import xrange +from h._compat import text_type, xrange @pytest.mark.usefixtures('flags') @@ -53,6 +53,56 @@ def test_it_skips_creating_flag_when_already_exists(self, svc, db_session, facto .count() == 1 +@pytest.mark.usefixtures('flags') +class TestFlagServiceList(object): + def test_it_filters_by_user(self, svc, users, flags): + expected = {f for k, f in flags.iteritems() if k.startswith('alice-')} + assert set(svc.list(users['alice'])) == expected + + def test_it_optionally_filters_by_group(self, svc, users, flags, groups): + expected = [flags['alice-politics']] + result = svc.list(users['alice'], group=groups['politics']).all() + assert result == expected + + def test_it_optionally_filters_by_uri(self, svc, users, flags): + expected = [flags['alice-climate']] + result = svc.list(users['alice'], uris=['https://science.org']).all() + assert result == expected + + def test_it_supports_multiple_uri_filters(self, svc, users, flags): + expected = [flags['alice-climate'], flags['alice-politics']] + result = svc.list(users['alice'], uris=['https://science.org', 'https://news.com']).all() + assert result == expected + + @pytest.fixture + def users(self, factories): + return {'alice': factories.User(username='alice'), + 'bob': factories.User(username='bob')} + + @pytest.fixture + def groups(self, factories): + return {'climate': text_type(factories.Group(name='Climate').pubid), + 'politics': text_type(factories.Group(name='Politics').pubid)} + + @pytest.fixture + def flags(self, factories, users, groups, db_session): + ann_climate = factories.Annotation(groupid=groups['climate'], + target_uri='https://science.com') + factories.DocumentURI(claimant='https://science.org', + uri='https://science.org', + type='rel-alternate', + document=ann_climate.document) + + ann_politics = factories.Annotation(groupid=groups['politics'], + target_uri='https://news.com') + + return { + 'alice-climate': factories.Flag(user=users['alice'], annotation=ann_climate), + 'alice-politics': factories.Flag(user=users['alice'], annotation=ann_politics), + 'bob-politics': factories.Flag(user=users['bob'], annotation=ann_politics), + } + + class TestFlagServiceFactory(object): def test_it_returns_flag_service(self, pyramid_request): svc = flag.flag_service_factory(None, pyramid_request) diff --git a/tests/h/views/api_flags_test.py b/tests/h/views/api_flags_test.py index 03ce81be58a..d49d49dcc29 100644 --- a/tests/h/views/api_flags_test.py +++ b/tests/h/views/api_flags_test.py @@ -101,3 +101,64 @@ def flag_service(self, pyramid_config): flag_service = mock.Mock(spec_set=['create']) pyramid_config.register_service(flag_service, name='flag') return flag_service + + +@pytest.mark.usefixtures('flag_service') +class TestIndex(object): + def test_it_passes_user_filter(self, pyramid_request, flag_service): + views.index(pyramid_request) + flag_service.list.assert_called_once_with(pyramid_request.authenticated_user, + group=None, + uris=[]) + + def test_it_passes_group_filter(self, pyramid_request, flag_service): + pyramid_request.GET['group'] = 'test-pubid' + + views.index(pyramid_request) + + flag_service.list.assert_called_once_with(mock.ANY, + group='test-pubid', + uris=[]) + + def test_it_skips_empty_group_filter(self, pyramid_request, flag_service): + pyramid_request.GET['group'] = '' + + views.index(pyramid_request) + + flag_service.list.assert_called_once_with(mock.ANY, + group=None, + uris=[]) + + def test_it_passes_uris_filter(self, pyramid_request, flag_service): + pyramid_request.GET.add('uri', 'https://example.com/document') + pyramid_request.GET.add('uri', 'https://example.org/document') + + views.index(pyramid_request) + + flag_service.list.assert_called_once_with(mock.ANY, + group=mock.ANY, + uris=['https://example.com/document', 'https://example.org/document']) + + def test_it_renders_flags(self, pyramid_request, flags): + expected = [{'annotation': f.annotation_id} for f in flags] + + response = views.index(pyramid_request) + assert response == expected + + @pytest.fixture + def flags(self, factories): + return [factories.Flag.build(annotation_id='test-annotation-1'), + factories.Flag.build(annotation_id='test-annotation-2'), + factories.Flag.build(annotation_id='test-annotation-3')] + + @pytest.fixture + def flag_service(self, pyramid_config, flags): + flag_service = mock.Mock(spec_set=['list']) + flag_service.list.return_value = flags + pyramid_config.register_service(flag_service, name='flag') + return flag_service + + @pytest.fixture + def pyramid_request(self, pyramid_request): + pyramid_request.authenticated_user = mock.Mock() + return pyramid_request