diff --git a/datasette/app.py b/datasette/app.py index 7502009a45..72ff526435 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -18,7 +18,6 @@ import time from .utils import ( Filters, - compound_pks_from_path, CustomJSONEncoder, compound_keys_after_sql, detect_fts_sql, @@ -33,6 +32,7 @@ path_with_ext, sqlite_timelimit, to_css_class, + urlsafe_components, validate_sql_select, ) from .version import __version__ @@ -613,6 +613,14 @@ async def data(self, request, name, hash, table): search_description = 'search matches "{}"'.format(search) params['search'] = search + # Allow for custom sort order + sort = special_args.get('_sort') + if sort: + order_by = sort + sort_desc = special_args.get('_sort_desc') + if sort_desc: + order_by = '{} desc'.format(sort_desc) + count_sql = 'select count(*) from {table_name} {where}'.format( table_name=escape_sqlite(table), where=( @@ -638,20 +646,46 @@ async def data(self, request, name, hash, table): if is_view: # _next is an offset offset = ' offset {}'.format(int(_next)) - elif use_rowid: - where_clauses.append( - 'rowid > :p{}'.format( - len(params), - ) - ) - params['p{}'.format(len(params))] = _next else: - pk_values = compound_pks_from_path(_next) - if len(pk_values) == len(pks): - param_len = len(params) - where_clauses.append(compound_keys_after_sql(pks, param_len)) - for i, pk_value in enumerate(pk_values): - params['p{}'.format(param_len + i)] = pk_value + components = urlsafe_components(_next) + # If a sort order is applied, the first of these is the sort value + if sort or sort_desc: + sort_value = components[0] + components = components[1:] + print('sort_varlue = {}, components = {}'.format( + sort_value, components + )) + + # Figure out the SQL for next-based-on-primary-key first + next_by_pk_clauses = [] + if use_rowid: + next_by_pk_clauses.append( + 'rowid > :p{}'.format( + len(params), + ) + ) + params['p{}'.format(len(params))] = components[0] + else: + # Apply the tie-breaker based on primary keys + if len(components) == len(pks): + param_len = len(params) + next_by_pk_clauses.append(compound_keys_after_sql(pks, param_len)) + for i, pk_value in enumerate(components): + params['p{}'.format(param_len + i)] = pk_value + + # Now add the sort SQL, which may incorporate next_by_pk_clauses + if sort or sort_desc: + where_clauses.append( + '({column} {op} :p{p} or ({column} = :p{p} and {next_clauses}))'.format( + column=escape_sqlite(sort or sort_desc), + op='>' if sort else '<', + p=len(params), + next_clauses=' and '.join(next_by_pk_clauses), + ) + ) + params['p{}'.format(len(params))] = sort_value + else: + where_clauses.extend(next_by_pk_clauses) where_clause = '' if where_clauses: @@ -707,9 +741,26 @@ async def data(self, request, name, hash, table): next_value = int(_next or 0) + self.page_size else: next_value = path_from_row_pks(rows[-2], pks, use_rowid) - next_url = urllib.parse.urljoin(request.url, path_with_added_args(request, { - '_next': next_value, - })) + # If there's a sort or sort_desc, add that value as a prefix + if (sort or sort_desc) and not is_view: + prefix = str(rows[-2][sort or sort_desc]) + next_value = '{},{}'.format( + urllib.parse.quote_plus(prefix), next_value + ) + added_args = { + '_next': next_value, + } + if sort: + added_args['_sort'] = sort + else: + added_args['_sort_desc'] = sort_desc + else: + added_args = { + '_next': next_value, + } + next_url = urllib.parse.urljoin(request.url, path_with_added_args( + request, added_args + )) rows = rows[:self.page_size] # Number of filtered rows in whole set: @@ -778,7 +829,7 @@ async def extra_template(): class RowView(RowTableShared): async def data(self, request, name, hash, table, pk_path): table = urllib.parse.unquote_plus(table) - pk_values = compound_pks_from_path(pk_path) + pk_values = urlsafe_components(pk_path) pks = await self.pks_for_table(name, table) use_rowid = not pks select = '*' diff --git a/datasette/utils.py b/datasette/utils.py index c674349b82..b7798bf45e 100644 --- a/datasette/utils.py +++ b/datasette/utils.py @@ -29,9 +29,10 @@ ).split()) -def compound_pks_from_path(path): +def urlsafe_components(token): + "Splits token on commas and URL decodes each component" return [ - urllib.parse.unquote_plus(b) for b in path.split(',') + urllib.parse.unquote_plus(b) for b in token.split(',') ] diff --git a/tests/fixtures.py b/tests/fixtures.py index cf19fa3148..4ede6a906b 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,6 +1,7 @@ from datasette.app import Datasette import itertools import os +import random import sqlite3 import sys import string @@ -34,6 +35,25 @@ def generate_compound_rows(num): yield a, b, c, '{}-{}-{}'.format(a, b, c) +def generate_sortable_rows(num): + rand = random.Random(42) + for a, b in itertools.islice( + itertools.product(string.ascii_lowercase, repeat=2), num + ): + yield { + 'pk1': a, + 'pk2': b, + 'content': '{}-{}'.format(a, b), + 'sortable': rand.randint(-100, 100), + 'sortable_with_nulls': rand.choice([ + None, rand.random(), rand.random() + ]), + 'sortable_with_nulls_2': rand.choice([ + None, rand.random(), rand.random() + ]), + } + + METADATA = { 'title': 'Datasette Title', 'description': 'Datasette Description', @@ -70,7 +90,6 @@ def generate_compound_rows(num): INSERT INTO compound_primary_key VALUES ('a', 'b', 'c'); - CREATE TABLE compound_three_primary_keys ( pk1 varchar(30), pk2 varchar(30), @@ -79,6 +98,15 @@ def generate_compound_rows(num): PRIMARY KEY (pk1, pk2, pk3) ); +CREATE TABLE sortable ( + pk1 varchar(30), + pk2 varchar(30), + content text, + sortable integer, + sortable_with_nulls real, + sortable_with_nulls_2 real, + PRIMARY KEY (pk1, pk2) +); CREATE TABLE no_primary_key ( content text, @@ -142,6 +170,13 @@ def generate_compound_rows(num): 'INSERT INTO compound_three_primary_keys VALUES ("{a}", "{b}", "{c}", "{content}");'.format( a=a, b=b, c=c, content=content ) for a, b, c, content in generate_compound_rows(1001) +]) + '\n'.join([ + '''INSERT INTO sortable VALUES ( + "{pk1}", "{pk2}", "{content}", {sortable}, + {sortable_with_nulls}, {sortable_with_nulls_2}); + '''.format( + **row + ).replace('None', 'null') for row in generate_sortable_rows(201) ]) diff --git a/tests/test_api.py b/tests/test_api.py index 733727c673..acdd5e7132 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,6 +1,7 @@ from .fixtures import ( app_client, generate_compound_rows, + generate_sortable_rows, ) import pytest @@ -13,7 +14,7 @@ def test_homepage(app_client): assert response.json.keys() == {'test_tables': 0}.keys() d = response.json['test_tables'] assert d['name'] == 'test_tables' - assert d['tables_count'] == 9 + assert d['tables_count'] == 10 def test_database_page(app_client): @@ -106,6 +107,16 @@ def test_database_page(app_client): 'outgoing': [], }, 'label_column': None, + }, { + 'columns': [ + 'pk1', 'pk2', 'content', 'sortable', 'sortable_with_nulls', + 'sortable_with_nulls_2' + ], + 'name': 'sortable', + 'count': 201, + 'hidden': False, + 'foreign_keys': {'incoming': [], 'outgoing': []}, + 'label_column': None, }, { 'columns': ['pk', 'content'], 'name': 'table/with/slashes.csv', @@ -345,6 +356,30 @@ def test_paginate_compound_keys_with_extra_filters(app_client): assert expected == [f['content'] for f in fetched] +@pytest.mark.parametrize('query_string,sort_key', [ + ('_sort=sortable', lambda row: row['sortable']), + ('_sort_desc=sortable', lambda row: -row['sortable']), +]) +def test_sortable(app_client, query_string, sort_key): + path = '/test_tables/sortable.jsono?{}'.format(query_string) + fetched = [] + page = 0 + while path: + page += 1 + assert page < 100 + response = app_client.get(path, gather_request=False) + fetched.extend(response.json['rows']) + path = response.json['next_url'] + assert 5 == page + expected = list(generate_sortable_rows(201)) + expected.sort(key=sort_key) + assert [ + r['content'] for r in expected + ] == [ + r['content'] for r in fetched + ] + + @pytest.mark.parametrize('path,expected_rows', [ ('/test_tables/simple_primary_key.json?content=hello', [ ['1', 'hello'], diff --git a/tests/test_utils.py b/tests/test_utils.py index a52042561c..d71ecc60fd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -18,8 +18,8 @@ ('123%2C433,112', ['123,433', '112']), ('123%2F433%2F112', ['123/433/112']), ]) -def test_compound_pks_from_path(path, expected): - assert expected == utils.compound_pks_from_path(path) +def test_urlsafe_components(path, expected): + assert expected == utils.urlsafe_components(path) @pytest.mark.parametrize('row,pks,expected_path', [