Skip to content

Commit

Permalink
WIP: writting tests for categorical data in design sheet
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Dec 10, 2024
1 parent af2fb0f commit e5ff167
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions tests/ert/ui_tests/cli/analysis/test_design_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_run_poly_example_with_design_matrix():
{
"REAL": list(range(num_realizations)),
"a": a_values,
"category": 5 * ["cat1"] + 5 * ["cat2"],
}
)
default_sheet_df = pd.DataFrame([["b", 1], ["c", 2]])
Expand Down Expand Up @@ -59,6 +60,7 @@ def _load_coeffs(filename):
return json.load(f)["DESIGN_MATRIX"]
def _evaluate(coeffs, x):
assert coeffs["category"] in ["cat1", "cat2"]
return coeffs["a"] * x**2 + coeffs["b"] * x + coeffs["c"]
if __name__ == "__main__":
Expand Down Expand Up @@ -88,8 +90,9 @@ def _evaluate(coeffs, x):
"DESIGN_MATRIX"
)["values"]
np.testing.assert_array_equal(params[:, 0], a_values)
np.testing.assert_array_equal(params[:, 1], 10 * [1])
np.testing.assert_array_equal(params[:, 2], 10 * [2])
np.testing.assert_array_equal(params[:, 0], 5 * ["cat1"] + 5 * ["cat2"])
np.testing.assert_array_equal(params[:, 2], 10 * [1])
np.testing.assert_array_equal(params[:, 3], 10 * [2])


@pytest.mark.usefixtures("copy_poly_case")
Expand Down

0 comments on commit e5ff167

Please sign in to comment.