Skip to content

Commit

Permalink
Implement the diff command (raw term with colors)
Browse files Browse the repository at this point in the history
  • Loading branch information
NyanKiyoshi committed May 20, 2019
1 parent d231310 commit 8d1821a
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 47 deletions.
39 changes: 39 additions & 0 deletions pytest_django_queries/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@
from jinja2 import Template
from jinja2 import exceptions as jinja_exceptions

from pytest_django_queries.diff import DiffGenerator
from pytest_django_queries.entry import flatten_entries
from pytest_django_queries.plugin import DEFAULT_RESULT_FILENAME
from pytest_django_queries.tables import print_entries, print_entries_as_html

HERE = dirname(__file__)
DEFAULT_TEMPLATE_PATH = abspath(pathjoin(HERE, 'templates', 'default_bootstrap.jinja2'))

DIFF_TERM_COLOR = {'-': 'red', '+': 'green'}
DEFAULT_TERM_DIFF_COLOR = None


class JsonFileParamType(click.File):
name = 'integer'
Expand Down Expand Up @@ -66,5 +71,39 @@ def html(input_file, template):
return print_entries_as_html(input_file, template)


@main.command()
@click.argument(
'left_file', type=JsonFileParamType('r'))
@click.argument(
'right_file', type=JsonFileParamType('r'), default=DEFAULT_RESULT_FILENAME)
def diff(left_file, right_file):
"""Render the diff as a console table with colors."""
left = flatten_entries(left_file)
right = flatten_entries(right_file)
first_line = True
for module_name, lines in DiffGenerator(left, right):
if not first_line:
click.echo()
else:
first_line = False

click.echo('# %s' % module_name)
for line in lines:
fg_color = DIFF_TERM_COLOR.get(line[0], DEFAULT_TERM_DIFF_COLOR)
click.secho(line, fg=fg_color)


@main.command()
@click.argument(
'left_file', type=JsonFileParamType('r'))
@click.argument(
'right_file', type=JsonFileParamType('r'), default=DEFAULT_RESULT_FILENAME)
def ediff(left_file, right_file):
"""Render the diff as HTML instead of a diff table."""
left = flatten_entries(left_file)
right = flatten_entries(right_file)
raise NotImplementedError


if __name__ == '__main__':
main()
173 changes: 173 additions & 0 deletions pytest_django_queries/diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# coding=utf-8
from collections import namedtuple
from pytest_django_queries.entry import Entry
from pytest_django_queries.filters import format_underscore_name_to_human

_ROW_FIELD = namedtuple('_RowField', ('comp_field', 'align_char', 'length_field'))
_ROW_FIELDS = (
_ROW_FIELD('test_name', '<', 'test_name'),
_ROW_FIELD('left_count', '>', 'query_count'),
_ROW_FIELD('right_count', '>', 'query_count'),
)
_ROW_PREFIX = ' '
_NA_CHAR = '-'


def entry_row(entry_comp, lengths):
cols = []

for field, align, length_key in _ROW_FIELDS:
fmt = '{cmp.%s: %s{lengths[%s]}}' % (field, align, length_key)
cols.append(fmt.format(cmp=entry_comp, lengths=lengths))

return '%(diff_char)s %(results)s' % ({
'diff_char': entry_comp.diff,
'results': '\t'.join(cols)})


def get_header_row(lengths):
sep_row = []
head_row = []

for field, _, length_key in _ROW_FIELDS:
length = lengths[length_key]
sep_row.append('%s' % ('-' * length))
head_row.append('{field: <{length}}'.format(
field=field.replace('_', ' '), length=length))

return '%(prefix)s%(head)s\n%(prefix)s%(sep)s' % ({
'prefix': _ROW_PREFIX,
'head': '\t'.join(head_row),
'sep': '\t'.join(sep_row)})


class DiffChars(object):
NEGATIVE = '-'
NEUTRAL = ' '
POSITIVE = '+'

@classmethod
def convert(cls, diff):
if diff < 0:
return DiffChars.POSITIVE
if diff > 0:
return DiffChars.NEGATIVE
return DiffChars.NEUTRAL


class SingleEntryComparison(object):
__slots__ = ["left", "right", "diff"]

def __init__(self, left=None, right=None):
"""
:param left: Previous version.
:type left: Entry
:param right: Newest version.
:type right: Entry
"""

self.left = left
self.right = right
self.diff = None

def _diff_from_newest(self):
"""
Returns the query count difference from the previous version.
If there is no older version, we assume it's an "improvement" (positive output)
If there is no new version, we assume it's not an improvement (negative output)
"""
if self.left is None:
return DiffChars.POSITIVE
if self.right is None:
return DiffChars.NEGATIVE
return DiffChars.convert(self.right.query_count - self.left.query_count)

@property
def test(self):
return self.left or self.right

@property
def test_name(self):
return format_underscore_name_to_human(self.test.test_name)

@property
def left_count(self):
return str(self.left.query_count) if self.left else _NA_CHAR

@property
def right_count(self):
return str(self.right.query_count) if self.right else _NA_CHAR

def to_string(self, lengths):
if self.diff is None:
self.diff = self._diff_from_newest()

return entry_row(self, lengths=lengths)


