Merge pull request apache#5 from precedenceguo/test
Add detection support
winstywang authored Jun 19, 2016
2 parents 17c6819 + 329f58b commit 2a2809c
Showing 2 changed files with 66 additions and 12 deletions.
63 changes: 53 additions & 10 deletions python/mxnet/executor_manager.py
@@ -80,7 +80,12 @@ def _load_general(data, targets):
             d_src.copyto(d_targets)
         else:
             for slice_idx, d_dst in d_targets:
-                d_src[slice_idx].copyto(d_dst)
+                if d_src[slice_idx].shape != d_dst.shape:
+                    n = d_dst.shape[0] / (slice_idx.stop - slice_idx.start)
+                    new_slice = slice(slice_idx.start * n, slice_idx.stop * n)
+                    d_src[new_slice].copyto(d_dst)
+                else:
+                    d_src[slice_idx].copyto(d_dst)
 
 def _load_data(batch, targets):
     """Load data into sliced arrays"""
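The change above handles inputs whose first dimension is a multiple of the number of images in the batch (detection labels typically carry several boxes per image): when the destination shape does not match the device slice, the slice is scaled by that multiple before copying. A minimal NumPy sketch of the same idea, with illustrative values not taken from this commit and `//` used so it also runs under Python 3:

import numpy as np

def load_general_sketch(d_src, d_targets):
    """Copy a full-batch array into per-device destinations, rescaling the
    slice when the destination's first dimension is a multiple of the
    device slice length (e.g. labels with several rows per image)."""
    for slice_idx, d_dst in d_targets:
        if d_src[slice_idx].shape != d_dst.shape:
            n = d_dst.shape[0] // (slice_idx.stop - slice_idx.start)
            new_slice = slice(slice_idx.start * n, slice_idx.stop * n)
            d_dst[...] = d_src[new_slice]
        else:
            d_dst[...] = d_src[slice_idx]

# 4 images split across 2 devices; labels carry 3 rows per image (12 total).
labels = np.arange(12.0).reshape(12, 1)
targets = [(slice(0, 2), np.zeros((6, 1))), (slice(2, 4), np.zeros((6, 1)))]
load_general_sketch(labels, targets)
print(targets[1][1].ravel())  # rows 6..11 of the label array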
@@ -196,6 +201,10 @@ class DataParallelExecutorGroup(object):
         The dataset for training. It could be any object with `provide_data` and
         `provide_label` properties. Loading of actual data is not necessarily needed
         at this stage.
+    max_data_shape: list of tuple (name, shape)
+        Maximum shape of input data. The order is the same as `train_data.provide_data`
+    max_label_shape: list of tuple (name, shape)
+        Maximum shape of input label. The order is the same as `train_data.provide_label`
     shared_grop: DataParallelExecutorGroup
         An existing executor group, if to share parameters with it.
     """
@@ -216,13 +225,31 @@ def __init__(self, sym, arg_names, param_names, ctx, slices, train_data, shared_
 
         self.train_execs = []
         for i in range(len(ctx)):
-            data_shapes = {k: tuple([slices[i].stop-slices[i].start] + list(v[1:]))
-                           for k, v in train_data.provide_data + train_data.provide_label}
-            shared_exec = None if shared_group is None else shared_group.train_execs[i]
-            train_exec = _bind_exec(sym, ctx[i], data_shapes, self.param_names,
-                                    need_grad=True, base_exec=shared_exec,
-                                    shared_data_arrays=self.shared_data_arrays[i])
-            self.train_execs.append(train_exec)
+            data_shapes = {}
+            batch_size = 0
+            for k, v in train_data.provide_data:
+                if k == 'data':
+                    batch_size = v[0]
+            if max_data_shape is None:
+                max_data_shape = []
+            if max_label_shape is None:
+                max_label_shape = []
+            max_data_shape_dict = {k: v for k, v in max_data_shape + max_label_shape}
+            for k, v in train_data.provide_data + train_data.provide_label:
+                # initialize first executor group with maximum shape provided
+                if shared_group is None:
+                    if k in max_data_shape_dict:
+                        # data size is set to max possible size of input data
+                        data_shapes[k] = tuple([slices[i].stop - slices[i].start] + \
+                            list(max_data_shape_dict[k][1:]))
+                    else:
+                        # support inputs with different batch size from data
+                        # by indexing first dimension of input by portions in batch instead of batch_size
+                        data_shapes[k] = tuple([int((slices[i].stop - slices[i].start) * v[0] \
+                            / batch_size)] + list(v[1:]))
+                else:
+                    data_shapes[k] = tuple([int((slices[i].stop - slices[i].start) * v[0] \
+                        / batch_size)] + list(v[1:]))
 
         # data structure
         self.data_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)]
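A note on the per-device shape computation added above: the first executor group (the one created without a `shared_group`) binds any input listed in `max_data_shape`/`max_label_shape` at its maximum size, while every other input, and all inputs of later shared groups, has its first dimension scaled by the ratio of that input's batch dimension to the data batch size. A stand-alone sketch of that rule with assumed toy values (names and numbers are illustrative, not from this commit):

def device_shapes(provide_data, provide_label, dev_slice, batch_size,
                  max_shape_dict, is_first_group):
    """Compute per-device input shapes the way the loop above does."""
    shapes = {}
    dev_len = dev_slice.stop - dev_slice.start
    for name, shape in provide_data + provide_label:
        if is_first_group and name in max_shape_dict:
            # first executor group is bound at the largest expected input size
            shapes[name] = (dev_len,) + tuple(max_shape_dict[name][1:])
        else:
            # keep the input's ratio to the batch size (e.g. several label
            # rows per image) when slicing its first dimension per device
            shapes[name] = (int(dev_len * shape[0] / batch_size),) + tuple(shape[1:])
    return shapes

# Batch of 8 images, one device slice of 4, labels with 5 rows per image,
# and a maximum image size of 3 x 600 x 800:
print(device_shapes([('data', (8, 3, 224, 224))], [('label', (40, 6))],
                    slice(0, 4), 8, {'data': (8, 3, 600, 800)}, True))
# -> {'data': (4, 3, 600, 800), 'label': (20, 6)}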
@@ -258,7 +285,9 @@ def backward(self):
     def update_metric(self, metric, labels):
         """ Update evaluation metric with label and current outputs """
         for texec, islice in zip(self.train_execs, self.slices):
-            labels_slice = [label[islice] for label in labels]
+            n = int(texec.outputs[0].shape[0] / (islice.stop - islice.start))
+            new_slice = slice(islice.start * n, islice.stop * n)
+            labels_slice = [label[new_slice] for label in labels]
             metric.update(labels_slice, texec.outputs)
 
 class DataParallelExecutorManager(object):
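The `update_metric` change applies the same proportional slicing, but infers the multiple `n` from the executor's output shape rather than from a destination array. A quick numeric check with assumed values, 8 images over two devices and 5 output rows per image:

slices = [slice(0, 4), slice(4, 8)]   # per-device image slices
output_rows = 20                      # each executor emits 5 rows per image
for islice in slices:
    n = output_rows // (islice.stop - islice.start)
    new_slice = slice(islice.start * n, islice.stop * n)
    print(new_slice)
# slice(0, 20) and slice(20, 40): each device's labels line up with its outputs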
@@ -284,6 +313,12 @@ class DataParallelExecutorManager(object):
         When not specified, default logger will be used.
     sym_gen : a function that generate new Symbols depending on different
         input shapes. Used only for bucketing.
+    mutable_data_shape: bool
+        Whether input data have different shapes or not.
+    max_data_shape: list of tuple (name, shape)
+        Maximum shape of input data. The order is the same as `train_data.provide_data`
+    max_label_shape: list of tuple (name, shape)
+        Maximum shape of input label. The order is the same as `train_data.provide_label`
     """
     def __init__(self, symbol, ctx, train_data,
                  arg_names, param_names, aux_names,
@@ -306,9 +341,10 @@ def __init__(self, symbol, ctx, train_data,
         self.param_names = param_names
         self.aux_names = aux_names
         self.ctx = ctx
+        self.mutable_data_shape = mutable_data_shape
 
         self.execgrp = DataParallelExecutorGroup(symbol, self.arg_names, self.param_names, self.ctx,
-                                                 self.slices, train_data)
+                                                 self.slices, train_data, max_data_shape, max_label_shape)
         self.symbol = symbol
 
         self.sym_gen = sym_gen
@@ -388,6 +424,13 @@ def load_data_batch(self, data_batch):
                 self.execgrp_bucket[key] = execgrp
 
             self.curr_execgrp = self.execgrp_bucket[key]
+        elif self.mutable_data_shape is True:
+            # for each data batch, generate new execgrp and share params with the initial one
+            execgrp = DataParallelExecutorGroup(self.symbol, self.arg_names,
+                                                self.param_names, self.ctx,
+                                                self.slices, data_batch,
+                                                shared_group=self.execgrp)
+            self.curr_execgrp = execgrp
         else:
             self.curr_execgrp = self.execgrp
 
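With `mutable_data_shape` enabled, `load_data_batch` builds a fresh executor group for every incoming batch and shares parameters with the initial group, which was bound at the maximum shapes. The sketch below mocks out the group construction to show just that dispatch; `ManagerSketch` and `make_group` are illustrative stand-ins, not part of this commit:

class ManagerSketch:
    """Schematic version of the load_data_batch dispatch above."""
    def __init__(self, make_group, mutable_data_shape=True):
        self.make_group = make_group
        self.mutable_data_shape = mutable_data_shape
        self.sym_gen = None
        # initial group, bound once at the maximum expected shapes
        self.execgrp = make_group(batch=None, shared_group=None)

    def load_data_batch(self, data_batch):
        if self.sym_gen is not None:
            pass  # bucketing path, omitted in this sketch
        elif self.mutable_data_shape:
            # new executor group per batch, parameters shared with the first one
            self.curr_execgrp = self.make_group(batch=data_batch,
                                                shared_group=self.execgrp)
        else:
            self.curr_execgrp = self.execgrp

# Stand-in for DataParallelExecutorGroup construction.
def make_group(batch, shared_group):
    return {'batch': batch, 'shares_params': shared_group is not None}

mgr = ManagerSketch(make_group)
mgr.load_data_batch({'data': (2, 3, 480, 640)})
print(mgr.curr_execgrp)   # {'batch': {...}, 'shares_params': True}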
15 changes: 13 additions & 2 deletions python/mxnet/model.py
@@ -124,7 +124,8 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
                         train_data, eval_data=None, eval_metric=None,
                         epoch_end_callback=None, batch_end_callback=None,
                         logger=None, work_load_list=None, monitor=None,
-                        eval_batch_end_callback=None, sym_gen=None):
+                        eval_batch_end_callback=None, sym_gen=None,
+                        mutable_data_shape=False, max_data_shape=None, max_label_shape=None):
     """Internal training function on multiple devices.
     This function will also work for single device as well.
     Parameters
@@ -176,6 +177,13 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
     monitor : Monitor, optional
         Monitor installed to executor,
         for monitoring outputs, weights, and gradients for debugging.
+    mutable_data_shape: bool, optional
+        Whether input data have different shapes or not.
+        It is set to False in default.
+    max_data_shape: list of tuple (name, shape)
+        Maximum shape of input data. The order is the same as `train_data.provide_data`
+    max_label_shape: list of tuple (name, shape)
+        Maximum shape of input label. The order is the same as `train_data.provide_label`
     Notes
     -----
     - This function will inplace update the NDArrays in arg_params and aux_states.
@@ -190,7 +198,10 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
                                                    arg_names=arg_names,
                                                    aux_names=aux_names,
                                                    work_load_list=work_load_list,
-                                                   logger=logger)
+                                                   logger=logger,
+                                                   mutable_data_shape=mutable_data_shape,
+                                                   max_data_shape=max_data_shape,
+                                                   max_label_shape=max_label_shape)
     if monitor:
         executor_manager.install_monitor(monitor)
 
