Skip to content

Commit

Permalink
Update test input data, add some test cases.
Browse files Browse the repository at this point in the history
  • Loading branch information
MImmesberger committed Dec 11, 2024
1 parent 5ce6c2e commit 97d8c93
Showing 1 changed file with 41 additions and 22 deletions.
63 changes: 41 additions & 22 deletions src/_gettsim_tests/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def minimal_input_data():
return out


@pytest.fixture
def input_data_aggregate_by_p_id():
@pytest.fixture(scope="module")
def minimal_input_data_as_df():
return pd.DataFrame(
{
"groupings__p_id": pd.Series([1, 2, 3], name="p_id"),
Expand Down Expand Up @@ -186,17 +186,23 @@ def c(b):
compute_taxes_and_transfers(minimal_input_data, environment, targets="c")


def test_data_as_dict():
def test_data_as_dict(minimal_input_data):
def c(b):
return b

data = {
"groupings": {
"p_id": pd.Series([1, 2, 3]),
"hh_id": pd.Series([1, 1, 2]),
},
"b": pd.Series([100, 200, 300]),
}
data = minimal_input_data.copy()
data["b"] = pd.Series([1, 2, 3], name="b")

environment = PolicyEnvironment({"c": c})
compute_taxes_and_transfers(data, environment, targets="c")


def test_data_as_df(minimal_input_data_as_df):
def c(b):
return b

data = minimal_input_data_as_df.copy()
data["b"] = pd.Series([1, 2, 3], name="b")

environment = PolicyEnvironment({"c": c})
compute_taxes_and_transfers(data, environment, targets="c")
Expand Down Expand Up @@ -439,22 +445,35 @@ def test_user_provided_aggregate_by_group_specs():
)


def test_user_provided_aggregate_by_group_specs_function():
@pytest.mark.parametrize(
"aggregate_by_group_specs",
[
{
"module_name": {
"betrag_double_m_hh": {
"source_col": "betrag_m_double",
"aggr": "max",
},
},
},
{
"module_name": {
"betrag_double_m_hh": {
"source_col": "module_name__betrag_m_double",
"aggr": "max",
},
},
},
],
)
def test_user_provided_aggregate_by_group_specs_function(aggregate_by_group_specs):
data = pd.DataFrame(
{
"groupings__p_id": [1, 2, 3],
"groupings__hh_id": [1, 1, 2],
"module_name__betrag_m": [200, 100, 100],
}
)
aggregate_by_group_specs = {
"module_name": {
"betrag_double_m_hh": {
"source_col": "betrag_m_double",
"aggr": "max",
}
},
}
expected_res = pd.Series([400, 400, 200])

def betrag_m_double(betrag_m):
Expand Down Expand Up @@ -542,7 +561,7 @@ def test_aggregate_by_group_specs_agg_not_impl():
("df, aggregate_by_p_id_spec, target, expected"),
[
(
"input_data_aggregate_by_p_id",
"minimal_input_data_as_df",
{
"module": {
"target_func": {
Expand All @@ -557,7 +576,7 @@ def test_aggregate_by_group_specs_agg_not_impl():
pd.Series([200, 100, 0]),
),
(
"input_data_aggregate_by_p_id",
"minimal_input_data_as_df",
{
"module": {
"target_func_m": {
Expand All @@ -572,7 +591,7 @@ def test_aggregate_by_group_specs_agg_not_impl():
pd.Series([2400, 1200, 0]),
),
(
"input_data_aggregate_by_p_id",
"minimal_input_data_as_df",
{
"module": {
"target_func_m": {
Expand Down

0 comments on commit 97d8c93

Please sign in to comment.