diff --git a/gnucash_reports/reports/savings_goals.py b/gnucash_reports/reports/savings_goals.py index 40232ca..c245a57 100644 --- a/gnucash_reports/reports/savings_goals.py +++ b/gnucash_reports/reports/savings_goals.py @@ -8,13 +8,11 @@ from gnucash_reports.configuration.current_date import get_today from gnucash_reports.configuration.currency import get_currency from gnucash_reports.periods import PeriodStart -from gnucash_reports.wrapper import get_account, get_balance_on_date, account_walker +from gnucash_reports.wrapper import get_account, get_balance_on_date, account_walker, parse_walker_parameters def savings_goal(definition): - accounts = definition['account'] - if isinstance(accounts, basestring): - accounts = [accounts] + walker_params = parse_walker_parameters(definition['savings']) goal_amount = Decimal(definition.get('goal', Decimal(0.0))) @@ -27,17 +25,9 @@ def savings_goal(definition): total_balance = Decimal('0.0') currency = get_currency() - for account_description in accounts: - multiplier = Decimal('1.0') - if isinstance(account_description, basestring): - account = account_description - else: - account = account_description[0] - multiplier = Decimal(account_description[1]) - - for account_name in account_walker([account]): - balance = get_balance_on_date(account_name, as_of.date, currency) - total_balance += (balance * multiplier) + for account in account_walker(**walker_params): + balance = get_balance_on_date(account, as_of.date, currency) + total_balance += balance for contribution in contributions: total_balance += Decimal(contribution) diff --git a/gnucash_reports/wrapper.py b/gnucash_reports/wrapper.py index 9c03a5b..513af2c 100644 --- a/gnucash_reports/wrapper.py +++ b/gnucash_reports/wrapper.py @@ -98,6 +98,10 @@ def account_walker(accounts, ignores=None, place_holders=False, recursive=True, if not ignores: ignores = [] + # Allow for a none account list to be provided + if accounts is None: + accounts = [] + _account_list = [a for a in accounts] while _account_list: @@ -125,6 +129,10 @@ def parse_walker_parameters(definition): 'recursive': True } + # Allow for a none definition to be provided and overwrite to an empty list + if definition is None: + definition = [] + if isinstance(definition, dict): return_value.update(definition) elif isinstance(definition, list) or isinstance(definition, set):