From ca1151f0923f02cc2f280b71648833e8c5e00d0f Mon Sep 17 00:00:00 2001 From: davidusb-geek Date: Sat, 3 Feb 2024 15:46:39 +0100 Subject: [PATCH] Addeed more tests but passing correctly --- tests/test_command_line_utils.py | 1 + tests/test_utils.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/tests/test_command_line_utils.py b/tests/test_command_line_utils.py index d84cfb3e..71c69fdf 100644 --- a/tests/test_command_line_utils.py +++ b/tests/test_command_line_utils.py @@ -290,6 +290,7 @@ def test_forecast_model_fit_predict_tune(self): action, logger, get_data_from_file=True) self.assertTrue(input_data_dict['params']['passed_data']['model_type'] == 'load_forecast') self.assertTrue(input_data_dict['params']['passed_data']['sklearn_model'] == 'KNeighborsRegressor') + self.assertTrue(input_data_dict['params']['passed_data']['perform_backtest'] == False) # Check that the default params are loaded input_data_dict = set_input_data_dict(config_path, base_path, costfun, self.params_json, self.runtimeparams_json, action, logger, get_data_from_file=True) diff --git a/tests/test_utils.py b/tests/test_utils.py index f766d147..facbea7c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -131,6 +131,19 @@ def test_treat_runtimeparams(self): runtimeparams.update({'custom_unit_load_cost_id':'my_custom_unit_load_cost_id'}) runtimeparams.update({'custom_unit_prod_price_id':'my_custom_unit_prod_price_id'}) runtimeparams.update({'custom_deferrable_forecast_id':'my_custom_deferrable_forecast_id'}) + runtimeparams.update({'days_to_retrieve':15}) + runtimeparams.update({'model_type':"my_special_model"}) + runtimeparams.update({'var_model':"sensor.my_special_sensor"}) + runtimeparams.update({'sklearn_model':"LinearRegression"}) + runtimeparams.update({'num_lags':12}) + runtimeparams.update({'split_date_delta':"24h"}) + runtimeparams.update({'perform_backtest':True}) + runtimeparams.update({'model_predict_entity_id':"sensor.my_custom_model_forecast"}) + runtimeparams.update({'model_predict_unit_of_measurement':"kW"}) + runtimeparams.update({'model_predict_friendly_name':"That friendly name"}) + runtimeparams.update({'model_predict_unit_of_measurement':"kW"}) + runtimeparams.update({'model_predict_unit_of_measurement':"kW"}) + runtimeparams.update({'model_predict_unit_of_measurement':"kW"}) runtimeparams_json = json.dumps(runtimeparams) retrieve_hass_conf, optim_conf, plant_conf = utils.get_yaml_parse( @@ -167,6 +180,20 @@ def test_treat_runtimeparams(self): self.assertTrue(params['passed_data']['custom_unit_load_cost_id'] == 'my_custom_unit_load_cost_id') self.assertTrue(params['passed_data']['custom_unit_prod_price_id'] == 'my_custom_unit_prod_price_id') self.assertTrue(params['passed_data']['custom_deferrable_forecast_id'] == 'my_custom_deferrable_forecast_id') + + self.assertTrue(params['passed_data']['days_to_retrieve'] == 15) + self.assertTrue(params['passed_data']['model_type'] == "my_special_model") + self.assertTrue(params['passed_data']['var_model'] == "sensor.my_special_sensor") + self.assertTrue(params['passed_data']['sklearn_model'] == "LinearRegression") + self.assertTrue(params['passed_data']['num_lags'] == 12) + self.assertTrue(params['passed_data']['split_date_delta'] == "24h") + self.assertTrue(params['passed_data']['perform_backtest'] == True) + self.assertTrue(params['passed_data']['model_predict_entity_id'] == "sensor.my_custom_model_forecast") + self.assertTrue(params['passed_data']['model_predict_unit_of_measurement'] == "kW") + self.assertTrue(params['passed_data']['model_predict_friendly_name'] == "That friendly name") + self.assertTrue(params['passed_data']['model_predict_unit_of_measurement'] == "kW") + self.assertTrue(params['passed_data']['model_predict_unit_of_measurement'] == "kW") + self.assertTrue(params['passed_data']['model_predict_unit_of_measurement'] == "kW") def test_treat_runtimeparams_failed(self): params = TestCommandLineUtils.get_test_params()