From 1f69269fe93e4cd42e56890126cc0dbcf719c6cb Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 13 May 2018 09:44:22 -0300 Subject: [PATCH] Refactored views into new views/ modules, refs #256 --- datasette/app.py | 1142 +---------------------------------- datasette/views/__init__.py | 0 datasette/views/base.py | 356 +++++++++++ datasette/views/database.py | 45 ++ datasette/views/index.py | 59 ++ datasette/views/table.py | 695 +++++++++++++++++++++ 6 files changed, 1165 insertions(+), 1132 deletions(-) create mode 100644 datasette/views/__init__.py create mode 100644 datasette/views/base.py create mode 100644 datasette/views/database.py create mode 100644 datasette/views/index.py create mode 100644 datasette/views/table.py diff --git a/datasette/app.py b/datasette/app.py index beb7e924df..8c4fe41c8a 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1,448 +1,43 @@ from sanic import Sanic from sanic import response from sanic.exceptions import NotFound, InvalidUsage -from sanic.views import HTTPMethodView -from sanic.request import RequestParameters +from datasette.views.base import RenderMixin, DatasetteError, ureg, HASH_BLOCK_SIZE +from datasette.views.index import IndexView +from datasette.views.database import DatabaseView, DatabaseDownload +from datasette.views.table import TableView, RowView from jinja2 import Environment, FileSystemLoader, ChoiceLoader, PrefixLoader -import re import sqlite3 from pathlib import Path from concurrent import futures -import asyncio import os -import threading import urllib.parse import itertools import json -import jinja2 import hashlib import sys -import time -import pint import pluggy import traceback from .utils import ( - Filters, - CustomJSONEncoder, - compound_keys_after_sql, detect_fts, detect_spatialite, escape_css_string, escape_sqlite, - filters_should_redirect, get_all_foreign_keys, get_plugins, - is_url, - InvalidSql, module_from_path, - path_from_row_pks, - path_with_added_args, - path_with_ext, - sqlite_timelimit, to_css_class, - urlsafe_components, - validate_sql_select, ) -from . import __version__ from . import hookspecs from .version import __version__ app_root = Path(__file__).parent.parent -HASH_BLOCK_SIZE = 1024 * 1024 -HASH_LENGTH = 7 - -connections = threading.local() -ureg = pint.UnitRegistry() - pm = pluggy.PluginManager('datasette') pm.add_hookspecs(hookspecs) pm.load_setuptools_entrypoints('datasette') -class DatasetteError(Exception): - def __init__(self, message, title=None, error_dict=None, status=500, template=None): - self.message = message - self.title = title - self.error_dict = error_dict or {} - self.status = status - - -class RenderMixin(HTTPMethodView): - def render(self, templates, **context): - template = self.jinja_env.select_template(templates) - select_templates = ['{}{}'.format( - '*' if template_name == template.name else '', - template_name - ) for template_name in templates] - return response.html( - template.render({ - **context, **{ - 'app_css_hash': self.ds.app_css_hash(), - 'select_templates': select_templates, - 'zip': zip, - } - }) - ) - - -class BaseView(RenderMixin): - re_named_parameter = re.compile(':([a-zA-Z0-9_]+)') - - def __init__(self, datasette): - self.ds = datasette - self.files = datasette.files - self.jinja_env = datasette.jinja_env - self.executor = datasette.executor - self.page_size = datasette.page_size - self.max_returned_rows = datasette.max_returned_rows - - def table_metadata(self, database, table): - "Fetch table-specific metadata." - return self.ds.metadata.get( - 'databases', {} - ).get(database, {}).get('tables', {}).get(table, {}) - - def options(self, request, *args, **kwargs): - r = response.text('ok') - if self.ds.cors: - r.headers['Access-Control-Allow-Origin'] = '*' - return r - - def redirect(self, request, path, forward_querystring=True): - if request.query_string and '?' not in path and forward_querystring: - path = '{}?{}'.format( - path, request.query_string - ) - r = response.redirect(path) - r.headers['Link'] = '<{}>; rel=preload'.format(path) - if self.ds.cors: - r.headers['Access-Control-Allow-Origin'] = '*' - return r - - def resolve_db_name(self, db_name, **kwargs): - databases = self.ds.inspect() - hash = None - name = None - if '-' in db_name: - # Might be name-and-hash, or might just be - # a name with a hyphen in it - name, hash = db_name.rsplit('-', 1) - if name not in databases: - # Try the whole name - name = db_name - hash = None - else: - name = db_name - # Verify the hash - try: - info = databases[name] - except KeyError: - raise NotFound('Database not found: {}'.format(name)) - expected = info['hash'][:HASH_LENGTH] - if expected != hash: - should_redirect = '/{}-{}'.format( - name, expected, - ) - if 'table' in kwargs: - should_redirect += '/' + kwargs['table'] - if 'pk_path' in kwargs: - should_redirect += '/' + kwargs['pk_path'] - if 'as_json' in kwargs: - should_redirect += kwargs['as_json'] - if 'as_db' in kwargs: - should_redirect += kwargs['as_db'] - return name, expected, should_redirect - return name, expected, None - - async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None, page_size=None): - """Executes sql against db_name in a thread""" - page_size = page_size or self.page_size - - def sql_operation_in_thread(): - conn = getattr(connections, db_name, None) - if not conn: - info = self.ds.inspect()[db_name] - conn = sqlite3.connect( - 'file:{}?immutable=1'.format(info['file']), - uri=True, - check_same_thread=False, - ) - self.ds.prepare_connection(conn) - setattr(connections, db_name, conn) - - time_limit_ms = self.ds.sql_time_limit_ms - if custom_time_limit and custom_time_limit < self.ds.sql_time_limit_ms: - time_limit_ms = custom_time_limit - - with sqlite_timelimit(conn, time_limit_ms): - try: - cursor = conn.cursor() - cursor.execute(sql, params or {}) - max_returned_rows = self.max_returned_rows - if max_returned_rows == page_size: - max_returned_rows += 1 - if max_returned_rows and truncate: - rows = cursor.fetchmany(max_returned_rows + 1) - truncated = len(rows) > max_returned_rows - rows = rows[:max_returned_rows] - else: - rows = cursor.fetchall() - truncated = False - except Exception as e: - print('ERROR: conn={}, sql = {}, params = {}: {}'.format( - conn, repr(sql), params, e - )) - raise - if truncate: - return rows, truncated, cursor.description - else: - return rows - - return await asyncio.get_event_loop().run_in_executor( - self.executor, sql_operation_in_thread - ) - - def get_templates(self, database, table=None): - assert NotImplemented - - async def get(self, request, db_name, **kwargs): - name, hash, should_redirect = self.resolve_db_name(db_name, **kwargs) - if should_redirect: - return self.redirect(request, should_redirect) - return await self.view_get(request, name, hash, **kwargs) - - async def view_get(self, request, name, hash, **kwargs): - try: - as_json = kwargs.pop('as_json') - except KeyError: - as_json = False - extra_template_data = {} - start = time.time() - status_code = 200 - templates = [] - try: - response_or_template_contexts = await self.data( - request, name, hash, **kwargs - ) - if isinstance(response_or_template_contexts, response.HTTPResponse): - return response_or_template_contexts - else: - data, extra_template_data, templates = response_or_template_contexts - except (sqlite3.OperationalError, InvalidSql) as e: - raise DatasetteError(str(e), title='Invalid SQL', status=400) - except (sqlite3.OperationalError) as e: - raise DatasetteError(str(e)) - except DatasetteError: - raise - end = time.time() - data['query_ms'] = (end - start) * 1000 - for key in ('source', 'source_url', 'license', 'license_url'): - value = self.ds.metadata.get(key) - if value: - data[key] = value - if as_json: - # Special case for .jsono extension - redirect to _shape=objects - if as_json == '.jsono': - return self.redirect( - request, - path_with_added_args( - request, - {'_shape': 'objects'}, - path=request.path.rsplit('.jsono', 1)[0] + '.json' - ), - forward_querystring=False - ) - # Deal with the _shape option - shape = request.args.get('_shape', 'arrays') - if shape in ('objects', 'object', 'array'): - columns = data.get('columns') - rows = data.get('rows') - if rows and columns: - data['rows'] = [ - dict(zip(columns, row)) - for row in rows - ] - if shape == 'object': - error = None - if 'primary_keys' not in data: - error = '_shape=object is only available on tables' - else: - pks = data['primary_keys'] - if not pks: - error = '_shape=object not available for tables with no primary keys' - else: - object_rows = {} - for row in data['rows']: - pk_string = path_from_row_pks(row, pks, not pks) - object_rows[pk_string] = row - data = object_rows - if error: - data = { - 'ok': False, - 'error': error, - 'database': name, - 'database_hash': hash, - } - elif shape == 'array': - data = data['rows'] - elif shape == 'arrays': - pass - else: - status_code = 400 - data = { - 'ok': False, - 'error': 'Invalid _shape: {}'.format(shape), - 'status': 400, - 'title': None, - } - headers = {} - if self.ds.cors: - headers['Access-Control-Allow-Origin'] = '*' - r = response.HTTPResponse( - json.dumps( - data, cls=CustomJSONEncoder - ), - status=status_code, - content_type='application/json', - headers=headers, - ) - else: - extras = {} - if callable(extra_template_data): - extras = extra_template_data() - if asyncio.iscoroutine(extras): - extras = await extras - else: - extras = extra_template_data - context = { - **data, - **extras, - **{ - 'url_json': path_with_ext(request, '.json'), - 'url_jsono': path_with_ext(request, '.jsono'), - 'extra_css_urls': self.ds.extra_css_urls(), - 'extra_js_urls': self.ds.extra_js_urls(), - 'datasette_version': __version__, - } - } - if 'metadata' not in context: - context['metadata'] = self.ds.metadata - r = self.render( - templates, - **context, - ) - r.status = status_code - # Set far-future cache expiry - if self.ds.cache_headers: - r.headers['Cache-Control'] = 'max-age={}'.format( - 365 * 24 * 60 * 60 - ) - return r - - async def custom_sql(self, request, name, hash, sql, editable=True, canned_query=None): - params = request.raw_args - if 'sql' in params: - params.pop('sql') - if '_shape' in params: - params.pop('_shape') - # Extract any :named parameters - named_parameters = self.re_named_parameter.findall(sql) - named_parameter_values = { - named_parameter: params.get(named_parameter) or '' - for named_parameter in named_parameters - } - - # Set to blank string if missing from params - for named_parameter in named_parameters: - if named_parameter not in params: - params[named_parameter] = '' - - extra_args = {} - if params.get('_timelimit'): - extra_args['custom_time_limit'] = int(params['_timelimit']) - rows, truncated, description = await self.execute( - name, sql, params, truncate=True, **extra_args - ) - columns = [r[0] for r in description] - - templates = ['query-{}.html'.format(to_css_class(name)), 'query.html'] - if canned_query: - templates.insert(0, 'query-{}-{}.html'.format( - to_css_class(name), to_css_class(canned_query) - )) - - return { - 'database': name, - 'rows': rows, - 'truncated': truncated, - 'columns': columns, - 'query': { - 'sql': sql, - 'params': params, - } - }, { - 'database_hash': hash, - 'custom_sql': True, - 'named_parameter_values': named_parameter_values, - 'editable': editable, - 'canned_query': canned_query, - }, templates - - -class IndexView(RenderMixin): - def __init__(self, datasette): - self.ds = datasette - self.files = datasette.files - self.jinja_env = datasette.jinja_env - self.executor = datasette.executor - - async def get(self, request, as_json): - databases = [] - for key, info in sorted(self.ds.inspect().items()): - tables = [t for t in info['tables'].values() if not t['hidden']] - hidden_tables = [t for t in info['tables'].values() if t['hidden']] - database = { - 'name': key, - 'hash': info['hash'], - 'path': '{}-{}'.format(key, info['hash'][:HASH_LENGTH]), - 'tables_truncated': sorted( - tables, - key=lambda t: t['count'], - reverse=True - )[:5], - 'tables_count': len(tables), - 'tables_more': len(tables) > 5, - 'table_rows_sum': sum(t['count'] for t in tables), - 'hidden_table_rows_sum': sum(t['count'] for t in hidden_tables), - 'hidden_tables_count': len(hidden_tables), - 'views_count': len(info['views']), - } - databases.append(database) - if as_json: - headers = {} - if self.ds.cors: - headers['Access-Control-Allow-Origin'] = '*' - return response.HTTPResponse( - json.dumps( - {db['name']: db for db in databases}, - cls=CustomJSONEncoder - ), - content_type='application/json', - headers=headers, - ) - else: - return self.render( - ['index.html'], - databases=databases, - metadata=self.ds.metadata, - datasette_version=__version__, - extra_css_urls=self.ds.extra_css_urls(), - extra_js_urls=self.ds.extra_js_urls(), - ) - - class JsonDataView(RenderMixin): def __init__(self, datasette, filename, data_callback): self.ds = datasette @@ -473,722 +68,6 @@ async def favicon(request): return response.text('') -class DatabaseView(BaseView): - async def data(self, request, name, hash): - if request.args.get('sql'): - sql = request.raw_args.pop('sql') - validate_sql_select(sql) - return await self.custom_sql(request, name, hash, sql) - info = self.ds.inspect()[name] - metadata = self.ds.metadata.get('databases', {}).get(name, {}) - self.ds.update_with_inherited_metadata(metadata) - tables = list(info['tables'].values()) - tables.sort(key=lambda t: (t['hidden'], t['name'])) - return { - 'database': name, - 'tables': tables, - 'hidden_count': len([t for t in tables if t['hidden']]), - 'views': info['views'], - 'queries': [{ - 'name': query_name, - 'sql': query_sql, - } for query_name, query_sql in (metadata.get('queries') or {}).items()], - }, { - 'database_hash': hash, - 'show_hidden': request.args.get('_show_hidden'), - 'editable': True, - 'metadata': metadata, - }, ('database-{}.html'.format(to_css_class(name)), 'database.html') - - -class DatabaseDownload(BaseView): - async def view_get(self, request, name, hash, **kwargs): - filepath = self.ds.inspect()[name]['file'] - return await response.file_stream( - filepath, - filename=os.path.basename(filepath), - mime_type='application/octet-stream', - ) - - -class RowTableShared(BaseView): - def sortable_columns_for_table(self, name, table, use_rowid): - table_metadata = self.table_metadata(name, table) - if 'sortable_columns' in table_metadata: - sortable_columns = set(table_metadata['sortable_columns']) - else: - table_info = self.ds.inspect()[name]['tables'].get(table) or {} - sortable_columns = set(table_info.get('columns', [])) - if use_rowid: - sortable_columns.add('rowid') - return sortable_columns - - async def display_columns_and_rows(self, database, table, description, rows, link_column=False, expand_foreign_keys=True): - "Returns columns, rows for specified table - including fancy foreign key treatment" - table_metadata = self.table_metadata(database, table) - info = self.ds.inspect()[database] - sortable_columns = self.sortable_columns_for_table(database, table, True) - columns = [{ - 'name': r[0], - 'sortable': r[0] in sortable_columns, - } for r in description] - tables = info['tables'] - table_info = tables.get(table) or {} - pks = table_info.get('primary_keys') or [] - - # Prefetch foreign key resolutions for later expansion: - fks = {} - labeled_fks = {} - if table_info and expand_foreign_keys: - foreign_keys = table_info['foreign_keys']['outgoing'] - for fk in foreign_keys: - label_column = ( - # First look in metadata.json definition for this foreign key table: - self.table_metadata(database, fk['other_table']).get('label_column') - # Fall back to label_column from .inspect() detection: - or tables.get(fk['other_table'], {}).get('label_column') - ) - if not label_column: - # No label for this FK - fks[fk['column']] = fk['other_table'] - continue - ids_to_lookup = set([row[fk['column']] for row in rows]) - sql = 'select "{other_column}", "{label_column}" from {other_table} where "{other_column}" in ({placeholders})'.format( - other_column=fk['other_column'], - label_column=label_column, - other_table=escape_sqlite(fk['other_table']), - placeholders=', '.join(['?'] * len(ids_to_lookup)), - ) - try: - results = await self.execute(database, sql, list(set(ids_to_lookup))) - except sqlite3.OperationalError: - # Probably hit the timelimit - pass - else: - for id, value in results: - labeled_fks[(fk['column'], id)] = (fk['other_table'], value) - - cell_rows = [] - for row in rows: - cells = [] - # Unless we are a view, the first column is a link - either to the rowid - # or to the simple or compound primary key - if link_column: - cells.append({ - 'column': pks[0] if len(pks) == 1 else 'Link', - 'value': jinja2.Markup( - '{flat_pks}'.format( - database=database, - table=urllib.parse.quote_plus(table), - flat_pks=str(jinja2.escape(path_from_row_pks(row, pks, not pks, False))), - flat_pks_quoted=path_from_row_pks(row, pks, not pks) - ) - ), - }) - - for value, column_dict in zip(row, columns): - column = column_dict['name'] - if link_column and len(pks) == 1 and column == pks[0]: - # If there's a simple primary key, don't repeat the value as it's - # already shown in the link column. - continue - if (column, value) in labeled_fks: - other_table, label = labeled_fks[(column, value)] - display_value = jinja2.Markup( - '{label} {id}'.format( - database=database, - table=urllib.parse.quote_plus(other_table), - link_id=urllib.parse.quote_plus(str(value)), - id=str(jinja2.escape(value)), - label=str(jinja2.escape(label)), - ) - ) - elif column in fks: - display_value = jinja2.Markup( - '{id}'.format( - database=database, - table=urllib.parse.quote_plus(fks[column]), - link_id=urllib.parse.quote_plus(str(value)), - id=str(jinja2.escape(value)))) - elif value is None: - display_value = jinja2.Markup(' ') - elif is_url(str(value).strip()): - display_value = jinja2.Markup( - '{url}'.format( - url=jinja2.escape(value.strip()) - ) - ) - elif column in table_metadata.get('units', {}) and value != '': - # Interpret units using pint - value = value * ureg(table_metadata['units'][column]) - # Pint uses floating point which sometimes introduces errors in the compact - # representation, which we have to round off to avoid ugliness. In the vast - # majority of cases this rounding will be inconsequential. I hope. - value = round(value.to_compact(), 6) - display_value = jinja2.Markup('{:~P}'.format(value).replace(' ', ' ')) - else: - display_value = str(value) - - cells.append({ - 'column': column, - 'value': display_value, - }) - cell_rows.append(cells) - - if link_column: - # Add the link column header. - # If it's a simple primary key, we have to remove and re-add that column name at - # the beginning of the header row. - if len(pks) == 1: - columns = [col for col in columns if col['name'] != pks[0]] - - columns = [{ - 'name': pks[0] if len(pks) == 1 else 'Link', - 'sortable': len(pks) == 1, - }] + columns - return columns, cell_rows - - -class TableView(RowTableShared): - async def data(self, request, name, hash, table): - table = urllib.parse.unquote_plus(table) - canned_query = self.ds.get_canned_query(name, table) - if canned_query is not None: - return await self.custom_sql(request, name, hash, canned_query['sql'], editable=False, canned_query=table) - is_view = bool(list(await self.execute( - name, - "SELECT count(*) from sqlite_master WHERE type = 'view' and name=:n", - {'n': table} - ))[0][0]) - view_definition = None - table_definition = None - if is_view: - view_definition = list(await self.execute( - name, - 'select sql from sqlite_master where name = :n and type="view"', - {'n': table} - ))[0][0] - else: - table_definition_rows = list(await self.execute( - name, - 'select sql from sqlite_master where name = :n and type="table"', - {'n': table} - )) - if not table_definition_rows: - raise NotFound('Table not found: {}'.format(table)) - table_definition = table_definition_rows[0][0] - info = self.ds.inspect() - table_info = info[name]['tables'].get(table) or {} - pks = table_info.get('primary_keys') or [] - use_rowid = not pks and not is_view - if use_rowid: - select = 'rowid, *' - order_by = 'rowid' - order_by_pks = 'rowid' - else: - select = '*' - order_by_pks = ', '.join([escape_sqlite(pk) for pk in pks]) - order_by = order_by_pks - - if is_view: - order_by = '' - - # We roll our own query_string decoder because by default Sanic - # drops anything with an empty value e.g. ?name__exact= - args = RequestParameters( - urllib.parse.parse_qs(request.query_string, keep_blank_values=True) - ) - - # Special args start with _ and do not contain a __ - # That's so if there is a column that starts with _ - # it can still be queried using ?_col__exact=blah - special_args = {} - special_args_lists = {} - other_args = {} - for key, value in args.items(): - if key.startswith('_') and '__' not in key: - special_args[key] = value[0] - special_args_lists[key] = value - else: - other_args[key] = value[0] - - # Handle ?_filter_column and redirect, if present - redirect_params = filters_should_redirect(special_args) - if redirect_params: - return self.redirect( - request, - path_with_added_args(request, redirect_params), - forward_querystring=False - ) - - # Spot ?_sort_by_desc and redirect to _sort_desc=(_sort) - if '_sort_by_desc' in special_args: - return self.redirect( - request, - path_with_added_args(request, { - '_sort_desc': special_args.get('_sort'), - '_sort_by_desc': None, - '_sort': None, - }), - forward_querystring=False - ) - - table_metadata = self.table_metadata(name, table) - units = table_metadata.get('units', {}) - filters = Filters(sorted(other_args.items()), units, ureg) - where_clauses, params = filters.build_where_clauses() - - # _search support: - fts_table = info[name]['tables'].get(table, {}).get('fts_table') - search_args = dict( - pair for pair in special_args.items() - if pair[0].startswith('_search') - ) - search_descriptions = [] - search = '' - if fts_table and search_args: - if '_search' in search_args: - # Simple ?_search=xxx - search = search_args['_search'] - where_clauses.append( - 'rowid in (select rowid from [{fts_table}] where [{fts_table}] match :search)'.format( - fts_table=fts_table - ) - ) - search_descriptions.append('search matches "{}"'.format(search)) - params['search'] = search - else: - # More complex: search against specific columns - valid_columns = set(info[name]['tables'][fts_table]['columns']) - for i, (key, search_text) in enumerate(search_args.items()): - search_col = key.split('_search_', 1)[1] - if search_col not in valid_columns: - raise DatasetteError( - 'Cannot search by that column', - status=400 - ) - where_clauses.append( - 'rowid in (select rowid from [{fts_table}] where [{search_col}] match :search_{i})'.format( - fts_table=fts_table, - search_col=search_col, - i=i, - ) - ) - search_descriptions.append( - 'search column "{}" matches "{}"'.format(search_col, search_text) - ) - params['search_{}'.format(i)] = search_text - - table_rows_count = None - sortable_columns = set() - if not is_view: - table_rows_count = table_info['count'] - sortable_columns = self.sortable_columns_for_table(name, table, use_rowid) - - # Allow for custom sort order - sort = special_args.get('_sort') - if sort: - if sort not in sortable_columns: - raise DatasetteError('Cannot sort table by {}'.format(sort)) - order_by = escape_sqlite(sort) - sort_desc = special_args.get('_sort_desc') - if sort_desc: - if sort_desc not in sortable_columns: - raise DatasetteError('Cannot sort table by {}'.format(sort_desc)) - if sort: - raise DatasetteError('Cannot use _sort and _sort_desc at the same time') - order_by = '{} desc'.format(escape_sqlite(sort_desc)) - - from_sql = 'from {table_name} {where}'.format( - table_name=escape_sqlite(table), - where=( - 'where {} '.format(' and '.join(where_clauses)) - ) if where_clauses else '', - ) - count_sql = 'select count(*) {}'.format(from_sql) - - _next = special_args.get('_next') - offset = '' - if _next: - if is_view: - # _next is an offset - offset = ' offset {}'.format(int(_next)) - else: - components = urlsafe_components(_next) - # If a sort order is applied, the first of these is the sort value - if sort or sort_desc: - sort_value = components[0] - # Special case for if non-urlencoded first token was $null - if _next.split(',')[0] == '$null': - sort_value = None - components = components[1:] - - # Figure out the SQL for next-based-on-primary-key first - next_by_pk_clauses = [] - if use_rowid: - next_by_pk_clauses.append( - 'rowid > :p{}'.format( - len(params), - ) - ) - params['p{}'.format(len(params))] = components[0] - else: - # Apply the tie-breaker based on primary keys - if len(components) == len(pks): - param_len = len(params) - next_by_pk_clauses.append(compound_keys_after_sql(pks, param_len)) - for i, pk_value in enumerate(components): - params['p{}'.format(param_len + i)] = pk_value - - # Now add the sort SQL, which may incorporate next_by_pk_clauses - if sort or sort_desc: - if sort_value is None: - if sort_desc: - # Just items where column is null ordered by pk - where_clauses.append( - '({column} is null and {next_clauses})'.format( - column=escape_sqlite(sort_desc), - next_clauses=' and '.join(next_by_pk_clauses), - ) - ) - else: - where_clauses.append( - '({column} is not null or ({column} is null and {next_clauses}))'.format( - column=escape_sqlite(sort), - next_clauses=' and '.join(next_by_pk_clauses), - ) - ) - else: - where_clauses.append( - '({column} {op} :p{p}{extra_desc_only} or ({column} = :p{p} and {next_clauses}))'.format( - column=escape_sqlite(sort or sort_desc), - op='>' if sort else '<', - p=len(params), - extra_desc_only='' if sort else ' or {column2} is null'.format( - column2=escape_sqlite(sort or sort_desc), - ), - next_clauses=' and '.join(next_by_pk_clauses), - ) - ) - params['p{}'.format(len(params))] = sort_value - order_by = '{}, {}'.format( - order_by, order_by_pks - ) - else: - where_clauses.extend(next_by_pk_clauses) - - where_clause = '' - if where_clauses: - where_clause = 'where {} '.format(' and '.join(where_clauses)) - - if order_by: - order_by = 'order by {} '.format(order_by) - - # _group_count=col1&_group_count=col2 - group_count = special_args_lists.get('_group_count') or [] - if group_count: - sql = 'select {group_cols}, count(*) as "count" from {table_name} {where} group by {group_cols} order by "count" desc limit 100'.format( - group_cols=', '.join('"{}"'.format(group_count_col) for group_count_col in group_count), - table_name=escape_sqlite(table), - where=where_clause, - ) - return await self.custom_sql(request, name, hash, sql, editable=True) - - extra_args = {} - # Handle ?_page_size=500 - page_size = request.raw_args.get('_size') - if page_size: - if page_size == 'max': - page_size = self.max_returned_rows - try: - page_size = int(page_size) - if page_size < 0: - raise ValueError - except ValueError: - raise DatasetteError( - '_size must be a positive integer', - status=400 - ) - if page_size > self.max_returned_rows: - raise DatasetteError( - '_size must be <= {}'.format(self.max_returned_rows), - status=400 - ) - extra_args['page_size'] = page_size - else: - page_size = self.page_size - - sql = 'select {select} from {table_name} {where}{order_by}limit {limit}{offset}'.format( - select=select, - table_name=escape_sqlite(table), - where=where_clause, - order_by=order_by, - limit=page_size + 1, - offset=offset, - ) - - if request.raw_args.get('_timelimit'): - extra_args['custom_time_limit'] = int(request.raw_args['_timelimit']) - - rows, truncated, description = await self.execute( - name, sql, params, truncate=True, **extra_args - ) - - # facets support - try: - facets = request.args['_facet'] - except KeyError: - facets = table_metadata.get('facets', []) - facet_results = {} - for column in facets: - facet_sql = ''' - select {col} as value, count(*) as count - {from_sql} - group by {col} order by count desc limit 20 - '''.format(col=escape_sqlite(column), from_sql=from_sql) - try: - facet_rows = await self.execute( - name, - facet_sql, - params, - truncate=False, - custom_time_limit=200 - ) - facet_results[column] = [{ - 'value': row['value'], - 'count': row['count'], - 'toggle_url': urllib.parse.urljoin( - request.url, path_with_added_args( - request, {column: row['value']} - ) - ) - } for row in facet_rows] - except sqlite3.OperationalError: - # Hit time limit - pass - - columns = [r[0] for r in description] - rows = list(rows) - - filter_columns = columns[:] - if use_rowid and filter_columns[0] == 'rowid': - filter_columns = filter_columns[1:] - - # Pagination next link - next_value = None - next_url = None - if len(rows) > page_size and page_size > 0: - if is_view: - next_value = int(_next or 0) + page_size - else: - next_value = path_from_row_pks(rows[-2], pks, use_rowid) - # If there's a sort or sort_desc, add that value as a prefix - if (sort or sort_desc) and not is_view: - prefix = rows[-2][sort or sort_desc] - if prefix is None: - prefix = '$null' - else: - prefix = urllib.parse.quote_plus(str(prefix)) - next_value = '{},{}'.format(prefix, next_value) - added_args = { - '_next': next_value, - } - if sort: - added_args['_sort'] = sort - else: - added_args['_sort_desc'] = sort_desc - else: - added_args = { - '_next': next_value, - } - next_url = urllib.parse.urljoin(request.url, path_with_added_args( - request, added_args - )) - rows = rows[:page_size] - - # Number of filtered rows in whole set: - filtered_table_rows_count = None - if count_sql: - try: - count_rows = list(await self.execute(name, count_sql, params)) - filtered_table_rows_count = count_rows[0][0] - except sqlite3.OperationalError: - # Almost certainly hit the timeout - pass - - # human_description_en combines filters AND search, if provided - human_description_en = filters.human_description_en(extra=search_descriptions) - - if sort or sort_desc: - sorted_by = 'sorted by {}{}'.format( - (sort or sort_desc), - ' descending' if sort_desc else '', - ) - human_description_en = ' '.join([ - b for b in [human_description_en, sorted_by] if b - ]) - - async def extra_template(): - display_columns, display_rows = await self.display_columns_and_rows( - name, table, description, rows, link_column=not is_view, expand_foreign_keys=True - ) - metadata = self.ds.metadata.get( - 'databases', {} - ).get(name, {}).get('tables', {}).get(table, {}) - self.ds.update_with_inherited_metadata(metadata) - return { - 'database_hash': hash, - 'supports_search': bool(fts_table), - 'search': search or '', - 'use_rowid': use_rowid, - 'filters': filters, - 'display_columns': display_columns, - 'filter_columns': filter_columns, - 'display_rows': display_rows, - 'is_sortable': any(c['sortable'] for c in display_columns), - 'path_with_added_args': path_with_added_args, - 'request': request, - 'sort': sort, - 'sort_desc': sort_desc, - 'disable_sort': is_view, - 'custom_rows_and_columns_templates': [ - '_rows_and_columns-{}-{}.html'.format(to_css_class(name), to_css_class(table)), - '_rows_and_columns-table-{}-{}.html'.format(to_css_class(name), to_css_class(table)), - '_rows_and_columns.html', - ], - 'metadata': metadata, - } - - return { - 'database': name, - 'table': table, - 'is_view': is_view, - 'view_definition': view_definition, - 'table_definition': table_definition, - 'human_description_en': human_description_en, - 'rows': rows[:page_size], - 'truncated': truncated, - 'table_rows_count': table_rows_count, - 'filtered_table_rows_count': filtered_table_rows_count, - 'columns': columns, - 'primary_keys': pks, - 'units': units, - 'query': { - 'sql': sql, - 'params': params, - }, - 'facet_results': facet_results, - 'next': next_value and str(next_value) or None, - 'next_url': next_url, - }, extra_template, ( - 'table-{}-{}.html'.format(to_css_class(name), to_css_class(table)), - 'table.html' - ) - - -class RowView(RowTableShared): - async def data(self, request, name, hash, table, pk_path): - table = urllib.parse.unquote_plus(table) - pk_values = urlsafe_components(pk_path) - info = self.ds.inspect()[name] - table_info = info['tables'].get(table) or {} - pks = table_info.get('primary_keys') or [] - use_rowid = not pks - select = '*' - if use_rowid: - select = 'rowid, *' - pks = ['rowid'] - wheres = [ - '"{}"=:p{}'.format(pk, i) - for i, pk in enumerate(pks) - ] - sql = 'select {} from "{}" where {}'.format( - select, table, ' AND '.join(wheres) - ) - params = {} - for i, pk_value in enumerate(pk_values): - params['p{}'.format(i)] = pk_value - # rows, truncated, description = await self.execute(name, sql, params, truncate=True) - rows, truncated, description = await self.execute(name, sql, params, truncate=True) - columns = [r[0] for r in description] - rows = list(rows) - if not rows: - raise NotFound('Record not found: {}'.format(pk_values)) - - async def template_data(): - display_columns, display_rows = await self.display_columns_and_rows( - name, table, description, rows, link_column=False, expand_foreign_keys=True - ) - for column in display_columns: - column['sortable'] = False - return { - 'database_hash': hash, - 'foreign_key_tables': await self.foreign_key_tables(name, table, pk_values), - 'display_columns': display_columns, - 'display_rows': display_rows, - 'custom_rows_and_columns_templates': [ - '_rows_and_columns-{}-{}.html'.format(to_css_class(name), to_css_class(table)), - '_rows_and_columns-row-{}-{}.html'.format(to_css_class(name), to_css_class(table)), - '_rows_and_columns.html', - ], - 'metadata': self.ds.metadata.get( - 'databases', {} - ).get(name, {}).get('tables', {}).get(table, {}), - } - - data = { - 'database': name, - 'table': table, - 'rows': rows, - 'columns': columns, - 'primary_keys': pks, - 'primary_key_values': pk_values, - 'units': self.table_metadata(name, table).get('units', {}) - } - - if 'foreign_key_tables' in (request.raw_args.get('_extras') or '').split(','): - data['foreign_key_tables'] = await self.foreign_key_tables(name, table, pk_values) - - return data, template_data, ( - 'row-{}-{}.html'.format(to_css_class(name), to_css_class(table)), - 'row.html' - ) - - async def foreign_key_tables(self, name, table, pk_values): - if len(pk_values) != 1: - return [] - table_info = self.ds.inspect()[name]['tables'].get(table) - if not table_info: - return [] - foreign_keys = table_info['foreign_keys']['incoming'] - if len(foreign_keys) == 0: - return [] - - sql = 'select ' + ', '.join([ - '(select count(*) from {table} where "{column}"=:id)'.format( - table=escape_sqlite(fk['other_table']), - column=fk['other_column'], - ) - for fk in foreign_keys - ]) - try: - rows = list(await self.execute(name, sql, {'id': pk_values[0]})) - except sqlite3.OperationalError: - # Almost certainly hit the timeout - return [] - foreign_table_counts = dict( - zip( - [(fk['other_table'], fk['other_column']) for fk in foreign_keys], - list(rows[0]), - ) - ) - foreign_key_tables = [] - for fk in foreign_keys: - count = foreign_table_counts.get((fk['other_table'], fk['other_column'])) or 0 - foreign_key_tables.append({**fk, **{'count': count}}) - return foreign_key_tables - - class Datasette: def __init__( self, files, num_threads=3, cache_headers=True, page_size=100, @@ -1217,13 +96,12 @@ def __init__( if self.plugins_dir: for filename in os.listdir(self.plugins_dir): filepath = os.path.join(self.plugins_dir, filename) - with open(filepath) as f: - mod = module_from_path(filepath, name=filename) - try: - pm.register(mod) - except ValueError: - # Plugin already registered - pass + mod = module_from_path(filepath, name=filename) + try: + pm.register(mod) + except ValueError: + # Plugin already registered + pass def app_css_hash(self): if not hasattr(self, '_app_css_hash'): diff --git a/datasette/views/__init__.py b/datasette/views/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/datasette/views/base.py b/datasette/views/base.py new file mode 100644 index 0000000000..3687d0a15e --- /dev/null +++ b/datasette/views/base.py @@ -0,0 +1,356 @@ +from sanic.views import HTTPMethodView +from sanic import response +from sanic.exceptions import NotFound +from datasette import __version__ +from datasette.utils import ( + CustomJSONEncoder, + InvalidSql, + path_from_row_pks, + path_with_added_args, + path_with_ext, + sqlite_timelimit, + to_css_class, +) +import re +import threading +import sqlite3 +import asyncio +import time +import json +import pint + +connections = threading.local() +ureg = pint.UnitRegistry() + +HASH_BLOCK_SIZE = 1024 * 1024 +HASH_LENGTH = 7 + + +class DatasetteError(Exception): + def __init__(self, message, title=None, error_dict=None, status=500, template=None): + self.message = message + self.title = title + self.error_dict = error_dict or {} + self.status = status + + +class RenderMixin(HTTPMethodView): + def render(self, templates, **context): + template = self.jinja_env.select_template(templates) + select_templates = ['{}{}'.format( + '*' if template_name == template.name else '', + template_name + ) for template_name in templates] + return response.html( + template.render({ + **context, **{ + 'app_css_hash': self.ds.app_css_hash(), + 'select_templates': select_templates, + 'zip': zip, + } + }) + ) + + +class BaseView(RenderMixin): + re_named_parameter = re.compile(':([a-zA-Z0-9_]+)') + + def __init__(self, datasette): + self.ds = datasette + self.files = datasette.files + self.jinja_env = datasette.jinja_env + self.executor = datasette.executor + self.page_size = datasette.page_size + self.max_returned_rows = datasette.max_returned_rows + + def table_metadata(self, database, table): + "Fetch table-specific metadata." + return self.ds.metadata.get( + 'databases', {} + ).get(database, {}).get('tables', {}).get(table, {}) + + def options(self, request, *args, **kwargs): + r = response.text('ok') + if self.ds.cors: + r.headers['Access-Control-Allow-Origin'] = '*' + return r + + def redirect(self, request, path, forward_querystring=True): + if request.query_string and '?' not in path and forward_querystring: + path = '{}?{}'.format( + path, request.query_string + ) + r = response.redirect(path) + r.headers['Link'] = '<{}>; rel=preload'.format(path) + if self.ds.cors: + r.headers['Access-Control-Allow-Origin'] = '*' + return r + + def resolve_db_name(self, db_name, **kwargs): + databases = self.ds.inspect() + hash = None + name = None + if '-' in db_name: + # Might be name-and-hash, or might just be + # a name with a hyphen in it + name, hash = db_name.rsplit('-', 1) + if name not in databases: + # Try the whole name + name = db_name + hash = None + else: + name = db_name + # Verify the hash + try: + info = databases[name] + except KeyError: + raise NotFound('Database not found: {}'.format(name)) + expected = info['hash'][:HASH_LENGTH] + if expected != hash: + should_redirect = '/{}-{}'.format( + name, expected, + ) + if 'table' in kwargs: + should_redirect += '/' + kwargs['table'] + if 'pk_path' in kwargs: + should_redirect += '/' + kwargs['pk_path'] + if 'as_json' in kwargs: + should_redirect += kwargs['as_json'] + if 'as_db' in kwargs: + should_redirect += kwargs['as_db'] + return name, expected, should_redirect + return name, expected, None + + async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None, page_size=None): + """Executes sql against db_name in a thread""" + page_size = page_size or self.page_size + + def sql_operation_in_thread(): + conn = getattr(connections, db_name, None) + if not conn: + info = self.ds.inspect()[db_name] + conn = sqlite3.connect( + 'file:{}?immutable=1'.format(info['file']), + uri=True, + check_same_thread=False, + ) + self.ds.prepare_connection(conn) + setattr(connections, db_name, conn) + + time_limit_ms = self.ds.sql_time_limit_ms + if custom_time_limit and custom_time_limit < self.ds.sql_time_limit_ms: + time_limit_ms = custom_time_limit + + with sqlite_timelimit(conn, time_limit_ms): + try: + cursor = conn.cursor() + cursor.execute(sql, params or {}) + max_returned_rows = self.max_returned_rows + if max_returned_rows == page_size: + max_returned_rows += 1 + if max_returned_rows and truncate: + rows = cursor.fetchmany(max_returned_rows + 1) + truncated = len(rows) > max_returned_rows + rows = rows[:max_returned_rows] + else: + rows = cursor.fetchall() + truncated = False + except Exception as e: + print('ERROR: conn={}, sql = {}, params = {}: {}'.format( + conn, repr(sql), params, e + )) + raise + if truncate: + return rows, truncated, cursor.description + else: + return rows + + return await asyncio.get_event_loop().run_in_executor( + self.executor, sql_operation_in_thread + ) + + def get_templates(self, database, table=None): + assert NotImplemented + + async def get(self, request, db_name, **kwargs): + name, hash, should_redirect = self.resolve_db_name(db_name, **kwargs) + if should_redirect: + return self.redirect(request, should_redirect) + return await self.view_get(request, name, hash, **kwargs) + + async def view_get(self, request, name, hash, **kwargs): + try: + as_json = kwargs.pop('as_json') + except KeyError: + as_json = False + extra_template_data = {} + start = time.time() + status_code = 200 + templates = [] + try: + response_or_template_contexts = await self.data( + request, name, hash, **kwargs + ) + if isinstance(response_or_template_contexts, response.HTTPResponse): + return response_or_template_contexts + else: + data, extra_template_data, templates = response_or_template_contexts + except (sqlite3.OperationalError, InvalidSql) as e: + raise DatasetteError(str(e), title='Invalid SQL', status=400) + except (sqlite3.OperationalError) as e: + raise DatasetteError(str(e)) + except DatasetteError: + raise + end = time.time() + data['query_ms'] = (end - start) * 1000 + for key in ('source', 'source_url', 'license', 'license_url'): + value = self.ds.metadata.get(key) + if value: + data[key] = value + if as_json: + # Special case for .jsono extension - redirect to _shape=objects + if as_json == '.jsono': + return self.redirect( + request, + path_with_added_args( + request, + {'_shape': 'objects'}, + path=request.path.rsplit('.jsono', 1)[0] + '.json' + ), + forward_querystring=False + ) + # Deal with the _shape option + shape = request.args.get('_shape', 'arrays') + if shape in ('objects', 'object', 'array'): + columns = data.get('columns') + rows = data.get('rows') + if rows and columns: + data['rows'] = [ + dict(zip(columns, row)) + for row in rows + ] + if shape == 'object': + error = None + if 'primary_keys' not in data: + error = '_shape=object is only available on tables' + else: + pks = data['primary_keys'] + if not pks: + error = '_shape=object not available for tables with no primary keys' + else: + object_rows = {} + for row in data['rows']: + pk_string = path_from_row_pks(row, pks, not pks) + object_rows[pk_string] = row + data = object_rows + if error: + data = { + 'ok': False, + 'error': error, + 'database': name, + 'database_hash': hash, + } + elif shape == 'array': + data = data['rows'] + elif shape == 'arrays': + pass + else: + status_code = 400 + data = { + 'ok': False, + 'error': 'Invalid _shape: {}'.format(shape), + 'status': 400, + 'title': None, + } + headers = {} + if self.ds.cors: + headers['Access-Control-Allow-Origin'] = '*' + r = response.HTTPResponse( + json.dumps( + data, cls=CustomJSONEncoder + ), + status=status_code, + content_type='application/json', + headers=headers, + ) + else: + extras = {} + if callable(extra_template_data): + extras = extra_template_data() + if asyncio.iscoroutine(extras): + extras = await extras + else: + extras = extra_template_data + context = { + **data, + **extras, + **{ + 'url_json': path_with_ext(request, '.json'), + 'url_jsono': path_with_ext(request, '.jsono'), + 'extra_css_urls': self.ds.extra_css_urls(), + 'extra_js_urls': self.ds.extra_js_urls(), + 'datasette_version': __version__, + } + } + if 'metadata' not in context: + context['metadata'] = self.ds.metadata + r = self.render( + templates, + **context, + ) + r.status = status_code + # Set far-future cache expiry + if self.ds.cache_headers: + r.headers['Cache-Control'] = 'max-age={}'.format( + 365 * 24 * 60 * 60 + ) + return r + + async def custom_sql(self, request, name, hash, sql, editable=True, canned_query=None): + params = request.raw_args + if 'sql' in params: + params.pop('sql') + if '_shape' in params: + params.pop('_shape') + # Extract any :named parameters + named_parameters = self.re_named_parameter.findall(sql) + named_parameter_values = { + named_parameter: params.get(named_parameter) or '' + for named_parameter in named_parameters + } + + # Set to blank string if missing from params + for named_parameter in named_parameters: + if named_parameter not in params: + params[named_parameter] = '' + + extra_args = {} + if params.get('_timelimit'): + extra_args['custom_time_limit'] = int(params['_timelimit']) + rows, truncated, description = await self.execute( + name, sql, params, truncate=True, **extra_args + ) + columns = [r[0] for r in description] + + templates = ['query-{}.html'.format(to_css_class(name)), 'query.html'] + if canned_query: + templates.insert(0, 'query-{}-{}.html'.format( + to_css_class(name), to_css_class(canned_query) + )) + + return { + 'database': name, + 'rows': rows, + 'truncated': truncated, + 'columns': columns, + 'query': { + 'sql': sql, + 'params': params, + } + }, { + 'database_hash': hash, + 'custom_sql': True, + 'named_parameter_values': named_parameter_values, + 'editable': editable, + 'canned_query': canned_query, + }, templates diff --git a/datasette/views/database.py b/datasette/views/database.py new file mode 100644 index 0000000000..c7c57ff1ce --- /dev/null +++ b/datasette/views/database.py @@ -0,0 +1,45 @@ +from .base import BaseView +from datasette.utils import ( + validate_sql_select, + to_css_class, +) +from sanic import response +import os + + +class DatabaseView(BaseView): + async def data(self, request, name, hash): + if request.args.get('sql'): + sql = request.raw_args.pop('sql') + validate_sql_select(sql) + return await self.custom_sql(request, name, hash, sql) + info = self.ds.inspect()[name] + metadata = self.ds.metadata.get('databases', {}).get(name, {}) + self.ds.update_with_inherited_metadata(metadata) + tables = list(info['tables'].values()) + tables.sort(key=lambda t: (t['hidden'], t['name'])) + return { + 'database': name, + 'tables': tables, + 'hidden_count': len([t for t in tables if t['hidden']]), + 'views': info['views'], + 'queries': [{ + 'name': query_name, + 'sql': query_sql, + } for query_name, query_sql in (metadata.get('queries') or {}).items()], + }, { + 'database_hash': hash, + 'show_hidden': request.args.get('_show_hidden'), + 'editable': True, + 'metadata': metadata, + }, ('database-{}.html'.format(to_css_class(name)), 'database.html') + + +class DatabaseDownload(BaseView): + async def view_get(self, request, name, hash, **kwargs): + filepath = self.ds.inspect()[name]['file'] + return await response.file_stream( + filepath, + filename=os.path.basename(filepath), + mime_type='application/octet-stream', + ) diff --git a/datasette/views/index.py b/datasette/views/index.py new file mode 100644 index 0000000000..dd0191adfa --- /dev/null +++ b/datasette/views/index.py @@ -0,0 +1,59 @@ +from .base import RenderMixin, HASH_LENGTH +from sanic import response +from datasette.utils import ( + CustomJSONEncoder, +) +from datasette.version import __version__ +import json + + +class IndexView(RenderMixin): + def __init__(self, datasette): + self.ds = datasette + self.files = datasette.files + self.jinja_env = datasette.jinja_env + self.executor = datasette.executor + + async def get(self, request, as_json): + databases = [] + for key, info in sorted(self.ds.inspect().items()): + tables = [t for t in info['tables'].values() if not t['hidden']] + hidden_tables = [t for t in info['tables'].values() if t['hidden']] + database = { + 'name': key, + 'hash': info['hash'], + 'path': '{}-{}'.format(key, info['hash'][:HASH_LENGTH]), + 'tables_truncated': sorted( + tables, + key=lambda t: t['count'], + reverse=True + )[:5], + 'tables_count': len(tables), + 'tables_more': len(tables) > 5, + 'table_rows_sum': sum(t['count'] for t in tables), + 'hidden_table_rows_sum': sum(t['count'] for t in hidden_tables), + 'hidden_tables_count': len(hidden_tables), + 'views_count': len(info['views']), + } + databases.append(database) + if as_json: + headers = {} + if self.ds.cors: + headers['Access-Control-Allow-Origin'] = '*' + return response.HTTPResponse( + json.dumps( + {db['name']: db for db in databases}, + cls=CustomJSONEncoder + ), + content_type='application/json', + headers=headers, + ) + else: + return self.render( + ['index.html'], + databases=databases, + metadata=self.ds.metadata, + datasette_version=__version__, + extra_css_urls=self.ds.extra_css_urls(), + extra_js_urls=self.ds.extra_js_urls(), + ) diff --git a/datasette/views/table.py b/datasette/views/table.py new file mode 100644 index 0000000000..8b997ddcbf --- /dev/null +++ b/datasette/views/table.py @@ -0,0 +1,695 @@ +from sanic.request import RequestParameters +from .base import BaseView, DatasetteError, ureg +from sanic.exceptions import NotFound +from datasette.utils import ( + Filters, + compound_keys_after_sql, + escape_sqlite, + filters_should_redirect, + is_url, + path_from_row_pks, + path_with_added_args, + to_css_class, + urlsafe_components, +) +import sqlite3 +import jinja2 +import urllib + + +class RowTableShared(BaseView): + def sortable_columns_for_table(self, name, table, use_rowid): + table_metadata = self.table_metadata(name, table) + if 'sortable_columns' in table_metadata: + sortable_columns = set(table_metadata['sortable_columns']) + else: + table_info = self.ds.inspect()[name]['tables'].get(table) or {} + sortable_columns = set(table_info.get('columns', [])) + if use_rowid: + sortable_columns.add('rowid') + return sortable_columns + + async def display_columns_and_rows(self, database, table, description, rows, link_column=False, expand_foreign_keys=True): + "Returns columns, rows for specified table - including fancy foreign key treatment" + table_metadata = self.table_metadata(database, table) + info = self.ds.inspect()[database] + sortable_columns = self.sortable_columns_for_table(database, table, True) + columns = [{ + 'name': r[0], + 'sortable': r[0] in sortable_columns, + } for r in description] + tables = info['tables'] + table_info = tables.get(table) or {} + pks = table_info.get('primary_keys') or [] + + # Prefetch foreign key resolutions for later expansion: + fks = {} + labeled_fks = {} + if table_info and expand_foreign_keys: + foreign_keys = table_info['foreign_keys']['outgoing'] + for fk in foreign_keys: + label_column = ( + # First look in metadata.json definition for this foreign key table: + self.table_metadata(database, fk['other_table']).get('label_column') + # Fall back to label_column from .inspect() detection: + or tables.get(fk['other_table'], {}).get('label_column') + ) + if not label_column: + # No label for this FK + fks[fk['column']] = fk['other_table'] + continue + ids_to_lookup = set([row[fk['column']] for row in rows]) + sql = 'select "{other_column}", "{label_column}" from {other_table} where "{other_column}" in ({placeholders})'.format( + other_column=fk['other_column'], + label_column=label_column, + other_table=escape_sqlite(fk['other_table']), + placeholders=', '.join(['?'] * len(ids_to_lookup)), + ) + try: + results = await self.execute(database, sql, list(set(ids_to_lookup))) + except sqlite3.OperationalError: + # Probably hit the timelimit + pass + else: + for id, value in results: + labeled_fks[(fk['column'], id)] = (fk['other_table'], value) + + cell_rows = [] + for row in rows: + cells = [] + # Unless we are a view, the first column is a link - either to the rowid + # or to the simple or compound primary key + if link_column: + cells.append({ + 'column': pks[0] if len(pks) == 1 else 'Link', + 'value': jinja2.Markup( + '{flat_pks}'.format( + database=database, + table=urllib.parse.quote_plus(table), + flat_pks=str(jinja2.escape(path_from_row_pks(row, pks, not pks, False))), + flat_pks_quoted=path_from_row_pks(row, pks, not pks) + ) + ), + }) + + for value, column_dict in zip(row, columns): + column = column_dict['name'] + if link_column and len(pks) == 1 and column == pks[0]: + # If there's a simple primary key, don't repeat the value as it's + # already shown in the link column. + continue + if (column, value) in labeled_fks: + other_table, label = labeled_fks[(column, value)] + display_value = jinja2.Markup( + '{label} {id}'.format( + database=database, + table=urllib.parse.quote_plus(other_table), + link_id=urllib.parse.quote_plus(str(value)), + id=str(jinja2.escape(value)), + label=str(jinja2.escape(label)), + ) + ) + elif column in fks: + display_value = jinja2.Markup( + '{id}'.format( + database=database, + table=urllib.parse.quote_plus(fks[column]), + link_id=urllib.parse.quote_plus(str(value)), + id=str(jinja2.escape(value)))) + elif value is None: + display_value = jinja2.Markup(' ') + elif is_url(str(value).strip()): + display_value = jinja2.Markup( + '{url}'.format( + url=jinja2.escape(value.strip()) + ) + ) + elif column in table_metadata.get('units', {}) and value != '': + # Interpret units using pint + value = value * ureg(table_metadata['units'][column]) + # Pint uses floating point which sometimes introduces errors in the compact + # representation, which we have to round off to avoid ugliness. In the vast + # majority of cases this rounding will be inconsequential. I hope. + value = round(value.to_compact(), 6) + display_value = jinja2.Markup('{:~P}'.format(value).replace(' ', ' ')) + else: + display_value = str(value) + + cells.append({ + 'column': column, + 'value': display_value, + }) + cell_rows.append(cells) + + if link_column: + # Add the link column header. + # If it's a simple primary key, we have to remove and re-add that column name at + # the beginning of the header row. + if len(pks) == 1: + columns = [col for col in columns if col['name'] != pks[0]] + + columns = [{ + 'name': pks[0] if len(pks) == 1 else 'Link', + 'sortable': len(pks) == 1, + }] + columns + return columns, cell_rows + + +class TableView(RowTableShared): + async def data(self, request, name, hash, table): + table = urllib.parse.unquote_plus(table) + canned_query = self.ds.get_canned_query(name, table) + if canned_query is not None: + return await self.custom_sql(request, name, hash, canned_query['sql'], editable=False, canned_query=table) + is_view = bool(list(await self.execute( + name, + "SELECT count(*) from sqlite_master WHERE type = 'view' and name=:n", + {'n': table} + ))[0][0]) + view_definition = None + table_definition = None + if is_view: + view_definition = list(await self.execute( + name, + 'select sql from sqlite_master where name = :n and type="view"', + {'n': table} + ))[0][0] + else: + table_definition_rows = list(await self.execute( + name, + 'select sql from sqlite_master where name = :n and type="table"', + {'n': table} + )) + if not table_definition_rows: + raise NotFound('Table not found: {}'.format(table)) + table_definition = table_definition_rows[0][0] + info = self.ds.inspect() + table_info = info[name]['tables'].get(table) or {} + pks = table_info.get('primary_keys') or [] + use_rowid = not pks and not is_view + if use_rowid: + select = 'rowid, *' + order_by = 'rowid' + order_by_pks = 'rowid' + else: + select = '*' + order_by_pks = ', '.join([escape_sqlite(pk) for pk in pks]) + order_by = order_by_pks + + if is_view: + order_by = '' + + # We roll our own query_string decoder because by default Sanic + # drops anything with an empty value e.g. ?name__exact= + args = RequestParameters( + urllib.parse.parse_qs(request.query_string, keep_blank_values=True) + ) + + # Special args start with _ and do not contain a __ + # That's so if there is a column that starts with _ + # it can still be queried using ?_col__exact=blah + special_args = {} + special_args_lists = {} + other_args = {} + for key, value in args.items(): + if key.startswith('_') and '__' not in key: + special_args[key] = value[0] + special_args_lists[key] = value + else: + other_args[key] = value[0] + + # Handle ?_filter_column and redirect, if present + redirect_params = filters_should_redirect(special_args) + if redirect_params: + return self.redirect( + request, + path_with_added_args(request, redirect_params), + forward_querystring=False + ) + + # Spot ?_sort_by_desc and redirect to _sort_desc=(_sort) + if '_sort_by_desc' in special_args: + return self.redirect( + request, + path_with_added_args(request, { + '_sort_desc': special_args.get('_sort'), + '_sort_by_desc': None, + '_sort': None, + }), + forward_querystring=False + ) + + table_metadata = self.table_metadata(name, table) + units = table_metadata.get('units', {}) + filters = Filters(sorted(other_args.items()), units, ureg) + where_clauses, params = filters.build_where_clauses() + + # _search support: + fts_table = info[name]['tables'].get(table, {}).get('fts_table') + search_args = dict( + pair for pair in special_args.items() + if pair[0].startswith('_search') + ) + search_descriptions = [] + search = '' + if fts_table and search_args: + if '_search' in search_args: + # Simple ?_search=xxx + search = search_args['_search'] + where_clauses.append( + 'rowid in (select rowid from [{fts_table}] where [{fts_table}] match :search)'.format( + fts_table=fts_table + ) + ) + search_descriptions.append('search matches "{}"'.format(search)) + params['search'] = search + else: + # More complex: search against specific columns + valid_columns = set(info[name]['tables'][fts_table]['columns']) + for i, (key, search_text) in enumerate(search_args.items()): + search_col = key.split('_search_', 1)[1] + if search_col not in valid_columns: + raise DatasetteError( + 'Cannot search by that column', + status=400 + ) + where_clauses.append( + 'rowid in (select rowid from [{fts_table}] where [{search_col}] match :search_{i})'.format( + fts_table=fts_table, + search_col=search_col, + i=i, + ) + ) + search_descriptions.append( + 'search column "{}" matches "{}"'.format(search_col, search_text) + ) + params['search_{}'.format(i)] = search_text + + table_rows_count = None + sortable_columns = set() + if not is_view: + table_rows_count = table_info['count'] + sortable_columns = self.sortable_columns_for_table(name, table, use_rowid) + + # Allow for custom sort order + sort = special_args.get('_sort') + if sort: + if sort not in sortable_columns: + raise DatasetteError('Cannot sort table by {}'.format(sort)) + order_by = escape_sqlite(sort) + sort_desc = special_args.get('_sort_desc') + if sort_desc: + if sort_desc not in sortable_columns: + raise DatasetteError('Cannot sort table by {}'.format(sort_desc)) + if sort: + raise DatasetteError('Cannot use _sort and _sort_desc at the same time') + order_by = '{} desc'.format(escape_sqlite(sort_desc)) + + from_sql = 'from {table_name} {where}'.format( + table_name=escape_sqlite(table), + where=( + 'where {} '.format(' and '.join(where_clauses)) + ) if where_clauses else '', + ) + count_sql = 'select count(*) {}'.format(from_sql) + + _next = special_args.get('_next') + offset = '' + if _next: + if is_view: + # _next is an offset + offset = ' offset {}'.format(int(_next)) + else: + components = urlsafe_components(_next) + # If a sort order is applied, the first of these is the sort value + if sort or sort_desc: + sort_value = components[0] + # Special case for if non-urlencoded first token was $null + if _next.split(',')[0] == '$null': + sort_value = None + components = components[1:] + + # Figure out the SQL for next-based-on-primary-key first + next_by_pk_clauses = [] + if use_rowid: + next_by_pk_clauses.append( + 'rowid > :p{}'.format( + len(params), + ) + ) + params['p{}'.format(len(params))] = components[0] + else: + # Apply the tie-breaker based on primary keys + if len(components) == len(pks): + param_len = len(params) + next_by_pk_clauses.append(compound_keys_after_sql(pks, param_len)) + for i, pk_value in enumerate(components): + params['p{}'.format(param_len + i)] = pk_value + + # Now add the sort SQL, which may incorporate next_by_pk_clauses + if sort or sort_desc: + if sort_value is None: + if sort_desc: + # Just items where column is null ordered by pk + where_clauses.append( + '({column} is null and {next_clauses})'.format( + column=escape_sqlite(sort_desc), + next_clauses=' and '.join(next_by_pk_clauses), + ) + ) + else: + where_clauses.append( + '({column} is not null or ({column} is null and {next_clauses}))'.format( + column=escape_sqlite(sort), + next_clauses=' and '.join(next_by_pk_clauses), + ) + ) + else: + where_clauses.append( + '({column} {op} :p{p}{extra_desc_only} or ({column} = :p{p} and {next_clauses}))'.format( + column=escape_sqlite(sort or sort_desc), + op='>' if sort else '<', + p=len(params), + extra_desc_only='' if sort else ' or {column2} is null'.format( + column2=escape_sqlite(sort or sort_desc), + ), + next_clauses=' and '.join(next_by_pk_clauses), + ) + ) + params['p{}'.format(len(params))] = sort_value + order_by = '{}, {}'.format( + order_by, order_by_pks + ) + else: + where_clauses.extend(next_by_pk_clauses) + + where_clause = '' + if where_clauses: + where_clause = 'where {} '.format(' and '.join(where_clauses)) + + if order_by: + order_by = 'order by {} '.format(order_by) + + # _group_count=col1&_group_count=col2 + group_count = special_args_lists.get('_group_count') or [] + if group_count: + sql = 'select {group_cols}, count(*) as "count" from {table_name} {where} group by {group_cols} order by "count" desc limit 100'.format( + group_cols=', '.join('"{}"'.format(group_count_col) for group_count_col in group_count), + table_name=escape_sqlite(table), + where=where_clause, + ) + return await self.custom_sql(request, name, hash, sql, editable=True) + + extra_args = {} + # Handle ?_page_size=500 + page_size = request.raw_args.get('_size') + if page_size: + if page_size == 'max': + page_size = self.max_returned_rows + try: + page_size = int(page_size) + if page_size < 0: + raise ValueError + except ValueError: + raise DatasetteError( + '_size must be a positive integer', + status=400 + ) + if page_size > self.max_returned_rows: + raise DatasetteError( + '_size must be <= {}'.format(self.max_returned_rows), + status=400 + ) + extra_args['page_size'] = page_size + else: + page_size = self.page_size + + sql = 'select {select} from {table_name} {where}{order_by}limit {limit}{offset}'.format( + select=select, + table_name=escape_sqlite(table), + where=where_clause, + order_by=order_by, + limit=page_size + 1, + offset=offset, + ) + + if request.raw_args.get('_timelimit'): + extra_args['custom_time_limit'] = int(request.raw_args['_timelimit']) + + rows, truncated, description = await self.execute( + name, sql, params, truncate=True, **extra_args + ) + + # facets support + try: + facets = request.args['_facet'] + except KeyError: + facets = table_metadata.get('facets', []) + facet_results = {} + for column in facets: + facet_sql = ''' + select {col} as value, count(*) as count + {from_sql} + group by {col} order by count desc limit 20 + '''.format(col=escape_sqlite(column), from_sql=from_sql) + try: + facet_rows = await self.execute( + name, + facet_sql, + params, + truncate=False, + custom_time_limit=200 + ) + facet_results[column] = [{ + 'value': row['value'], + 'count': row['count'], + 'toggle_url': urllib.parse.urljoin( + request.url, path_with_added_args( + request, {column: row['value']} + ) + ) + } for row in facet_rows] + except sqlite3.OperationalError: + # Hit time limit + pass + + columns = [r[0] for r in description] + rows = list(rows) + + filter_columns = columns[:] + if use_rowid and filter_columns[0] == 'rowid': + filter_columns = filter_columns[1:] + + # Pagination next link + next_value = None + next_url = None + if len(rows) > page_size and page_size > 0: + if is_view: + next_value = int(_next or 0) + page_size + else: + next_value = path_from_row_pks(rows[-2], pks, use_rowid) + # If there's a sort or sort_desc, add that value as a prefix + if (sort or sort_desc) and not is_view: + prefix = rows[-2][sort or sort_desc] + if prefix is None: + prefix = '$null' + else: + prefix = urllib.parse.quote_plus(str(prefix)) + next_value = '{},{}'.format(prefix, next_value) + added_args = { + '_next': next_value, + } + if sort: + added_args['_sort'] = sort + else: + added_args['_sort_desc'] = sort_desc + else: + added_args = { + '_next': next_value, + } + next_url = urllib.parse.urljoin(request.url, path_with_added_args( + request, added_args + )) + rows = rows[:page_size] + + # Number of filtered rows in whole set: + filtered_table_rows_count = None + if count_sql: + try: + count_rows = list(await self.execute(name, count_sql, params)) + filtered_table_rows_count = count_rows[0][0] + except sqlite3.OperationalError: + # Almost certainly hit the timeout + pass + + # human_description_en combines filters AND search, if provided + human_description_en = filters.human_description_en(extra=search_descriptions) + + if sort or sort_desc: + sorted_by = 'sorted by {}{}'.format( + (sort or sort_desc), + ' descending' if sort_desc else '', + ) + human_description_en = ' '.join([ + b for b in [human_description_en, sorted_by] if b + ]) + + async def extra_template(): + display_columns, display_rows = await self.display_columns_and_rows( + name, table, description, rows, link_column=not is_view, expand_foreign_keys=True + ) + metadata = self.ds.metadata.get( + 'databases', {} + ).get(name, {}).get('tables', {}).get(table, {}) + self.ds.update_with_inherited_metadata(metadata) + return { + 'database_hash': hash, + 'supports_search': bool(fts_table), + 'search': search or '', + 'use_rowid': use_rowid, + 'filters': filters, + 'display_columns': display_columns, + 'filter_columns': filter_columns, + 'display_rows': display_rows, + 'is_sortable': any(c['sortable'] for c in display_columns), + 'path_with_added_args': path_with_added_args, + 'request': request, + 'sort': sort, + 'sort_desc': sort_desc, + 'disable_sort': is_view, + 'custom_rows_and_columns_templates': [ + '_rows_and_columns-{}-{}.html'.format(to_css_class(name), to_css_class(table)), + '_rows_and_columns-table-{}-{}.html'.format(to_css_class(name), to_css_class(table)), + '_rows_and_columns.html', + ], + 'metadata': metadata, + } + + return { + 'database': name, + 'table': table, + 'is_view': is_view, + 'view_definition': view_definition, + 'table_definition': table_definition, + 'human_description_en': human_description_en, + 'rows': rows[:page_size], + 'truncated': truncated, + 'table_rows_count': table_rows_count, + 'filtered_table_rows_count': filtered_table_rows_count, + 'columns': columns, + 'primary_keys': pks, + 'units': units, + 'query': { + 'sql': sql, + 'params': params, + }, + 'facet_results': facet_results, + 'next': next_value and str(next_value) or None, + 'next_url': next_url, + }, extra_template, ( + 'table-{}-{}.html'.format(to_css_class(name), to_css_class(table)), + 'table.html' + ) + + +class RowView(RowTableShared): + async def data(self, request, name, hash, table, pk_path): + table = urllib.parse.unquote_plus(table) + pk_values = urlsafe_components(pk_path) + info = self.ds.inspect()[name] + table_info = info['tables'].get(table) or {} + pks = table_info.get('primary_keys') or [] + use_rowid = not pks + select = '*' + if use_rowid: + select = 'rowid, *' + pks = ['rowid'] + wheres = [ + '"{}"=:p{}'.format(pk, i) + for i, pk in enumerate(pks) + ] + sql = 'select {} from "{}" where {}'.format( + select, table, ' AND '.join(wheres) + ) + params = {} + for i, pk_value in enumerate(pk_values): + params['p{}'.format(i)] = pk_value + # rows, truncated, description = await self.execute(name, sql, params, truncate=True) + rows, truncated, description = await self.execute(name, sql, params, truncate=True) + columns = [r[0] for r in description] + rows = list(rows) + if not rows: + raise NotFound('Record not found: {}'.format(pk_values)) + + async def template_data(): + display_columns, display_rows = await self.display_columns_and_rows( + name, table, description, rows, link_column=False, expand_foreign_keys=True + ) + for column in display_columns: + column['sortable'] = False + return { + 'database_hash': hash, + 'foreign_key_tables': await self.foreign_key_tables(name, table, pk_values), + 'display_columns': display_columns, + 'display_rows': display_rows, + 'custom_rows_and_columns_templates': [ + '_rows_and_columns-{}-{}.html'.format(to_css_class(name), to_css_class(table)), + '_rows_and_columns-row-{}-{}.html'.format(to_css_class(name), to_css_class(table)), + '_rows_and_columns.html', + ], + 'metadata': self.ds.metadata.get( + 'databases', {} + ).get(name, {}).get('tables', {}).get(table, {}), + } + + data = { + 'database': name, + 'table': table, + 'rows': rows, + 'columns': columns, + 'primary_keys': pks, + 'primary_key_values': pk_values, + 'units': self.table_metadata(name, table).get('units', {}) + } + + if 'foreign_key_tables' in (request.raw_args.get('_extras') or '').split(','): + data['foreign_key_tables'] = await self.foreign_key_tables(name, table, pk_values) + + return data, template_data, ( + 'row-{}-{}.html'.format(to_css_class(name), to_css_class(table)), + 'row.html' + ) + + async def foreign_key_tables(self, name, table, pk_values): + if len(pk_values) != 1: + return [] + table_info = self.ds.inspect()[name]['tables'].get(table) + if not table_info: + return [] + foreign_keys = table_info['foreign_keys']['incoming'] + if len(foreign_keys) == 0: + return [] + + sql = 'select ' + ', '.join([ + '(select count(*) from {table} where "{column}"=:id)'.format( + table=escape_sqlite(fk['other_table']), + column=fk['other_column'], + ) + for fk in foreign_keys + ]) + try: + rows = list(await self.execute(name, sql, {'id': pk_values[0]})) + except sqlite3.OperationalError: + # Almost certainly hit the timeout + return [] + foreign_table_counts = dict( + zip( + [(fk['other_table'], fk['other_column']) for fk in foreign_keys], + list(rows[0]), + ) + ) + foreign_key_tables = [] + for fk in foreign_keys: + count = foreign_table_counts.get((fk['other_table'], fk['other_column'])) or 0 + foreign_key_tables.append({**fk, **{'count': count}}) + return foreign_key_tables