diff --git a/tests/test_candidates.py b/tests/test_candidates.py index 78c331de0..9e078e455 100644 --- a/tests/test_candidates.py +++ b/tests/test_candidates.py @@ -164,7 +164,6 @@ def test_cand_filters(self): response = self._response(page) self.assertGreater(original_count, response['pagination']['count']) - def test_candidate_sort(self): candidates = [ factories.CandidateFactory(candidate_status='P'), @@ -180,6 +179,22 @@ def test_candidate_sort(self): results = self._results(api.url_for(CandidateSearch, sort='-candidate_status')) self.assertEqual([each['candidate_id'] for each in results], candidate_ids) + def test_candidate_sort_nulls_last(self): + """ + Nulls will sort last by default when sorting ascending - + sort_nulls_last forces nulls to the bottom for descending sort + """ + candidates = [ + factories.CandidateFactory(candidate_id='1'), + factories.CandidateFactory(candidate_id='2', candidate_status='P'), + factories.CandidateFactory(candidate_id='3', candidate_status='C'), + ] + candidate_ids = [each.candidate_id for each in candidates] + results = self._results(api.url_for(CandidateList, sort='candidate_status', sort_nulls_last=True)) + self.assertEqual([each['candidate_id'] for each in results], candidate_ids[::-1]) + results = self._results(api.url_for(CandidateList, sort='-candidate_status', sort_nulls_last=True)) + self.assertEqual([each['candidate_id'] for each in results], ['2', '3', '1']) + class TestCandidateHistory(ApiBaseTest): def setUp(self): diff --git a/tests/test_itemized.py b/tests/test_itemized.py index a09519b25..a529b9624 100644 --- a/tests/test_itemized.py +++ b/tests/test_itemized.py @@ -642,6 +642,17 @@ def test_amount_sched_e(self): results = self._results(api.url_for(ScheduleEView, min_amount=100, max_amount=150)) self.assertTrue(all(100 <= each['expenditure_amount'] <= 150 for each in results)) + def test_sort_sched_e(self): + expenditures = [ + factories.ScheduleEFactory(expenditure_amount=50), + factories.ScheduleEFactory(expenditure_amount=100, expenditure_date=datetime.date(2016, 1, 1)), + factories.ScheduleEFactory(expenditure_amount=150, expenditure_date=datetime.date(2016, 2, 1)), + factories.ScheduleEFactory(expenditure_amount=200, expenditure_date=datetime.date(2016, 3, 1)), + ] + sub_ids = [str(each.sub_id) for each in expenditures] + results = self._results(api.url_for(ScheduleEView, sort='-expenditure_date', sort_nulls_last=True)) + self.assertEqual([each['sub_id'] for each in results], sub_ids[::-1]) + def test_filters_sched_e(self): filters = [ ('image_number', ScheduleE.image_number, ['123', '456']), diff --git a/webservices/args.py b/webservices/args.py index 85f4c7705..522ff5589 100644 --- a/webservices/args.py +++ b/webservices/args.py @@ -117,8 +117,8 @@ def __call__(self, value): status_code=422 ) -def make_sort_args(default=None, validator=None, default_hide_null=False, default_reverse_nulls=True, - default_nulls_only=False): +def make_sort_args(default=None, validator=None, default_hide_null=False, + default_nulls_only=False, default_sort_nulls_last=False): return { 'sort': fields.Str( missing=default, @@ -132,13 +132,17 @@ def make_sort_args(default=None, validator=None, default_hide_null=False, defaul 'sort_null_only': fields.Bool( missing=default_nulls_only, description='Toggle that filters out all rows having sort column that is non-null' + ), + 'sort_nulls_last': fields.Bool( + missing=default_sort_nulls_last, + description='Toggle that sorts null values last' ) } -def make_multi_sort_args(default=None, validator=None, default_hide_null=False, default_reverse_nulls=True, - default_nulls_only=False): - args = make_sort_args(default, validator, default_hide_null, default_reverse_nulls, default_nulls_only) +def make_multi_sort_args(default=None, validator=None, default_hide_null=False, + default_nulls_only=False, default_sort_nulls_last=False): + args = make_sort_args(default, validator, default_hide_null, default_nulls_only, default_sort_nulls_last) args['sort'] = fields.List(fields.Str, missing=default, validate=validator, required=False, allow_none=True, description='Provide a field to sort by. Use - for descending order.',) return args diff --git a/webservices/sorting.py b/webservices/sorting.py index 74dbc1db7..10329786a 100644 --- a/webservices/sorting.py +++ b/webservices/sorting.py @@ -41,14 +41,14 @@ def parse_option(option, model=None, aliases=None, join_columns=None, query=None def multi_sort(query, keys, model, aliases=None, join_columns=None, clear=False, - hide_null=False, index_column=None): + hide_null=False, index_column=None, nulls_last=False): for key in keys: - query,_ = sort(query, key, model, aliases, join_columns, clear, hide_null, index_column) + query,_ = sort(query, key, model, aliases, join_columns, clear, hide_null, index_column, nulls_last) return query,_ def sort(query, key, model, aliases=None, join_columns=None, clear=False, - hide_null=False, index_column=None): + hide_null=False, index_column=None, nulls_last=False): """Sort query using string-formatted columns. :param query: Original query @@ -59,7 +59,7 @@ def sort(query, key, model, aliases=None, join_columns=None, clear=False, :param clear: Clear existing sort conditions :param hide_null: Exclude null values on sorted column(s) :param index_column: - :param reverse_nulls: Swap order of null values on sorted column(s) in results; + :param nulls_last: Sort null values on sorted column(s) last in results; Ignored if hide_null is True """ @@ -111,7 +111,10 @@ def sort(query, key, model, aliases=None, join_columns=None, clear=False, is_expression = True sort_column = order(column) - query = query.order_by(sort_column) + if nulls_last and not hide_null: + query = query.order_by(sa.nullslast(sort_column)) + else: + query = query.order_by(sort_column) if relationship: query = query.join(relationship) diff --git a/webservices/utils.py b/webservices/utils.py index cf68f74fa..c200322c9 100644 --- a/webservices/utils.py +++ b/webservices/utils.py @@ -62,19 +62,18 @@ def check_cap(kwargs, cap): ) -def fetch_page(query, kwargs, model=None, aliases=None, join_columns=None, clear=False, - count=None, cap=100, index_column=None, multi=False): +def fetch_page(query, kwargs, model=None, aliases=None, join_columns=None, clear=False, count=None, cap=100, index_column=None, multi=False): check_cap(kwargs, cap) - sort, hide_null, reverse_nulls = kwargs.get('sort'), kwargs.get('sort_hide_null'), kwargs.get('sort_reverse_nulls') + sort, hide_null, nulls_last = kwargs.get('sort'), kwargs.get('sort_hide_null'), kwargs.get('sort_nulls_last') if sort and multi: query, _ = sorting.multi_sort( query, sort, model=model, aliases=aliases, join_columns=join_columns, - clear=clear, hide_null=hide_null, index_column=index_column + clear=clear, hide_null=hide_null, index_column=index_column, nulls_last=nulls_last ) elif sort: query, _ = sorting.sort( query, sort, model=model, aliases=aliases, join_columns=join_columns, - clear=clear, hide_null=hide_null, index_column=index_column + clear=clear, hide_null=hide_null, index_column=index_column, nulls_last=nulls_last ) paginator = paginators.OffsetPaginator(query, kwargs['per_page'], count=count) return paginator.get_page(kwargs['page']) @@ -216,11 +215,11 @@ def fetch_seek_page(query, kwargs, index_column, clear=False, count=None, cap=10 def fetch_seek_paginator(query, kwargs, index_column, clear=False, count=None, cap=100): check_cap(kwargs, cap) model = index_column.parent.class_ - sort, hide_null = kwargs.get('sort'), kwargs.get('sort_hide_null') + sort, hide_null, nulls_last = kwargs.get('sort'), kwargs.get('sort_hide_null'), kwargs.get('sort_nulls_last') if sort: query, sort_column = sorting.sort( query, sort, - model=model, clear=clear, hide_null=hide_null + model=model, clear=clear, hide_null=hide_null, nulls_last=nulls_last ) else: sort_column = None