diff --git a/.demo/Fee Waiver demo.ipynb b/.demo/Fee Waiver demo.ipynb new file mode 100644 index 0000000000..504f0e8219 --- /dev/null +++ b/.demo/Fee Waiver demo.ipynb @@ -0,0 +1,338 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Connected to flyte.lyft.net\n" + ] + } + ], + "source": [ + "from flytekit.configuration import set_flyte_config_file, platform\n", + "set_flyte_config_file(\"/Users/kumare/.ssh/notebook-production.config\")\n", + "#set_flyte_config_file(\"notebook.config\")\n", + "\n", + "print(\"Connected to {}\".format(platform.URL.get()))\n", + "\n", + "def print_console_url(exc):\n", + " print(\"http://{}/console/projects/{}/domains/{}/executions/{}\".format(platform.URL.get(), exc.id.project, exc.id.domain, exc.id.name))" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "query=\"\"\"WITH eme AS (\n", + " SELECT\n", + " ride_id,\n", + " feature_driver_distance_at_arrival_meters,\n", + " feature_driver_distance_at_cancellation_meters,\n", + " feature_dvr_cancellation_rate,\n", + " feature_dvr_no_show_rate,\n", + " feature_dvr_num_voice_calls_to_pax,\n", + " feature_dvr_rides_28d,\n", + " feature_dvr_sum_call_duration,\n", + " feature_dvr_total_rides,\n", + " feature_fixed_fare_amount,\n", + " feature_gh6_total_rides,\n", + " feature_has_waypoint,\n", + " feature_hour_local,\n", + " feature_hour_of_week_local,\n", + " feature_hour_of_week_shifted_local,\n", + " feature_hour_shifted_local,\n", + " feature_is_scheduled_ride,\n", + " feature_num_average_daily_rides_canceled,\n", + " feature_num_rides_taken,\n", + " feature_pax_avg_pickup_time_seconds,\n", + " feature_pax_no_show_rate,\n", + " feature_pax_num_voice_calls_to_driver,\n", + " feature_pax_sms,\n", + " feature_pax_sms_char_len,\n", + " feature_pax_sum_call_duration,\n", + " feature_pax_total_rides,\n", + " feature_pax_unsuccessful_voice,\n", + " feature_request_started_at_to_arrived_at_seconds,\n", + " feature_seconds_since_arrival,\n", + " feature_upfront_fare_amount\n", + " FROM event_model_executed\n", + " WHERE ds >= '{{.inputs.start_date}}'\n", + " AND ds < '{{.inputs.end_date}}'\n", + " AND model = 'dummyfeatureloggingnoshowmodel'\n", + "),\n", + "\n", + "dsi AS (\n", + " SELECT\n", + " ride_id,\n", + " MAX(CAST(is_a1k AS INT)) AS pax_a1k\n", + " FROM dimension_support_issues\n", + " WHERE issue_started_at >= CAST('{{.inputs.start_date}}' AS TIMESTAMP)\n", + " AND issue_started_at < CAST('{{.inputs.end_date}}' AS TIMESTAMP) + INTERVAL '7' DAY\n", + " AND impacted_user = 'passenger'\n", + " GROUP BY ride_id\n", + ")\n", + "\n", + "SELECT\n", + " erc.ride_id,\n", + " feature_driver_distance_at_arrival_meters,\n", + " feature_driver_distance_at_cancellation_meters,\n", + " feature_dvr_cancellation_rate,\n", + " feature_dvr_no_show_rate,\n", + " feature_dvr_num_voice_calls_to_pax,\n", + " feature_dvr_rides_28d, \n", + " feature_dvr_sum_call_duration,\n", + " feature_dvr_total_rides,\n", + " feature_fixed_fare_amount,\n", + " feature_gh6_total_rides,\n", + " feature_has_waypoint,\n", + " feature_hour_local,\n", + " feature_hour_of_week_local,\n", + " feature_hour_of_week_shifted_local,\n", + " feature_hour_shifted_local,\n", + " feature_is_scheduled_ride,\n", + " feature_num_average_daily_rides_canceled,\n", + " feature_num_rides_taken,\n", + " feature_pax_avg_pickup_time_seconds,\n", + " feature_pax_no_show_rate,\n", + " feature_pax_num_voice_calls_to_driver,\n", + " feature_pax_sms,\n", + " feature_pax_sms_char_len,\n", + " feature_pax_sum_call_duration,\n", + " feature_pax_total_rides,\n", + " feature_pax_unsuccessful_voice,\n", + " feature_request_started_at_to_arrived_at_seconds,\n", + " feature_seconds_since_arrival,\n", + " feature_upfront_fare_amount,\n", + " CASE WHEN dsi.pax_a1k = 1 THEN TRUE ELSE FALSE END AS should_waive_fee\n", + "\n", + "FROM event_cancels_process_canceled_ride erc\n", + "JOIN experimentation.latest_exposure le\n", + " ON erc.passenger_lyft_id = le.user_lyft_id\n", + " AND erc.ds >= '{{.inputs.start_date}}'\n", + " AND erc.ds < '{{.inputs.end_date}}'\n", + " AND erc.after_arrived = TRUE\n", + " AND (erc.due_to_no_show = TRUE OR erc.canceling_party = 'passenger')\n", + " AND erc.cancel_penalty > 0\n", + " AND le.experiment = 'CP_SXP_PAC_NS_JointHoldout_2019Q4'\n", + " AND erc.occurred_at > le.first_exposed_at\n", + " AND le.variant = 'holdout'\n", + "JOIN eme \n", + " ON erc.ride_id = eme.ride_id\n", + "LEFT JOIN dsi\n", + " ON erc.ride_id = dsi.ride_id\n", + "WHERE erc.ds >= '{{.inputs.start_date}}'\n", + " AND erc.ds < '{{.inputs.end_date}}'\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from flytekit.sdk.tasks import inputs\n", + "from flytekit.sdk.types import Types\n", + "from flytekit.common.tasks.presto_task import SdkPrestoTask\n", + "\n", + "schema = Types.Schema([\n", + "('feature_driver_distance_at_arrival_meters', Types.Integer),\n", + "('feature_driver_distance_at_cancellation_meters', Types.Integer),\n", + "('feature_dvr_cancellation_rate', Types.Integer),\n", + "('feature_dvr_no_show_rate', Types.Integer),\n", + "('feature_dvr_num_voice_calls_to_pax', Types.Integer),\n", + "('feature_dvr_rides_28d', Types.Integer),\n", + "('feature_dvr_sum_call_duration', Types.Integer),\n", + "('feature_dvr_total_rides', Types.Integer),\n", + "('feature_fixed_fare_amount', Types.Integer),\n", + "('feature_gh6_total_rides', Types.Integer),\n", + "('feature_has_waypoint', Types.Integer),\n", + "('feature_hour_local', Types.Integer),\n", + "('feature_hour_of_week_local', Types.Integer),\n", + "('feature_hour_of_week_shifted_local', Types.Integer),\n", + "('feature_hour_shifted_local', Types.Integer),\n", + "('feature_is_scheduled_ride', Types.Integer),\n", + "('feature_num_average_daily_rides_canceled', Types.Integer),\n", + "('feature_num_rides_taken', Types.Integer),\n", + "('feature_pax_avg_pickup_time_seconds', Types.Integer),\n", + "('feature_pax_no_show_rate', Types.Integer),\n", + "('feature_pax_num_voice_calls_to_driver', Types.Integer),\n", + "('feature_pax_sms', Types.Integer),\n", + "('feature_pax_sms_char_len', Types.Integer),\n", + "('feature_pax_sum_call_duration', Types.Integer),\n", + "('feature_pax_total_rides', Types.Integer),\n", + "('feature_pax_unsuccessful_voice', Types.Integer),\n", + "('feature_request_started_at_to_arrived_at_seconds', Types.Integer),\n", + "('feature_seconds_since_arrival', Types.Integer),\n", + "('feature_upfront_fare_amount', Types.Integer),\n", + "])\n", + "\n", + "schema = Types.Schema()\n", + "\n", + "presto = SdkPrestoTask(\n", + " task_inputs=inputs(start_date=Types.String, end_date=Types.String),\n", + " statement=query,\n", + " output_schema=schema,\n", + " catalog=\"hive\",\n", + " schema=\"default\",\n", + " discoverable=True,\n", + " discovery_version=\"1\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "http://flyte.lyft.net/console/projects/flyteexamples/domains/development/executions/d42y9db6qz\n" + ] + } + ], + "source": [ + "exc = presto.register_and_launch(\"flyteexamples\", \"development\", inputs={\"start_date\":\"2020-04-07\", \"end_date\":\"2020-04-01\"})\n", + "print_console_url(exc)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "exc.wait_for_completion()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "exc.sync()\n", + "results = exc.outputs[\"results\"]\n", + "results.download(\"/tmp/data\", overwrite=True)\n", + "dfs = []\n", + "with results as reader:\n", + " for df in reader.iter_chunks():\n", + " dfs.append(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dfs" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "ename": "IndentationError", + "evalue": "unexpected indent (, line 3)", + "output_type": "error", + "traceback": [ + "\u001b[0;36m File \u001b[0;32m\"\"\u001b[0;36m, line \u001b[0;32m3\u001b[0m\n\u001b[0;31m train_features = Index(['feature_driver_distance_at_arrival_meters',\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mIndentationError\u001b[0m\u001b[0;31m:\u001b[0m unexpected indent\n" + ] + } + ], + "source": [ + "from sklearn.model_selection import train_test_split\n", + "train_dataset, test_dataset = train_test_split(df, test_size=0.33, random_state=42)\n", + " train_features = Index(['feature_driver_distance_at_arrival_meters',\n", + " 'feature_driver_distance_at_cancellation_meters',\n", + " 'feature_dvr_cancellation_rate',\n", + " 'feature_dvr_no_show_rate',\n", + " 'feature_dvr_num_voice_calls_to_pax',\n", + " 'feature_dvr_rides_28d', \n", + " 'feature_dvr_sum_call_duration',\n", + " 'feature_dvr_total_rides',\n", + " 'feature_fixed_fare_amount',\n", + " 'feature_gh6_total_rides',\n", + " 'feature_has_waypoint',\n", + " 'feature_hour_local',\n", + " 'feature_hour_of_week_local',\n", + " 'feature_hour_of_week_shifted_local',\n", + " 'feature_hour_shifted_local',\n", + " 'feature_is_scheduled_ride',\n", + " 'feature_num_average_daily_rides_canceled',\n", + " 'feature_num_rides_taken',\n", + " 'feature_pax_avg_pickup_time_seconds',\n", + " 'feature_pax_no_show_rate',\n", + " 'feature_pax_num_voice_calls_to_driver',\n", + " 'feature_pax_sms',\n", + " 'feature_pax_sms_char_len',\n", + " 'feature_pax_sum_call_duration',\n", + " 'feature_pax_total_rides',\n", + " 'feature_pax_unsuccessful_voice',\n", + " 'feature_request_started_at_to_arrived_at_seconds',\n", + " 'feature_seconds_since_arrival',\n", + " 'feature_upfront_fare_amount'], dtype='object')\n", + " \n", + "labels = Index(['should_waive_fee'])\n", + "\n", + "x_train = train_dataset[train_features]\n", + "y_train = train_dataset[labels]\n", + "\n", + "x_test = test_dataset[train_features]\n", + "y_test = test_dataset[labels]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flytekit.sdk.workflow import workflow_class, Input, Output" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.7.4 64-bit ('flytekit': virtualenv)", + "language": "python", + "name": "python37464bitflytekitvirtualenv72cbb5e9968e4a299c6026c09cce8d4c" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/.demo/waive-fee.ipynb b/.demo/waive-fee.ipynb new file mode 100644 index 0000000000..2b0d6ef834 --- /dev/null +++ b/.demo/waive-fee.ipynb @@ -0,0 +1,295 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-25T18:29:22.036093Z", + "start_time": "2020-03-25T18:29:22.025464Z" + } + }, + "outputs": [], + "source": [ + "from pandas import Index\n", + "import pandas as pd\n", + "import numpy as np\n", + "from modelexeclib.wrappers.lgbm import LGBMRegressor\n", + "\n", + "class Model(object):\n", + " HYPERPARAMETERS = [\n", + " {'name':'num_leaves','type':'int', 'default_value': 2},\n", + " ]\n", + "\n", + " def __init__(self, hyperparameters=None):\n", + " hyperparameters = hyperparameters or {}\n", + " # Read and convert hyperparameters\n", + "\n", + " def train(self):\n", + " # Get training data\n", + " from lyft_analysis.db import presto\n", + " df = presto.DatabaseTool().query(\"\"\"WITH eme AS (\n", + " SELECT\n", + " ride_id,\n", + " feature_driver_distance_at_arrival_meters,\n", + " feature_driver_distance_at_cancellation_meters,\n", + " feature_dvr_cancellation_rate,\n", + " feature_dvr_no_show_rate,\n", + " feature_dvr_num_voice_calls_to_pax,\n", + " feature_dvr_rides_28d,\n", + " feature_dvr_sum_call_duration,\n", + " feature_dvr_total_rides,\n", + " feature_fixed_fare_amount,\n", + " feature_gh6_total_rides,\n", + " feature_has_waypoint,\n", + " feature_hour_local,\n", + " feature_hour_of_week_local,\n", + " feature_hour_of_week_shifted_local,\n", + " feature_hour_shifted_local,\n", + " feature_is_scheduled_ride,\n", + " feature_num_average_daily_rides_canceled,\n", + " feature_num_rides_taken,\n", + " feature_pax_avg_pickup_time_seconds,\n", + " feature_pax_no_show_rate,\n", + " feature_pax_num_voice_calls_to_driver,\n", + " feature_pax_sms,\n", + " feature_pax_sms_char_len,\n", + " feature_pax_sum_call_duration,\n", + " feature_pax_total_rides,\n", + " feature_pax_unsuccessful_voice,\n", + " feature_request_started_at_to_arrived_at_seconds,\n", + " feature_seconds_since_arrival,\n", + " feature_upfront_fare_amount\n", + " FROM hive.default.event_model_executed\n", + " WHERE ds >= '2020-02-04'\n", + " AND ds < '2020-03-06'\n", + " AND model = 'dummyfeatureloggingnoshowmodel'\n", + "),\n", + "\n", + "dsi AS (\n", + " SELECT\n", + " ride_id,\n", + " MAX(CAST(is_a1k AS INT)) AS pax_a1k\n", + " FROM default.dimension_support_issues\n", + " WHERE issue_started_at >= CAST('2020-02-04' AS TIMESTAMP)\n", + " AND issue_started_at < CAST('2020-03-06' AS TIMESTAMP) + INTERVAL '7' DAY\n", + " AND impacted_user = 'passenger'\n", + " GROUP BY ride_id\n", + ")\n", + "\n", + "SELECT\n", + " erc.ride_id,\n", + " feature_driver_distance_at_arrival_meters,\n", + " feature_driver_distance_at_cancellation_meters,\n", + " feature_dvr_cancellation_rate,\n", + " feature_dvr_no_show_rate,\n", + " feature_dvr_num_voice_calls_to_pax,\n", + " feature_dvr_rides_28d, \n", + " feature_dvr_sum_call_duration,\n", + " feature_dvr_total_rides,\n", + " feature_fixed_fare_amount,\n", + " feature_gh6_total_rides,\n", + " feature_has_waypoint,\n", + " feature_hour_local,\n", + " feature_hour_of_week_local,\n", + " feature_hour_of_week_shifted_local,\n", + " feature_hour_shifted_local,\n", + " feature_is_scheduled_ride,\n", + " feature_num_average_daily_rides_canceled,\n", + " feature_num_rides_taken,\n", + " feature_pax_avg_pickup_time_seconds,\n", + " feature_pax_no_show_rate,\n", + " feature_pax_num_voice_calls_to_driver,\n", + " feature_pax_sms,\n", + " feature_pax_sms_char_len,\n", + " feature_pax_sum_call_duration,\n", + " feature_pax_total_rides,\n", + " feature_pax_unsuccessful_voice,\n", + " feature_request_started_at_to_arrived_at_seconds,\n", + " feature_seconds_since_arrival,\n", + " feature_upfront_fare_amount,\n", + " CASE WHEN dsi.pax_a1k = 1 THEN TRUE ELSE FALSE END AS should_waive_fee\n", + "\n", + "FROM default.event_cancels_process_canceled_ride erc\n", + "JOIN experimentation.latest_exposure le\n", + " ON erc.passenger_lyft_id = le.user_lyft_id\n", + " AND erc.ds >= '2020-02-04'\n", + " AND erc.ds < '2020-03-06'\n", + " AND erc.after_arrived = TRUE\n", + " AND (erc.due_to_no_show = TRUE OR erc.canceling_party = 'passenger')\n", + " AND erc.cancel_penalty > 0\n", + " AND le.experiment = 'CP_SXP_PAC_NS_JointHoldout_2019Q4'\n", + " AND erc.occurred_at > le.first_exposed_at\n", + " AND le.variant = 'holdout'\n", + "JOIN eme \n", + " ON erc.ride_id = eme.ride_id\n", + "LEFT JOIN dsi\n", + " ON erc.ride_id = dsi.ride_id\n", + "WHERE erc.ds >= '2020-02-04'\n", + " AND erc.ds < '2020-03-06'\"\"\")\n", + " print(\"retrieved data\")\n", + "\n", + " from sklearn.model_selection import train_test_split\n", + " train_dataset, test_dataset = train_test_split(df, test_size=0.33, random_state=42)\n", + " train_features = Index(['feature_driver_distance_at_arrival_meters',\n", + " 'feature_driver_distance_at_cancellation_meters',\n", + " 'feature_dvr_cancellation_rate',\n", + " 'feature_dvr_no_show_rate',\n", + " 'feature_dvr_num_voice_calls_to_pax',\n", + " 'feature_dvr_rides_28d', \n", + " 'feature_dvr_sum_call_duration',\n", + " 'feature_dvr_total_rides',\n", + " 'feature_fixed_fare_amount',\n", + " 'feature_gh6_total_rides',\n", + " 'feature_has_waypoint',\n", + " 'feature_hour_local',\n", + " 'feature_hour_of_week_local',\n", + " 'feature_hour_of_week_shifted_local',\n", + " 'feature_hour_shifted_local',\n", + " 'feature_is_scheduled_ride',\n", + " 'feature_num_average_daily_rides_canceled',\n", + " 'feature_num_rides_taken',\n", + " 'feature_pax_avg_pickup_time_seconds',\n", + " 'feature_pax_no_show_rate',\n", + " 'feature_pax_num_voice_calls_to_driver',\n", + " 'feature_pax_sms',\n", + " 'feature_pax_sms_char_len',\n", + " 'feature_pax_sum_call_duration',\n", + " 'feature_pax_total_rides',\n", + " 'feature_pax_unsuccessful_voice',\n", + " 'feature_request_started_at_to_arrived_at_seconds',\n", + " 'feature_seconds_since_arrival',\n", + " 'feature_upfront_fare_amount'], dtype='object')\n", + " labels = Index(['should_waive_fee'])\n", + " \n", + " x_train = train_dataset[train_features]\n", + " y_train = train_dataset[labels]\n", + " \n", + " x_test = test_dataset[train_features]\n", + " y_test = test_dataset[labels]\n", + " print(\"split data set\")\n", + "\n", + " # Construct model using modelexeclib wrapper\n", + " lgbm = LGBMRegressor(n_estimators=2)\n", + "\n", + " # Fit model\n", + " lgbm.fit(x_train, y_train)\n", + " print(\"model fit done\")\n", + " \n", + " y_predict = lgbm.predict(x_test)\n", + " print(y_predict)\n", + " \n", + " from sklearn.metrics import f1_score\n", + " score = f1_score(y_test, y_predict.round(), average='weighted')\n", + " print(\"f1 score computed {}\".format(score))\n", + "\n", + " from lyftlearnclient.metrics import Metrics\n", + " metrics = Metrics()\n", + " metrics.emit('f1-score', score)\n", + "\n", + " # Return fitted model\n", + " # return lgbm\n", + "\n", + " def init_predict(self):\n", + " # type: (None) -> None\n", + " # This will be called before batch_predict() calls, and called once before serving predict() calls, so any slow\n", + " # operations to set up the model, e.g download weights from S3 or load model checkpoints should be done here.\n", + " pass\n", + "\n", + " def predict(self, request_data):\n", + " # type: (dict) -> (object):\n", + " # Online prediction on a single sample.\n", + " # The input dict will be parsed from the a REST POST request's json body\n", + " # The output object must be json serializable (e.g. a python dictionary)\n", + " pass\n", + "\n", + " def batch_predict(self):\n", + " # type: (None) -> None\n", + " # Fetch data to predict, run prediction, save results.\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-25T18:17:09.987599Z", + "start_time": "2020-03-25T18:16:50.292311Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "retrieved data\n", + "split data set\n", + "model fit done\n", + "[0.29592739 0.2219032 0.25455321 ... 0.23480338 0.2219032 0.30006669]\n", + "f1 score computed\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/code/venvs/venv/lib/python3.6/site-packages/sklearn/metrics/classification.py:1135: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.\n", + " 'precision', 'predicted', average, warn_for)\n", + "WARNING:lyftlearnclient.metrics:There was an error retrieving model uuid.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "f1-score=0.6351484574799411\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = Model({x['name']: x['default_value'] for x in Model.HYPERPARAMETERS})\n", + "model.train()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 4ebe7115b2..1679c05c5a 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -2,4 +2,4 @@ import flytekit.plugins -__version__ = '0.10.1' +__version__ = '0.10.2' diff --git a/flytekit/common/workflow.py b/flytekit/common/workflow.py index 2eec22ed44..cbf1368945 100644 --- a/flytekit/common/workflow.py +++ b/flytekit/common/workflow.py @@ -455,17 +455,16 @@ def _discover_workflow_components(workflow_class): return inputs, outputs, nodes -def build_sdk_workflow_from_metaclass(metaclass, queuing_budget=None, on_failure=None, cls=None): +def build_sdk_workflow_from_metaclass(metaclass, on_failure=None, cls=None): """ :param T metaclass: :param cls: This is the class that will be instantiated from the inputs, outputs, and nodes. This will be used by users extending the base Flyte programming model. If set, it must be a subclass of SdkWorkflow. - :param queuing_budget datetime.timedelta: [Optional] Budget that specifies the amount of time a workflow can be queued up for execution. :param on_failure flytekit.models.core.workflow.WorkflowMetadata.OnFailurePolicy: [Optional] The execution policy when the workflow detects a failure. :rtype: SdkWorkflow """ inputs, outputs, nodes = _discover_workflow_components(metaclass) - metadata = _workflow_models.WorkflowMetadata(queuing_budget=queuing_budget if queuing_budget else None, on_failure=on_failure if on_failure else None) + metadata = _workflow_models.WorkflowMetadata(on_failure=on_failure if on_failure else None) return (cls or SdkWorkflow)( inputs=[i for i in sorted(inputs, key=lambda x: x.name)], outputs=[o for o in sorted(outputs, key=lambda x: x.name)], diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index 6e70016108..fdf12c694f 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -465,23 +465,14 @@ class OnFailurePolicy(object): FAIL_IMMEDIATELY = _core_workflow.WorkflowMetadata.FAIL_IMMEDIATELY FAIL_AFTER_EXECUTABLE_NODES_COMPLETE = _core_workflow.WorkflowMetadata.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE - def __init__(self, queuing_budget=None, on_failure=None): + def __init__(self, on_failure=None): """ Metadata for the workflow. - :param queuing_budget datetime.timedelta: [Optional] Budget that specifies the amount of time a workflow can be queued up for execution. :param on_failure flytekit.models.core.workflow.WorkflowMetadata.OnFailurePolicy: [Optional] The execution policy when the workflow detects a failure. """ - self._queuing_budget = queuing_budget self._on_failure = on_failure - @property - def queuing_budget(self): - """ - :rtype: datetime.timedelta - """ - return self._queuing_budget - @property def on_failure(self): """ @@ -494,8 +485,6 @@ def to_flyte_idl(self): :rtype: flyteidl.core.workflow_pb2.WorkflowMetadata """ workflow_metadata = _core_workflow.WorkflowMetadata() - if self._queuing_budget: - workflow_metadata.queuing_budget.FromTimedelta(self.queuing_budget) if self.on_failure: workflow_metadata.on_failure = self.on_failure return workflow_metadata @@ -507,10 +496,10 @@ def from_flyte_idl(cls, pb2_object): :rtype: WorkflowMetadata """ return cls( - queuing_budget=pb2_object.queuing_budget.ToTimedelta() if pb2_object.queuing_budget else None, on_failure=pb2_object.on_failure if pb2_object.on_failure else WorkflowMetadata.OnFailurePolicy.FAIL_IMMEDIATELY ) + class WorkflowMetadataDefaults(_common.FlyteIdlEntity): def __init__(self, interruptible=None): diff --git a/flytekit/sdk/workflow.py b/flytekit/sdk/workflow.py index 57c9c2d71b..a6abe2bb26 100644 --- a/flytekit/sdk/workflow.py +++ b/flytekit/sdk/workflow.py @@ -1,7 +1,9 @@ from __future__ import absolute_import + +import six as _six + from flytekit.common import workflow as _common_workflow, promise as _promise from flytekit.common.types import helpers as _type_helpers -import six as _six class Input(_promise.Input): @@ -42,7 +44,7 @@ def __init__(self, value, sdk_type=None, help=None): ) -def workflow_class(_workflow_metaclass=None, cls=None, queuing_budget=None, on_failure=None): +def workflow_class(_workflow_metaclass=None, cls=None, on_failure=None): """ This is a decorator for wrapping class definitions into workflows. @@ -62,13 +64,12 @@ class MyWorkflow(object): :param cls: This is the class that will be instantiated from the inputs, outputs, and nodes. This will be used by users extending the base Flyte programming model. If set, it must be a subclass of :py:class:`flytekit.common.workflow.SdkWorkflow`. - :param queuing_budget datetime.timedelta: [Optional] Budget that specifies the amount of time a workflow can be queued up for execution. :param on_failure flytekit.models.core.workflow.WorkflowMetadata.OnFailurePolicy: [Optional] The execution policy when the workflow detects a failure. :rtype: flytekit.common.workflow.SdkWorkflow """ def wrapper(metaclass): - wf = _common_workflow.build_sdk_workflow_from_metaclass(metaclass, cls=cls, queuing_budget=queuing_budget, on_failure=on_failure) + wf = _common_workflow.build_sdk_workflow_from_metaclass(metaclass, cls=cls, on_failure=on_failure) return wf if _workflow_metaclass is not None: @@ -76,7 +77,7 @@ def wrapper(metaclass): return wrapper -def workflow(nodes, inputs=None, outputs=None, cls=None, queuing_budget=None, on_failure=None): +def workflow(nodes, inputs=None, outputs=None, cls=None, on_failure=None): """ This function provides a user-friendly interface for authoring workflows. @@ -109,14 +110,12 @@ def workflow(nodes, inputs=None, outputs=None, cls=None, queuing_budget=None, on :param T cls: This is the class that will be instantiated from the inputs, outputs, and nodes. This will be used by users extending the base Flyte programming model. If set, it must be a subclass of :py:class:`flytekit.common.workflow.SdkWorkflow`. - :param queuing_budget datetime.timedelta: [Optional] Budget that specifies the amount of time a workflow can be queued up for execution. - :param on_failure flytekit.models.core.workflow.WorkflowMetadata.OnFailurePolicy: [Optional] The execution policy when the workflow detects a failure. + :param flytekit.models.core.workflow.WorkflowMetadata.OnFailurePolicy on_failure: [Optional] The execution policy when the workflow detects a failure. :rtype: flytekit.common.workflow.SdkWorkflow """ wf = (cls or _common_workflow.SdkWorkflow)( inputs=[v.rename_and_return_reference(k) for k, v in sorted(_six.iteritems(inputs or {}))], outputs=[v.rename_and_return_reference(k) for k, v in sorted(_six.iteritems(outputs or {}))], nodes=[v.assign_id_and_return(k) for k, v in sorted(_six.iteritems(nodes))], - metadata=_common_workflow._workflow_models.WorkflowMetadata(queuing_budget=queuing_budget) if queuing_budget else None - ) + metadata=_common_workflow._workflow_models.WorkflowMetadata(on_failure=on_failure)) return wf diff --git a/tests/flytekit/unit/models/core/test_workflow.py b/tests/flytekit/unit/models/core/test_workflow.py index c5a34ed094..8a3cb26cae 100644 --- a/tests/flytekit/unit/models/core/test_workflow.py +++ b/tests/flytekit/unit/models/core/test_workflow.py @@ -27,6 +27,7 @@ def test_alias(): assert obj2.alias == 'myalias' assert obj2.var == 'myvar' + def test_workflow_template(): task = _workflow.TaskNode(reference_id=_generic_id) nm = _get_sample_node_metadata() @@ -50,7 +51,7 @@ def test_workflow_template(): ) obj = _workflow.WorkflowTemplate( id=_generic_id, - metadata=wf_metadata, + metadata=wf_metadata, metadata_defaults=wf_metadata_defaults, interface=typed_interface, nodes=[wf_node], @@ -58,54 +59,22 @@ def test_workflow_template(): obj2 = _workflow.WorkflowTemplate.from_flyte_idl(obj.to_flyte_idl()) assert obj2 == obj -def test_workflow_template_with_queuing_budget(): - task = _workflow.TaskNode(reference_id=_generic_id) - nm = _get_sample_node_metadata() - int_type = _types.LiteralType(_types.SimpleType.INTEGER) - wf_metadata = _workflow.WorkflowMetadata(queuing_budget=timedelta(seconds=10)) - wf_metadata_defaults = _workflow.WorkflowMetadataDefaults() - typed_interface = _interface.TypedInterface( - {'a': _interface.Variable(int_type, "description1")}, - { - 'b': _interface.Variable(int_type, "description2"), - 'c': _interface.Variable(int_type, "description3") - } - ) - wf_node = _workflow.Node( - id='some:node:id', - metadata=nm, - inputs=[], - upstream_node_ids=[], - output_aliases=[], - task_node=task - ) - obj = _workflow.WorkflowTemplate( - id=_generic_id, - metadata=wf_metadata, - metadata_defaults=wf_metadata_defaults, - interface=typed_interface, - nodes=[wf_node], - outputs=[]) - obj2 = _workflow.WorkflowTemplate.from_flyte_idl(obj.to_flyte_idl()) - assert obj2 == obj - -def test_workflow_metadata_queuing_budget(): - obj = _workflow.WorkflowMetadata(queuing_budget=timedelta(seconds=10)) - obj2 = _workflow.WorkflowMetadata.from_flyte_idl(obj.to_flyte_idl()) - assert obj == obj2 def test_workflow_metadata_failure_policy(): - obj = _workflow.WorkflowMetadata(on_failure=_workflow.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE) + obj = _workflow.WorkflowMetadata( + on_failure=_workflow.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE) obj2 = _workflow.WorkflowMetadata.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 assert obj.on_failure == _workflow.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE assert obj2.on_failure == _workflow.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE + def test_workflow_metadata(): obj = _workflow.WorkflowMetadata() obj2 = _workflow.WorkflowMetadata.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 + def test_task_node(): obj = _workflow.TaskNode(reference_id=_generic_id) assert obj.reference_id == _generic_id diff --git a/tests/flytekit/unit/sdk/test_workflow.py b/tests/flytekit/unit/sdk/test_workflow.py index 75daae0303..eee4efda91 100644 --- a/tests/flytekit/unit/sdk/test_workflow.py +++ b/tests/flytekit/unit/sdk/test_workflow.py @@ -156,16 +156,3 @@ class sup(object): assert _get_node_by_id(sup, 'b').inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID assert _get_node_by_id(sup, 'b').inputs[0].binding.promise.var == 'input_2' assert _get_node_by_id(sup, 'c').inputs[0].binding.scalar.primitive.integer == 100 - -def test_workflow_queuing_budget(): - @inputs(a=Types.Integer) - @outputs(b=Types.Integer) - @python_task - def my_task(wf_params, a, b): - b.set(a + 1) - - @workflow_class(queuing_budget=datetime.timedelta(seconds=10)) - class my_workflow(object): - b = my_task(a=100) - - assert my_workflow.metadata.queuing_budget == datetime.timedelta(seconds=10)