Skip to content

Commit

Permalink
Implemented Thompson sampling correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Jul 23, 2024
1 parent a2c9f4f commit 5d2f057
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 34 deletions.
27 changes: 13 additions & 14 deletions docs/guide/bandits.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -43,7 +43,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -191,7 +191,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -210,17 +210,16 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Completed 100% [====================]\n",
"Saved population to file /tmp/tmpshx7uzc1/population.json\n",
"saving final population as archive...\n",
"score: 0.8900709673753779\n"
"score: 0.8972961714305525\n"
]
}
],
Expand All @@ -239,7 +238,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -287,7 +286,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand All @@ -297,7 +296,7 @@
"Search Space\n",
"===\n",
"terminal_map: {\"ArrayB\": [\"1.00\"], \"ArrayI\": [\"x5\", \"x7\", \"1.00\"], \"ArrayF\": [\"x0\", \"x1\", \"x2\", \"x3\", \"x4\", \"x6\", \"1.00\", \"1.00*MeanLabel\"]}\n",
"terminal_weights: {\"ArrayB\": [-nan], \"ArrayI\": [0.5, 0.5, 0.5], \"ArrayF\": [0.5302013, 0.5300546, 0.55172414, 0.55825245, 0.5, 0.55445546, 0.407767, 0.3604651]}\n",
"terminal_weights: {\"ArrayB\": [-nan], \"ArrayI\": [0.5, 0.5, 0.5], \"ArrayF\": [0.31623933, 0.37096775, 0.28431374, 0.3359375, 0.36923078, 0.3006993, 0.5485714, 0.5503876]}\n",
"node_map[ArrayI][[\"ArrayI\", \"ArrayI\"]][SplitBest] = SplitBest, weight = 1\n",
"node_map[MatrixF][[\"ArrayF\", \"ArrayF\", \"ArrayF\", \"ArrayF\"]][Logabs] = Logabs, weight = 1\n",
"node_map[MatrixF][[\"ArrayF\", \"ArrayF\", \"ArrayF\", \"ArrayF\"]][Exp] = Exp, weight = 1\n",
Expand Down Expand Up @@ -335,7 +334,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand All @@ -354,15 +353,15 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.19159255921840668\n",
"{'delete': 0.3327217102050781, 'insert': 0.014492753893136978, 'point': 0.47653430700302124, 'subtree': 0.687393069267273, 'toggle_weight_off': 0.40665435791015625, 'toggle_weight_on': 0.019417475908994675}\n"
"0.1598760038614273\n",
"{'delete': 0.272018700838089, 'insert': 0.019999999552965164, 'point': 0.38034459948539734, 'subtree': 0.6300863027572632, 'toggle_weight_off': 0.5979797840118408, 'toggle_weight_on': 0.036269430071115494}\n"
]
}
],
Expand All @@ -388,7 +387,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.1.undefined"
"version": "3.12.2"
}
},
"nbformat": 4,
Expand Down
181 changes: 169 additions & 12 deletions docs/guide/search_space.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,133 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 11,
"id": "b667948a",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>sex</th>\n",
" <th>race</th>\n",
" <th>target</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>993.00000</td>\n",
" <td>993.000000</td>\n",
" <td>993.000000</td>\n",
" <td>993.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>496.00000</td>\n",
" <td>0.487412</td>\n",
" <td>2.625378</td>\n",
" <td>8.219092</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>286.79871</td>\n",
" <td>0.500093</td>\n",
" <td>1.725240</td>\n",
" <td>1.101319</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.00000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>1.337280</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>248.00000</td>\n",
" <td>0.000000</td>\n",
" <td>1.000000</td>\n",
" <td>7.836757</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>496.00000</td>\n",
" <td>0.000000</td>\n",
" <td>3.000000</td>\n",
" <td>8.404038</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>744.00000</td>\n",
" <td>1.000000</td>\n",
" <td>4.000000</td>\n",
" <td>8.810710</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>992.00000</td>\n",
" <td>1.000000</td>\n",
" <td>5.000000</td>\n",
" <td>11.410597</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id sex race target\n",
"count 993.00000 993.000000 993.000000 993.000000\n",
"mean 496.00000 0.487412 2.625378 8.219092\n",
"std 286.79871 0.500093 1.725240 1.101319\n",
"min 0.00000 0.000000 0.000000 1.337280\n",
"25% 248.00000 0.000000 1.000000 7.836757\n",
"50% 496.00000 0.000000 3.000000 8.404038\n",
"75% 744.00000 1.000000 4.000000 8.810710\n",
"max 992.00000 1.000000 5.000000 11.410597"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"from pybrush import Dataset, SearchSpace\n",
"\n",
"df = pd.read_csv('../examples/datasets/d_enc.csv')\n",
"X = df.drop(columns='label')\n",
"y = df['label']\n",
"df = pd.read_csv('../examples/datasets/d_example_patients.csv')\n",
"X = df.drop(columns='target')\n",
"y = df['target']\n",
"\n",
"df.describe()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "6d563aae",
"metadata": {},
"outputs": [],
"source": [
"data = Dataset(X,y)\n",
"\n",
"search_space = SearchSpace(data)"
Expand All @@ -59,7 +174,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 14,
"id": "23d6f552",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -93,25 +208,25 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 15,
"id": "a2953719",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Search Space\n",
"===\n",
"terminal_map: {\"ArrayB\": [\"1.00\"], \"ArrayI\": [\"x_5\", \"x_7\", \"1.00\"], \"ArrayF\": [\"x_0\", \"x_1\", \"x_2\", \"x_3\", \"x_4\", \"x_6\", \"1.00\", \"1.00*MeanLabel\"]}\n",
"terminal_weights: {\"ArrayB\": [-nan], \"ArrayI\": [0.011619061, 0.03579926, 0.023709161], \"ArrayF\": [0.6343385, 0.67299956, 0.42711574, 0.8625447, 0.8957853, 0.20750472, 0.6167148, 0.6167148]}\n",
"=== Search space ===\n",
"terminal_map: {\"ArrayI\": [\"x_2\", \"1.00\"], \"ArrayB\": [\"x_1\", \"1.00\"], \"ArrayF\": [\"x_0\", \"1.00\", \"1.00*MeanLabel\"]}\n",
"terminal_weights: {\"ArrayI\": [0.01214596, 0.01214596], \"ArrayB\": [0.026419641, 0.026419641], \"ArrayF\": [0.056145623, 0.056145623, 0.056145623]}\n",
"node_map[ArrayB][[\"ArrayB\", \"ArrayB\"]][SplitBest] = SplitBest, weight = 0.2\n",
"node_map[ArrayI][[\"ArrayI\", \"ArrayI\"]][SplitBest] = SplitBest, weight = 0.2\n",
"node_map[ArrayF][[\"ArrayF\", \"ArrayF\"]][SplitBest] = SplitBest, weight = 0.2\n",
"node_map[ArrayF][[\"ArrayF\", \"ArrayF\"]][Div] = Div, weight = 0.1\n",
"node_map[ArrayF][[\"ArrayF\", \"ArrayF\"]][Mul] = Mul, weight = 1\n",
"node_map[ArrayF][[\"ArrayF\", \"ArrayF\"]][Sub] = Sub, weight = 0.5\n",
"node_map[ArrayF][[\"ArrayF\", \"ArrayF\"]][Add] = Add, weight = 0.5\n",
"===\n"
"\n"
]
}
],
Expand All @@ -129,6 +244,48 @@
"Note also that the default behavior is to give both of these nodes the same weight as specified by the user. "
]
},
{
"cell_type": "markdown",
"id": "ca903d90",
"metadata": {},
"source": [
"## Loading datatypes\n",
"\n",
"If you pass a numpy array, Brush will try to infer the data types based on its values.\n",
"If you instead pass a pandas dataframe, Brush will use the data types inferred by pandas as its own."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "1c8c72c1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=== Search space ===\n",
"terminal_map: {\"ArrayI\": [\"x_2\", \"1.00\"], \"ArrayB\": [\"x_1\", \"1.00\"], \"ArrayF\": [\"x_0\", \"1.00\", \"1.00*MeanLabel\"]}\n",
"terminal_weights: {\"ArrayI\": [0.01214596, 0.01214596], \"ArrayB\": [0.026419641, 0.026419641], \"ArrayF\": [0.056145623, 0.056145623, 0.056145623]}\n",
"node_map[ArrayB][[\"ArrayB\", \"ArrayB\"]][SplitBest] = SplitBest, weight = 0.2\n",
"node_map[ArrayI][[\"ArrayI\", \"ArrayI\"]][SplitBest] = SplitBest, weight = 0.2\n",
"node_map[ArrayF][[\"ArrayF\", \"ArrayF\"]][SplitBest] = SplitBest, weight = 0.2\n",
"node_map[ArrayF][[\"ArrayF\", \"ArrayF\"]][Div] = Div, weight = 0.1\n",
"node_map[ArrayF][[\"ArrayF\", \"ArrayF\"]][Mul] = Mul, weight = 1\n",
"node_map[ArrayF][[\"ArrayF\", \"ArrayF\"]][Sub] = Sub, weight = 0.5\n",
"node_map[ArrayF][[\"ArrayF\", \"ArrayF\"]][Add] = Add, weight = 0.5\n",
"\n"
]
}
],
"source": [
"data = Dataset(X.values, y.values)\n",
"\n",
"search_space = SearchSpace(data, user_ops)\n",
"search_space.print()"
]
},
{
"cell_type": "markdown",
"id": "d662c5a7",
Expand Down
30 changes: 29 additions & 1 deletion src/bandit/thompson.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,39 @@ ThompsonSamplingBandit<T>::ThompsonSamplingBandit(map<T, float> arms_probs)

template <typename T>
std::map<T, float> ThompsonSamplingBandit<T>::sample_probs(bool update) {

// from https://stackoverflow.com/questions/4181403/generate-random-number-based-on-beta-distribution-using-boost
// You'll first want to draw a random number uniformly from the
// range (0,1). Given any distribution, you can then plug that number
// into the distribution's "quantile function," and the result is as
// if a random value was drawn from the distribution.

// from https://stackoverflow.com/questions/10358064/random-numbers-from-beta-distribution-c
// The beta distribution is related to the gamma distribution. Let X be a
// random number drawn from Gamma(α,1) and Y from Gamma(β,1), where the
// first argument to the gamma distribution is the shape parameter.
// Then Z=X/(X+Y) has distribution Beta(α,β).

if (update) {
// 1. use a beta distribution based on alphas and betas to sample probabilities
// 2. normalize probabilities so the sum is 1?

float alpha, beta, X, Y, prob;
for (const auto& pair : this->probabilities) {
T arm = pair.first;
float prob = static_cast<float>(alphas[arm] - 1) / static_cast<float>(alphas[arm] + betas[arm] - 2);

alpha = alphas[arm];
beta = betas[arm];

// TODO: stop using boost and use std::gamma_distribution (first, search to see if it is faster)
boost::math::gamma_distribution<> gammaX(alpha);
boost::math::gamma_distribution<> gammaY(beta);

X = boost::math::quantile(gammaX, Brush::Util::r.rnd_flt());
Y = boost::math::quantile(gammaY, Brush::Util::r.rnd_flt());

prob = X/(X+Y);

this->probabilities[arm] = prob;
}
}
Expand Down
Loading

0 comments on commit 5d2f057

Please sign in to comment.