diff --git a/.gitignore b/.gitignore
index 8be49ffa06d..ef42fd34a68 100644
--- a/.gitignore
+++ b/.gitignore
@@ -48,3 +48,4 @@ __pycache__
 target
 build
 dist
+apps/wide-deep-recommendation/model_training.ipynb
diff --git a/python/friesian/example/multi_task/README.md b/python/friesian/example/multi_task/README.md
new file mode 100644
index 00000000000..3f65b13d5e4
--- /dev/null
+++ b/python/friesian/example/multi_task/README.md
@@ -0,0 +1,160 @@
+# Multi-task Recommendation with BigDL
+In addition to providing personalized recommendations, recommendation systems need to output diverse
+predictions to meet the needs of real-world applications, such as predicting user click-through rates and browsing (or watching) time for products.
+This example demonstrates how to use BigDL Friesian to train [MMoE](https://dl.acm.org/doi/pdf/10.1145/3219819.3220007) or [PLE](https://dl.acm.org/doi/pdf/10.1145/3383313.3412236?casa_token=8fchWD8CHc0AAAAA:2cyP8EwkhIUlSFPRpfCGHahTddki0OEjDxfbUFMkXY5fU0FNtkvRzmYloJtLowFmL1en88FRFY4Q) for multi-task recommendation with large-scale data.
+
+## Prepare the environment
+We highly recommend you use [Anaconda](https://www.anaconda.com/distribution/#linux) to prepare the environment, especially if you want to run on a yarn cluster.
+```
+conda create -n bigdl python=3.7  # bigdl is the conda environment name; you can use any other name you like.
+conda activate bigdl
+pip install bigdl-friesian[train]
+pip install tensorflow==2.9.1
+pip install deepctr[cpu]
+```
+Refer to [this document](https://bigdl.readthedocs.io/en/latest/doc/UserGuide/python.html#install) for more installation guidance.
+
+## Data Preparation
+In this example, a [news dataset](https://github.com/zhongqiangwu960812/AI-RecommenderSystem/tree/master/Dataset) is used to demonstrate the training and testing process.
+The original data contains more than 1 million users and more than 60 million clicks; the processed training and test sets have 2,977,923 and 962,066 records respectively.
+Each row contains several feature values, timestamps and two labels. The timestamps are used to split the data into training and testing sets.
+Click prediction (classification) and duration-time prediction (regression) are the two output targets.
Original data examples are as follows: +```angular2html ++----------+----------+-------------------+----------+----------+-------------+-----+--------+------+-------+--------+--------+------+------+-------------------+-------+-------------+--------------------+ +| user_id|article_id| expo_time|net_status|flush_nums|exop_position|click|duration|device| os|province| city| age|gender| ctime|img_num| cat_1| cat_2| ++----------+----------+-------------------+----------+----------+-------------+-----+--------+------+-------+--------+--------+------+------+-------------------+-------+-------------+--------------------+ +|1000541010| 464467760|2021-06-30 09:57:14| 2| 0| 13| 1| 28|V2054A|Android|Shanghai|Shanghai|A_0_24|female|2021-06-29 14:46:43| 3|Entertainment| Entertainment/Stars| +|1000541010| 463850913|2021-06-30 09:57:14| 2| 0| 15| 0| 0|V2054A|Android|Shanghai|Shanghai|A_0_24|female|2021-06-27 22:29:13| 11| Fashions|Fashions/Female F...| +|1000541010| 464022440|2021-06-30 09:57:14| 2| 0| 17| 0| 0|V2054A|Android|Shanghai|Shanghai|A_0_24|female|2021-06-28 12:22:54| 7| Rural|Rural/Agriculture...| +|1000541010| 464586545|2021-06-30 09:58:31| 2| 1| 20| 0| 0|V2054A|Android|Shanghai|Shanghai|A_0_24|female|2021-06-29 13:25:06| 5|Entertainment| Entertainment/Stars| +|1000541010| 465352885|2021-07-03 18:13:03| 5| 0| 18| 0| 0|V2054A|Android|Shanghai|Shanghai|A_0_24|female|2021-07-02 10:43:51| 18|Entertainment| Entertainment/Stars| ++----------+----------+-------------------+----------+----------+-------------+-----+--------+------+-------+--------+--------+------+------+-------------------+-------+-------------+--------------------+ +``` + +With the built-in high-level preprocessing operations in Friesian FeatureTable, we can easily perform distributed preprocessing for large-scale data. +The details of preprocessing can be found [here](https://github.com/intel-analytics/BigDL/blob/main/apps/wide-deep-recommendation/feature_engineering.ipynb). 
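+The snippet below is a condensed, illustrative sketch of the main FeatureTable operations used in `data_processing.py` (paths are placeholders and only a few columns are shown; the full script also cleans malformed values and encodes all sparse features):
+```python
+from bigdl.orca import init_orca_context, stop_orca_context
+from bigdl.friesian.feature import FeatureTable
+
+sc = init_orca_context("local", cores=8, memory="12g")
+
+# Read the raw CSV into a distributed FeatureTable (names is the 18-column header shown above).
+names = ['user_id', 'article_id', 'expo_time', 'net_status', 'flush_nums', 'exop_position',
+         'click', 'duration', 'device', 'os', 'province', 'city', 'age', 'gender',
+         'ctime', 'img_num', 'cat_1', 'cat_2']
+tbl = FeatureTable.read_csv("/path/to/input/dataset", header=False, names=names)
+
+tbl = tbl.fillna("", "cat_2")  # fill missing categorical values
+tbl = tbl.apply("img_num", "img_num", lambda x: float(x), "float")  # simplified numeric cleanup
+
+# Split by timestamp, scale the dense feature and index the categorical features.
+train = FeatureTable(tbl.df[tbl.df["expo_time"] < "2021-07-06"])
+test = FeatureTable(tbl.df[tbl.df["expo_time"] >= "2021-07-06"])
+train, min_max = train.min_max_scale(["img_num"])
+test = test.transform_min_max_scale(["img_num"], min_max)
+for col in ["user_id", "article_id", "cat_1", "cat_2"]:
+    train, idx = train.category_encode(col)             # map raw values to contiguous indices
+    test = test.encode_string(col, idx).fillna(0, col)  # unseen values fall back to index 0
+
+train.write_parquet("/path/to/save/processed/dataset/train_processed")
+test.write_parquet("/path/to/save/processed/dataset/test_processed")
+stop_orca_context()
+```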
Examples of processed data are as follows: + +```angular2html ++-------------------+-----+--------+-------------------+-----------+-----+-------+----------+----------+----------+-------------+------+---+--------+----+---+------+-----+ +| expo_time|click|duration| ctime| img_num|cat_2|user_id|article_id|net_status|flush_nums|exop_position|device| os|province|city|age|gender|cat_1| ++-------------------+-----+--------+-------------------+-----------+-----+-------+----------+----------+----------+-------------+------+---+--------+----+---+------+-----+ +|2021-06-30 09:57:14| 1| 28|2021-06-29 14:46:43|0.016574586| 60| 14089| 87717| 4| 73| 1003| 36| 2| 38| 308| 5| 1| 5| +|2021-06-30 09:57:14| 0| 0|2021-06-27 22:29:13| 0.06077348| 47| 14089| 35684| 4| 73| 43| 36| 2| 38| 308| 5| 1| 32| +|2021-06-30 09:57:14| 0| 0|2021-06-28 12:22:54|0.038674034| 157| 14089| 20413| 4| 73| 363| 36| 2| 38| 308| 5| 1| 20| +|2021-06-30 09:58:31| 0| 0|2021-06-29 13:25:06|0.027624309| 60| 14089| 15410| 4| 312| 848| 36| 2| 38| 308| 5| 1| 5| +|2021-07-03 18:13:03| 0| 0|2021-07-02 10:43:51| 0.09944751| 60| 14089| 81707| 2| 73| 313| 36| 2| 38| 308| 5| 1| 5| ++-------------------+-----+--------+-------------------+-----------+-----+-------+----------+----------+----------+-------------+------+---+--------+----+---+------+-----+ +``` +Data preprocessing command: +```bash +python data_processing.py \ + --input_path /path/to/input/dataset \ + --output_path /path/to/save/processed/dataset \ + --cluster_mode local \ + --executor_cores 8 \ + --executor_memory 12g \ +``` +```bash +python data_processing.py \ + --input_path /path/to/input/dataset \ + --output_path /path/to/save/processed/dataset \ + --cluster_mode yarn \ + --executor_cores 8 \ + --executor_memory 12g \ + --num_executors 4 \ + --driver_cores 2 \ + --driver_memory 8g +``` + +__Options for data_processing:__ +* `input_path`: The path to input dataset. +* `output_path`: The path to save processed dataset. +* `cluster_mode`: The cluster mode, such as local, yarn, standalone or spark-submit. Default to be local. +* `master`: The master url, only used when cluster mode is standalone. Default to be None. +* `executor_cores`: The executor core number. Default to be 8. +* `executor_memory`: The executor memory. Default to be 12g. +* `num_executors`: The number of executors. Default to be 4. +* `driver_cores`: The driver core number. Default to be 2. +* `driver_memory`: The driver memory. Default to be 8g. + +__NOTE:__ +When the *cluster_mode* is yarn, *input_path* and *output_path* should be HDFS paths. 
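+
+After the script finishes, the processed train and test splits are written to the `train_processed` and `test_processed` sub-directories of *output_path*. They can be loaded back with FeatureTable for a quick sanity check (a minimal sketch; the expected record counts refer to the full dataset):
+```python
+from bigdl.orca import init_orca_context, stop_orca_context
+from bigdl.friesian.feature import FeatureTable
+
+sc = init_orca_context("local", cores=4, memory="8g")
+train = FeatureTable.read_parquet("/path/to/save/processed/dataset/train_processed")
+test = FeatureTable.read_parquet("/path/to/save/processed/dataset/test_processed")
+print("train records:", train.size())  # expected: 2,977,923
+print("test records:", test.size())    # expected: 962,066
+stop_orca_context()
+```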
+
+## Train and test Multi-task models
+After data preprocessing, the training command for the MMoE or PLE model is as follows (for local mode and yarn mode respectively):
+```bash
+python run_multi_task.py \
+    --model_type mmoe \
+    --train_data_path /path/to/training/dataset \
+    --test_data_path /path/to/testing/dataset \
+    --model_save_path /path/to/save/the/trained/model \
+    --cluster_mode local \
+    --executor_cores 8 \
+    --executor_memory 12g
+```
+```bash
+python run_multi_task.py \
+    --model_type mmoe \
+    --train_data_path /path/to/training/dataset \
+    --test_data_path /path/to/testing/dataset \
+    --model_save_path /path/to/save/the/trained/model \
+    --cluster_mode yarn \
+    --executor_cores 8 \
+    --executor_memory 12g \
+    --num_executors 4 \
+    --driver_cores 2 \
+    --driver_memory 8g
+```
+To evaluate the trained model on the test dataset, run:
+```bash
+python run_multi_task.py \
+    --model_type mmoe \
+    --test_data_path /path/to/testing/dataset \
+    --model_save_path /path/to/save/the/trained/model \
+    --cluster_mode local \
+    --executor_cores 8 \
+    --executor_memory 12g \
+    --num_executors 4 \
+    --driver_cores 2 \
+    --driver_memory 8g
+```
+Results:
+```angular2html
+1. For MMoE:
+50/50 [==============================] - 85s 2s/step - loss: 5505.2607 - duration_loss: 5504.8799 - click_loss: 0.3727 - duration_mae: 30.[...] - click_auc: 0.6574 - click_precision: 0.0000e+00 - click_recall: 0.0000e+00 - val_loss: 6546.5293 - val_duration_loss: 6546.0991 - val_click_loss: 0.4202 - val_duration_mae: 39.1881 - val_click_auc: 0.6486 - val_click_precision: 0.4036 - val_click_recall: 0.0012
+(Worker pid=22945) Epoch 7: early stopping
+Save model to path: ./save_model/mmoe_model.bin
+3759/3759 [==============================] - 78s 20ms/step - loss: 6546.6997 - duration_loss: 6546.2734 - click_loss: 0.4202 - duration_mae: 39.1884 - click_auc: 0.6486 - click_precision: 0.4036 - click_recall: 0.0012
+validation_loss 6546.69970703125
+validation_duration_loss 6546.2734375
+validation_click_loss 0.42016342282295227
+validation_duration_mae 39.18841552734375
+validation_click_auc 0.648556113243103
+
+2. For PLE:
+50/50 [==============================] - 87s 2s/step - loss: 6788.6426 - duration_loss: 6788.2168 - click_loss: 0.4217 - duration_mae: 38.3158 - click_auc: 0.6523 - click_precision: 0.3333 - click_recall: 9.7752e-04 - val_loss: 6610.4990 - val_duration_loss: 6610.0732 - val_click_loss: 0.4236 - val_duration_mae: 42.6656 - val_click_auc: 0.6482 - val_click_precision: 0.6667 - val_click_recall: 9.7058e-05
+(Worker pid=13791) Epoch 4: early stopping
+Save model to path: ./save_model/ple_model.bin
+3753/3759 [============================>.] - ETA: 0s - loss: 6612.4531 - duration_loss: 6612.0410 - click_loss: 0.4236 - duration_mae: 42.6693 - click_auc: 0.6482 - click_precision: 0.6667 - click_recall: 9.7249e-05
+validation_loss 6610.6552734375
+validation_duration_loss 6610.244140625
+validation_click_loss 0.4236340820789337
+validation_duration_mae 42.66642379760742
+validation_click_auc 0.6481693387031555
+```
+
+__Options for training and test:__
+* `model_type`: The multi-task model, mmoe or ple. Default to be mmoe.
+* `train_data_path`: The path to training dataset.
+* `test_data_path`: The path to testing dataset.
+* `model_save_path`: The path to save model.
+* `cluster_mode`: The cluster mode, such as local, yarn, standalone or spark-submit. Default to be local.
+* `master`: The master url, only used when cluster mode is standalone. Default to be None.
+* `executor_cores`: The executor core number. Default to be 8.
+* `executor_memory`: The executor memory.
Default to be 12g. +* `num_executors`: The number of executors. Default to be 4. +* `driver_cores`: The driver core number. Default to be 2. +* `driver_memory`: The driver memory. Default to be 8g. + +__NOTE:__ +When the *cluster_mode* is yarn, *train_data_path*, *test_data_path* ans *model_save_path* should be HDFS paths. diff --git a/python/friesian/example/multi_task/data_processing.py b/python/friesian/example/multi_task/data_processing.py new file mode 100644 index 00000000000..c8f27f5b1e5 --- /dev/null +++ b/python/friesian/example/multi_task/data_processing.py @@ -0,0 +1,159 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import argparse +import os +from argparse import ArgumentParser + +from bigdl.friesian.feature import FeatureTable +from bigdl.orca import init_orca_context, stop_orca_context +from bigdl.dllib.utils.log4Error import invalidInputError + + +def transform(x): + # dealing with some abnormal data + if x == '上海': + return 0.0 + elif isinstance(x, float): + return float(x) + else: + return float(eval(x)) + + +def transform_cat_2(x): + return '-'.join(sorted(x.split('/'))) + + +def read_and_split(data_input_path, sparse_int_features, sparse_string_features, dense_features): + header_names = ['user_id', 'article_id', 'expo_time', 'net_status', 'flush_nums', + 'exop_position', 'click', 'duration', 'device', 'os', 'province', 'city', + 'age', 'gender', 'ctime', 'img_num', 'cat_1', 'cat_2' + ] + if data_input_path.endswith("csv"): + # data_pd = pd.read_csv(os.path.join(data_input_path, 'train_data.csv'), index_col=0, + # parse_dates=['expo_time'], low_memory=False) + # data_pd.to_csv('../train_data_new.csv', index=False, header=None) + tbl = FeatureTable.read_csv(data_input_path, header=False, names=header_names) + else: + tbl = FeatureTable.read_parquet(data_input_path) + + print('The number of total data: ', tbl.size()) + + tbl = tbl.cast(sparse_int_features, 'string') + tbl = tbl.cast(dense_features, 'string') + + # fill absence data + for feature in (sparse_int_features + sparse_string_features): + tbl = tbl.fillna("", feature) + tbl = tbl.fillna('0.0', 'img_num') + + process_img_num = lambda x: transform(x) + process_cat_2 = lambda x: transform_cat_2(x) + tbl = tbl.apply("img_num", "img_num", process_img_num, "float") + tbl = tbl.apply("cat_2", "cat_2", process_cat_2, "string") + + train_tbl = FeatureTable(tbl.df[tbl.df['expo_time'] < '2021-07-06']) + valid_tbl = FeatureTable(tbl.df[tbl.df['expo_time'] >= '2021-07-06']) + print('The number of train data: ', train_tbl.size()) + print('The number of test data: ', valid_tbl.size()) + return train_tbl, valid_tbl + + +def feature_engineering(train_tbl, valid_tbl, output_path, sparse_int_features, + sparse_string_features, dense_features): + import json + train_tbl, min_max_dict = train_tbl.min_max_scale(dense_features) + valid_tbl = valid_tbl.transform_min_max_scale(dense_features, min_max_dict) + cat_cols = sparse_string_features[-1:] + sparse_int_features + 
sparse_string_features[:-1] + for feature in cat_cols: + train_tbl, feature_idx = train_tbl.category_encode(feature) + valid_tbl = valid_tbl.encode_string(feature, feature_idx) + valid_tbl = valid_tbl.fillna(0, feature) + print("The class number of feature: {}/{}".format(feature, feature_idx.size())) + feature_idx.write_parquet(os.path.join(output_path, 'feature_maps')) + return train_tbl, valid_tbl + + +def parse_args(): + parser = ArgumentParser(description="Transform dataset for multi task demo") + parser.add_argument('--input_path', type=str, + default='/path/to/input/dataset', + help='The path for input dataset') + parser.add_argument('--output_path', type=str, default='/path/to/save/processed/dataset', + help='The path for output dataset') + parser.add_argument('--cluster_mode', type=str, default="local", + help='The cluster mode, such as local, yarn, standalone or spark-submit.') + parser.add_argument('--master', type=str, default=None, + help='The master url, only used when cluster mode is standalone.') + parser.add_argument('--executor_cores', type=int, default=8, + help='The executor core number.') + parser.add_argument('--executor_memory', type=str, default="12g", + help='The executor memory.') + parser.add_argument('--num_executors', type=int, default=4, + help='The number of executors.') + parser.add_argument('--driver_cores', type=int, default=2, + help='The driver core number.') + parser.add_argument('--driver_memory', type=str, default="8g", + help='The driver memory.') + args_ = parser.parse_args() + return args_ + + +if __name__ == '__main__': + args = parse_args() + if args.cluster_mode == "local": + sc = init_orca_context("local", cores=args.executor_cores, + memory=args.executor_memory) + elif args.cluster_mode == "standalone": + sc = init_orca_context("standalone", master=args.master, + cores=args.executor_cores, num_nodes=args.num_executors, + memory=args.executor_memory, + driver_cores=args.driver_cores, + driver_memory=args.driver_memory) + elif args.cluster_mode == "yarn": + sc = init_orca_context("yarn-client", cores=args.executor_cores, + num_nodes=args.num_executors, memory=args.executor_memory, + driver_cores=args.driver_cores, driver_memory=args.driver_memory) + elif args.cluster_mode == "spark-submit": + sc = init_orca_context("spark-submit") + else: + invalidInputError(False, + "cluster_mode should be one of 'local', 'yarn', 'standalone' and" + " 'spark-submit', but got " + args.cluster_mode) + + sparse_int_features = [ + 'user_id', 'article_id', + 'net_status', 'flush_nums', + 'exop_position', + ] + # put cat_2 at first bug + # put cat_1,cat_2 at first bug + sparse_string_features = [ + 'device', 'os', 'province', + 'city', 'age', + 'gender', 'cat_1', 'cat_2' + ] + dense_features = ['img_num'] + + # read, reformat and split data + df_train, df_test = read_and_split(args.input_path, sparse_int_features, + sparse_string_features, dense_features) + train_tbl, valid_tbl = feature_engineering(df_train, df_test, + args.output_path, + sparse_int_features, + sparse_string_features, dense_features) + train_tbl.write_parquet(os.path.join(args.output_path, 'train_processed')) + valid_tbl.write_parquet(os.path.join(args.output_path, 'test_processed')) + stop_orca_context() diff --git a/python/friesian/example/multi_task/run_multi_task.py b/python/friesian/example/multi_task/run_multi_task.py new file mode 100644 index 00000000000..c0577c7abcd --- /dev/null +++ b/python/friesian/example/multi_task/run_multi_task.py @@ -0,0 +1,233 @@ +# +# Copyright 2016 The BigDL 
Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import math +from time import time +from argparse import ArgumentParser, ArgumentError +from keras.callbacks import EarlyStopping + +from bigdl.orca import init_orca_context, stop_orca_context +from bigdl.orca.learn.tf2.estimator import Estimator +from bigdl.friesian.feature import FeatureTable + +from deepctr.feature_column import SparseFeat, DenseFeat +from deepctr.models import MMOE, PLE + +from bigdl.dllib.utils.log4Error import invalidInputError + + +def build_model(model_type, sparse_features, dense_features, feature_max_idx): + sparse_feature_columns = [SparseFeat(feat, feature_max_idx[feat], + embedding_dim='auto') for feat in sparse_features] + dense_feature_columns = [DenseFeat(feat, 1) for feat in dense_features] + dnn_features_columns = sparse_feature_columns + dense_feature_columns + if model_type == 'mmoe': + model = MMOE(dnn_features_columns, tower_dnn_hidden_units=[], + task_types=['regression', 'binary'], + task_names=['duration', 'click']) + elif model_type == 'ple': + model = PLE(dnn_features_columns, shared_expert_num=1, specific_expert_num=1, + task_types=['regression', 'binary'], + num_levels=2, task_names=['duration', 'click']) + else: + invalidInputError(False, 'model_type should be one of "mmoe" and "ple", ' + 'but got ' + model_type) + return model + + +def model_creator(config): + model = build_model(model_type=config['model_type'], + sparse_features=config['column_info']['cat_cols'], + dense_features=config['column_info']['continuous_cols'], + feature_max_idx=config['column_info']['feature_max_idx']) + model.compile(optimizer='adam', + loss=["mean_squared_error", "binary_crossentropy"], + metrics=[['mae'], ["AUC", 'Precision', 'Recall']]) + return model + + +def label_cols(column_info): + return column_info["label"] + + +def feature_cols(column_info): + return column_info["cat_cols"] + column_info["embed_cols"] + column_info["continuous_cols"] + + +def train_multi_task(train_tbl_data, valid_tbl_data, save_path, model, + cat_cols, continuous_cols, feature_max_idx): + column_info = { + "cat_cols": cat_cols, + "continuous_cols": continuous_cols, + "feature_max_idx": feature_max_idx, + "embed_cols": [], + "embed_in_dims": [], + "embed_out_dims": [], + "label": ['duration', 'click']} + + config = { + "column_info": column_info, + "inter_op_parallelism": 4, + "intra_op_parallelism": 8, + "model_type": model # mmoe or ple + } + + batch_size = 256 + estimator = Estimator.from_keras( + model_creator=model_creator, + verbose=False, + config=config) + + train_count = train_tbl_data.size() + print("Total number of train records: {}".format(train_count)) + total_steps = math.ceil(train_count / batch_size) + steps_per_epoch = 50 + # To train the full dataset for an entire epoch + epochs = math.ceil(total_steps / steps_per_epoch) + val_count = valid_tbl_data.size() + print("Total number of val records: {}".format(val_count)) + val_steps = math.ceil(val_count / batch_size) + callbacks = 
[EarlyStopping(monitor='val_duration_mae', mode='min', verbose=1, patience=3), + EarlyStopping(monitor='val_click_auc', mode='max', verbose=1, patience=3)] + + start = time() + estimator.fit(data=train_tbl_data.df, + epochs=epochs, + batch_size=batch_size, + steps_per_epoch=steps_per_epoch, + validation_data=valid_tbl_data.df, + validation_steps=val_steps, + callbacks=callbacks, + feature_cols=feature_cols(column_info), + label_cols=label_cols(column_info)) + end = time() + print("Training time is: ", end - start) + estimator.save(save_path) + print('Save model to path: ', save_path) + + +def test_multi_task(valid_tbl_data, save_path, model, cat_cols, continuous_cols, feature_max_idx): + column_info = { + "cat_cols": cat_cols, + "continuous_cols": continuous_cols, + "feature_max_idx": feature_max_idx, + "embed_cols": [], + "embed_in_dims": [], + "embed_out_dims": [], + "label": ['duration', 'click']} + config = { + "column_info": column_info, + "inter_op_parallelism": 4, + "intra_op_parallelism": 8, + "model_type": model # mmoe or ple + } + estimator = Estimator.from_keras( + model_creator=model_creator, + verbose=False, + config=config) + estimator.load(save_path) + + batch_size = 256 + val_steps = math.ceil(valid_tbl_data.size() / batch_size) + eval_results = estimator.evaluate(data=valid_tbl_data.df, + num_steps=val_steps, + batch_size=batch_size, + feature_cols=feature_cols(column_info), + label_cols=label_cols(column_info)) + for k, v in eval_results[0].items(): + print(k, v) + + +def _parse_args(): + parser = ArgumentParser(description="Set parameters for multi task demo") + + parser.add_argument('--model_type', type=str, default="mmoe", + help='The multi task model, mmoe or ple.') + parser.add_argument('--train_data_path', type=str, + default='path/to/training/dataset', + help='The path for training dataset.') + parser.add_argument('--test_data_path', type=str, + default='path/to/testing/dataset', + help='The path for testing dataset.') + parser.add_argument('--model_save_path', type=str, + default='path/to/save/the/trained/model', + help='The path for saving the trained model.') + parser.add_argument('--cluster_mode', type=str, default="local", + help='The cluster mode, such as local, yarn, standalone or spark-submit.') + parser.add_argument('--master', type=str, default=None, + help='The master url, only used when cluster mode is standalone.') + parser.add_argument('--executor_cores', type=int, default=8, + help='The executor core number.') + parser.add_argument('--executor_memory', type=str, default="12g", + help='The executor memory.') + parser.add_argument('--num_executors', type=int, default=4, + help='The number of executors.') + parser.add_argument('--driver_cores', type=int, default=2, + help='The driver core number.') + parser.add_argument('--driver_memory', type=str, default="8g", + help='The driver memory.') + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = _parse_args() + if args.cluster_mode == "local": # For local machine + sc = init_orca_context(cluster_mode="local", + cores=args.executor_cores, memory=args.executor_memory) + elif args.cluster_mode == "standalone": + sc = init_orca_context("standalone", master=args.master, + cores=args.executor_cores, num_nodes=args.num_executors, + memory=args.executor_memory, + driver_cores=args.driver_cores, + driver_memory=args.driver_memory) + elif args.cluster_mode == "yarn": # For Hadoop/YARN cluster + sc = init_orca_context(cluster_mode="yarn", cores=args.executor_cores, + 
num_nodes=args.num_executors, memory=args.executor_memory, + driver_cores=args.driver_cores, driver_memory=args.driver_memory, + object_store_memory="80g") + elif args.cluster_mode == "spark-submit": + sc = init_orca_context("spark-submit") + else: + invalidInputError(False, + "cluster_mode should be one of 'local', 'yarn', 'standalone' and" + " 'spark-submit', but got " + args.cluster_mode) + cat_cols = [ + 'user_id', + 'article_id', + 'net_status', + 'exop_position', + 'device', + 'city', + 'age', + 'gender', + 'cat_1', + ] + continuous_cols = ['img_num'] + feature_max_idx = {'user_id': 40000, 'article_id': 200000, 'net_status': 1004, + 'exop_position': 2000, 'device': 2000, + 'city': 1379, 'age': 1005, 'gender': 1003, 'cat_1': 1038} + + # do train + train_tbl = FeatureTable.read_parquet(args.train_data_path) + valid_tbl = FeatureTable.read_parquet(args.test_data_path) + train_multi_task(train_tbl, valid_tbl, args.model_save_path, + args.model_type, cat_cols, continuous_cols, + feature_max_idx) + # do test + test_multi_task(valid_tbl, args.model_save_path, args.model_type, + cat_cols, continuous_cols, feature_max_idx) + + stop_orca_context()