From 675ce91e431b7b9090e8658d3408621c37ec3083 Mon Sep 17 00:00:00 2001 From: gAldeia Date: Sun, 7 Jul 2024 20:19:42 -0300 Subject: [PATCH] updated guide to show how to access archive in saving_loading_populations --- docs/guide/saving_loading_populations.ipynb | 281 +++++++++++--------- 1 file changed, 162 insertions(+), 119 deletions(-) diff --git a/docs/guide/saving_loading_populations.ipynb b/docs/guide/saving_loading_populations.ipynb index 3ed2376a..3763c894 100644 --- a/docs/guide/saving_loading_populations.ipynb +++ b/docs/guide/saving_loading_populations.ipynb @@ -48,78 +48,78 @@ "output_type": "stream", "text": [ "Generation 1/10 [////// ]\n", - "Train Loss (Med): 16.41696 (74.37033)\n", - "Val Loss (Med): 16.41696 (74.37033)\n", - "Median Size (Max): 3 (12)\n", - "Median complexity (Max): 9 (156)\n", - "Time (s): 0.07226\n", + "Train Loss (Med): 10.84173 (54.01542)\n", + "Val Loss (Med): 10.84173 (54.01542)\n", + "Median Size (Max): 3 (20)\n", + "Median complexity (Max): 19 (2163)\n", + "Time (s): 0.09850\n", "\n", "Generation 2/10 [/////////// ]\n", - "Train Loss (Med): 12.66635 (49.96683)\n", - "Val Loss (Med): 12.66635 (49.96683)\n", - "Median Size (Max): 3 (12)\n", - "Median complexity (Max): 9 (165)\n", - "Time (s): 0.12100\n", + "Train Loss (Med): 10.84173 (38.19256)\n", + "Val Loss (Med): 10.84173 (38.19256)\n", + "Median Size (Max): 3 (20)\n", + "Median complexity (Max): 9 (2163)\n", + "Time (s): 0.26875\n", "\n", "Generation 3/10 [//////////////// ]\n", - "Train Loss (Med): 12.66635 (16.41696)\n", - "Val Loss (Med): 12.66635 (16.41696)\n", - "Median Size (Max): 5 (14)\n", - "Median complexity (Max): 33 (408)\n", - "Time (s): 0.16357\n", + "Train Loss (Med): 10.43982 (34.84354)\n", + "Val Loss (Med): 10.43982 (34.84354)\n", + "Median Size (Max): 3 (20)\n", + "Median complexity (Max): 9 (2163)\n", + "Time (s): 0.46519\n", "\n", "Generation 4/10 [///////////////////// ]\n", - "Train Loss (Med): 10.97588 (17.85729)\n", - "Val Loss (Med): 10.97588 
(17.85729)\n", - "Median Size (Max): 5 (14)\n", - "Median complexity (Max): 20 (360)\n", - "Time (s): 0.21556\n", + "Train Loss (Med): 10.26326 (17.94969)\n", + "Val Loss (Med): 10.26326 (17.94969)\n", + "Median Size (Max): 3 (19)\n", + "Median complexity (Max): 9 (723)\n", + "Time (s): 0.58006\n", "\n", "Generation 5/10 [////////////////////////// ]\n", - "Train Loss (Med): 10.97588 (16.95482)\n", - "Val Loss (Med): 10.97588 (16.95482)\n", - "Median Size (Max): 5 (15)\n", - "Median complexity (Max): 33 (399)\n", - "Time (s): 0.26767\n", + "Train Loss (Med): 10.26326 (17.94969)\n", + "Val Loss (Med): 10.26326 (17.94969)\n", + "Median Size (Max): 3 (19)\n", + "Median complexity (Max): 9 (304)\n", + "Time (s): 0.69868\n", "\n", "Generation 6/10 [/////////////////////////////// ]\n", - "Train Loss (Med): 10.97588 (16.41696)\n", - "Val Loss (Med): 10.97588 (16.41696)\n", - "Median Size (Max): 5 (15)\n", - "Median complexity (Max): 33 (315)\n", - "Time (s): 0.33006\n", + "Train Loss (Med): 10.26326 (17.90349)\n", + "Val Loss (Med): 10.26326 (17.90349)\n", + "Median Size (Max): 4 (19)\n", + "Median complexity (Max): 22 (591)\n", + "Time (s): 0.82559\n", "\n", "Generation 7/10 [//////////////////////////////////// ]\n", - "Train Loss (Med): 10.97588 (16.41696)\n", - "Val Loss (Med): 10.97588 (16.41696)\n", - "Median Size (Max): 5 (15)\n", - "Median complexity (Max): 33 (273)\n", - "Time (s): 0.43463\n", + "Train Loss (Med): 10.26326 (16.75967)\n", + "Val Loss (Med): 10.26326 (16.75967)\n", + "Median Size (Max): 6 (19)\n", + "Median complexity (Max): 33 (591)\n", + "Time (s): 1.00143\n", "\n", "Generation 8/10 [///////////////////////////////////////// ]\n", - "Train Loss (Med): 10.97588 (15.46647)\n", - "Val Loss (Med): 10.97588 (15.46647)\n", - "Median Size (Max): 7 (15)\n", - "Median complexity (Max): 43 (273)\n", - "Time (s): 0.51012\n", + "Train Loss (Med): 10.26326 (17.85729)\n", + "Val Loss (Med): 10.26326 (17.85729)\n", + "Median Size (Max): 5 (19)\n", + "Median 
complexity (Max): 9 (591)\n", + "Time (s): 1.21972\n", "\n", "Generation 9/10 [////////////////////////////////////////////// ]\n", - "Train Loss (Med): 10.97588 (15.94172)\n", - "Val Loss (Med): 10.97588 (15.94172)\n", - "Median Size (Max): 6 (15)\n", - "Median complexity (Max): 34 (273)\n", - "Time (s): 0.58572\n", + "Train Loss (Med): 10.26326 (17.94969)\n", + "Val Loss (Med): 10.26326 (17.94969)\n", + "Median Size (Max): 4 (19)\n", + "Median complexity (Max): 8 (324)\n", + "Time (s): 1.39759\n", "\n", "Generation 10/10 [//////////////////////////////////////////////////]\n", - "Train Loss (Med): 10.97588 (15.94172)\n", - "Val Loss (Med): 10.97588 (15.94172)\n", - "Median Size (Max): 6 (15)\n", - "Median complexity (Max): 34 (273)\n", - "Time (s): 0.64205\n", + "Train Loss (Med): 10.26326 (89.12132)\n", + "Val Loss (Med): 10.26326 (89.12132)\n", + "Median Size (Max): 3 (19)\n", + "Median complexity (Max): 6 (723)\n", + "Time (s): 1.56908\n", "\n", - "Saved population to file /tmp/tmpfphckt_3/population.json\n", + "Saved population to file /tmp/tmpvo0quw4a/population.json\n", "saving final population as archive...\n", - "score: 0.8785654367436365\n" + "score: 0.8864497096851035\n" ] } ], @@ -161,10 +161,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "Loaded population from /tmp/tmpfphckt_3/population.json of size = 200\n", + "Loaded population from /tmp/tmpvo0quw4a/population.json of size = 200\n", "Completed 100% [====================]\n", "saving final population as archive...\n", - "score: 0.888055116477749\n" + "score: 0.8900709681238513\n" ] } ], @@ -181,6 +181,49 @@ "print('score:', est.score(X,y))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving just the archive\n", + "\n", + "In case you want to use another expression rather than the final `best_estimator_`, brush provides the archive option.\n", + "\n", + "The archive is just the pareto front from the population. 
You can use `predict_archive` (and `predict_proba_archive` if using a `BrushClassifier`) to call the prediction methods for the entire archive, instead of the selected best individual.\n", + "\n", + "But first, you need to enable this option with `use_arch=True`. When set to `False`, the archive stores the entire final population instead." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded population from /tmp/tmpvo0quw4a/population.json of size = 200\n", + "Completed 100% [====================]\n", + "{'complexity': 240, 'crowding_dist': 3.4028234663852886e+38, 'dcounter': 0, 'depth': 3, 'dominated': [], 'linear_complexity': 24, 'loss': 9.675766944885254, 'loss_v': 9.675766944885254, 'rank': 1, 'size': 17, 'values': [9.675766944885254, 17.0], 'weights': [-1.0, -1.0], 'wvalues': [-9.675766944885254, -17.0]}\n" + ] + } + ], + "source": [ + "est = BrushRegressor(\n", + " functions=['SplitBest','Add','Mul','Sin','Cos','Exp','Logabs'],\n", + " load_population=pop_file,\n", + " use_arch=True,\n", + " max_gens=10,\n", + " verbosity=1\n", + ")\n", + "\n", + "est.fit(X,y)\n", + "\n", + "# accessing first expression from the archive. 
It is serialized as a dict\n", + "print(est.archive_[0]['fitness'])" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -205,7 +248,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -213,78 +256,78 @@ "output_type": "stream", "text": [ "Generation 1/10 [////// ]\n", - "Train Loss (Med): 0.54851 (0.69315)\n", - "Val Loss (Med): 0.54851 (0.69315)\n", - "Median Size (Max): 5 (12)\n", + "Train Loss (Med): 0.54847 (0.69315)\n", + "Val Loss (Med): 0.54847 (0.69315)\n", + "Median Size (Max): 5 (10)\n", "Median complexity (Max): 6 (270)\n", - "Time (s): 0.03284\n", + "Time (s): 0.04547\n", "\n", "Generation 2/10 [/////////// ]\n", - "Train Loss (Med): 0.54850 (0.69315)\n", - "Val Loss (Med): 0.54850 (0.69315)\n", + "Train Loss (Med): 0.54846 (0.69315)\n", + "Val Loss (Med): 0.54846 (0.69315)\n", "Median Size (Max): 5 (10)\n", - "Median complexity (Max): 6 (165)\n", - "Time (s): 0.05459\n", + "Median complexity (Max): 6 (90)\n", + "Time (s): 0.07358\n", "\n", "Generation 3/10 [//////////////// ]\n", - "Train Loss (Med): 0.54851 (0.69315)\n", - "Val Loss (Med): 0.54851 (0.69315)\n", - "Median Size (Max): 3 (10)\n", - "Median complexity (Max): 3 (165)\n", - "Time (s): 0.07147\n", + "Train Loss (Med): 0.54846 (0.69315)\n", + "Val Loss (Med): 0.54846 (0.69315)\n", + "Median Size (Max): 5 (10)\n", + "Median complexity (Max): 6 (279)\n", + "Time (s): 0.10416\n", "\n", "Generation 4/10 [///////////////////// ]\n", - "Train Loss (Med): 0.54851 (0.69315)\n", - "Val Loss (Med): 0.54851 (0.69315)\n", - "Median Size (Max): 1 (10)\n", - "Median complexity (Max): 2 (54)\n", - "Time (s): 0.08754\n", + "Train Loss (Med): 0.54846 (0.69315)\n", + "Val Loss (Med): 0.54846 (0.69315)\n", + "Median Size (Max): 2 (10)\n", + "Median complexity (Max): 2 (279)\n", + "Time (s): 0.13638\n", "\n", "Generation 5/10 [////////////////////////// ]\n", "Train Loss (Med): 0.54846 (0.69315)\n", "Val Loss (Med): 0.54846 (0.69315)\n", - "Median 
Size (Max): 1 (10)\n", - "Median complexity (Max): 2 (54)\n", - "Time (s): 0.11204\n", + "Median Size (Max): 2 (10)\n", + "Median complexity (Max): 2 (279)\n", + "Time (s): 0.16392\n", "\n", "Generation 6/10 [/////////////////////////////// ]\n", "Train Loss (Med): 0.54846 (0.69315)\n", "Val Loss (Med): 0.54846 (0.69315)\n", "Median Size (Max): 1 (10)\n", - "Median complexity (Max): 1 (90)\n", - "Time (s): 0.13126\n", + "Median complexity (Max): 1 (54)\n", + "Time (s): 0.19176\n", "\n", "Generation 7/10 [//////////////////////////////////// ]\n", "Train Loss (Med): 0.54846 (0.69315)\n", "Val Loss (Med): 0.54846 (0.69315)\n", "Median Size (Max): 1 (10)\n", - "Median complexity (Max): 1 (54)\n", - "Time (s): 0.14825\n", + "Median complexity (Max): 1 (90)\n", + "Time (s): 0.21715\n", "\n", "Generation 8/10 [///////////////////////////////////////// ]\n", "Train Loss (Med): 0.54846 (0.69315)\n", "Val Loss (Med): 0.54846 (0.69315)\n", "Median Size (Max): 1 (10)\n", - "Median complexity (Max): 1 (52)\n", - "Time (s): 0.16841\n", + "Median complexity (Max): 1 (54)\n", + "Time (s): 0.24563\n", "\n", "Generation 9/10 [////////////////////////////////////////////// ]\n", "Train Loss (Med): 0.54846 (0.69315)\n", "Val Loss (Med): 0.54846 (0.69315)\n", "Median Size (Max): 1 (10)\n", "Median complexity (Max): 1 (48)\n", - "Time (s): 0.19657\n", + "Time (s): 0.27388\n", "\n", "Generation 10/10 [//////////////////////////////////////////////////]\n", "Train Loss (Med): 0.54846 (0.69315)\n", "Val Loss (Med): 0.54846 (0.69315)\n", "Median Size (Max): 1 (10)\n", "Median complexity (Max): 1 (48)\n", - "Time (s): 0.22023\n", + "Time (s): 0.29905\n", "\n", - "Saved population to file /tmp/tmpe7n_mbgz/population.json\n", + "Saved population to file /tmp/tmp3lun3trc/population.json\n", "saving final population as archive...\n", - "If(AIDS>15890.50,1.18*Logistic(1.69*MeanLabel),0.39*MeanLabel)\n", + "Logistic(Sum(-0.83599114,If(AIDS>15890.50,15.27,0.39*MeanLabel)))\n", "score: 0.68\n" ] } 
@@ -317,7 +360,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -328,75 +371,75 @@ "Train Loss (Med): 0.46115 (0.31675)\n", "Val Loss (Med): 0.46115 (0.31675)\n", "Median Size (Max): 5 (9)\n", - "Median complexity (Max): 6 (180)\n", - "Time (s): 0.03686\n", + "Median complexity (Max): 6 (120)\n", + "Time (s): 0.07469\n", "\n", "Generation 2/10 [/////////// ]\n", - "Train Loss (Med): 0.75212 (0.31675)\n", - "Val Loss (Med): 0.75212 (0.31675)\n", + "Train Loss (Med): 0.55038 (0.31675)\n", + "Val Loss (Med): 0.55038 (0.31675)\n", "Median Size (Max): 5 (9)\n", - "Median complexity (Max): 6 (120)\n", - "Time (s): 0.06046\n", + "Median complexity (Max): 6 (648)\n", + "Time (s): 0.12616\n", "\n", "Generation 3/10 [//////////////// ]\n", - "Train Loss (Med): 0.75212 (0.31675)\n", - "Val Loss (Med): 0.75212 (0.31675)\n", - "Median Size (Max): 4 (9)\n", - "Median complexity (Max): 2 (90)\n", - "Time (s): 0.07728\n", + "Train Loss (Med): 0.45385 (0.00000)\n", + "Val Loss (Med): 0.45385 (0.00000)\n", + "Median Size (Max): 1 (8)\n", + "Median complexity (Max): 1 (648)\n", + "Time (s): 0.16134\n", "\n", "Generation 4/10 [///////////////////// ]\n", "Train Loss (Med): 0.75212 (0.00000)\n", "Val Loss (Med): 0.75212 (0.00000)\n", "Median Size (Max): 1 (9)\n", - "Median complexity (Max): 1 (90)\n", - "Time (s): 0.09751\n", + "Median complexity (Max): 1 (30)\n", + "Time (s): 0.20292\n", "\n", "Generation 5/10 [////////////////////////// ]\n", "Train Loss (Med): 0.75212 (0.00000)\n", "Val Loss (Med): 0.75212 (0.00000)\n", - "Median Size (Max): 1 (6)\n", - "Median complexity (Max): 1 (90)\n", - "Time (s): 0.11214\n", + "Median Size (Max): 1 (9)\n", + "Median complexity (Max): 1 (30)\n", + "Time (s): 0.23893\n", "\n", "Generation 6/10 [/////////////////////////////// ]\n", - "Train Loss (Med): 0.75212 (0.00000)\n", - "Val Loss (Med): 0.75212 (0.00000)\n", - "Median Size (Max): 1 (6)\n", - "Median complexity (Max): 1 
(90)\n", - "Time (s): 0.12957\n", + "Train Loss (Med): 0.75212 (0.31675)\n", + "Val Loss (Med): 0.75212 (0.31675)\n", + "Median Size (Max): 1 (9)\n", + "Median complexity (Max): 1 (30)\n", + "Time (s): 0.27777\n", "\n", "Generation 7/10 [//////////////////////////////////// ]\n", - "Train Loss (Med): 0.75258 (0.00000)\n", - "Val Loss (Med): 0.75258 (0.00000)\n", - "Median Size (Max): 1 (8)\n", - "Median complexity (Max): 1 (360)\n", - "Time (s): 0.14327\n", + "Train Loss (Med): 0.75212 (0.31675)\n", + "Val Loss (Med): 0.75212 (0.31675)\n", + "Median Size (Max): 1 (9)\n", + "Median complexity (Max): 1 (30)\n", + "Time (s): 0.33410\n", "\n", "Generation 8/10 [///////////////////////////////////////// ]\n", "Train Loss (Med): 0.75212 (0.31675)\n", "Val Loss (Med): 0.75212 (0.31675)\n", - "Median Size (Max): 1 (6)\n", + "Median Size (Max): 1 (9)\n", "Median complexity (Max): 1 (30)\n", - "Time (s): 0.15835\n", + "Time (s): 0.37918\n", "\n", "Generation 9/10 [////////////////////////////////////////////// ]\n", "Train Loss (Med): 0.75212 (0.31675)\n", "Val Loss (Med): 0.75212 (0.31675)\n", "Median Size (Max): 1 (6)\n", "Median complexity (Max): 1 (30)\n", - "Time (s): 0.17249\n", + "Time (s): 0.41311\n", "\n", "Generation 10/10 [//////////////////////////////////////////////////]\n", "Train Loss (Med): 0.75212 (0.31675)\n", "Val Loss (Med): 0.75212 (0.31675)\n", "Median Size (Max): 1 (6)\n", "Median complexity (Max): 1 (30)\n", - "Time (s): 0.18584\n", + "Time (s): 0.44007\n", "\n", "saving final population as archive...\n", - "Logistic(-0.07*Mul(0.00*AIDS,-0.03*MeanLabel))\n", - "score: 0.52\n" + "Add(0.50*MeanLabel,AIDS)\n", + "score: 0.5\n" ] } ],