Add doc for MXNetTrainer and some code refactor (intel-analytics#2198)
* add doc and some code refactor

* minor

* update

* minor
hkvision committed Apr 13, 2020
1 parent 632ea5c commit e377768
Showing 6 changed files with 213 additions and 135 deletions.
python/orca/src/bigdl/orca/ray/mxnet/__init__.py: 3 changes (2 additions, 1 deletion)
@@ -14,4 +14,5 @@
# limitations under the License.
#

from .mxnet_runner import MXNetTrainer
from .mxnet_trainer import MXNetTrainer
from .utils import create_trainer_config
python/orca/src/bigdl/orca/ray/mxnet/mxnet_runner.py: 128 changes (15 additions, 113 deletions)
@@ -18,21 +18,23 @@
import time
import logging
import subprocess
import socket
import ray.services
import mxnet as mx
from mxnet import gluon
from contextlib import closing
from dmlc_tracker.tracker import get_host_ip
from zoo.ray.mxnet.utils import find_free_port


class MXNetRunner(object):
"""Manages a MXNet model for training."""

def setup_distributed(self, env, config, data_creator, model_creator,
loss_creator=None, metrics_creator=None):
logging.basicConfig(level=logging.INFO)  # This prints log messages to the console.
self.logger = logging.getLogger()
self.config = config # TODO: add check for config keys
assert isinstance(config, dict), "config must be a dict"
for param in ["batch_size", "optimizer", "optimizer_params", "log_interval"]:
assert param in config, param + " must be specified in config"
self.config = config
self.data_creator = data_creator
self.model_creator = model_creator
self.loss_creator = loss_creator
@@ -117,9 +119,8 @@ def train(self, nb_epoch=1):
self.trainer.step(batch.data[0].shape[0])
if self.metrics:
self.metrics.update(label, outputs)
if "log_interval" in self.config and \
not (i + 1) % self.config["log_interval"]:
# This would print on driver for each pid.
if not (i + 1) % self.config["log_interval"]:
# This would be logged on driver for each worker process.
print_output = ""
print_output \
+= 'Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f' \
@@ -135,6 +136,7 @@ def train(self, nb_epoch=1):
print_output += ' %s=%f' % (name, acc)
self.logger.info(print_output)
batch_start_time = time.time()
# TODO: save checkpoints
if self.metrics:
names, accs = self.metrics.get()
if not isinstance(names, list):
@@ -144,19 +146,21 @@ def train(self, nb_epoch=1):
stats[name] = acc
else: # Symbolic API
# TODO: seems no history (i.e. validation accuracy) returned by fit?
if "init" not in self.config:
from mxnet.initializer import Uniform
self.config["init"] = Uniform(0.01) # This is the default value for MXNet
self.model.fit(train_data=self.train_data,
num_epoch=nb_epoch,
initializer=self.config["init"],
kvstore=self.kv,
optimizer=self.config["optimizer"],
optimizer_params=self.config["optimizer_params"],
eval_data=self.val_data,
# TODO: eval and validation metrics could be different
eval_metric=self.metrics,
validation_metric=self.metrics,
eval_data=self.val_data,
batch_end_callback=None if "log_interval" not in self.config
else mx.callback.Speedometer(self.config["batch_size"],
self.config["log_interval"]),
batch_end_callback=mx.callback.Speedometer(
self.config["batch_size"], self.config["log_interval"]),
epoch_end_callback=None if "model" not in self.config
else mx.callback.do_checkpoint(self.config["model"]))
epoch_time = time.time() - start_time
@@ -182,105 +186,3 @@ def get_node_ip(self):
def find_free_port(self):
"""Finds a free port on the current node."""
return find_free_port()


def find_free_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]


class MXNetTrainer(object):
# TODO: Add documentation.
def __init__(self,
config, # Pass in some config, including initializer, batch_size, etc.
data_creator,
# Return an MXNet model defined with either the symbolic or gluon API.
model_creator,
# Not needed for the symbolic API; the loss is already defined as the model output.
loss_creator=None,
metrics_creator=None,
# Specify cpu resources for actors if necessary.
runner_cores=None):
self.config = config
self.data_creator = data_creator
self.model_creator = model_creator
self.loss_creator = loss_creator
self.metrics_creator = metrics_creator
self.num_workers = config["num_workers"]
self.num_servers = config["num_servers"] if "num_servers" in self.config \
else self.num_workers

# Generate actor classes.
# Add a dummy custom resource to distinguish workers from servers if runner_cores is
# specified, so that we can place one worker and one server on each node for better performance.
Worker = ray.remote(num_cpus=runner_cores, resources={"_mxnet_worker": 1})(MXNetRunner) \
if runner_cores else ray.remote(MXNetRunner)
Server = ray.remote(num_cpus=runner_cores, resources={"_mxnet_server": 1})(MXNetRunner) \
if runner_cores else ray.remote(MXNetRunner)

# Start runners: workers followed by servers
self.runners = [
Worker.remote()
for i in range(self.num_workers)
]
self.runners += [
Server.remote()
for i in range(self.num_servers)
]

# Compute URL for initializing distributed setup
ips = ray.get(
[runner.get_node_ip.remote() for runner in self.runners])
ports = ray.get(
[runner.find_free_port.remote() for runner in self.runners])
logger = logging.getLogger()
logger.info(ips)
logger.info(ports)

env = {
"DMLC_PS_ROOT_URI": str(get_host_ip()),
"DMLC_PS_ROOT_PORT": str(find_free_port()),
"DMLC_NUM_SERVER": str(self.num_servers),
"DMLC_NUM_WORKER": str(self.num_workers),
}
envs = []
for i in range(self.num_workers):
current_env = env.copy()
current_env['DMLC_ROLE'] = 'worker'
envs.append(current_env)
for i in range(self.num_servers):
current_env = env.copy()
current_env['DMLC_ROLE'] = 'server'
envs.append(current_env)

env['DMLC_ROLE'] = 'scheduler'
modified_env = os.environ.copy()
modified_env.update(env)
# Need to include the system env vars for the subprocess to run
# TODO: Need to kill this process manually?
subprocess.Popen("python -c 'import mxnet'", shell=True, env=modified_env)

ray.get([
runner.setup_distributed.remote(envs[i], self.config,
self.data_creator,
self.model_creator,
self.loss_creator,
self.metrics_creator)
for i, runner in enumerate(self.runners)
])

def train(self, nb_epoch=1):
"""Trains an MXNet model for several epochs."""
stats = ray.get([w.train.remote(nb_epoch) for w in self.runners])
return stats

def shutdown(self):
"""Shuts down runners and releases resources."""
for runner in self.runners:
runner.shutdown.remote()
runner.__ray_terminate__.remote()

# TODO: add model save and restore
# TODO: add predict, evaluate
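
For reference, the env dicts assembled in the code above follow the standard DMLC role convention. Below is a minimal sketch of what they contain for a hypothetical run with two workers and two servers; the host and port values are placeholders, since the real ones come from get_host_ip() and find_free_port().

base_env = {
    "DMLC_PS_ROOT_URI": "10.0.0.1",   # placeholder scheduler host
    "DMLC_PS_ROOT_PORT": "9000",      # placeholder free port
    "DMLC_NUM_SERVER": "2",
    "DMLC_NUM_WORKER": "2",
}
# One env dict per runner: workers first, then servers.
envs = [dict(base_env, DMLC_ROLE="worker") for _ in range(2)]
envs += [dict(base_env, DMLC_ROLE="server") for _ in range(2)]
# The scheduler runs as a separate local process; with DMLC_ROLE=scheduler
# set, merely importing mxnet starts it.
scheduler_env = dict(base_env, DMLC_ROLE="scheduler")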
python/orca/src/bigdl/orca/ray/mxnet/mxnet_trainer.py: 146 changes (146 additions, 0 deletions)
@@ -0,0 +1,146 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import logging
import subprocess
import ray.services
from dmlc_tracker.tracker import get_host_ip
from zoo.ray.mxnet.mxnet_runner import MXNetRunner
from zoo.ray.mxnet.utils import find_free_port


class MXNetTrainer(object):
"""
MXNetTrainer provides an automatic setup for synchronous distributed MXNet training.
:param config: A dictionary of training configurations. It must include the following keys:
batch_size, optimizer, optimizer_params, log_interval.
optimizer should be an MXNet optimizer or its string representation.
optimizer_params should be a dict that accompanies the optimizer. It can contain learning_rate
and other optimization configurations.
log_interval should be an integer specifying the interval (in batches) for logging throughput
and metrics (if any) during the training process.
You can call create_trainer_config to create the config easily.
You can specify "seed" in config to set the random seed.
You can specify "init" in config to set the model initializer.
:param data_creator: A function that takes config and kv as arguments and returns an MXNet
DataIter/DataLoader for training, or a tuple of training and validation datasets.
You can specify data-related configurations for this function in the config argument above.
kv is an instance of the MXNet distributed key-value store. kv.num_workers and kv.rank
can be used in this function to split the data across workers if necessary.
:param model_creator: A function that takes config as its argument and returns an MXNet model.
The model can be defined with either the MXNet symbolic API or the imperative (Gluon) API.
:param loss_creator: A function that takes config as its argument and returns an MXNet loss.
This is not needed for the symbolic API, where the loss is already defined as the model output.
:param metrics_creator: A function that takes config as its argument and returns one or a list
of MXNet metrics, or their string representations, for example 'accuracy'.
This is not needed if you don't use validation data during training.
:param num_workers: The number of workers for distributed training. Default is 1.
:param num_servers: The number of servers for distributed training. Default is None, in which
case it equals the number of workers.
:param runner_cores: The number of CPU cores allocated to each MXNet worker and server.
Default is None. You may need to specify this for better performance.
"""
def __init__(self, config, data_creator, model_creator,
loss_creator=None, metrics_creator=None,
num_workers=1, num_servers=None, runner_cores=None):
self.config = config
self.data_creator = data_creator
self.model_creator = model_creator
self.loss_creator = loss_creator
self.metrics_creator = metrics_creator
self.num_workers = num_workers
self.num_servers = num_servers if num_servers else self.num_workers

# Generate actor classes.
# Add dummy custom resources (_mxnet_worker and _mxnet_server) to distinguish workers
# from servers if runner_cores is specified, so that we can place one worker and one
# server on each node for better performance.
Worker = ray.remote(num_cpus=runner_cores, resources={"_mxnet_worker": 1})(MXNetRunner) \
if runner_cores else ray.remote(MXNetRunner)
Server = ray.remote(num_cpus=runner_cores, resources={"_mxnet_server": 1})(MXNetRunner) \
if runner_cores else ray.remote(MXNetRunner)

# Start runners: workers followed by servers
self.runners = [
Worker.remote()
for i in range(self.num_workers)
]
self.runners += [
Server.remote()
for i in range(self.num_servers)
]

# Compute URL for initializing distributed setup
ips = ray.get(
[runner.get_node_ip.remote() for runner in self.runners])
ports = ray.get(
[runner.find_free_port.remote() for runner in self.runners])
logger = logging.getLogger()
logger.info(ips)
logger.info(ports)

env = {
"DMLC_PS_ROOT_URI": str(get_host_ip()),
"DMLC_PS_ROOT_PORT": str(find_free_port()),
"DMLC_NUM_SERVER": str(self.num_servers),
"DMLC_NUM_WORKER": str(self.num_workers),
}
envs = []
for i in range(self.num_workers):
current_env = env.copy()
current_env['DMLC_ROLE'] = 'worker'
envs.append(current_env)
for i in range(self.num_servers):
current_env = env.copy()
current_env['DMLC_ROLE'] = 'server'
envs.append(current_env)

env['DMLC_ROLE'] = 'scheduler'
modified_env = os.environ.copy()
modified_env.update(env)
# Need to include the system env vars for the subprocess to run
# TODO: Need to kill this process manually?
subprocess.Popen("python -c 'import mxnet'", shell=True, env=modified_env)

ray.get([
runner.setup_distributed.remote(envs[i], self.config,
self.data_creator,
self.model_creator,
self.loss_creator,
self.metrics_creator)
for i, runner in enumerate(self.runners)
])

def train(self, nb_epoch=1):
"""Trains an MXNet model for several epochs."""
stats = ray.get([w.train.remote(nb_epoch) for w in self.runners])
return stats

def shutdown(self):
"""Shuts down runners and releases resources."""
for runner in self.runners:
runner.shutdown.remote()
runner.__ray_terminate__.remote()

# TODO: add model save and restore
# TODO: add predict, evaluate
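
For illustration, here is a minimal end-to-end sketch of how MXNetTrainer could be used with the Gluon API. It assumes a Ray cluster has already been initialized and uses the zoo.ray.mxnet import path that appears inside this diff; the synthetic dataset and the two-layer model are made up for this example and are not part of the commit.

import mxnet as mx
import numpy as np
from mxnet import gluon
from zoo.ray.mxnet import MXNetTrainer, create_trainer_config


def data_creator(config, kv):
    # Synthetic data; shard it across workers via the kvstore handle.
    X = np.random.rand(1000, 20).astype("float32")
    y = np.random.randint(0, 2, size=(1000,)).astype("float32")
    X, y = X[kv.rank::kv.num_workers], y[kv.rank::kv.num_workers]
    return mx.io.NDArrayIter(X, y, batch_size=config["batch_size"], shuffle=True)


def model_creator(config):
    net = gluon.nn.Sequential()
    net.add(gluon.nn.Dense(16, activation="relu"))
    net.add(gluon.nn.Dense(1))
    return net


def loss_creator(config):
    return gluon.loss.SigmoidBinaryCrossEntropyLoss()


config = create_trainer_config(batch_size=32, optimizer="sgd",
                               optimizer_params={"learning_rate": 0.01})
trainer = MXNetTrainer(config, data_creator, model_creator,
                       loss_creator, num_workers=2)
stats = trainer.train(nb_epoch=2)
trainer.shutdown()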
python/orca/src/bigdl/orca/ray/mxnet/utils.py: 43 changes (43 additions, 0 deletions)
@@ -0,0 +1,43 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import socket
from contextlib import closing


def find_free_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]


def create_trainer_config(batch_size, optimizer="sgd", optimizer_params=None,
log_interval=10, seed=None, extra_config=None):
if not optimizer_params:
optimizer_params = {'learning_rate': 0.01}
config = {
"batch_size": batch_size,
"optimizer": optimizer,
"optimizer_params": optimizer_params,
"log_interval": log_interval,
}
if seed is not None:
config["seed"] = seed
if extra_config:
assert isinstance(extra_config, dict), "extra_config must be a dict"
config.update(extra_config)
return config
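
As an example, a config with a non-default initializer could be built as follows; the Xavier initializer and the particular values here are illustrative choices, not defaults of this API.

import mxnet as mx
from zoo.ray.mxnet import create_trainer_config

config = create_trainer_config(
    batch_size=64,
    optimizer="adam",
    optimizer_params={"learning_rate": 1e-3},
    log_interval=20,
    seed=42,
    extra_config={"init": mx.init.Xavier()},  # extra keys are merged into config
)
# config now holds batch_size, optimizer, optimizer_params,
# log_interval, seed and init.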