diff --git a/README.md b/README.md index 5a37327..e55d5d1 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ schema = { {"key": "vat", "column-number": 10, "type": "decimal", "optional": False}, {"key": "initial-price", "column-number": 11, "type": "decimal", "post-processors": {"name": "divide", "parameters": {"denominator": 100}}}, {"key": "unit-of-measurement", "column-number": 12, "type": "int", "pre-processors": [{"name": "map", "parameters": {"values": {"K": 0, "A": 1, "L": 2}}}]}, - {"key": "volume", "column-number": 13, "type": "decimal"}, + {"key": "volume", "column-number": 13, "type": "decimal", "post-processors": {"name": "round", "parameters": {"precision": 3}}}, ] } @@ -121,3 +121,4 @@ assert rows == [{"name": "Joe"}, {"name": "William"}, {"name": "Jack"}, {"name": #### Post-processors - divide +- round diff --git a/magicparse/post_processors.py b/magicparse/post_processors.py index 6131c4d..2e901dd 100644 --- a/magicparse/post_processors.py +++ b/magicparse/post_processors.py @@ -42,4 +42,24 @@ def key() -> str: return "divide" -builtins = [Divide] +class Round(PostProcessor): + Number = TypeVar("Number", int, float, Decimal) + + def __init__(self, precision: int) -> None: + if precision < 0: + raise ValueError( + "post-processor 'round': " + "'precision' parameter must be a positive or zero integer" + ) + + self.precision = precision + + def apply(self, value: Number) -> Number: + return round(value, self.precision) + + @staticmethod + def key() -> str: + return "round" + + +builtins = [Divide, Round] diff --git a/tests/test_post_processors.py b/tests/test_post_processors.py index 59bd07b..b6bd4e1 100644 --- a/tests/test_post_processors.py +++ b/tests/test_post_processors.py @@ -49,6 +49,22 @@ def test_divide_decimal(self): assert post_processor.apply(Decimal("1.63")) == Decimal("0.0163") +class TestRound(TestCase): + def test_with_negative_precision(self): + error_message = ( + "post-processor 'round': " + "'precision' parameter must be a positive or zero integer" + ) + with pytest.raises(ValueError, match=error_message): + PostProcessor.build({"name": "round", "parameters": {"precision": -2}}) + + def test_with_valid_precision(self): + post_processor = PostProcessor.build( + {"name": "round", "parameters": {"precision": 2}} + ) + assert post_processor.apply(3.14159265359) == 3.14 + + class TestRegister(TestCase): class NoThanksPostProcessor(PostProcessor): @staticmethod