From e37776820f40d37cbea73a28118bfddee2e4b1aa Mon Sep 17 00:00:00 2001 From: Kai Huang Date: Mon, 13 Apr 2020 19:09:49 +0800 Subject: [PATCH] Add doc for MXNetTrainer and some code refactor (#2198) * add doc and some code refactor * minor * update * minor --- .../orca/src/bigdl/orca/ray/mxnet/__init__.py | 3 +- .../src/bigdl/orca/ray/mxnet/mxnet_runner.py | 128 ++------------- .../src/bigdl/orca/ray/mxnet/mxnet_trainer.py | 146 ++++++++++++++++++ python/orca/src/bigdl/orca/ray/mxnet/utils.py | 43 ++++++ .../bigdl/orca/ray/mxnet/test_mxnet_gluon.py | 16 +- .../bigdl/orca/ray/mxnet/test_mxnet_symbol.py | 12 +- 6 files changed, 213 insertions(+), 135 deletions(-) create mode 100644 python/orca/src/bigdl/orca/ray/mxnet/mxnet_trainer.py create mode 100644 python/orca/src/bigdl/orca/ray/mxnet/utils.py diff --git a/python/orca/src/bigdl/orca/ray/mxnet/__init__.py b/python/orca/src/bigdl/orca/ray/mxnet/__init__.py index 895a220599b..740484906bc 100644 --- a/python/orca/src/bigdl/orca/ray/mxnet/__init__.py +++ b/python/orca/src/bigdl/orca/ray/mxnet/__init__.py @@ -14,4 +14,5 @@ # limitations under the License. # -from .mxnet_runner import MXNetTrainer +from .mxnet_trainer import MXNetTrainer +from .utils import create_trainer_config diff --git a/python/orca/src/bigdl/orca/ray/mxnet/mxnet_runner.py b/python/orca/src/bigdl/orca/ray/mxnet/mxnet_runner.py index 370cf7b58b5..1300bdd1ba2 100644 --- a/python/orca/src/bigdl/orca/ray/mxnet/mxnet_runner.py +++ b/python/orca/src/bigdl/orca/ray/mxnet/mxnet_runner.py @@ -18,21 +18,23 @@ import time import logging import subprocess -import socket import ray.services import mxnet as mx from mxnet import gluon -from contextlib import closing -from dmlc_tracker.tracker import get_host_ip +from zoo.ray.mxnet.utils import find_free_port class MXNetRunner(object): """Manages a MXNet model for training.""" + def setup_distributed(self, env, config, data_creator, model_creator, loss_creator=None, metrics_creator=None): logging.basicConfig(level=logging.INFO) # This can print log messages to console. self.logger = logging.getLogger() - self.config = config # TODO: add check for config keys + assert isinstance(config, dict), "config must be a dict" + for param in ["batch_size", "optimizer", "optimizer_params", "log_interval"]: + assert param in config, param + " must be specified in config" + self.config = config self.data_creator = data_creator self.model_creator = model_creator self.loss_creator = loss_creator @@ -117,9 +119,8 @@ def train(self, nb_epoch=1): self.trainer.step(batch.data[0].shape[0]) if self.metrics: self.metrics.update(label, outputs) - if "log_interval" in self.config and \ - not (i + 1) % self.config["log_interval"]: - # This would print on driver for each pid. + if not (i + 1) % self.config["log_interval"]: + # This would be logged on driver for each worker process. print_output = "" print_output \ += 'Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f' \ @@ -135,6 +136,7 @@ def train(self, nb_epoch=1): print_output += ' %s=%f' % (name, acc) self.logger.info(print_output) batch_start_time = time.time() + # TODO: save checkpoints if self.metrics: names, accs = self.metrics.get() if not isinstance(names, list): @@ -144,19 +146,21 @@ def train(self, nb_epoch=1): stats[name] = acc else: # Symbolic API # TODO: seems no history (i.e. validation accuracy) returned by fit? 
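+            # Unlike the Gluon branch above, the symbolic Module's fit() drives the
+            # whole batch loop internally, so per-batch throughput is reported through
+            # the Speedometer callback passed to fit below.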
+ if "init" not in self.config: + from mxnet.initializer import Uniform + self.config["init"] = Uniform(0.01) # This is the default value for MXNet self.model.fit(train_data=self.train_data, num_epoch=nb_epoch, initializer=self.config["init"], kvstore=self.kv, optimizer=self.config["optimizer"], optimizer_params=self.config["optimizer_params"], + eval_data=self.val_data, # TODO: eval and validation metrics could be different eval_metric=self.metrics, validation_metric=self.metrics, - eval_data=self.val_data, - batch_end_callback=None if "log_interval" not in self.config - else mx.callback.Speedometer(self.config["batch_size"], - self.config["log_interval"]), + batch_end_callback=mx.callback.Speedometer( + self.config["batch_size"], self.config["log_interval"]), epoch_end_callback=None if "model" not in self.config else mx.callback.do_checkpoint(self.config["model"])) epoch_time = time.time() - start_time @@ -182,105 +186,3 @@ def get_node_ip(self): def find_free_port(self): """Finds a free port on the current node.""" return find_free_port() - - -def find_free_port(): - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: - s.bind(("", 0)) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - return s.getsockname()[1] - - -class MXNetTrainer(object): - # TODO: Add documentation. - def __init__(self, - config, # Pass in some config, including initializer, batch_size, etc. - data_creator, - # Return a MXNET model defined with either symbolic or gluon API. - model_creator, - # No need for symbolic API. Loss is already defined as model output. - loss_creator=None, - metrics_creator=None, - # Specify cpu resources for actors if necessary. - runner_cores=None): - self.config = config - self.data_creator = data_creator - self.model_creator = model_creator - self.loss_creator = loss_creator - self.metrics_creator = metrics_creator - self.num_workers = config["num_workers"] - self.num_servers = config["num_servers"] if "num_servers" in self.config \ - else self.num_workers - - # Generate actor class - # Add a dummy custom resource to diff worker from server if runner_cores is specified - # so that we can place one worker and one server on a node for better performance. 
- Worker = ray.remote(num_cpus=runner_cores, resources={"_mxnet_worker": 1})(MXNetRunner) \ - if runner_cores else ray.remote(MXNetRunner) - Server = ray.remote(num_cpus=runner_cores, resources={"_mxnet_server": 1})(MXNetRunner) \ - if runner_cores else ray.remote(MXNetRunner) - - # Start runners: workers followed by servers - self.runners = [ - Worker.remote() - for i in range(self.num_workers) - ] - self.runners += [ - Server.remote() - for i in range(self.num_servers) - ] - - # Compute URL for initializing distributed setup - ips = ray.get( - [runner.get_node_ip.remote() for runner in self.runners]) - ports = ray.get( - [runner.find_free_port.remote() for runner in self.runners]) - logger = logging.getLogger() - logger.info(ips) - logger.info(ports) - - env = { - "DMLC_PS_ROOT_URI": str(get_host_ip()), - "DMLC_PS_ROOT_PORT": str(find_free_port()), - "DMLC_NUM_SERVER": str(self.num_servers), - "DMLC_NUM_WORKER": str(self.num_workers), - } - envs = [] - for i in range(self.num_workers): - current_env = env.copy() - current_env['DMLC_ROLE'] = 'worker' - envs.append(current_env) - for i in range(self.num_servers): - current_env = env.copy() - current_env['DMLC_ROLE'] = 'server' - envs.append(current_env) - - env['DMLC_ROLE'] = 'scheduler' - modified_env = os.environ.copy() - modified_env.update(env) - # Need to contain system env to run bash - # TODO: Need to kill this process manually? - subprocess.Popen("python -c 'import mxnet'", shell=True, env=modified_env) - - ray.get([ - runner.setup_distributed.remote(envs[i], self.config, - self.data_creator, - self.model_creator, - self.loss_creator, - self.metrics_creator) - for i, runner in enumerate(self.runners) - ]) - - def train(self, nb_epoch=1): - """Trains an MXNet model for several epochs.""" - stats = ray.get([w.train.remote(nb_epoch) for w in self.runners]) - return stats - - def shutdown(self): - """Shuts down runners and releases resources.""" - for runner in self.runners: - runner.shutdown.remote() - runner.__ray_terminate__.remote() - -# TODO: add model save and restore -# TODO: add predict, evaluate diff --git a/python/orca/src/bigdl/orca/ray/mxnet/mxnet_trainer.py b/python/orca/src/bigdl/orca/ray/mxnet/mxnet_trainer.py new file mode 100644 index 00000000000..5f35b88834a --- /dev/null +++ b/python/orca/src/bigdl/orca/ray/mxnet/mxnet_trainer.py @@ -0,0 +1,146 @@ +# +# Copyright 2018 Analytics Zoo Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import logging +import subprocess +import ray.services +from dmlc_tracker.tracker import get_host_ip +from zoo.ray.mxnet.mxnet_runner import MXNetRunner +from zoo.ray.mxnet.utils import find_free_port + + +class MXNetTrainer(object): + """ + MXNetTrainer provides an automatic setup for synchronous distributed MXNet training. + + :param config: A dictionary for training configurations. Keys must include the following: + batch_size, optimizer, optimizer_params, log_interval. + optimizer should be an MXNet optimizer or its string representation. 
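+    For example, optimizer="sgd" with optimizer_params={'learning_rate': 0.01} (the defaults
+    used by create_trainer_config below) selects plain SGD with a 0.01 learning rate.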
+    optimizer_params should be a dict that goes with the optimizer. It can contain learning_rate
+    and other optimization configurations.
+    log_interval should be an integer specifying how often (in batches) to log throughput and
+    metrics (if any) during the training process.
+    You can call create_trainer_config to create the config easily.
+    You can specify "seed" in config to set the random seed.
+    You can specify "init" in config to set the model initializer.
+
+    :param data_creator: A function that takes config and kv as arguments and returns an MXNet
+    DataIter/DataLoader for training, or a tuple of training and validation datasets.
+    You can specify data-related configurations for this function in the config argument above.
+    kv is an instance of the MXNet distributed key-value store. kv.num_workers and kv.rank
+    can be used in this function to split the data across workers if necessary.
+
+    :param model_creator: A function that takes config as argument and returns an MXNet model.
+    The model can be defined with either the MXNet symbolic API or the imperative (Gluon) API.
+
+    :param loss_creator: A function that takes config as argument and returns an MXNet loss.
+    This is not needed for the symbolic API, where the loss is already defined as the model output.
+
+    :param metrics_creator: A function that takes config as argument and returns one or a list of
+    MXNet metrics, or their corresponding string representations, for example, 'accuracy'.
+    This is not needed if you don't use validation data during training.
+
+    :param num_workers: The number of workers for distributed training. Default is 1.
+    :param num_servers: The number of servers for distributed training. Default is None, in which
+    case it equals the number of workers.
+    :param runner_cores: The number of CPU cores allocated for each MXNet worker and server.
+    Default is None. You may need to specify this for better performance.
+    """
+    def __init__(self, config, data_creator, model_creator,
+                 loss_creator=None, metrics_creator=None,
+                 num_workers=1, num_servers=None, runner_cores=None):
+        self.config = config
+        self.data_creator = data_creator
+        self.model_creator = model_creator
+        self.loss_creator = loss_creator
+        self.metrics_creator = metrics_creator
+        self.num_workers = num_workers
+        self.num_servers = num_servers if num_servers else self.num_workers
+
+        # Generate actor classes.
+        # Add dummy custom resources _mxnet_worker and _mxnet_server to distinguish workers
+        # from servers if runner_cores is specified, so that we can place one worker and one
+        # server on a node for better performance.
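+        # Only the workers and servers become Ray actors; the DMLC scheduler is
+        # launched later in this constructor as a local subprocess on the driver.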
+        Worker = ray.remote(num_cpus=runner_cores, resources={"_mxnet_worker": 1})(MXNetRunner) \
+            if runner_cores else ray.remote(MXNetRunner)
+        Server = ray.remote(num_cpus=runner_cores, resources={"_mxnet_server": 1})(MXNetRunner) \
+            if runner_cores else ray.remote(MXNetRunner)
+
+        # Start runners: workers followed by servers
+        self.runners = [
+            Worker.remote()
+            for i in range(self.num_workers)
+        ]
+        self.runners += [
+            Server.remote()
+            for i in range(self.num_servers)
+        ]
+
+        # Log the IPs and free ports of all runners; the scheduler URI below uses the
+        # driver's own address and a free port on it.
+        ips = ray.get(
+            [runner.get_node_ip.remote() for runner in self.runners])
+        ports = ray.get(
+            [runner.find_free_port.remote() for runner in self.runners])
+        logger = logging.getLogger()
+        logger.info(ips)
+        logger.info(ports)
+
+        env = {
+            "DMLC_PS_ROOT_URI": str(get_host_ip()),
+            "DMLC_PS_ROOT_PORT": str(find_free_port()),
+            "DMLC_NUM_SERVER": str(self.num_servers),
+            "DMLC_NUM_WORKER": str(self.num_workers),
+        }
+        envs = []
+        for i in range(self.num_workers):
+            current_env = env.copy()
+            current_env['DMLC_ROLE'] = 'worker'
+            envs.append(current_env)
+        for i in range(self.num_servers):
+            current_env = env.copy()
+            current_env['DMLC_ROLE'] = 'server'
+            envs.append(current_env)
+
+        env['DMLC_ROLE'] = 'scheduler'
+        modified_env = os.environ.copy()
+        modified_env.update(env)
+        # The subprocess needs the system environment in addition to the DMLC variables;
+        # importing mxnet with DMLC_ROLE=scheduler starts the DMLC scheduler on the driver.
+        # TODO: Need to kill this process manually?
+        subprocess.Popen("python -c 'import mxnet'", shell=True, env=modified_env)
+
+        ray.get([
+            runner.setup_distributed.remote(envs[i], self.config,
+                                            self.data_creator,
+                                            self.model_creator,
+                                            self.loss_creator,
+                                            self.metrics_creator)
+            for i, runner in enumerate(self.runners)
+        ])
+
+    def train(self, nb_epoch=1):
+        """Trains an MXNet model for the given number of epochs."""
+        stats = ray.get([w.train.remote(nb_epoch) for w in self.runners])
+        return stats
+
+    def shutdown(self):
+        """Shuts down runners and releases resources."""
+        for runner in self.runners:
+            runner.shutdown.remote()
+            runner.__ray_terminate__.remote()
+
+# TODO: add model save and restore
+# TODO: add predict, evaluate
diff --git a/python/orca/src/bigdl/orca/ray/mxnet/utils.py b/python/orca/src/bigdl/orca/ray/mxnet/utils.py
new file mode 100644
index 00000000000..595f8ae65db
--- /dev/null
+++ b/python/orca/src/bigdl/orca/ray/mxnet/utils.py
@@ -0,0 +1,43 @@
+#
+# Copyright 2018 Analytics Zoo Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import socket
+from contextlib import closing
+
+
+def find_free_port():
+    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+        s.bind(("", 0))
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        return s.getsockname()[1]
+
+
+def create_trainer_config(batch_size, optimizer="sgd", optimizer_params=None,
+                          log_interval=10, seed=None, extra_config=None):
+    if optimizer_params is None:
+        optimizer_params = {'learning_rate': 0.01}
+    config = {
+        "batch_size": batch_size,
+        "optimizer": optimizer,
+        "optimizer_params": optimizer_params,
+        "log_interval": log_interval,
+    }
+    if seed is not None:
+        config["seed"] = seed
+    if extra_config:
+        assert isinstance(extra_config, dict), "extra_config must be a dict"
+        config.update(extra_config)
+    return config
diff --git a/python/orca/test/bigdl/orca/ray/mxnet/test_mxnet_gluon.py b/python/orca/test/bigdl/orca/ray/mxnet/test_mxnet_gluon.py
index afd8d0fe978..b82c9e943de 100644
--- a/python/orca/test/bigdl/orca/ray/mxnet/test_mxnet_gluon.py
+++ b/python/orca/test/bigdl/orca/ray/mxnet/test_mxnet_gluon.py
@@ -22,7 +22,7 @@
 import mxnet as mx
 from mxnet import gluon
 from mxnet.gluon import nn
-from zoo.ray.mxnet import MXNetTrainer
+from zoo.ray.mxnet import MXNetTrainer, create_trainer_config
 
 np.random.seed(1337)  # for reproducibility
 
@@ -66,16 +66,10 @@ def get_metrics(config):
 
 class TestMXNetGluon(TestCase):
     def test_gluon(self):
-        config = {
-            "num_workers": 2,
-            "num_servers": 2,
-            "batch_size": 32,
-            "optimizer": "sgd",
-            "optimizer_params": {'learning_rate': 0.01},
-            "log_interval": 2,
-            "seed": 42
-        }
-        trainer = MXNetTrainer(config, get_data_iters, get_model, get_loss, get_metrics)
+        config = create_trainer_config(batch_size=32, log_interval=2, optimizer="adam",
+                                       optimizer_params={'learning_rate': 0.02})
+        trainer = MXNetTrainer(config, get_data_iters, get_model, get_loss, get_metrics,
+                               num_workers=2)
         trainer.train(nb_epoch=2)
 
diff --git a/python/orca/test/bigdl/orca/ray/mxnet/test_mxnet_symbol.py b/python/orca/test/bigdl/orca/ray/mxnet/test_mxnet_symbol.py
index 737b778650e..52de8d952e2 100644
--- a/python/orca/test/bigdl/orca/ray/mxnet/test_mxnet_symbol.py
+++ b/python/orca/test/bigdl/orca/ray/mxnet/test_mxnet_symbol.py
@@ -19,7 +19,7 @@
 import pytest
 import mxnet as mx
-from zoo.ray.mxnet import MXNetTrainer
+from zoo.ray.mxnet import MXNetTrainer, create_trainer_config
 
 np.random.seed(1337)  # for reproducibility
 
@@ -55,15 +55,7 @@ def get_metrics(config):
 
 class TestMXNetSymbol(TestCase):
     def test_symbol(self):
-        config = {
-            "num_workers": 1,
-            "batch_size": 32,
-            "optimizer": "sgd",
-            "init": mx.init.Xavier(),
-            "optimizer_params": {'learning_rate': 0.01},
-            "log_interval": 2,
-            "seed": 42
-        }
+        config = create_trainer_config(batch_size=32, log_interval=2, seed=42)
         trainer = MXNetTrainer(config, get_data_iters, get_model, metrics_creator=get_metrics)
         trainer.train(nb_epoch=2)
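
For reference, below is a minimal end-to-end sketch of the user-facing API this patch introduces, written in the Gluon style of the tests above. The random dataset and the tiny network are illustrative stand-ins rather than part of the patch; whether the model is initialized by the creator or inside MXNetRunner is an implementation detail not shown in this diff, and running it assumes an initialized Ray cluster.

import numpy as np
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from zoo.ray.mxnet import MXNetTrainer, create_trainer_config


def get_data_iters(config, kv):
    # Toy random data; real code would shard by kv.rank and kv.num_workers.
    x = np.random.rand(320, 10).astype(np.float32)
    y = np.random.randint(0, 2, size=(320,)).astype(np.float32)
    return mx.io.NDArrayIter(x, y, batch_size=config["batch_size"], shuffle=True)


def get_model(config):
    net = nn.Sequential()
    net.add(nn.Dense(16, activation="relu"), nn.Dense(2))
    return net


def get_loss(config):
    return gluon.loss.SoftmaxCrossEntropyLoss()


def get_metrics(config):
    return 'accuracy'


config = create_trainer_config(batch_size=32, optimizer="sgd",
                               optimizer_params={'learning_rate': 0.01},
                               log_interval=2, seed=42)
trainer = MXNetTrainer(config, get_data_iters, get_model,
                       loss_creator=get_loss, metrics_creator=get_metrics,
                       num_workers=2)
trainer.train(nb_epoch=2)
trainer.shutdown()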