Skip to content

Commit

Permalink
Extract and refactor filters into filters.py
Browse files Browse the repository at this point in the history
This will help in implementing __in as a filter, refs #433
  • Loading branch information
simonw committed Apr 15, 2019
1 parent 9dc7a18 commit 6da567d
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 200 deletions.
156 changes: 156 additions & 0 deletions datasette/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import numbers
from .utils import detect_json1


class Filter:
key = None
display = None
no_argument = False

def where_clause(self, table, column, value, param_counter):
raise NotImplementedError

def human_clause(self, column, value):
raise NotImplementedError


class TemplatedFilter(Filter):
def __init__(self, key, display, sql_template, human_template, format='{}', numeric=False, no_argument=False):
self.key = key
self.display = display
self.sql_template = sql_template
self.human_template = human_template
self.format = format
self.numeric = numeric
self.no_argument = no_argument

def where_clause(self, table, column, value, param_counter):
converted = self.format.format(value)
if self.numeric and converted.isdigit():
converted = int(converted)
if self.no_argument:
kwargs = {
'c': column,
}
converted = None
else:
kwargs = {
'c': column,
'p': 'p{}'.format(param_counter),
't': table,
}
return self.sql_template.format(**kwargs), converted

def human_clause(self, column, value):
if callable(self.human_template):
template = self.human_template(column, value)
else:
template = self.human_template
if self.no_argument:
return template.format(c=column)
else:
return template.format(c=column, v=value)


class Filters:
_filters = [
# key, display, sql_template, human_template, format=, numeric=, no_argument=
TemplatedFilter('exact', '=', '"{c}" = :{p}', lambda c, v: '{c} = {v}' if v.isdigit() else '{c} = "{v}"'),
TemplatedFilter('not', '!=', '"{c}" != :{p}', lambda c, v: '{c} != {v}' if v.isdigit() else '{c} != "{v}"'),
TemplatedFilter('contains', 'contains', '"{c}" like :{p}', '{c} contains "{v}"', format='%{}%'),
TemplatedFilter('endswith', 'ends with', '"{c}" like :{p}', '{c} ends with "{v}"', format='%{}'),
TemplatedFilter('startswith', 'starts with', '"{c}" like :{p}', '{c} starts with "{v}"', format='{}%'),
TemplatedFilter('gt', '>', '"{c}" > :{p}', '{c} > {v}', numeric=True),
TemplatedFilter('gte', '\u2265', '"{c}" >= :{p}', '{c} \u2265 {v}', numeric=True),
TemplatedFilter('lt', '<', '"{c}" < :{p}', '{c} < {v}', numeric=True),
TemplatedFilter('lte', '\u2264', '"{c}" <= :{p}', '{c} \u2264 {v}', numeric=True),
TemplatedFilter('glob', 'glob', '"{c}" glob :{p}', '{c} glob "{v}"'),
TemplatedFilter('like', 'like', '"{c}" like :{p}', '{c} like "{v}"'),
] + ([TemplatedFilter('arraycontains', 'array contains', """rowid in (
select {t}.rowid from {t}, json_each({t}.{c}) j
where j.value = :{p}
)""", '{c} contains "{v}"')
] if detect_json1() else []) + [
TemplatedFilter('isnull', 'is null', '"{c}" is null', '{c} is null', no_argument=True),
TemplatedFilter('notnull', 'is not null', '"{c}" is not null', '{c} is not null', no_argument=True),
TemplatedFilter('isblank', 'is blank', '("{c}" is null or "{c}" = "")', '{c} is blank', no_argument=True),
TemplatedFilter('notblank', 'is not blank', '("{c}" is not null and "{c}" != "")', '{c} is not blank', no_argument=True),
]
_filters_by_key = {
f.key: f for f in _filters
}

def __init__(self, pairs, units={}, ureg=None):
self.pairs = pairs
self.units = units
self.ureg = ureg

def lookups(self):
"Yields (lookup, display, no_argument) pairs"
for filter in self._filters:
yield filter.key, filter.display, filter.no_argument

