From 2ff4bab0a55d75378882dd9acfa218b3e9a2f3ba Mon Sep 17 00:00:00 2001 From: miguelgfierro Date: Wed, 13 Nov 2019 15:38:08 +0000 Subject: [PATCH 1/2] small changes in xlearn notebook --- notebooks/02_model/fm_deep_dive.ipynb | 223 +++++++++----------------- 1 file changed, 77 insertions(+), 146 deletions(-) diff --git a/notebooks/02_model/fm_deep_dive.ipynb b/notebooks/02_model/fm_deep_dive.ipynb index 1bff37847f..6df427ac9b 100644 --- a/notebooks/02_model/fm_deep_dive.ipynb +++ b/notebooks/02_model/fm_deep_dive.ipynb @@ -181,7 +181,7 @@ "|-----------------|------------------|------------------|---------------------|\n", "|[libfm](https://github.com/srendle/libfm)|C++|Implementation of FM algorithm|-|\n", "|[libffm](https://github.com/ycjuan/libffm)|C++|Original implemenation of FFM algorithm. It is handy in model building, but does not support Python interface|-|\n", - "|[xlearn](https://github.com/aksnzhy/xlearn)|C++ with Python interface|More computationally efficient compared to libffm without loss of modeling effectiveness|Appear soon|\n", + "|[xlearn](https://github.com/aksnzhy/xlearn)|C++ with Python interface|More computationally efficient compared to libffm without loss of modeling effectiveness|[notebook](https://github.com/microsoft/recommenders/blob/master/notebooks/02_model/fm_deep_dive.ipynb)|\n", "|[Vowpal Wabbit FM](https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Matrix-factorization-example)|Online library with estimator API|Easy to use by calling API, but flexibility and configurability are limited|[notebook](https://github.com/microsoft/recommenders/blob/master/notebooks/02_model/vowpal_wabbit_deep_dive.ipynb) / [utilities](https://github.com/microsoft/recommenders/tree/master/reco_utils/recommender/vowpal_wabbit)\n", "|[microsoft/recommenders xDeepFM](https://github.com/microsoft/recommenders/blob/master/reco_utils/recommender/deeprec/models/xDeepFM.py)|Python|Support flexible interface with different configurations of FM and FM extensions, i.e., LR, FM, and/or CIN|[notebook](https://github.com/microsoft/recommenders/blob/master/notebooks/00_quick_start/xdeepfm_criteo.ipynb) / [utilities](https://github.com/microsoft/recommenders/blob/master/reco_utils/recommender/deeprec/models/xDeepFM.py)|" ] @@ -224,24 +224,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "%matplotlib notebook" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": "System version: 3.6.8 |Anaconda, Inc.| (default, Dec 30 2018, 01:22:34) \n[GCC 7.3.0]\nTensorflow version: 1.12.0\n" - } - ], "source": [ "import time\n", "import sys\n", @@ -250,11 +235,11 @@ "import papermill as pm\n", "from tempfile import TemporaryDirectory\n", "import xlearn as xl\n", - "import tensorflow as tf\n", "from sklearn.metrics import roc_auc_score\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", + "%matplotlib notebook\n", "from matplotlib import pyplot as plt\n", "\n", "from reco_utils.common.constants import SEED\n", @@ -268,7 +253,7 @@ "from reco_utils.dataset.pandas_df_utils import LibffmConverter\n", "\n", "print(\"System version: {}\".format(sys.version))\n", - "print(\"Tensorflow version: {}\".format(tf.__version__))" + "print(\"Xlearn version: {}\".format(xl.__version__))" ] }, { @@ -289,19 +274,9 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ratingfield1field2field3field4
011:1:12:4:33:5:1.04:6:1
101:2:12:4:43:5:2.04:7:1
201:3:12:4:53:5:3.04:8:1
311:3:12:4:63:5:4.04:9:1
411:3:12:4:73:5:5.04:10:1
\n
", - "text/plain": " rating field1 field2 field3 field4\n0 1 1:1:1 2:4:3 3:5:1.0 4:6:1\n1 0 1:2:1 2:4:4 3:5:2.0 4:7:1\n2 0 1:3:1 2:4:5 3:5:3.0 4:8:1\n3 1 1:3:1 2:4:6 3:5:4.0 4:9:1\n4 1 1:3:1 2:4:7 3:5:5.0 4:10:1" - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "df_feature_original = pd.DataFrame({\n", " 'rating': [1, 0, 0, 1, 1],\n", @@ -318,15 +293,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": "There are in total 4 fields and 10 features.\n" - } - ], + "outputs": [], "source": [ "print('There are in total {0} fields and {1} features.'.format(converter.field_count, converter.feature_count))" ] @@ -340,33 +309,36 @@ }, { "cell_type": "code", - "execution_count": 5, - "metadata": {}, + "execution_count": null, + "metadata": { + "tags": [ + "parameters" + ] + }, "outputs": [], "source": [ - "YAML_FILE_NAME = 'xDeepFM.yaml'\n", - "TRAIN_FILE_NAME = 'cretio_tiny_train'\n", - "VALID_FILE_NAME = 'cretio_tiny_valid'\n", - "TEST_FILE_NAME = 'cretio_tiny_test'\n", - "MODEL_FILE_NAME = 'model.out'\n", - "OUTPUT_FILE_NAME = 'output.txt'\n", + "# Parameters\n", + "YAML_FILE_NAME = \"xDeepFM.yaml\"\n", + "TRAIN_FILE_NAME = \"cretio_tiny_train\"\n", + "VALID_FILE_NAME = \"cretio_tiny_valid\"\n", + "TEST_FILE_NAME = \"cretio_tiny_test\"\n", + "MODEL_FILE_NAME = \"model.out\"\n", + "OUTPUT_FILE_NAME = \"output.txt\"\n", "\n", "LEARNING_RATE = 0.2\n", "LAMBDA = 0.002\n", - "METRIC = 'auc'" + "# The metrics for binary classification options are \"acc\", \"prec\", \"f1\" and \"auc\"\n", + "# for regression, options are \"rmse\", \"mae\", \"mape\"\n", + "METRIC = \"auc\" \n", + "EPOCH = 10\n", + "OPT_METHOD = \"sgd\" # options are \"sgd\", \"adagrad\" and \"ftrl\"" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": "100%|██████████| 10.3k/10.3k [00:05<00:00, 2.06kKB/s]\n" - } - ], + "outputs": [], "source": [ "tmpdir = TemporaryDirectory()\n", "\n", @@ -398,7 +370,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -411,12 +383,19 @@ "# 0. task: binary classification\n", "# 1. learning rate: 0.2\n", "# 2. regular lambda: 0.002\n", - "# 3. evaluation metric: accuracy\n", - "param = {'task':'binary', 'lr':LEARNING_RATE, 'lambda':LAMBDA, 'metric':METRIC}\n", + "# 3. evaluation metric: auc\n", + "# 4. number of epochs: 10\n", + "# 5. optimization method: sgd\n", + "param = {\"task\":\"binary\", \n", + " \"lr\": LEARNING_RATE, \n", + " \"lambda\": LAMBDA, \n", + " \"metric\": METRIC,\n", + " \"epoch\": EPOCH,\n", + " \"opt\": OPT_METHOD\n", + " }\n", "\n", "# Start to train\n", "# The trained model will be stored in model.out\n", - "\n", "with Timer() as time_train:\n", " ffm_model.fit(param, model_file)\n", "\n", @@ -439,7 +418,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -457,52 +436,27 @@ }, { "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": "0.7526414844951351" - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "auc_score" ] }, { "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "application/papermill.record+json": { - "auc_score": 0.7526414844951351 - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "pm.record('auc_score', auc_score)" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": "Training takes 26.74s and predicting takes 4.89s.\n" - } - ], + "outputs": [], "source": [ "print('Training takes {0:.2f}s and predicting takes {1:.2f}s.'.format(time_train.interval, time_predict.interval))" ] @@ -545,7 +499,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -559,7 +513,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -585,8 +539,24 @@ " truths = np.array([float(truth.split(' ')[0]) for truth in truths])\n", " predictions = np.array([float(prediction.strip('')) for prediction in predictions])\n", "\n", - " auc_scores.append(roc_auc_score(truths, predictions))\n", - "\n", + " auc_scores.append(roc_auc_score(truths, predictions))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print('Tuning by grid search takes {0:.2} min'.format(time_tune.interval / 60))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "auc_scores = [float('%.4f' % x) for x in auc_scores]\n", "auc_scores_array = np.reshape(auc_scores, (len(param_dict[\"lr\"]), len(param_dict[\"lambda\"]))) \n", "\n", @@ -594,54 +564,15 @@ " data=auc_scores_array, \n", " index=pd.Index(param_dict[\"lr\"], name=\"LR\"), \n", " columns=pd.Index(param_dict[\"lambda\"], name=\"Lambda\")\n", - ")" + ")\n", + "auc_df" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": "Tuning by grid search takes 7.6 min\n" - } - ], - "source": [ - "print('Tuning by grid search takes {0:.2} min'.format(time_tune.interval / 60))" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "application/javascript": "/* Put everything inside the global mpl namespace */\nwindow.mpl = {};\n\n\nmpl.get_websocket_type = function() {\n if (typeof(WebSocket) !== 'undefined') {\n return WebSocket;\n } else if (typeof(MozWebSocket) !== 'undefined') {\n return MozWebSocket;\n } else {\n alert('Your browser does not have WebSocket support. ' +\n 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n 'Firefox 4 and 5 are also supported but you ' +\n 'have to enable WebSockets in about:config.');\n };\n}\n\nmpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n this.id = figure_id;\n\n this.ws = websocket;\n\n this.supports_binary = (this.ws.binaryType != undefined);\n\n if (!this.supports_binary) {\n var warnings = document.getElementById(\"mpl-warnings\");\n if (warnings) {\n warnings.style.display = 'block';\n warnings.textContent = (\n \"This browser does not support binary websocket messages. \" +\n \"Performance may be slow.\");\n }\n }\n\n this.imageObj = new Image();\n\n this.context = undefined;\n this.message = undefined;\n this.canvas = undefined;\n this.rubberband_canvas = undefined;\n this.rubberband_context = undefined;\n this.format_dropdown = undefined;\n\n this.image_mode = 'full';\n\n this.root = $('
');\n this._root_extra_style(this.root)\n this.root.attr('style', 'display: inline-block');\n\n $(parent_element).append(this.root);\n\n this._init_header(this);\n this._init_canvas(this);\n this._init_toolbar(this);\n\n var fig = this;\n\n this.waiting = false;\n\n this.ws.onopen = function () {\n fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n fig.send_message(\"send_image_mode\", {});\n if (mpl.ratio != 1) {\n fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n }\n fig.send_message(\"refresh\", {});\n }\n\n this.imageObj.onload = function() {\n if (fig.image_mode == 'full') {\n // Full images could contain transparency (where diff images\n // almost always do), so we need to clear the canvas so that\n // there is no ghosting.\n fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n }\n fig.context.drawImage(fig.imageObj, 0, 0);\n };\n\n this.imageObj.onunload = function() {\n fig.ws.close();\n }\n\n this.ws.onmessage = this._make_on_message_function(this);\n\n this.ondownload = ondownload;\n}\n\nmpl.figure.prototype._init_header = function() {\n var titlebar = $(\n '
');\n var titletext = $(\n '
');\n titlebar.append(titletext)\n this.root.append(titlebar);\n this.header = titletext[0];\n}\n\n\n\nmpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n\n}\n\n\nmpl.figure.prototype._root_extra_style = function(canvas_div) {\n\n}\n\nmpl.figure.prototype._init_canvas = function() {\n var fig = this;\n\n var canvas_div = $('
');\n\n canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n\n function canvas_keyboard_event(event) {\n return fig.key_event(event, event['data']);\n }\n\n canvas_div.keydown('key_press', canvas_keyboard_event);\n canvas_div.keyup('key_release', canvas_keyboard_event);\n this.canvas_div = canvas_div\n this._canvas_extra_style(canvas_div)\n this.root.append(canvas_div);\n\n var canvas = $('');\n canvas.addClass('mpl-canvas');\n canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n\n this.canvas = canvas[0];\n this.context = canvas[0].getContext(\"2d\");\n\n var backingStore = this.context.backingStorePixelRatio ||\n\tthis.context.webkitBackingStorePixelRatio ||\n\tthis.context.mozBackingStorePixelRatio ||\n\tthis.context.msBackingStorePixelRatio ||\n\tthis.context.oBackingStorePixelRatio ||\n\tthis.context.backingStorePixelRatio || 1;\n\n mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n\n var rubberband = $('');\n rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n\n var pass_mouse_events = true;\n\n canvas_div.resizable({\n start: function(event, ui) {\n pass_mouse_events = false;\n },\n resize: function(event, ui) {\n fig.request_resize(ui.size.width, ui.size.height);\n },\n stop: function(event, ui) {\n pass_mouse_events = true;\n fig.request_resize(ui.size.width, ui.size.height);\n },\n });\n\n function mouse_event_fn(event) {\n if (pass_mouse_events)\n return fig.mouse_event(event, event['data']);\n }\n\n rubberband.mousedown('button_press', mouse_event_fn);\n rubberband.mouseup('button_release', mouse_event_fn);\n // Throttle sequential mouse events to 1 every 20ms.\n rubberband.mousemove('motion_notify', mouse_event_fn);\n\n rubberband.mouseenter('figure_enter', mouse_event_fn);\n rubberband.mouseleave('figure_leave', mouse_event_fn);\n\n canvas_div.on(\"wheel\", function (event) {\n event = event.originalEvent;\n event['data'] = 'scroll'\n if (event.deltaY < 0) {\n event.step = 1;\n } else {\n event.step = -1;\n }\n mouse_event_fn(event);\n });\n\n canvas_div.append(canvas);\n canvas_div.append(rubberband);\n\n this.rubberband = rubberband;\n this.rubberband_canvas = rubberband[0];\n this.rubberband_context = rubberband[0].getContext(\"2d\");\n this.rubberband_context.strokeStyle = \"#000000\";\n\n this._resize_canvas = function(width, height) {\n // Keep the size of the canvas, canvas container, and rubber band\n // canvas in synch.\n canvas_div.css('width', width)\n canvas_div.css('height', height)\n\n canvas.attr('width', width * mpl.ratio);\n canvas.attr('height', height * mpl.ratio);\n canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n\n rubberband.attr('width', width);\n rubberband.attr('height', height);\n }\n\n // Set the figure to an initial 600x600px, this will subsequently be updated\n // upon first draw.\n this._resize_canvas(600, 600);\n\n // Disable right mouse context menu.\n $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n return false;\n });\n\n function set_focus () {\n canvas.focus();\n canvas_div.focus();\n }\n\n window.setTimeout(set_focus, 100);\n}\n\nmpl.figure.prototype._init_toolbar = function() {\n var fig = this;\n\n var nav_element = $('
');\n nav_element.attr('style', 'width: 100%');\n this.root.append(nav_element);\n\n // Define a callback function for later on.\n function toolbar_event(event) {\n return fig.toolbar_button_onclick(event['data']);\n }\n function toolbar_mouse_event(event) {\n return fig.toolbar_button_onmouseover(event['data']);\n }\n\n for(var toolbar_ind in mpl.toolbar_items) {\n var name = mpl.toolbar_items[toolbar_ind][0];\n var tooltip = mpl.toolbar_items[toolbar_ind][1];\n var image = mpl.toolbar_items[toolbar_ind][2];\n var method_name = mpl.toolbar_items[toolbar_ind][3];\n\n if (!name) {\n // put a spacer in here.\n continue;\n }\n var button = $('