Skip to content

Commit

Permalink
Addeed more tests but passing correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
davidusb-geek authored and GeoDerp committed Feb 4, 2024
1 parent 70bea29 commit ca1151f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
1 change: 1 addition & 0 deletions tests/test_command_line_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def test_forecast_model_fit_predict_tune(self):
action, logger, get_data_from_file=True)
self.assertTrue(input_data_dict['params']['passed_data']['model_type'] == 'load_forecast')
self.assertTrue(input_data_dict['params']['passed_data']['sklearn_model'] == 'KNeighborsRegressor')
self.assertTrue(input_data_dict['params']['passed_data']['perform_backtest'] == False)
# Check that the default params are loaded
input_data_dict = set_input_data_dict(config_path, base_path, costfun, self.params_json, self.runtimeparams_json,
action, logger, get_data_from_file=True)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,19 @@ def test_treat_runtimeparams(self):
runtimeparams.update({'custom_unit_load_cost_id':'my_custom_unit_load_cost_id'})
runtimeparams.update({'custom_unit_prod_price_id':'my_custom_unit_prod_price_id'})
runtimeparams.update({'custom_deferrable_forecast_id':'my_custom_deferrable_forecast_id'})
runtimeparams.update({'days_to_retrieve':15})
runtimeparams.update({'model_type':"my_special_model"})
runtimeparams.update({'var_model':"sensor.my_special_sensor"})
runtimeparams.update({'sklearn_model':"LinearRegression"})
runtimeparams.update({'num_lags':12})
runtimeparams.update({'split_date_delta':"24h"})
runtimeparams.update({'perform_backtest':True})
runtimeparams.update({'model_predict_entity_id':"sensor.my_custom_model_forecast"})
runtimeparams.update({'model_predict_unit_of_measurement':"kW"})
runtimeparams.update({'model_predict_friendly_name':"That friendly name"})
runtimeparams.update({'model_predict_unit_of_measurement':"kW"})
runtimeparams.update({'model_predict_unit_of_measurement':"kW"})
runtimeparams.update({'model_predict_unit_of_measurement':"kW"})

runtimeparams_json = json.dumps(runtimeparams)
retrieve_hass_conf, optim_conf, plant_conf = utils.get_yaml_parse(
Expand Down Expand Up @@ -167,6 +180,20 @@ def test_treat_runtimeparams(self):
self.assertTrue(params['passed_data']['custom_unit_load_cost_id'] == 'my_custom_unit_load_cost_id')
self.assertTrue(params['passed_data']['custom_unit_prod_price_id'] == 'my_custom_unit_prod_price_id')
self.assertTrue(params['passed_data']['custom_deferrable_forecast_id'] == 'my_custom_deferrable_forecast_id')

self.assertTrue(params['passed_data']['days_to_retrieve'] == 15)
self.assertTrue(params['passed_data']['model_type'] == "my_special_model")
self.assertTrue(params['passed_data']['var_model'] == "sensor.my_special_sensor")
self.assertTrue(params['passed_data']['sklearn_model'] == "LinearRegression")
self.assertTrue(params['passed_data']['num_lags'] == 12)
self.assertTrue(params['passed_data']['split_date_delta'] == "24h")
self.assertTrue(params['passed_data']['perform_backtest'] == True)
self.assertTrue(params['passed_data']['model_predict_entity_id'] == "sensor.my_custom_model_forecast")
self.assertTrue(params['passed_data']['model_predict_unit_of_measurement'] == "kW")
self.assertTrue(params['passed_data']['model_predict_friendly_name'] == "That friendly name")
self.assertTrue(params['passed_data']['model_predict_unit_of_measurement'] == "kW")
self.assertTrue(params['passed_data']['model_predict_unit_of_measurement'] == "kW")
self.assertTrue(params['passed_data']['model_predict_unit_of_measurement'] == "kW")

def test_treat_runtimeparams_failed(self):
params = TestCommandLineUtils.get_test_params()
Expand Down

0 comments on commit ca1151f

Please sign in to comment.