def human_description_en(self, extra=None):
bits = []
if extra:
bits.extend(extra)
for column, lookup, value in self.selections():
filter = self._filters_by_key.get(lookup, None)
if filter:
bits.append(filter.human_clause(column, value))
# Comma separated, with an ' and ' at the end
and_bits = []
commas, tail = bits[:-1], bits[-1:]
if commas:
and_bits.append(', '.join(commas))
if tail:
and_bits.append(tail[0])
s = ' and '.join(and_bits)
if not s:
return ''
return 'where {}'.format(s)

def selections(self):
"Yields (column, lookup, value) tuples"
for key, value in self.pairs:
if '__' in key:
column, lookup = key.rsplit('__', 1)
else:
column = key
lookup = 'exact'
yield column, lookup, value

def has_selections(self):
return bool(self.pairs)

def convert_unit(self, column, value):
"If the user has provided a unit in the query, convert it into the column unit, if present."
if column not in self.units:
return value

# Try to interpret the value as a unit
value = self.ureg(value)
if isinstance(value, numbers.Number):
# It's just a bare number, assume it's the column unit
return value

column_unit = self.ureg(self.units[column])
return value.to(column_unit).magnitude

def build_where_clauses(self, table):
sql_bits = []
params = {}
i = 0
for column, lookup, value in self.selections():
filter = self._filters_by_key.get(lookup, None)
if filter:
sql_bit, param = filter.where_clause(table, column, self.convert_unit(column, value), i)
sql_bits.append(sql_bit)
if param is not None:
if not isinstance(param, list):
param = [param]
for individual_param in param:
param_id = 'p{}'.format(i)
params[param_id] = individual_param
i += 1
return sql_bits, params
137 changes: 0 additions & 137 deletions datasette/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,143 +584,6 @@ def table_columns(conn, table):
]


class Filter:
def __init__(self, key, display, sql_template, human_template, format='{}', numeric=False, no_argument=False):
self.key = key
self.display = display
self.sql_template = sql_template
self.human_template = human_template
self.format = format
self.numeric = numeric
self.no_argument = no_argument

def where_clause(self, table, column, value, param_counter):
converted = self.format.format(value)
if self.numeric and converted.isdigit():
converted = int(converted)
if self.no_argument:
kwargs = {
'c': column,
}
converted = None
else:
kwargs = {
'c': column,
'p': 'p{}'.format(param_counter),
't': table,
}
return self.sql_template.format(**kwargs), converted

def human_clause(self, column, value):
if callable(self.human_template):
template = self.human_template(column, value)
else:
template = self.human_template
if self.no_argument:
return template.format(c=column)
else:
return template.format(c=column, v=value)


class Filters:
_filters = [
# key, display, sql_template, human_template, format=, numeric=, no_argument=
Filter('exact', '=', '"{c}" = :{p}', lambda c, v: '{c} = {v}' if v.isdigit() else '{c} = "{v}"'),
Filter('not', '!=', '"{c}" != :{p}', lambda c, v: '{c} != {v}' if v.isdigit() else '{c} != "{v}"'),
Filter('contains', 'contains', '"{c}" like :{p}', '{c} contains "{v}"', format='%{}%'),
Filter('endswith', 'ends with', '"{c}" like :{p}', '{c} ends with "{v}"', format='%{}'),
Filter('startswith', 'starts with', '"{c}" like :{p}', '{c} starts with "{v}"', format='{}%'),
Filter('gt', '>', '"{c}" > :{p}', '{c} > {v}', numeric=True),
Filter('gte', '\u2265', '"{c}" >= :{p}', '{c} \u2265 {v}', numeric=True),
Filter('lt', '<', '"{c}" < :{p}', '{c} < {v}', numeric=True),
Filter('lte', '\u2264', '"{c}" <= :{p}', '{c} \u2264 {v}', numeric=True),
Filter('glob', 'glob', '"{c}" glob :{p}', '{c} glob "{v}"'),
Filter('like', 'like', '"{c}" like :{p}', '{c} like "{v}"'),
] + ([Filter('arraycontains', 'array contains', """rowid in (
select {t}.rowid from {t}, json_each({t}.{c}) j
where j.value = :{p}
)""", '{c} contains "{v}"')
] if detect_json1() else []) + [
Filter('isnull', 'is null', '"{c}" is null', '{c} is null', no_argument=True),
Filter('notnull', 'is not null', '"{c}" is not null', '{c} is not null', no_argument=True),
Filter('isblank', 'is blank', '("{c}" is null or "{c}" = "")', '{c} is blank', no_argument=True),
Filter('notblank', 'is not blank', '("{c}" is not null and "{c}" != "")', '{c} is not blank', no_argument=True),
]
_filters_by_key = {
f.key: f for f in _filters
}

