Skip to content

Commit

Permalink
Merge pull request #4439 from hypothesis/flag-index-endpoint
Browse files Browse the repository at this point in the history
Add flag index endpoint
  • Loading branch information
nickstenning authored Mar 17, 2017
2 parents 75b2138 + 8eafd34 commit 5654ce3
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 2 deletions.
25 changes: 25 additions & 0 deletions h/migrations/versions/e554d862135f_add_index_to_flag_user_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
Add index to flag.user_id
Revision ID: e554d862135f
Revises: 5655d56d7c29
Create Date: 2017-03-16 12:35:45.791202
"""

from __future__ import unicode_literals

from alembic import op


revision = 'e554d862135f'
down_revision = '5655d56d7c29'


def upgrade():
op.execute('COMMIT')
op.create_index(op.f('ix__flag__user_id'), 'flag', ['user_id'],
unique=False, postgresql_concurrently=True)


def downgrade():
op.drop_index(op.f('ix__flag__user_id'), 'flag')
2 changes: 1 addition & 1 deletion h/models/flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Flag(Base, Timestamps):

user_id = sa.Column(sa.Integer,
sa.ForeignKey('user.id', ondelete='cascade'),
nullable=False)
nullable=False, index=True)

#: The user who created the flag.
user = sa.orm.relationship('User')
Expand Down
42 changes: 42 additions & 0 deletions h/services/flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

from __future__ import unicode_literals

from memex import uri

from h import models
from h import storage


class FlagService(object):
Expand Down Expand Up @@ -50,6 +53,45 @@ def create(self, user, annotation):
annotation=annotation)
self.session.add(flag)

def list(self, user, group=None, uris=None):
"""
Return a list of flags made by the given user.
:param user: The user to filter flags on.
:type user: h.models.User
:param group: The annotation group pubid for filtering flags.
:type group: unicode
:param uris: A list of annotation uris for filtering flags.
:type uris: list of unicode
:returns: list of flags (``h.models.Flag``)
:rtype: list
"""

query = self.session.query(models.Flag).filter_by(user=user)

joined_annotation = False

if group is not None:
joined_annotation = True
query = query.join(models.Annotation) \
.filter(models.Annotation.groupid == group)

if uris:
query_uris = set()
for u in uris:
expanded = storage.expand_uri(self.session, u)
query_uris.update([uri.normalize(e) for e in expanded])

if not joined_annotation:
joined_annotation = True
query = query.join(models.Annotation)
query = query.filter(models.Annotation.target_uri_normalized.in_(query_uris))

return query


def flag_service_factory(context, request):
return FlagService(request.db)
17 changes: 17 additions & 0 deletions h/views/api_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@ def create(context, request):
return HTTPNoContent()


@api_config(route_name='api.flags',
request_method='GET',
link_name='flag.index',
description='List a users flagged annotations for review.',
effective_principals=security.Authenticated)
def index(request):
group = request.GET.get('group')
if not group:
group = None

uris = request.GET.getall('uri')

svc = request.find_service(name='flag')
flags = svc.list(request.authenticated_user, group=group, uris=uris)
return [{'annotation': flag.annotation_id} for flag in flags]


def _fetch_annotation(context, request):
try:
annotation_id = request.json_body.get('annotation')
Expand Down
4 changes: 4 additions & 0 deletions tests/h/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pyramid import testing
from pyramid.request import apply_request_extensions
from sqlalchemy.orm import sessionmaker
from webob.multidict import MultiDict

from h import db
from h import models # noqa: ensure base class set for memex
Expand Down Expand Up @@ -229,6 +230,9 @@ def pyramid_request(db_session, fake_feature, pyramid_settings):
request.matched_route = mock.Mock()
request.registry.settings = pyramid_settings
request.is_xhr = False
request.params = MultiDict()
request.GET = request.params
request.POST = request.params
return request


Expand Down
52 changes: 51 additions & 1 deletion tests/h/services/flag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from h.services import flag
from h import models
from h._compat import xrange
from h._compat import text_type, xrange


@pytest.mark.usefixtures('flags')
Expand Down Expand Up @@ -53,6 +53,56 @@ def test_it_skips_creating_flag_when_already_exists(self, svc, db_session, facto
.count() == 1


@pytest.mark.usefixtures('flags')
class TestFlagServiceList(object):
def test_it_filters_by_user(self, svc, users, flags):
expected = {f for k, f in flags.iteritems() if k.startswith('alice-')}
assert set(svc.list(users['alice'])) == expected

def test_it_optionally_filters_by_group(self, svc, users, flags, groups):
expected = [flags['alice-politics']]
result = svc.list(users['alice'], group=groups['politics']).all()
assert result == expected

def test_it_optionally_filters_by_uri(self, svc, users, flags):
expected = [flags['alice-climate']]
result = svc.list(users['alice'], uris=['https://science.org']).all()
assert result == expected

def test_it_supports_multiple_uri_filters(self, svc, users, flags):
expected = [flags['alice-climate'], flags['alice-politics']]
result = svc.list(users['alice'], uris=['https://science.org', 'https://news.com']).all()
assert result == expected

@pytest.fixture
def users(self, factories):
return {'alice': factories.User(username='alice'),
'bob': factories.User(username='bob')}

@pytest.fixture
def groups(self, factories):
return {'climate': text_type(factories.Group(name='Climate').pubid),
'politics': text_type(factories.Group(name='Politics').pubid)}

@pytest.fixture
def flags(self, factories, users, groups, db_session):
ann_climate = factories.Annotation(groupid=groups['climate'],
target_uri='https://science.com')
factories.DocumentURI(claimant='https://science.org',
uri='https://science.org',
type='rel-alternate',
document=ann_climate.document)

ann_politics = factories.Annotation(groupid=groups['politics'],
target_uri='https://news.com')

return {
'alice-climate': factories.Flag(user=users['alice'], annotation=ann_climate),
'alice-politics': factories.Flag(user=users['alice'], annotation=ann_politics),
'bob-politics': factories.Flag(user=users['bob'], annotation=ann_politics),
}


class TestFlagServiceFactory(object):
def test_it_returns_flag_service(self, pyramid_request):
svc = flag.flag_service_factory(None, pyramid_request)
Expand Down
61 changes: 61 additions & 0 deletions tests/h/views/api_flags_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,64 @@ def flag_service(self, pyramid_config):
flag_service = mock.Mock(spec_set=['create'])
pyramid_config.register_service(flag_service, name='flag')
return flag_service


@pytest.mark.usefixtures('flag_service')
class TestIndex(object):
def test_it_passes_user_filter(self, pyramid_request, flag_service):
views.index(pyramid_request)
flag_service.list.assert_called_once_with(pyramid_request.authenticated_user,
group=None,
uris=[])

def test_it_passes_group_filter(self, pyramid_request, flag_service):
pyramid_request.GET['group'] = 'test-pubid'

views.index(pyramid_request)

flag_service.list.assert_called_once_with(mock.ANY,
group='test-pubid',
uris=[])

def test_it_skips_empty_group_filter(self, pyramid_request, flag_service):
pyramid_request.GET['group'] = ''

views.index(pyramid_request)

flag_service.list.assert_called_once_with(mock.ANY,
group=None,
uris=[])

def test_it_passes_uris_filter(self, pyramid_request, flag_service):
pyramid_request.GET.add('uri', 'https://example.com/document')
pyramid_request.GET.add('uri', 'https://example.org/document')

views.index(pyramid_request)

flag_service.list.assert_called_once_with(mock.ANY,
group=mock.ANY,
uris=['https://example.com/document', 'https://example.org/document'])

def test_it_renders_flags(self, pyramid_request, flags):
expected = [{'annotation': f.annotation_id} for f in flags]

response = views.index(pyramid_request)
assert response == expected

@pytest.fixture
def flags(self, factories):
return [factories.Flag.build(annotation_id='test-annotation-1'),
factories.Flag.build(annotation_id='test-annotation-2'),
factories.Flag.build(annotation_id='test-annotation-3')]

@pytest.fixture
def flag_service(self, pyramid_config, flags):
flag_service = mock.Mock(spec_set=['list'])
flag_service.list.return_value = flags
pyramid_config.register_service(flag_service, name='flag')
return flag_service

@pytest.fixture
def pyramid_request(self, pyramid_request):
pyramid_request.authenticated_user = mock.Mock()
return pyramid_request

0 comments on commit 5654ce3

Please sign in to comment.