class DiffGenerator(object):
def __init__(self, entries_left, entries_right):
"""
Generates the diffs from two files.
:param entries_left:
:type entries_left: List[Entry]
:param entries_right:
:type entries_right: List[Entry]
"""

self.entries_left = entries_left
self.entries_right = entries_right

self._mapping = {}
self._generate_mapping()
self.longest_props = self._get_longest_per_prop({'query_count', 'test_name'})
self.header_rows = get_header_row(lengths=self.longest_props)

def _get_longest_per_prop(self, props):
"""
:param props:
:type props: set
:return:
"""

longest = {prop: 0 for prop in props}
entries = (
self.entries_left + self.entries_right + [field for field, _, _ in _ROW_FIELDS]
)

for entry in entries:
for prop in props:
if isinstance(entry, Entry):
current_length = len(str(getattr(entry, prop, None)))
else:
current_length = len(entry)
if current_length > longest[prop]:
longest[prop] = current_length

return longest

def _map_side(self, entries, side_name):
for entry in entries:
module_map = self._mapping.setdefault(entry.module_name, {})

if entry.test_name not in module_map:
module_map[entry.test_name] = SingleEntryComparison()

setattr(module_map[entry.test_name], side_name, entry)

def _generate_mapping(self):
self._map_side(self.entries_left, 'left')
self._map_side(self.entries_right, 'right')

def iter_module(self, module_entries):
yield self.header_rows
for test_comparison in module_entries.values(): # type: SingleEntryComparison
yield test_comparison.to_string(lengths=self.longest_props)

def __iter__(self):
for module_name, module_entries in self._mapping.items():
yield (
format_underscore_name_to_human(module_name),
self.iter_module(module_entries))
54 changes: 54 additions & 0 deletions pytest_django_queries/entry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from pytest_django_queries.utils import assert_type, raise_error


class Entry(object):
BASE_FIELDS = [
('test_name', 'Test Name')
]
REQUIRED_FIELDS = [
('query-count', 'Queries'),
]
FIELDS = BASE_FIELDS + REQUIRED_FIELDS

def __init__(self, test_name, module_name, data):
"""
:param data: The test entry's data.
:type data: dict
"""

assert_type(data, dict)

self._raw_data = data
self.test_name = test_name
self.module_name = module_name

for field, _ in self.REQUIRED_FIELDS:
setattr(self, field, self._get_required_key(field))

def __getitem__(self, item):
return getattr(self, item)

@property
def query_count(self):
return self['query-count']

def _get_required_key(self, key):
if key in self._raw_data:
return self._raw_data.get(key)
raise_error('Got invalid data. It is missing a required key: %s' % key)


def iter_entries(entries):
for module_name, module_data in sorted(entries.items()):
assert_type(module_data, dict)

yield module_name, (
Entry(test_name, module_name, test_data)
for test_name, test_data in sorted(module_data.items()))


def flatten_entries(file_content):
entries = []
for _, data in iter_entries(file_content):
entries += list(data)
return entries
4 changes: 4 additions & 0 deletions pytest_django_queries/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def format_underscore_name_to_human(name):
if name.startswith('test'):
_, name = name.split('test', 1)
return name.replace('_', ' ').strip()
51 changes: 4 additions & 47 deletions pytest_django_queries/tables.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,19 @@
import click
from beautifultable import BeautifulTable

from pytest_django_queries.utils import assert_type, raise_error


def format_underscore_name_to_human(name):
if name.startswith('test'):
_, name = name.split('test', 1)
return name.replace('_', ' ')


class TestEntryData(object):
BASE_FIELDS = [
('test_name', 'Test Name')
]
REQUIRED_FIELDS = [
('query-count', 'Queries'),
]
FIELDS = BASE_FIELDS + REQUIRED_FIELDS

def __init__(self, test_name, data):
"""
:param data: The test entry's data.
:type data: dict
"""

assert_type(data, dict)

self._raw_data = data
self.test_name = test_name

for field, _ in self.REQUIRED_FIELDS:
setattr(self, field, self._get_required_key(field))

def _get_required_key(self, key):
if key in self._raw_data:
return self._raw_data.get(key)
raise_error('Got invalid data. It is missing a required key: %s' % key)


def iter_entries(entries):
for module_name, module_data in sorted(entries.items()):
assert_type(module_data, dict)

yield module_name, (
TestEntryData(test_name, test_data)
for test_name, test_data in sorted(module_data.items()))
from pytest_django_queries.entry import Entry, iter_entries
from pytest_django_queries.filters import format_underscore_name_to_human


def print_entries(data):
table = BeautifulTable()
table.column_headers = ['Module', 'Tests']
for module_name, module_entries in iter_entries(data):
subtable = BeautifulTable()
subtable.column_headers = [field for _, field in TestEntryData.FIELDS]
subtable.column_headers = [field for _, field in Entry.FIELDS]
for entry in module_entries:
subtable.append_row([
getattr(entry, field) for field, _ in TestEntryData.FIELDS])
getattr(entry, field) for field, _ in Entry.FIELDS])
table.append_row([module_name, subtable])
click.echo(table)

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ django ; python_version >= '3.0'
Click
beautifultable==0.7.0
jinja2
colorama

0 comments on commit 8d1821a

Please sign in to comment.