Skip to content

Commit

Permalink
Add ExternalGenerationNode tutorial (#2281)
Browse files Browse the repository at this point in the history
Summary:

Adds a tutorial that uses ExternalGenerationNode with RandomForest.

Reviewed By: Balandat

Differential Revision: D54970325
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Mar 18, 2024
1 parent 639b731 commit 90a315d
Show file tree
Hide file tree
Showing 2 changed files with 401 additions and 0 deletions.
395 changes: 395 additions & 0 deletions tutorials/external_generation_node.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,395 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"customInput": null,
"originalKey": "448bd7a0-af5a-43b4-a4fa-6a43577193b5",
"outputsInitialized": false,
"showInput": false
},
"source": [
"# Using external methods for candidate generation in Ax\n",
"\n",
"Out of the box, Ax offers many options for candidate generation, most of which utilize Bayesian optimization algorithms built using [BoTorch](https://botorch.org/). For users that want to leverage Ax for experiment orchestration (via `AxClient` or `Scheduler`) and other features (e.g., early stopping), while relying on other methods for candidate generation, we introduced `ExternalGenerationNode`. \n",
"\n",
"A `GenerationNode` is a building block of a `GenerationStrategy`. They can be combined together utilize different methods for generating candidates at different stages of an experiment. `ExternalGenerationNode` exposes a lightweight interface to allow the users to easily integrate their methods into Ax, and use them as standalone or with other `GenerationNode`s in a `GenerationStrategy`.\n",
"\n",
"In this tutorial, we will implement a simple generation node using `RandomForestRegressor` from sklearn, and combine it with Sobol (for initialization) to optimize the Hartmann6 problem.\n",
"\n",
"NOTE: This is for illustration purposes only. We do not recommend using this strategy as it typically does not perform well compared to Ax's default algorithms due to it's overly greedy behavior."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"executionStartTime": 1710539298590,
"executionStopTime": 1710539307671,
"originalKey": "d07e3074-f374-40e8-af49-a018a00288b5",
"output": {
"id": "314819867912827",
"loadingStatus": "before loading"
},
"outputsInitialized": true,
"requestMsgId": "d07e3074-f374-40e8-af49-a018a00288b5",
"serverExecutionDuration": 4039.838102879
},
"outputs": [],
"source": [
"import time\n",
"from typing import Any, Dict, List, Optional, Tuple\n",
"\n",
"import numpy as np\n",
"from ax.core.base_trial import TrialStatus\n",
"from ax.core.data import Data\n",
"from ax.core.experiment import Experiment\n",
"from ax.core.parameter import RangeParameter\n",
"from ax.core.types import TParameterization\n",
"from ax.modelbridge.external_generation_node import ExternalGenerationNode\n",
"from ax.modelbridge.generation_node import GenerationNode\n",
"from ax.modelbridge.generation_strategy import GenerationStrategy\n",
"from ax.modelbridge.model_spec import ModelSpec\n",
"from ax.modelbridge.registry import Models\n",
"from ax.modelbridge.transition_criterion import MaxTrials\n",
"from ax.plot.trace import plot_objective_value_vs_trial_index\n",
"from ax.service.ax_client import AxClient, ObjectiveProperties\n",
"from ax.service.utils.report_utils import exp_to_df\n",
"from ax.utils.common.typeutils import checked_cast\n",
"from ax.utils.measurement.synthetic_functions import hartmann6\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"\n",
"\n",
"class RandomForestGenerationNode(ExternalGenerationNode):\n",
" \"\"\"A generation node that uses the RandomForestRegressor\n",
" from sklearn to predict candidate performance and picks the\n",
" next point as the random sample that has the best prediction.\n",
"\n",
" To leverage external methods for candidate generation, the user must\n",
" create a subclass that implements ``update_generator_state`` and\n",
" ``get_next_candidate`` methods. This can then be provided\n",
" as a node into a ``GenerationStrategy``, either as standalone or as\n",
" part of a larger generation strategy with other generation nodes,\n",
" e.g., with a Sobol node for initialization.\n",
" \"\"\"\n",
"\n",
" def __init__(self, num_samples: int, regressor_options: Dict[str, Any]) -> None:\n",
" \"\"\"Initialize the generation node.\n",
"\n",
" Args:\n",
" regressor_options: Options to pass to the random forest regressor.\n",
" num_samples: Number of random samples from the search space\n",
" used during candidate generation. The sample with the best\n",
" prediction is recommended as the next candidate.\n",
" \"\"\"\n",
" t_init_start = time.monotonic()\n",
" super().__init__(node_name=\"RandomForest\")\n",
" self.num_samples: int = num_samples\n",
" self.regressor: RandomForestRegressor = RandomForestRegressor(\n",
" **regressor_options\n",
" )\n",
" # We will set these later when updating the state.\n",
" # Alternatively, we could have required experiment as an input\n",
" # and extracted them here.\n",
" self.parameters: Optional[List[RangeParameter]] = None\n",
" self.minimize: Optional[bool] = None\n",
" # Recording time spent in initializing the generator. This is\n",
" # used to compute the time spent in candidate generation.\n",
" self.fit_time_since_gen: float = time.monotonic() - t_init_start\n",
"\n",
" def update_generator_state(self, experiment: Experiment, data: Data) -> None:\n",
" \"\"\"A method used to update the state of the generator. This includes any\n",
" models, predictors or any other custom state used by the generation node.\n",
" This method will be called with the up-to-date experiment and data before\n",
" ``get_next_candidate`` is called to generate the next trial(s). Note\n",
" that ``get_next_candidate`` may be called multiple times (to generate\n",
" multiple candidates) after a call to ``update_generator_state``.\n",
"\n",
" For this example, we will train the regressor using the latest data from\n",
" the experiment.\n",
"\n",
" Args:\n",
" experiment: The ``Experiment`` object representing the current state of the\n",
" experiment. The key properties includes ``trials``, ``search_space``,\n",
" and ``optimization_config``. The data is provided as a separate arg.\n",
" data: The data / metrics collected on the experiment so far.\n",
" \"\"\"\n",
" search_space = experiment.search_space\n",
" parameter_names = list(search_space.parameters.keys())\n",
" metric_names = list(experiment.optimization_config.metrics.keys())\n",
" if any(\n",
" not isinstance(p, RangeParameter) for p in search_space.parameters.values()\n",
" ):\n",
" raise NotImplementedError(\n",
" \"This example only supports RangeParameters in the search space.\"\n",
" )\n",
" if search_space.parameter_constraints:\n",
" raise NotImplementedError(\n",
" \"This example does not support parameter constraints.\"\n",
" )\n",
" if len(metric_names) != 1:\n",
" raise NotImplementedError(\n",
" \"This example only supports single-objective optimization.\"\n",
" )\n",
" # Get the data for the completed trials.\n",
" num_completed_trials = len(experiment.trials_by_status[TrialStatus.COMPLETED])\n",
" x = np.zeros([num_completed_trials, len(parameter_names)])\n",
" y = np.zeros([num_completed_trials, 1])\n",
" for t_idx, trial in experiment.trials.items():\n",
" if trial.status == \"COMPLETED\":\n",
" trial_parameters = trial.arm.parameters\n",
" x[t_idx, :] = np.array([trial_parameters[p] for p in parameter_names])\n",
" trial_df = data.df[data.df[\"trial_index\"] == t_idx]\n",
" y[t_idx, 0] = trial_df[trial_df[\"metric_name\"] == metric_names[0]][\n",
" \"mean\"\n",
" ].item()\n",
"\n",
" # Train the regressor.\n",
" self.regressor.fit(x, y)\n",
" # Update the attributes not set in __init__.\n",
" self.parameters = search_space.parameters\n",
" self.minimize = experiment.optimization_config.objective.minimize\n",
"\n",
" def get_next_candidate(\n",
" self, pending_parameters: List[TParameterization]\n",
" ) -> TParameterization:\n",
" \"\"\"Get the parameters for the next candidate configuration to evaluate.\n",
"\n",
" We will draw ``self.num_samples`` random samples from the search space\n",
" and predict the objective value for each sample. We will then return\n",
" the sample with the best predicted value.\n",
"\n",
" Args:\n",
" pending_parameters: A list of parameters of the candidates pending\n",
" evaluation. This is often used to avoid generating duplicate candidates.\n",
" We ignore this here for simplicity.\n",
"\n",
" Returns:\n",
" A dictionary mapping parameter names to parameter values for the next\n",
" candidate suggested by the method.\n",
" \"\"\"\n",
" bounds = np.array([[p.lower, p.upper] for p in self.parameters.values()])\n",
" unit_samples = np.random.random_sample([self.num_samples, len(bounds)])\n",
" samples = bounds[:, 0] + (bounds[:, 1] - bounds[:, 0]) * unit_samples\n",
" # Predict the objective value for each sample.\n",
" y_pred = self.regressor.predict(samples)\n",
" # Find the best sample.\n",
" best_idx = np.argmin(y_pred) if self.minimize else np.argmax(y_pred)\n",
" best_sample = samples[best_idx, :]\n",
" # Convert the sample to a parameterization.\n",
" candidate = {\n",
" p_name: best_sample[i].item()\n",
" for i, p_name in enumerate(self.parameters.keys())\n",
" }\n",
" return candidate"
]
},
{
"cell_type": "markdown",
"metadata": {
"customInput": null,
"originalKey": "e1c194ea-53f9-466b-a04a-d1e222751a62",
"outputsInitialized": false,
"showInput": false
},
"source": [
"## Construct the GenerationStrategy\n",
"\n",
"We will use Sobol for the first 5 trials and defer to random forest for the rest."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"customInput": null,
"executionStartTime": 1710539307673,
"executionStopTime": 1710539307752,
"originalKey": "389cb09c-adeb-4724-82b0-903806b6b403",
"outputsInitialized": true,
"requestMsgId": "389cb09c-adeb-4724-82b0-903806b6b403",
"serverExecutionDuration": 5.2677921485156,
"showInput": true
},
"outputs": [],
"source": [
"generation_strategy = GenerationStrategy(\n",
" name=\"Sobol+RandomForest\",\n",
" nodes=[\n",
" GenerationNode(\n",
" node_name=\"Sobol\",\n",
" model_specs=[ModelSpec(Models.SOBOL)],\n",
" transition_criteria=[\n",
" MaxTrials(\n",
" # This specifies the maximum number of trials to generate from this node.\n",
" threshold=5,\n",
" block_transition_if_unmet=True,\n",
" )\n",
" ],\n",
" gen_unlimited_trials=False,\n",
" ),\n",
" RandomForestGenerationNode(num_samples=128, regressor_options={}),\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"customInput": null,
"originalKey": "7bcf0a8e-39f7-4ceb-a791-c5453024bcfd",
"outputsInitialized": false,
"showInput": false
},
"source": [
"## Run a simple experiment using AxClient\n",
"\n",
"More details on how to use AxClient can be found in the [tutorial](https://ax.dev/tutorials/gpei_hartmann_service.html)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"customInput": null,
"executionStartTime": 1710539307754,
"executionStopTime": 1710539307854,
"originalKey": "4be26fc1-6183-40c4-a45e-79adb613b950",
"outputsInitialized": true,
"requestMsgId": "4be26fc1-6183-40c4-a45e-79adb613b950",
"serverExecutionDuration": 15.909331152216,
"showInput": true
},
"outputs": [],
"source": [
"ax_client = AxClient(generation_strategy=generation_strategy)\n",
"\n",
"ax_client.create_experiment(\n",
" name=\"hartmann_test_experiment\",\n",
" parameters=[\n",
" {\n",
" \"name\": f\"x{i}\",\n",
" \"type\": \"range\",\n",
" \"bounds\": [0.0, 1.0],\n",
" \"value_type\": \"float\", # Optional, defaults to inference from type of \"bounds\".\n",
" }\n",
" for i in range(1, 7)\n",
" ],\n",
" objectives={\"hartmann6\": ObjectiveProperties(minimize=True)},\n",
")\n",
"\n",
"\n",
"def evaluate(parameterization: TParameterization) -> Dict[str, Tuple[float, float]]:\n",
" x = np.array([parameterization.get(f\"x{i+1}\") for i in range(6)])\n",
" return {\"hartmann6\": (checked_cast(float, hartmann6(x)), 0.0)}"
]
},
{
"cell_type": "markdown",
"metadata": {
"customInput": null,
"originalKey": "a470eb3e-40a0-45d2-9d53-13a98a137ec2",
"outputsInitialized": false,
"showInput": false
},
"source": [
"### Run the optimization loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"customInput": null,
"executionStartTime": 1710539307855,
"executionStopTime": 1710539309651,
"originalKey": "f67454e1-2a1a-4e87-ba3b-038c3134b09d",
"outputsInitialized": false,
"requestMsgId": "f67454e1-2a1a-4e87-ba3b-038c3134b09d",
"serverExecutionDuration": 1679.0952710435,
"showInput": true
},
"outputs": [],
"source": [
"for i in range(15):\n",
" parameterization, trial_index = ax_client.get_next_trial()\n",
" ax_client.complete_trial(\n",
" trial_index=trial_index, raw_data=evaluate(parameterization)\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {
"customInput": null,
"originalKey": "d0655321-4875-46d7-a4bf-ac2c4e166d94",
"outputsInitialized": false,
"showInput": false
},
"source": [
"### View the trials generated during optimization"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"customInput": null,
"executionStartTime": 1710539309652,
"executionStopTime": 1710539309824,
"originalKey": "ba69ed8c-7ee2-49ef-9ccf-0aad2bc5ac61",
"outputsInitialized": true,
"requestMsgId": "ba69ed8c-7ee2-49ef-9ccf-0aad2bc5ac61",
"serverExecutionDuration": 73.840260040015,
"showInput": true
},
"outputs": [],
"source": [
"exp_df = exp_to_df(ax_client.experiment)\n",
"exp_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_objective_value_vs_trial_index(\n",
" exp_df=exp_df,\n",
" metric_colname=\"hartmann6\",\n",
" minimize=True,\n",
" title=\"Hartmann6 Objective Value vs. Trial Index\",\n",
")"
]
}
],
"metadata": {
"fileHeader": "",
"fileUid": "1ab8b45a-525c-4c25-b142-f7ef9fffb1c5",
"isAdHoc": false,
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 90a315d

Please sign in to comment.