def __init__(self, pairs, units={}, ureg=None):
self.pairs = pairs
self.units = units
self.ureg = ureg

def lookups(self):
"Yields (lookup, display, no_argument) pairs"
for filter in self._filters:
yield filter.key, filter.display, filter.no_argument

def human_description_en(self, extra=None):
bits = []
if extra:
bits.extend(extra)
for column, lookup, value in self.selections():
filter = self._filters_by_key.get(lookup, None)
if filter:
bits.append(filter.human_clause(column, value))
# Comma separated, with an ' and ' at the end
and_bits = []
commas, tail = bits[:-1], bits[-1:]
if commas:
and_bits.append(', '.join(commas))
if tail:
and_bits.append(tail[0])
s = ' and '.join(and_bits)
if not s:
return ''
return 'where {}'.format(s)

def selections(self):
"Yields (column, lookup, value) tuples"
for key, value in self.pairs:
if '__' in key:
column, lookup = key.rsplit('__', 1)
else:
column = key
lookup = 'exact'
yield column, lookup, value

def has_selections(self):
return bool(self.pairs)

def convert_unit(self, column, value):
"If the user has provided a unit in the query, convert it into the column unit, if present."
if column not in self.units:
return value

# Try to interpret the value as a unit
value = self.ureg(value)
if isinstance(value, numbers.Number):
# It's just a bare number, assume it's the column unit
return value

column_unit = self.ureg(self.units[column])
return value.to(column_unit).magnitude

def build_where_clauses(self, table):
sql_bits = []
params = {}
for i, (column, lookup, value) in enumerate(self.selections()):
filter = self._filters_by_key.get(lookup, None)
if filter:
sql_bit, param = filter.where_clause(table, column, self.convert_unit(column, value), i)
sql_bits.append(sql_bit)
if param is not None:
param_id = 'p{}'.format(i)
params[param_id] = param
return sql_bits, params


filter_column_re = re.compile(r'^_filter_column_\d+$')


Expand Down
2 changes: 1 addition & 1 deletion datasette/views/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from datasette.plugins import pm
from datasette.utils import (
CustomRow,
Filters,
InterruptedError,
append_querystring,
compound_keys_after_sql,
Expand All @@ -27,6 +26,7 @@
urlsafe_components,
value_as_boolean,
)
from datasette.filters import Filters
from .base import BaseView, DatasetteError, ureg

LINK_WITH_LABEL = '<a href="/{database}/{table}/{link_id}">{label}</a>&nbsp;<em>{id}</em>'
Expand Down
64 changes: 64 additions & 0 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from datasette.filters import Filters
import pytest


@pytest.mark.parametrize('args,expected_where,expected_params', [
(
{
'name_english__contains': 'foo',
},
['"name_english" like :p0'],
['%foo%']
),
(
{
'foo': 'bar',
'bar__contains': 'baz',
},
['"bar" like :p0', '"foo" = :p1'],
['%baz%', 'bar']
),
(
{
'foo__startswith': 'bar',
'bar__endswith': 'baz',
},
['"bar" like :p0', '"foo" like :p1'],
['%baz', 'bar%']
),
(
{
'foo__lt': '1',
'bar__gt': '2',
'baz__gte': '3',
'bax__lte': '4',
},
['"bar" > :p0', '"bax" <= :p1', '"baz" >= :p2', '"foo" < :p3'],
[2, 4, 3, 1]
),
(
{
'foo__like': '2%2',
'zax__glob': '3*',
},
['"foo" like :p0', '"zax" glob :p1'],
['2%2', '3*']
),
(
{
'foo__isnull': '1',
'baz__isnull': '1',
'bar__gt': '10'
},
['"bar" > :p0', '"baz" is null', '"foo" is null'],
[10]
),
])
def test_build_where(args, expected_where, expected_params):
f = Filters(sorted(args.items()))
sql_bits, actual_params = f.build_where_clauses("table")
assert expected_where == sql_bits
assert {
'p{}'.format(i): param
for i, param in enumerate(expected_params)
} == actual_params
Loading

0 comments on commit 6da567d

Please sign in to comment.