Skip to content

Commit

Permalink
Merge pull request #735 from dannysepler/sum-should-work-with-time-de…
Browse files Browse the repository at this point in the history
…ltas

Sum should work with time deltas, also isinstance cleanups
  • Loading branch information
jpmckinney authored Mar 31, 2020
2 parents bc72327 + 29ccc9d commit e2e259f
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 25 deletions.
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ agate is made by a community. The following individuals have contributed code, d
* `Kartik Agaram <https://github.com/akkartik>`_
* `Loïc Corbasson <https://github.com/lcorbasson>`_
* `Robert Schütz <https://github.com/dotlambda>`_
* `Danny Sepler <https://github.com/dannysepler>`_
10 changes: 3 additions & 7 deletions agate/aggregations/max.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,14 @@ def __init__(self, column_name):
def get_aggregate_data_type(self, table):
column = table.columns[self._column_name]

if (isinstance(column.data_type, Number) or
isinstance(column.data_type, Date) or
isinstance(column.data_type, DateTime)):
if isinstance(column.data_type, (Number, Date, DateTime)):
return column.data_type

def validate(self, table):
column = table.columns[self._column_name]

if not (isinstance(column.data_type, Number) or
isinstance(column.data_type, Date) or
isinstance(column.data_type, DateTime)):
raise DataTypeError('Min can only be applied to columns containing DateTime orNumber data.')
if not isinstance(column.data_type, (Number, Date, DateTime)):
raise DataTypeError('Min can only be applied to columns containing DateTime, Date or Number data.')

def run(self, table):
column = table.columns[self._column_name]
Expand Down
10 changes: 3 additions & 7 deletions agate/aggregations/min.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,14 @@ def __init__(self, column_name):
def get_aggregate_data_type(self, table):
column = table.columns[self._column_name]

if (isinstance(column.data_type, Number) or
isinstance(column.data_type, Date) or
isinstance(column.data_type, DateTime)):
if isinstance(column.data_type, (Number, Date, DateTime)):
return column.data_type

def validate(self, table):
column = table.columns[self._column_name]

if not (isinstance(column.data_type, Number) or
isinstance(column.data_type, Date) or
isinstance(column.data_type, DateTime)):
raise DataTypeError('Min can only be applied to columns containing DateTime orNumber data.')
if not isinstance(column.data_type, (Number, Date, DateTime)):
raise DataTypeError('Min can only be applied to columns containing DateTime, Date or Number data.')

def run(self, table):
column = table.columns[self._column_name]
Expand Down
18 changes: 13 additions & 5 deletions agate/aggregations/sum.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#!/usr/bin/env python
import datetime

from agate.aggregations.base import Aggregation
from agate.data_types import Number
from agate.data_types import Number, TimeDelta
from agate.exceptions import DataTypeError


Expand All @@ -16,15 +17,22 @@ def __init__(self, column_name):
self._column_name = column_name

def get_aggregate_data_type(self, table):
return Number()
column = table.columns[self._column_name]

if isinstance(column.data_type, (Number, TimeDelta)):
return column.data_type

def validate(self, table):
column = table.columns[self._column_name]

if not isinstance(column.data_type, Number):
raise DataTypeError('Sum can only be applied to columns containing Number data.')
if not isinstance(column.data_type, (Number, TimeDelta)):
raise DataTypeError('Sum can only be applied to columns containing Number or TimeDelta data.')

def run(self, table):
column = table.columns[self._column_name]

return sum(column.values_without_nulls())
start = 0
if isinstance(column.data_type, TimeDelta):
start = datetime.timedelta()

return sum(column.values_without_nulls(), start)
6 changes: 1 addition & 5 deletions agate/computations/change.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@ def __init__(self, before_column_name, after_column_name):
def get_computed_data_type(self, table):
before_column = table.columns[self._before_column_name]

if isinstance(before_column.data_type, Date):
return TimeDelta()
elif isinstance(before_column.data_type, DateTime):
return TimeDelta()
elif isinstance(before_column.data_type, TimeDelta):
if isinstance(before_column.data_type, (Date, DateTime, TimeDelta)):
return TimeDelta()
elif isinstance(before_column.data_type, Number):
return Number()
Expand Down
2 changes: 1 addition & 1 deletion agate/table/pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def apply_computation(table):
if pivot is not None:
groups = groups.group_by(pivot)

column_type = aggregation.get_aggregate_data_type(groups)
column_type = aggregation.get_aggregate_data_type(self)

table = groups.aggregate([
(aggregation_name, aggregation)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,18 @@ def test_max(self):
Max('test').validate(table)
self.assertEqual(Max('test').run(table), datetime.datetime(1994, 3, 3, 6, 31))

def test_sum(self):
rows = [
[datetime.timedelta(seconds=10)],
[datetime.timedelta(seconds=20)],
]

table = Table(rows, ['test'], [TimeDelta()])

self.assertIsInstance(Sum('test').get_aggregate_data_type(table), TimeDelta)
Sum('test').validate(table)
self.assertEqual(Sum('test').run(table), datetime.timedelta(seconds=30))


class TestNumberAggregation(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit e2e259f

Please sign in to comment.