Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #104 from antinucleon/master
Browse files Browse the repository at this point in the history
merge
  • Loading branch information
antinucleon committed Sep 20, 2015
2 parents e6cc80f + 935c513 commit ded4555
Show file tree
Hide file tree
Showing 7 changed files with 601 additions and 17 deletions.
445 changes: 445 additions & 0 deletions example/imagenet/alexnet.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from . import optimizer
from . import model
from . import initializer
from . import visualization
import atexit

__version__ = "0.1.0"
5 changes: 3 additions & 2 deletions python/mxnet/metric.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# pylint: disable=invalid-name
"""Online evaluation metric module."""
import numpy as np
from .ndarray import NDArray

class EvalMetric(object):
"""Base class of all evaluation metrics."""
def __init__(self, name):
self.name = name
self.reset()

def update(pred, label):
def update(self, pred, label):
"""Update the internal evaluation.
Parameters
Expand Down Expand Up @@ -40,6 +40,7 @@ def get(self):


class Accuracy(EvalMetric):
"""Calculate accuracy"""
def __init__(self):
super(Accuracy, self).__init__('accuracy')

Expand Down
27 changes: 14 additions & 13 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# pylint: skip-file
# pylint: disable=fixme, invalid-name, too-many-arguments, too-many-locals, no-member
# pylint: disable=too-many-branches, too-many-statements, unused-argument, unused-variable
"""MXNet model module"""
import numpy as np
import time
from . import io
from . import nd
from . import optimizer as opt
from . import metric
from .symbol import Symbol
from .context import Context
from .initializer import Xavier

Expand All @@ -20,7 +21,7 @@


def _train(symbol, ctx, input_shape,
arg_params, aux_states,
arg_params, aux_params,
begin_round, end_round, optimizer,
train_data, eval_data=None, eval_metric=None,
iter_end_callback=None, verbose=True):
Expand All @@ -40,7 +41,7 @@ def _train(symbol, ctx, input_shape,
arg_params : dict of str to NDArray
Model parameter, dict of name to NDArray of net's weights.
aux_states : dict of str to NDArray
aux_params : dict of str to NDArray
Model parameter, dict of name to NDArray of net's auxiliary states.
begin_round : int
Expand Down Expand Up @@ -81,16 +82,16 @@ def _train(symbol, ctx, input_shape,
grad_arrays = train_exec.grad_arrays
aux_arrays = train_exec.aux_arrays
# copy initialized parameters to executor parameters
for key, weight in zip(arg_names, arg_arrays):
for key, weight in list(zip(arg_names, arg_arrays)):
if key in arg_params:
arg_params[key].copyto(weight)
for key, weight in zip(aux_names, aux_arrays):
for key, weight in list(zip(aux_names, aux_arrays)):
if key in aux_params:
aux_params[key].copyto(weight)
# setup helper data structures
label_array = None
data_array = None
for name, arr in zip(symbol.list_arguments(), arg_arrays):
for name, arr in list(zip(symbol.list_arguments(), arg_arrays)):
if name.endswith('label'):
assert label_array is None
label_array = arr
Expand Down Expand Up @@ -151,10 +152,10 @@ def _train(symbol, ctx, input_shape,
for key, weight, gard in arg_blocks:
if key in arg_params:
weight.copyto(arg_params[key])
for key, arr in zip(aux_names, aux_states):
arr.copyto(aux_states[key])
for key, arr in list(zip(aux_names, aux_arrays)):
arr.copyto(aux_params[key])
if iter_end_callback:
iter_end_callback(i, arg_params, aux_states)
iter_end_callback(i, arg_params, aux_arrays)
# end of the function
return

Expand Down Expand Up @@ -224,11 +225,11 @@ def _init_params(self):
arg_shapes, _, aux_shapes = self.symbol.infer_shape(data=self.input_shape)
if self.arg_params is None:
arg_names = self.symbol.list_arguments()
self.arg_params = {k : nd.zeros(s) for k, s in zip(arg_names, arg_shapes)
self.arg_params = {k : nd.zeros(s) for k, s in list(zip(arg_names, arg_shapes))
if not is_data_arg(k)}
if self.aux_states is None:
aux_names = self.symbol.list_auxiliary_states()
self.aux_states = {k : nd.zeros(s) for k, s in zip(aux_names, aux_shapes)}
self.aux_states = {k : nd.zeros(s) for k, s in list(zip(aux_names, aux_shapes))}
for k, v in self.arg_params.items():
self.initializer(k, v)
for k, v in self.aux_states.items():
Expand All @@ -241,7 +242,7 @@ def _init_predictor(self):
# for now only use the first device
pred_exec = self.symbol.simple_bind(
self.ctx[0], grad_req='null', data=self.input_shape)
for name, value in zip(self.symbol.list_arguments(), pred_exec.arg_arrays):
for name, value in list(zip(self.symbol.list_arguments(), pred_exec.arg_arrays)):
if name not in self.arg_datas:
assert name in self.arg_params
self.arg_params[name].copyto(value)
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# pylint: skip-file
# pylint: disable=fixme, invalid-name
"""Common Optimization algorithms with regularizations."""
from .ndarray import NDArray, zeros

Expand Down
1 change: 0 additions & 1 deletion python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,6 @@ def simple_bind(self, ctx, grad_req='write', **kwargs):
arg_ndarrays = [zeros(shape, ctx) for shape in arg_shapes]

if grad_req != 'null':
req = {}
grad_ndarrays = {}
for name, shape in zip(self.list_arguments(), arg_shapes):
if not (name.endswith('data') or name.endswith('label')):
Expand Down
137 changes: 137 additions & 0 deletions python/mxnet/visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-locals, fixme
# pylint: disable=unused-argument, too-many-branches, too-many-statements
"""Visualization module"""
from .symbol import Symbol
import json
import re
import copy


def _str2tuple(string):
"""convert shape string to list, internal use only
Parameters
----------
string: str
shape string
Returns
-------
list of str to represent shape
"""
return re.findall(r"\d+", string)


def network2dot(title, symbol, shape=None):
    """Convert a symbol to a graphviz dot object for visualization.

    Parameters
    ----------
    title : str
        Title (name) of the dot graph.
    symbol : Symbol
        Symbol to be visualized.
    shape : optional
        Currently unused; reserved for future shape annotation support.

    Returns
    -------
    dot : Digraph
        Dot object describing the symbol's operator graph.

    Raises
    ------
    ImportError
        If the ``graphviz`` library cannot be imported.
    TypeError
        If ``symbol`` is not a ``Symbol``.
    """
    # TODO: add shape support
    try:
        from graphviz import Digraph
    except ImportError:
        # Only translate a genuinely missing dependency; anything else
        # (e.g. KeyboardInterrupt) must propagate untouched.
        raise ImportError("Draw network requires graphviz library")
    if not isinstance(symbol, Symbol):
        raise TypeError("symbol must be Symbol")

    conf = json.loads(symbol.tojson())
    nodes = conf["nodes"]
    # NOTE(review): only the first entry of conf["heads"] is consulted --
    # confirm this matches the JSON layout produced by Symbol.tojson().
    heads = set(conf["heads"][0])  # TODO(xxx): check careful

    # Attributes shared by every drawn node; values are immutable strings,
    # so a shallow copy per node is sufficient.
    node_attr = {"shape": "box", "fixedsize": "true",
                 "width": "1.3", "height": "0.8034", "style": "filled"}
    # Fill colour per operator; operators not listed use `default_color`.
    op_colors = {
        "Convolution": "royalblue1",
        "FullyConnected": "royalblue1",
        "BatchNorm": "orchid1",
        "Concat": "seagreen1",
        "Flatten": "seagreen1",
        "Reshape": "seagreen1",
        "Pooling": "firebrick2",
        "Activation": "salmon",
        "LeakyReLU": "salmon",
    }
    default_color = "olivedrab1"

    dot = Digraph(name=title)

    # make nodes
    for i, node in enumerate(nodes):
        op = node["op"]
        name = "%s_%d" % (op, i)

        if op == "null":
            # "null" nodes are variables; only those listed in `heads` are
            # drawn, labelled with their own name and no fill colour.
            if i in heads:
                dot.node(name=name, label=node["name"], **node_attr)
            continue

        if op == "Convolution":
            param = node["param"]
            kernel = _str2tuple(param["kernel"])
            label = "Convolution\n%sx%s/%s, %s" % (kernel[0],
                                                  kernel[1],
                                                  _str2tuple(param["stride"])[0],
                                                  param["num_filter"])
        elif op == "FullyConnected":
            label = "FullyConnected\n%s" % node["param"]["num_hidden"]
        elif op == "Pooling":
            param = node["param"]
            pool_type = param["pool_type"]
            kernel = _str2tuple(param["kernel"])
            label = "Pooling\n%s, %sx%s/%s" % (pool_type,
                                              kernel[0],
                                              kernel[1],
                                              _str2tuple(param["stride"])[0])
        elif op in ("Activation", "LeakyReLU"):
            label = "%s\n%s" % (op, node["param"]["act_type"])
        else:
            # BatchNorm, Concat, Flatten, Reshape and any unlisted operator
            # are labelled with the bare operator name.
            label = op

        attr = dict(node_attr, color=op_colors.get(op, default_color))
        dot.node(name=name, label=label, **attr)

    # add edges
    for i, node in enumerate(nodes):
        op = node["op"]
        if op == "null":
            continue
        name = "%s_%d" % (op, i)
        for item in node["inputs"]:
            input_id = item[0]
            input_node = nodes[input_id]
            # Skip edges from "null" inputs that were never drawn as nodes.
            if input_node["op"] != "null" or input_id in heads:
                input_name = "%s_%d" % (input_node["op"], input_id)
                # TODO: add shape into label
                # dir="back" draws the arrow from the input towards this node.
                dot.edge(tail_name=name, head_name=input_name, dir="back")

    return dot

0 comments on commit ded4555

Please sign in to comment.