Skip to content

Commit

Permalink
Added 4 new graphs to Pareto Optimizer + bug fixes (#1159)
Browse files Browse the repository at this point in the history
* added formatting with black-formatter

* Added 4 new graphs to Pareto Optimizer + bug fixes

* corrected notebooks

* need to set cluster = true if you want bootstrap graph

* Added Pareto utils call to src notebook

---------

Co-authored-by: Dhaval Patel <[email protected]>
  • Loading branch information
dhavalpatel624624 and Dhaval Patel authored Nov 22, 2024
1 parent f0e6322 commit 209513b
Show file tree
Hide file tree
Showing 17 changed files with 533 additions and 204 deletions.
4 changes: 2 additions & 2 deletions python/src/robyn/data/entities/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ class Hyperparameters:
hyperparameters (Dict[str, Hyperparameter]): A dictionary of hyperparameters where the key is the channel name and the value is a Hyperparameter object.
"""

hyperparameters: Dict[str, ChannelHyperparameters] = (None,)
adstock: AdstockType = (None,) # Mandatory. User provides this.
hyperparameters: Dict[str, ChannelHyperparameters] = field(default_factory=dict)
adstock: AdstockType = AdstockType.GEOMETRIC # Mandatory. User provides this.
lambda_: float = 0.0 # User does not provide this. Model run calculates it.
train_size: List[float] = field(default_factory=lambda: [0.5, 0.8])
hyper_bound_list_updated: Dict[str, List[float]] = field(default_factory=dict)
Expand Down
20 changes: 20 additions & 0 deletions python/src/robyn/data/entities/mmmdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,23 @@ def calculate_rolling_window_indices(self) -> None:
- self.mmmdata_spec.rolling_window_start_which
+ 1
)

def set_default_factor_vars(self) -> None:
"""
Set the default factor variables.
"""
factor_variables = self.mmmdata_spec.factor_vars
selected_columns = self.data[self.mmmdata_spec.context_vars]
non_numeric_columns = ~selected_columns.applymap(
lambda x: isinstance(x, (int, float))
).all()
if non_numeric_columns.any():
non_factor_columns = non_numeric_columns[
~non_numeric_columns.index.isin(factor_variables or [])
]
non_factor_columns = non_factor_columns[non_factor_columns]
if len(non_factor_columns) > 0:
factor_variables = (
factor_variables or []
) + non_factor_columns.index.tolist()
self.mmmdata_spec.factor_vars = factor_variables
2 changes: 1 addition & 1 deletion python/src/robyn/modeling/clustering/cluster_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _calculate_confidence_intervals(
cluster_collect = []

self.logger.debug(f"Processing {config.k_clusters} clusters")
for j in range(1, config.k_clusters + 1):
for j in range(0, config.k_clusters):
df_outcome = df_clusters_outcome[df_clusters_outcome["cluster"] == j]
if len(df_outcome["sol_id"].unique()) < 3:
self.logger.warning(
Expand Down
2 changes: 2 additions & 0 deletions python/src/robyn/modeling/feature_engineering.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,8 @@ def _prophet_decomposition(self, dt_mod: pd.DataFrame) -> pd.DataFrame:
dt_regressors["ds"] = pd.to_datetime(dt_regressors["ds"])

# Handle factor variables
if self.mmm_data.mmmdata_spec.factor_vars is None:
self.mmm_data.set_default_factor_vars()
factor_vars = self.mmm_data.mmmdata_spec.factor_vars
if factor_vars:
# Create dummy variables but keep original
Expand Down
2 changes: 1 addition & 1 deletion python/src/robyn/modeling/pareto/pareto_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def optimize(
self.logger.info("Pareto optimization completed successfully")
return ParetoResult(
pareto_solutions=plotting_data["pareto_solutions"],
pareto_fronts=pareto_fronts,
pareto_fronts=max(pareto_data.pareto_fronts),
result_hyp_param=aggregated_data["result_hyp_param"],
result_calibration=aggregated_data["result_calibration"],
x_decomp_agg=pareto_data.x_decomp_agg,
Expand Down
13 changes: 8 additions & 5 deletions python/src/robyn/reporting/onepager_reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from robyn.modeling.entities.pareto_result import ParetoResult
from robyn.modeling.entities.clustering_results import ClusteredResult
from robyn.data.entities.hyperparameters import AdstockType
from robyn.data.entities.hyperparameters import Hyperparameters
from robyn.data.entities.mmmdata import MMMData
from robyn.data.entities.enums import PlotType

Expand All @@ -27,13 +27,13 @@ def __init__(
self,
pareto_result: ParetoResult,
clustered_result: Optional[ClusteredResult] = None,
adstock: Optional[AdstockType] = None,
hyperparameter: Optional[Hyperparameters] = None,
mmm_data: Optional[MMMData] = None,
holidays_data: Optional[HolidaysData] = None,
):
self.pareto_result = pareto_result
self.clustered_result = clustered_result
self.adstock = adstock
self.hyperparameter = hyperparameter
self.mmm_data = mmm_data
self.holidays_data = holidays_data

Expand Down Expand Up @@ -270,9 +270,12 @@ def _generate_solution_plots(
# Initialize visualizers
pareto_viz = (
ParetoVisualizer(
self.pareto_result, self.adstock, self.mmm_data, self.holidays_data
self.pareto_result,
self.mmm_data,
self.holidays_data,
self.hyperparameter,
)
if self.adstock and self.holidays_data
if self.hyperparameter and self.holidays_data
else None
)
cluster_viz = (
Expand Down
19 changes: 10 additions & 9 deletions python/src/robyn/robyn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Dict, Optional, List
import numpy as np
from robyn.modeling.entities.clustering_results import ClusteredResult
from robyn.data.entities.enums import AdstockType, PlotType
import copy
from robyn.data.entities.mmmdata import MMMData
from robyn.data.entities.holidays_data import HolidaysData
from robyn.data.entities.hyperparameters import Hyperparameters
Expand Down Expand Up @@ -196,7 +194,6 @@ def train_models(
try:
logger.info("Training models")
trials_config = trials_config or TrialsConfig(trials=5, iterations=2000)

model_executor = ModelExecutor(
mmmdata=self.mmm_data,
holidays_data=self.holidays_data,
Expand Down Expand Up @@ -261,6 +258,7 @@ def evaluate_models(
holidays_data=self.holidays_data,
)
self.pareto_result = pareto_optimizer.optimize(**pareto_config)
unfiltered_pareto_result = copy.deepcopy(self.pareto_result)

# Optional clustering
is_clustered = False
Expand All @@ -274,10 +272,13 @@ def evaluate_models(
)
if display_plots or export_plots:
pareto_visualizer = ParetoVisualizer(
self.pareto_result,
self.hyperparameters.adstock,
self.mmm_data,
self.holidays_data,
pareto_result=self.pareto_result,
mmm_data=self.mmm_data,
holiday_data=self.holidays_data,
hyperparameter=self.hyperparameters,
featurized_mmm_data=self.featurized_mmm_data,
unfiltered_pareto_result=unfiltered_pareto_result,
model_outputs=self.model_outputs,
)
pareto_visualizer.plot_all(display_plots, self.working_dir)
if self.cluster_result:
Expand Down Expand Up @@ -412,7 +413,7 @@ def generate_one_pager(self, solution_id: Optional[str] = None) -> None:
onepager = OnePager(
pareto_result=self.pareto_result,
clustered_result=self.cluster_result,
adstock=self.hyperparameters.adstock,
hyperparameter=self.hyperparameters,
mmm_data=self.mmm_data,
holidays_data=self.holidays_data,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@
"outputs": [],
"source": [
"hyperparameters = Hyperparameters(\n",
" {\n",
" hyperparameters={\n",
" \"facebook_S\": ChannelHyperparameters(\n",
" alphas=[0.5, 3],\n",
" gammas=[0.3, 1],\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,17 +291,32 @@
"source": [
"from robyn.data.entities.enums import AdstockType\n",
"from robyn.reporting.onepager_reporting import OnePager\n",
"from robyn.visualization.pareto_visualizer import ParetoVisualizer\n",
"\n",
"visualizer = OnePager(\n",
" pareto_result=filtered_pareto_results,\n",
" clustered_result=cluster_results,\n",
" adstock=AdstockType.GEOMETRIC,\n",
" mmm_data=mmm_data,\n",
" holidays_data=holidays_data,\n",
")\n",
"\n",
"visualizer = OnePager(pareto_result=filtered_pareto_results, clustered_result=cluster_results, hyperparameter=hyperparameters, mmm_data=mmm_data, holidays_data=holidays_data)\n",
"visualizer.generate_one_pager(top_pareto=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"visualizer = ParetoVisualizer(\n",
" pareto_result=filtered_pareto_results, \n",
" hyperparameter=hyperparameters, \n",
" mmm_data=mmm_data, \n",
" holiday_data=holidays_data,\n",
" featurized_mmm_data=featurized_mmm_data,\n",
" unfiltered_pareto_result=pareto_result,\n",
" model_outputs=output_models)\n",
"\n",
"visualizer.plot_all(True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
8 changes: 3 additions & 5 deletions python/src/robyn/tutorials/tutorial1.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
},
{
"cell_type": "markdown",
"id": "0a1cce14",
"id": "eefbc5da",
"metadata": {},
"source": [
"## 2.2. Initialize Robyn\n",
Expand Down Expand Up @@ -126,7 +126,7 @@
"\n",
"# Create Hyperparameters\n",
"hyperparameters = Hyperparameters(\n",
" {\n",
" hyperparameters={\n",
" \"facebook_S\": ChannelHyperparameters(\n",
" alphas=[0.5, 3],\n",
" gammas=[0.3, 1],\n",
Expand Down Expand Up @@ -258,10 +258,8 @@
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"from robyn.modeling.clustering.clustering_config import ClusterBy, ClusteringConfig\n",
"\n",
"\n",
"configs = ClusteringConfig(\n",
" dep_var_type= DependentVarType(mmm_data.mmmdata_spec.dep_var_type),\n",
" cluster_by = ClusterBy.HYPERPARAMETERS,\n",
Expand Down Expand Up @@ -388,7 +386,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
"version": "3.11.5"
}
},
"nbformat": 4,
Expand Down
25 changes: 22 additions & 3 deletions python/src/robyn/tutorials/tutorial1_src.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -180,7 +180,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -333,6 +333,25 @@
"print(cluster_results)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reestablish Pareto Results"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from robyn.modeling.pareto.pareto_utils import ParetoUtils\n",
"\n",
"utils = ParetoUtils()\n",
"pareto_result = utils.process_pareto_clustered_results(pareto_result, clustered_result=cluster_results, ran_cluster=True, ran_calibration= False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -495,7 +514,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
"version": "3.11.5"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@
],
"source": [
"hyperparameters = Hyperparameters(\n",
" {\n",
" hyperparameters={\n",
" \"facebook_S\": ChannelHyperparameters(\n",
" alphas=[0.5, 3],\n",
" gammas=[0.3, 1],\n",
Expand Down
2 changes: 1 addition & 1 deletion python/src/robyn/tutorials/tutorial3_modeling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@
],
"source": [
"hyperparameters = Hyperparameters(\n",
" {\n",
" hyperparameters={\n",
" \"facebook_S\": ChannelHyperparameters(\n",
" alphas=[0.5, 3],\n",
" gammas=[0.3, 1],\n",
Expand Down
2 changes: 1 addition & 1 deletion python/src/robyn/tutorials/tutorial5_calibration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@
],
"source": [
"hyperparameters = Hyperparameters(\n",
" {\n",
" hyperparameters={\n",
" \"facebook_S\": ChannelHyperparameters(\n",
" alphas=[0.5, 3],\n",
" gammas=[0.3, 1],\n",
Expand Down
5 changes: 3 additions & 2 deletions python/src/robyn/tutorials/tutorial7_clustering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@
"from robyn.modeling.pareto.pareto_utils import ParetoUtils\n",
"\n",
"utils = ParetoUtils()\n",
"new_pareto_results = utils.process_pareto_clustered_results(pareto_result, clustered_result=cluster_results, ran_cluster=False, ran_calibration= False)"
"new_pareto_results = utils.process_pareto_clustered_results(pareto_result, clustered_result=cluster_results, ran_cluster=True, ran_calibration= False)"
]
},
{
Expand All @@ -147,9 +147,10 @@
"outputs": [],
"source": [
"from robyn.data.entities.enums import AdstockType \n",
"from robyn.data.entities.hyperparameters import Hyperparameters\n",
"from robyn.reporting.onepager_reporting import OnePager\n",
"\n",
"visualizer = OnePager(pareto_result=new_pareto_results, clustered_result=cluster_results, adstock=AdstockType.GEOMETRIC, mmm_data=mmm_data, holidays_data=holidays_data)\n",
"visualizer = OnePager(pareto_result=new_pareto_results, clustered_result=cluster_results, hyperparameter=Hyperparameters(adstock=AdstockType.GEOMETRIC), mmm_data=mmm_data, holidays_data=holidays_data)\n",
"visualizer.generate_one_pager(top_pareto=True)"
]
}
Expand Down
Loading

0 comments on commit 209513b

Please sign in to comment.