diff --git a/datasette/app.py b/datasette/app.py index f6b3e51443..4b09b7d17e 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1,6 +1,8 @@ import asyncio import collections import hashlib +import itertools +import json import os import re import sys @@ -12,7 +14,8 @@ import click from markupsafe import Markup -from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader +from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader, escape +from jinja2.environment import Template import uvicorn from .views.base import DatasetteError, ureg, AsgiRouter @@ -27,6 +30,7 @@ QueryInterrupted, escape_css_string, escape_sqlite, + format_bytes, get_plugins, module_from_path, sqlite3, @@ -35,6 +39,7 @@ from .utils.asgi import ( AsgiLifespan, NotFound, + Response, asgi_static, asgi_send, asgi_send_html, @@ -526,6 +531,96 @@ def register_renderers(self): for renderer in hook_renderers: self.renderers[renderer["extension"]] = renderer["callback"] + async def render_template( + self, templates, context=None, request=None, view_name=None + ): + context = context or {} + if isinstance(templates, Template): + template = templates + select_templates = [] + else: + if isinstance(templates, str): + templates = [templates] + template = self.jinja_env.select_template(templates) + select_templates = [ + "{}{}".format( + "*" if template_name == template.name else "", template_name + ) + for template_name in templates + ] + body_scripts = [] + # pylint: disable=no-member + for script in pm.hook.extra_body_script( + template=template.name, + database=context.get("database"), + table=context.get("table"), + view_name=view_name, + datasette=self, + ): + body_scripts.append(Markup(script)) + + extra_template_vars = {} + # pylint: disable=no-member + for extra_vars in pm.hook.extra_template_vars( + template=template.name, + database=context.get("database"), + table=context.get("table"), + view_name=view_name, + request=request, + datasette=self, + ): + if callable(extra_vars): + extra_vars = extra_vars() + if asyncio.iscoroutine(extra_vars): + extra_vars = await extra_vars + assert isinstance(extra_vars, dict), "extra_vars is of type {}".format( + type(extra_vars) + ) + extra_template_vars.update(extra_vars) + + template_context = { + **context, + **{ + "app_css_hash": self.app_css_hash(), + "select_templates": select_templates, + "zip": zip, + "body_scripts": body_scripts, + "format_bytes": format_bytes, + "extra_css_urls": self._asset_urls("extra_css_urls", template, context), + "extra_js_urls": self._asset_urls("extra_js_urls", template, context), + }, + **extra_template_vars, + } + return await template.render_async(template_context) + + def _asset_urls(self, key, template, context): + # Flatten list-of-lists from plugins: + seen_urls = set() + for url_or_dict in itertools.chain( + itertools.chain.from_iterable( + getattr(pm.hook, key)( + template=template.name, + database=context.get("database"), + table=context.get("table"), + datasette=self, + ) + ), + (self.metadata(key) or []), + ): + if isinstance(url_or_dict, dict): + url = url_or_dict["url"] + sri = url_or_dict.get("sri") + else: + url = url_or_dict + sri = None + if url in seen_urls: + continue + seen_urls.add(url) + if sri: + yield {"url": url, "sri": sri} + else: + yield {"url": url} + def app(self): "Returns an ASGI app function that serves the whole of Datasette" default_templates = str(app_root / "datasette" / "templates") diff --git a/datasette/views/base.py b/datasette/views/base.py index 61561a329f..673e6d7e3f 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -9,15 +9,12 @@ import jinja2 import pint -from html import escape - from datasette import __version__ from datasette.plugins import pm from datasette.utils import ( QueryInterrupted, InvalidSql, LimitedWriter, - format_bytes, is_url, path_with_added_args, path_with_removed_args, @@ -65,34 +62,6 @@ async def head(self, *args, **kwargs): response.body = b"" return response - def _asset_urls(self, key, template, context): - # Flatten list-of-lists from plugins: - seen_urls = set() - for url_or_dict in itertools.chain( - itertools.chain.from_iterable( - getattr(pm.hook, key)( - template=template.name, - database=context.get("database"), - table=context.get("table"), - datasette=self.ds, - ) - ), - (self.ds.metadata(key) or []), - ): - if isinstance(url_or_dict, dict): - url = url_or_dict["url"] - sri = url_or_dict.get("sri") - else: - url = url_or_dict - sri = None - if url in seen_urls: - continue - seen_urls.add(url) - if sri: - yield {"url": url, "sri": sri} - else: - yield {"url": url} - def database_url(self, database): db = self.ds.databases[database] if self.ds.config("hash_urls") and db.hash: @@ -105,62 +74,22 @@ def database_color(self, database): async def render(self, templates, request, context): template = self.ds.jinja_env.select_template(templates) - select_templates = [ - "{}{}".format("*" if template_name == template.name else "", template_name) - for template_name in templates - ] - body_scripts = [] - # pylint: disable=no-member - for script in pm.hook.extra_body_script( - template=template.name, - database=context.get("database"), - table=context.get("table"), - view_name=self.name, - datasette=self.ds, - ): - body_scripts.append(jinja2.Markup(script)) - - extra_template_vars = {} - # pylint: disable=no-member - for extra_vars in pm.hook.extra_template_vars( - template=template.name, - database=context.get("database"), - table=context.get("table"), - view_name=self.name, - request=request, - datasette=self.ds, - ): - if callable(extra_vars): - extra_vars = extra_vars() - if asyncio.iscoroutine(extra_vars): - extra_vars = await extra_vars - assert isinstance(extra_vars, dict), "extra_vars is of type {}".format( - type(extra_vars) - ) - extra_template_vars.update(extra_vars) - template_context = { - **context, - **{ - "app_css_hash": self.ds.app_css_hash(), - "select_templates": select_templates, - "zip": zip, - "body_scripts": body_scripts, - "extra_css_urls": self._asset_urls("extra_css_urls", template, context), - "extra_js_urls": self._asset_urls("extra_js_urls", template, context), - "format_bytes": format_bytes, - "database_url": self.database_url, - "database_color": self.database_color, - }, - **extra_template_vars, - } - if request.args.get("_context") and self.ds.config("template_debug"): + **context, + **{ + "database_url": self.database_url, + "database_color": self.database_color, + }, + } + if request and request.args.get("_context") and self.ds.config("template_debug"): return Response.html( "
{}
".format( - escape(json.dumps(template_context, default=repr, indent=4)) + jinja2.escape(json.dumps(template_context, default=repr, indent=4)) ) ) - return Response.html(await template.render_async(template_context)) + return Response.html(await self.ds.render_template( + template, template_context, request=request + )) class DataView(BaseView):