diff --git a/examples/cnn_ms/README.md b/examples/cnn_ms/README.md new file mode 100644 index 000000000..177ae4d32 --- /dev/null +++ b/examples/cnn_ms/README.md @@ -0,0 +1,44 @@ + + +# Image Classification using Convolutional Neural Networks + +Examples inside this folder show how to train CNN models using +SINGA for image classification. + +* `data` includes the scripts for preprocessing image datasets. + Currently, MNIST, CIFAR10 and CIFAR100 are included. + +* `model` includes the CNN model construction codes by creating + a subclass of `Module` to wrap the neural network operations + of each model. Then computational graph is enabled to optimized + the memory and efficiency. + +* `autograd` includes the codes to train CNN models by calling the + [neural network operations](../../python/singa/autograd.py) imperatively. + The computational graph is not created. + +* `train_cnn.py` is the training script, which controls the training flow by + doing BackPropagation and SGD update. + +* `train_multiprocess.py` is the script for distributed training on a single + node with multiple GPUs; it uses Python's multiprocessing module and NCCL. + +* `train_mpi.py` is the script for distributed training (among multiple nodes) + using MPI and NCCL for communication. + +* `benchmark.py` tests the training throughput using `ResNet50` as the workload. \ No newline at end of file diff --git a/examples/cnn_ms/autograd/cifar10_multiprocess.py b/examples/cnn_ms/autograd/cifar10_multiprocess.py new file mode 100644 index 000000000..815d0119e --- /dev/null +++ b/examples/cnn_ms/autograd/cifar10_multiprocess.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from resnet_cifar10 import * +import multiprocessing +import sys + +if __name__ == '__main__': + + # Generate a NCCL ID to be used for collective communication + nccl_id = singa.NcclIdHolder() + + # Configure the number of GPUs to be used + world_size = int(sys.argv[1]) + + # Testing the experimental partial-parameter update asynchronous training + partial_update = True + + process = [] + for local_rank in range(0, world_size): + process.append( + multiprocessing.Process(target=train_cifar10, + args=(True, local_rank, world_size, nccl_id, + partial_update))) + + for p in process: + p.start() diff --git a/examples/cnn_ms/autograd/xceptionnet.py b/examples/cnn_ms/autograd/xceptionnet.py new file mode 100644 index 000000000..8fb23d8cb --- /dev/null +++ b/examples/cnn_ms/autograd/xceptionnet.py @@ -0,0 +1,303 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from singa import autograd +from singa import tensor +from singa import device +from singa import layer +from singa import opt + +import numpy as np +from tqdm import trange + +# the code is modified from +# https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/xception.py + + +class Block(layer.Layer): + + def __init__(self, + in_filters, + out_filters, + reps, + strides=1, + padding=0, + start_with_relu=True, + grow_first=True): + super(Block, self).__init__() + + if out_filters != in_filters or strides != 1: + self.skip = layer.Conv2d(in_filters, + out_filters, + 1, + stride=strides, + padding=padding, + bias=False) + self.skipbn = layer.BatchNorm2d(out_filters) + else: + self.skip = None + + self.layers = [] + + filters = in_filters + if grow_first: + self.layers.append(layer.ReLU()) + self.layers.append( + layer.SeparableConv2d(in_filters, + out_filters, + 3, + stride=1, + padding=1, + bias=False)) + self.layers.append(layer.BatchNorm2d(out_filters)) + filters = out_filters + + for i in range(reps - 1): + self.layers.append(layer.ReLU()) + self.layers.append( + layer.SeparableConv2d(filters, + filters, + 3, + stride=1, + padding=1, + bias=False)) + self.layers.append(layer.BatchNorm2d(filters)) + + if not grow_first: + self.layers.append(layer.ReLU()) + self.layers.append( + layer.SeparableConv2d(in_filters, + out_filters, + 3, + stride=1, + padding=1, + bias=False)) + self.layers.append(layer.BatchNorm2d(out_filters)) + + if not start_with_relu: + self.layers = self.layers[1:] + else: + self.layers[0] = layer.ReLU() + + if strides != 1: + self.layers.append(layer.MaxPool2d(3, strides, padding + 1)) + + self.register_layers(*self.layers) + + self.add = layer.Add() + + def forward(self, x): + y = self.layers[0](x) + for layer in self.layers[1:]: + if isinstance(y, tuple): + y = y[0] + y = layer(y) + + if self.skip is not None: + skip = self.skip(x) + skip = self.skipbn(skip) + else: + skip = x + y = self.add(y, skip) + return y + + +__all__ = ['Xception'] + + +class Xception(layer.Layer): + """ + Xception optimized for the ImageNet dataset, as specified in + https://arxiv.org/pdf/1610.02357.pdf + """ + + def __init__(self, num_classes=1000): + """ Constructor + Args: + num_classes: number of classes + """ + super(Xception, self).__init__() + self.num_classes = num_classes + + self.conv1 = layer.Conv2d(3, 32, 3, 2, 0, bias=False) + self.bn1 = layer.BatchNorm2d(32) + self.relu1 = layer.ReLU() + + self.conv2 = layer.Conv2d(32, 64, 3, 1, 1, bias=False) + self.bn2 = layer.BatchNorm2d(64) + self.relu2 = layer.ReLU() + # do relu here + + self.block1 = Block(64, + 128, + 2, + 2, + padding=0, + start_with_relu=False, + grow_first=True) + self.block2 = Block(128, + 256, + 2, + 2, + padding=0, + start_with_relu=True, + grow_first=True) + self.block3 = Block(256, + 728, + 2, + 2, + padding=0, + start_with_relu=True, + grow_first=True) + + self.block4 = Block(728, + 728, + 3, + 1, + start_with_relu=True, + grow_first=True) + self.block5 = Block(728, + 728, + 3, + 1, + start_with_relu=True, + grow_first=True) + self.block6 = Block(728, + 728, + 3, + 1, + start_with_relu=True, + grow_first=True) + self.block7 = Block(728, + 728, + 3, + 1, + start_with_relu=True, + grow_first=True) + + self.block8 = Block(728, + 728, + 3, + 1, + start_with_relu=True, + grow_first=True) + self.block9 = Block(728, + 728, + 3, + 1, + start_with_relu=True, + grow_first=True) + self.block10 = Block(728, + 728, + 3, + 1, + start_with_relu=True, + grow_first=True) + self.block11 = Block(728, + 728, + 3, + 1, + start_with_relu=True, + grow_first=True) + + self.block12 = Block(728, + 1024, + 2, + 2, + start_with_relu=True, + grow_first=False) + + self.conv3 = layer.SeparableConv2d(1024, 1536, 3, 1, 1) + self.bn3 = layer.BatchNorm2d(1536) + self.relu3 = layer.ReLU() + + # Relu Layer + self.conv4 = layer.SeparableConv2d(1536, 2048, 3, 1, 1) + self.bn4 = layer.BatchNorm2d(2048) + + self.relu4 = layer.ReLU() + self.globalpooling = layer.MaxPool2d(10, 1) + self.flatten = layer.Flatten() + self.fc = layer.Linear(2048, num_classes) + + def features(self, input): + x = self.conv1(input) + x = self.bn1(x) + x = self.relu1(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu2(x) + + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + x = self.block8(x) + x = self.block9(x) + x = self.block10(x) + x = self.block11(x) + x = self.block12(x) + + x = self.conv3(x) + x = self.bn3(x) + x = self.relu3(x) + + x = self.conv4(x) + x = self.bn4(x) + return x + + def logits(self, features): + x = self.relu4(features) + x = self.globalpooling(x) + x = self.flatten(x) + x = self.fc(x) + return x + + def forward(self, input): + x = self.features(input) + x = self.logits(x) + return x + + +if __name__ == '__main__': + model = Xception(num_classes=1000) + print('Start intialization............') + dev = device.create_cuda_gpu_on(0) + #dev = device.create_cuda_gpu() + + niters = 20 + batch_size = 16 + IMG_SIZE = 299 + sgd = opt.SGD(lr=0.1, momentum=0.9, weight_decay=1e-5) + + tx = tensor.Tensor((batch_size, 3, IMG_SIZE, IMG_SIZE), dev) + ty = tensor.Tensor((batch_size,), dev, tensor.int32) + autograd.training = True + x = np.random.randn(batch_size, 3, IMG_SIZE, IMG_SIZE).astype(np.float32) + y = np.random.randint(0, 1000, batch_size, dtype=np.int32) + tx.copy_from_numpy(x) + ty.copy_from_numpy(y) + + with trange(niters) as t: + for _ in t: + x = model(tx) + loss = autograd.softmax_cross_entropy(x, ty) + sgd(loss) diff --git a/examples/cnn_ms/benchmark.py b/examples/cnn_ms/benchmark.py new file mode 100644 index 000000000..9f69feee0 --- /dev/null +++ b/examples/cnn_ms/benchmark.py @@ -0,0 +1,121 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# the code is modified from +# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py + +from singa import opt +from singa import device +from singa import tensor + +import argparse +import time +import numpy as np +from tqdm import trange + + +def train_resnet(DIST=True, graph=True, sequential=False, verbosity=0): + + # Define the hypermeters for the train_resnet + niters = 100 + batch_size = 32 + sgd = opt.SGD(lr=0.1, momentum=0.9, weight_decay=1e-5) + + IMG_SIZE = 224 + + # For distributed training, sequential has better throughput in the current version + if DIST == True: + sgd = opt.DistOpt(sgd) + world_size = sgd.world_size + local_rank = sgd.local_rank + global_rank = sgd.global_rank + sequential = True + else: + local_rank = 0 + world_size = 1 + global_rank = 0 + sequential = False + + dev = device.create_cuda_gpu_on(local_rank) + + tx = tensor.Tensor((batch_size, 3, IMG_SIZE, IMG_SIZE), dev) + ty = tensor.Tensor((batch_size,), dev, tensor.int32) + x = np.random.randn(batch_size, 3, IMG_SIZE, IMG_SIZE).astype(np.float32) + y = np.random.randint(0, 1000, batch_size, dtype=np.int32) + tx.copy_from_numpy(x) + ty.copy_from_numpy(y) + + dev.SetVerbosity(verbosity) + dev.SetSkipIteration(5) + + # Construct the model + from model import resnet + model = resnet.resnet50(num_channels=3, num_classes=1000) + + model.train() + model.set_optimizer(sgd) + model.compile([tx], is_train=True, use_graph=graph, sequential=sequential) + + # Train model + dev.Sync() + start = time.time() + with trange(niters) as t: + for _ in t: + model(tx, ty, dist_option='fp32', spars=None) + + dev.Sync() + end = time.time() + titer = (end - start) / float(niters) + throughput = float(niters * batch_size * world_size) / (end - start) + if global_rank == 0: + print("\nThroughput = {} per second".format(throughput), flush=True) + print("TotalTime={}".format(end - start), flush=True) + print("Total={}".format(titer), flush=True) + dev.PrintTimeProfiling() + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description='Throughput test using Resnet 50') + parser.add_argument('--dist', + '--enable-dist', + default='False', + action='store_true', + help='enable distributed training', + dest='DIST') + parser.add_argument('--no-graph', + '--disable-graph', + default='True', + action='store_false', + help='disable graph', + dest='graph') + parser.add_argument('--verbosity', + '--log-verbosity', + default=0, + type=int, + help='logging verbosity', + dest='verbosity') + + args = parser.parse_args() + + train_resnet(DIST=args.DIST, + graph=args.graph, + sequential=False, + verbosity=args.verbosity) diff --git a/examples/cnn_ms/pkg_model_code/model.py b/examples/cnn_ms/pkg_model_code/model.py new file mode 100644 index 000000000..3fea9143f --- /dev/null +++ b/examples/cnn_ms/pkg_model_code/model.py @@ -0,0 +1,357 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ============================================================================= +''' +This script includes Model class for python users +to use Computational Graph in their model. +''' + +import os +import gc +import time +import json +import zipfile +import numpy as np +from functools import wraps +from collections import Iterable + +from singa import tensor +from singa import autograd +from singa import layer +from .tensor import Tensor +from . import singa_wrap as singa + + +class ModelMeta(layer.LayerMeta): + + def buffer_operation(func): + + def remove_creator(tensors): + if not tensors: + return + + if isinstance(tensors, Iterable): + if isinstance(tensors, str): + return + else: + for item in tensors: + if isinstance(item, Iterable): + remove_creator(item) + elif isinstance(item, tensor.Tensor): + item.creator = None + elif isinstance(tensors, tensor.Tensor): + tensors.creator = None + + @wraps(func) + def wrapper(self, *args, **kwargs): + if self.graph_mode and self.training: + if len(args) == 0: + raise ValueError('expect at least one input tensor') + + if isinstance(args[0], list): + assert isinstance( + args[0][0], + Tensor), ('function expects PlaceHolders or Tensors') + dev = args[0][0].device + else: + assert isinstance( + args[0], + Tensor), ('function expects PlaceHolders or Tensors') + dev = args[0].device + + if not self._buffered: + # buffer operations + dev.EnableGraph(True) + self._results = func(self, *args, **kwargs) + dev.Sync() + dev.EnableGraph(False) + self._buffered = True + + # deconstruct Operations before running the entire graph + remove_creator(self._results) + + # make sure all Operations are deallocated + gc.collect() + + # run graph + dev.RunGraph(self.sequential) + return self._results + else: + return func(self, *args, **kwargs) + + return wrapper + + def __new__(cls, name, bases, attr): + if 'train_one_batch' in attr: + attr['train_one_batch'] = ModelMeta.buffer_operation( + attr['train_one_batch']) + + return super(ModelMeta, cls).__new__(cls, name, bases, attr) + + +class Model(layer.Layer, metaclass=ModelMeta): + """ Base class for your neural network models. + + Example usage:: + + import numpy as np + from singa import opt + from singa import tensor + from singa import device + from singa import autograd + from singa import layer + from singa import model + + class MyModel(model.Model): + def __init__(self): + super(MyModel, self).__init__() + + self.softmax_cross_entropy = layer.SoftMaxCrossEntropy() + self.conv1 = layer.Conv2d(1, 20, 5, padding=0) + self.conv2 = layer.Conv2d(20, 50, 5, padding=0) + self.sgd = opt.SGD(lr=0.01) + + def forward(self, x): + y = self.conv1(x) + y = self.conv2(y) + return y + + def train_one_batch(self, x, y): + out = self.forward(x) + loss = self.softmax_cross_entropy(out, y) + self.sgd(loss) + return out, loss + + """ + + # save load states constant + TENSOR_DICT_FILENAME = '/tensor_dict.npz' + STATES_ATTR_FILENAME = '/states_attr.json' + MODEL_STATE_TYPE = 0 + AUX_STATE_TYPE = 1 + + def __init__(self): + """ + Initializes internal Model state + """ + super(Model, self).__init__() + + self.training = True + self.graph_mode = True + self.sequential = False + self._buffered = False + self._results = None + + def compile(self, inputs, is_train=True, use_graph=False, sequential=False): + """ Compile and initialize the model + + This function will automatically derive the shape of parameters + in each sublayer based on the shape of input placeholders. It will + also do some settings. + + Args: + inputs(list): the list of input tensors(placeholders) + is_train(bool): when is_trainis True, this model will enter + training mode, otherwise it will enter the evaluation mode + use_graph(bool): when use_graph is True, computational graph + will be used to train this model + sequential(bool): when sequential is True, model will execute ops + in the graph follow the order of joining the graph + """ + assert len(inputs) > 0 and isinstance(inputs[0], Tensor), ( + 'compile function expects PlaceHolders or Tensors') + + dev = inputs[0].device + dev.EnableGraph(True) + self.forward(*inputs) + dev.EnableGraph(False) + dev.ResetGraph() + + autograd.training = is_train + self.training = is_train + self.graph_mode = use_graph + self.sequential = sequential + + def forward(self, *input): + """Defines the computation performed in every forward propagation. + + Should be overridden by all subclasses. + + Args: + *input: the input training data for the model + + Returns: + out: the outputs of the forward propagation. + """ + raise NotImplementedError + + def train_one_batch(self, *input, **kwargs): + """Defines the computation performed in every training iteration + + Should be overridden by all subclasses. + + Args: + *input: the arguments of train_one_batch + **kwargs: the keyword arguments of train_one_batch + """ + raise NotImplementedError + + def train(self, mode=True): + """Set the model in evaluation mode. + + Args: + mode(bool): when mode is True, this model will enter training mode + """ + self.training = mode + autograd.training = mode + + def eval(self): + """Sets the model in evaluation mode. + """ + self.train(mode=False) + + def graph(self, mode=True, sequential=False): + """ Turn on the computational graph. Specify execution mode. + + Args: + mode(bool): when mode is True, model will use computational graph + sequential(bool): when sequential is True, model will execute ops + in the graph follow the order of joining the graph + """ + self.graph_mode = mode + self.sequential = sequential + + def __get_name__(self): + return self.__class__.__name__ + + def __call__(self, *input, **kwargs): + if self.training: + return self.train_one_batch(*input, **kwargs) + else: + return self.forward(*input, **kwargs) + + def save_states(self, fpath, aux_states={}): + """Save states. + + Args: + fpath: output file path (without the extension) + aux_states(dict): values are standard data types or Tensor, + e.g., epoch ID, learning rate, optimizer states + """ + assert not os.path.isfile(fpath), ( + "Failed to save states, %s is already existed." % fpath) + + states = self.get_states() + + # save states data and attr + tensor_dict = {} + states_attr = {} + for k, v in states.items(): + assert isinstance(v, tensor.Tensor), "Only tensor state is allowed" + tensor_dict[k] = tensor.to_numpy(v) + states_attr[k] = { + 'state_type': self.MODEL_STATE_TYPE, + 'shape': v.shape, + 'dtype': v.dtype + } + + for k, v in aux_states.items(): + assert isinstance(v, + tensor.Tensor), "Only tensor aux state is allowed" + tensor_dict[k] = tensor.to_numpy(v) + states_attr[k] = { + 'state_type': self.AUX_STATE_TYPE, + 'shape': v.shape, + 'dtype': v.dtype + } + + # save to files + timestamp = time.time() + tmp_dir = '/tmp/singa_save_states_%s' % timestamp + os.mkdir(tmp_dir) + tensor_dict_fp = tmp_dir + self.TENSOR_DICT_FILENAME + states_attr_fp = tmp_dir + self.STATES_ATTR_FILENAME + + np.savez(tensor_dict_fp, **tensor_dict) + + with open(states_attr_fp, 'w') as fp: + json.dump(states_attr, fp) + + compression = zipfile.ZIP_DEFLATED + with zipfile.ZipFile(fpath, mode="w") as zf: + zf.write(tensor_dict_fp, + os.path.basename(tensor_dict_fp), + compress_type=compression) + zf.write(states_attr_fp, + os.path.basename(states_attr_fp), + compress_type=compression) + + # clean up tmp files + os.remove(tensor_dict_fp) + os.remove(states_attr_fp) + os.rmdir(tmp_dir) + + def load_states(self, fpath): + """Load the model states and auxiliary states from disk. + + Usage: + m = MyModel() + m.compile(...) + aux_states = m.load_states('mymodel.zip') + + Args: + path: input file path (without the extension) + Returns: + dict + """ + + assert os.path.isfile(fpath), ( + "Failed to load states, %s is not exist." % fpath) + + timestamp = time.time() + tmp_dir = '/tmp/singa_load_states_%s' % timestamp + os.mkdir(tmp_dir) + + with zipfile.ZipFile(fpath, 'r') as zf: + zf.extractall(tmp_dir) + + tensor_dict_fp = tmp_dir + self.TENSOR_DICT_FILENAME + states_attr_fp = tmp_dir + self.STATES_ATTR_FILENAME + + with open(states_attr_fp) as f: + states_attr = json.load(f) + + tensor_dict = np.load(tensor_dict_fp) + + # restore singa tensor from numpy + model_states = dict() + aux_states = dict() + + for k in tensor_dict.files: + if states_attr[k]['state_type'] == self.MODEL_STATE_TYPE: + model_states[k] = tensor.from_numpy(tensor_dict[k]) + elif states_attr[k]['state_type'] == self.AUX_STATE_TYPE: + aux_states[k] = tensor.from_numpy(tensor_dict[k]) + + # restore model_states + self.set_states(model_states) + + # clean up tmp files + os.remove(tensor_dict_fp) + os.remove(states_attr_fp) + os.rmdir(tmp_dir) + return aux_states \ No newline at end of file diff --git a/examples/ms_model_mlp/model.py b/examples/ms_model_mlp/model.py new file mode 100644 index 000000000..454b382d5 --- /dev/null +++ b/examples/ms_model_mlp/model.py @@ -0,0 +1,226 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from singa import layer +from singa import model +from singa import tensor +from singa import opt +from singa import device +from singa.autograd import Operator +from singa.layer import Layer +from singa import singa_wrap as singa +import argparse +import numpy as np + +np_dtype = {"float16": np.float16, "float32": np.float32} + +singa_dtype = {"float16": tensor.float16, "float32": tensor.float32} + +#### self-defined loss begin + +### from autograd.py +class SumError(Operator): + + def __init__(self): + super(SumError, self).__init__() + # self.t = t.data + + def forward(self, x): + # self.err = singa.__sub__(x, self.t) + self.data_x = x + # sqr = singa.Square(self.err) + # loss = singa.SumAll(sqr) + loss = singa.SumAll(x) + # self.n = 1 + # for s in x.shape(): + # self.n *= s + # loss /= self.n + return loss + + def backward(self, dy=1.0): + # dx = self.err + dev = device.get_default_device() + dx = tensor.Tensor(self.data_x.shape, dev, singa_dtype['float32']) + dx.copy_from_numpy(np.ones(self.data_x.shape)) + # dx *= float(2 / self.n) + dx *= dy + return dx + +def se_loss(x): + # assert x.shape == t.shape, "input and target shape different: %s, %s" % ( + # x.shape, t.shape) + return SumError()(x)[0] + +### from layer.py +class SumErrorLayer(Layer): + """ + Generate a MeanSquareError operator + """ + + def __init__(self): + super(SumErrorLayer, self).__init__() + + def forward(self, x): + return se_loss(x) + +#### self-defined loss end + +class MSMLP(model.Model): + + def __init__(self, data_size=10, perceptron_size=100, num_classes=10, layer_hidden_list=[10,10,10,10]): + super(MSMLP, self).__init__() + self.num_classes = num_classes + self.dimension = 2 + + self.relu = layer.ReLU() + self.linear1 = layer.Linear(layer_hidden_list[0]) + self.linear2 = layer.Linear(layer_hidden_list[1]) + self.linear3 = layer.Linear(layer_hidden_list[2]) + self.linear4 = layer.Linear(layer_hidden_list[3]) + self.linear5 = layer.Linear(num_classes) + self.softmax_cross_entropy = layer.SoftMaxCrossEntropy() + self.sum_error = SumErrorLayer() + + def forward(self, inputs): + y = self.linear1(inputs) + y = self.relu(y) + y = self.linear2(y) + y = self.relu(y) + y = self.linear3(y) + y = self.relu(y) + y = self.linear4(y) + y = self.relu(y) + y = self.linear5(y) + return y + + def train_one_batch(self, x, y, dist_option, spars, synflow_flag): + # print ("in train_one_batch") + out = self.forward(x) + # print ("train_one_batch x.data: \n", x.data) + # print ("train_one_batch y.data: \n", y.data) + # print ("train_one_batch out.data: \n", out.data) + if synflow_flag: + # print ("sum_error") + loss = self.sum_error(out) + else: # normal training + # print ("softmax_cross_entropy") + loss = self.softmax_cross_entropy(out, y) + # print ("train_one_batch loss.data: \n", loss.data) + + if dist_option == 'plain': + # print ("before pn_p_g_list = self.optimizer(loss)") + pn_p_g_list = self.optimizer(loss) + # print ("after pn_p_g_list = self.optimizer(loss)") + elif dist_option == 'half': + self.optimizer.backward_and_update_half(loss) + elif dist_option == 'partialUpdate': + self.optimizer.backward_and_partial_update(loss) + elif dist_option == 'sparseTopK': + self.optimizer.backward_and_sparse_update(loss, + topK=True, + spars=spars) + elif dist_option == 'sparseThreshold': + self.optimizer.backward_and_sparse_update(loss, + topK=False, + spars=spars) + # print ("len(pn_p_g_list): \n", len(pn_p_g_list)) + # print ("len(pn_p_g_list[0]): \n", len(pn_p_g_list[0])) + # print ("pn_p_g_list[0][0]: \n", pn_p_g_list[0][0]) + # print ("pn_p_g_list[0][1].data: \n", pn_p_g_list[0][1].data) + # print ("pn_p_g_list[0][2].data: \n", pn_p_g_list[0][2].data) + return pn_p_g_list, out, loss + # return pn_p_g_list[0], pn_p_g_list[1], pn_p_g_list[2], out, loss + + def set_optimizer(self, optimizer): + self.optimizer = optimizer + + +def create_model(pretrained=False, **kwargs): + """Constructs a CNN model. + + Args: + pretrained (bool): If True, returns a pre-trained model. + + Returns: + The created CNN model. + """ + model = MSMLP(**kwargs) + + return model + + +__all__ = ['MLP', 'create_model'] + +if __name__ == "__main__": + np.random.seed(0) + + parser = argparse.ArgumentParser() + parser.add_argument('-p', + choices=['float32', 'float16'], + default='float32', + dest='precision') + parser.add_argument('-g', + '--disable-graph', + default='True', + action='store_false', + help='disable graph', + dest='graph') + parser.add_argument('-m', + '--max-epoch', + default=1001, + type=int, + help='maximum epochs', + dest='max_epoch') + args = parser.parse_args() + + # generate the boundary + f = lambda x: (5 * x + 1) + bd_x = np.linspace(-1.0, 1, 200) + bd_y = f(bd_x) + + # generate the training data + x = np.random.uniform(-1, 1, 400) + y = f(x) + 2 * np.random.randn(len(x)) + + # choose one precision + precision = singa_dtype[args.precision] + np_precision = np_dtype[args.precision] + + # convert training data to 2d space + label = np.asarray([5 * a + 1 > b for (a, b) in zip(x, y)]).astype(np.int32) + data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np_precision) + + dev = device.create_cuda_gpu_on(0) + sgd = opt.SGD(0.1, 0.9, 1e-5, dtype=singa_dtype[args.precision]) + tx = tensor.Tensor((400, 2), dev, precision) + ty = tensor.Tensor((400,), dev, tensor.int32) + model = MLP(data_size=2, perceptron_size=3, num_classes=2) + + # attach model to graph + model.set_optimizer(sgd) + model.compile([tx], is_train=True, use_graph=args.graph, sequential=True) + model.train() + + for i in range(args.max_epoch): + tx.copy_from_numpy(data) + ty.copy_from_numpy(label) + out, loss = model(tx, ty, 'fp32', spars=None) + + if i % 100 == 0: + print("training loss = ", tensor.to_numpy(loss)[0]) diff --git a/examples/ms_model_mlp/native.py b/examples/ms_model_mlp/native.py new file mode 100644 index 000000000..a82ec3b24 --- /dev/null +++ b/examples/ms_model_mlp/native.py @@ -0,0 +1,137 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from singa import tensor +from singa.tensor import Tensor +from singa import autograd +from singa import opt +import numpy as np +from singa import device +import argparse + +np_dtype = {"float16": np.float16, "float32": np.float32} + +singa_dtype = {"float16": tensor.float16, "float32": tensor.float32} + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-p', + choices=['float32', 'float16'], + default='float32', + dest='precision') + parser.add_argument('-m', + '--max-epoch', + default=1001, + type=int, + help='maximum epochs', + dest='max_epoch') + args = parser.parse_args() + + np.random.seed(0) + + autograd.training = True + + # prepare training data in numpy array + + # generate the boundary + f = lambda x: (5 * x + 1) + bd_x = np.linspace(-1.0, 1, 200) + bd_y = f(bd_x) + + # generate the training data + x = np.random.uniform(-1, 1, 400) + y = f(x) + 2 * np.random.randn(len(x)) + + # convert training data to 2d space + label = np.asarray([5 * a + 1 > b for (a, b) in zip(x, y)]) + data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np.float32) + + def to_categorical(y, num_classes): + """ + Converts a class vector (integers) to binary class matrix. + + Args: + y: class vector to be converted into a matrix + (integers from 0 to num_classes). + num_classes: total number of classes. + + Returns: + A binary matrix representation of the input. + """ + y = np.array(y, dtype="int") + n = y.shape[0] + categorical = np.zeros((n, num_classes)) + categorical[np.arange(n), y] = 1 + return categorical + + label = to_categorical(label, 2).astype(np.float32) + print("train_data_shape:", data.shape) + print("train_label_shape:", label.shape) + + precision = singa_dtype[args.precision] + np_precision = np_dtype[args.precision] + + dev = device.create_cuda_gpu() + + inputs = Tensor(data=data, device=dev) + target = Tensor(data=label, device=dev) + + inputs = inputs.as_type(precision) + target = target.as_type(tensor.int32) + + w0_np = np.random.normal(0, 0.1, (2, 3)).astype(np_precision) + w0 = Tensor(data=w0_np, + device=dev, + dtype=precision, + requires_grad=True, + stores_grad=True) + b0 = Tensor(shape=(3,), + device=dev, + dtype=precision, + requires_grad=True, + stores_grad=True) + b0.set_value(0.0) + + w1_np = np.random.normal(0, 0.1, (3, 2)).astype(np_precision) + w1 = Tensor(data=w1_np, + device=dev, + dtype=precision, + requires_grad=True, + stores_grad=True) + b1 = Tensor(shape=(2,), + device=dev, + dtype=precision, + requires_grad=True, + stores_grad=True) + b1.set_value(0.0) + + sgd = opt.SGD(0.05, 0.8) + + # training process + for i in range(args.max_epoch): + x = autograd.matmul(inputs, w0) + x = autograd.add_bias(x, b0) + x = autograd.relu(x) + x = autograd.matmul(x, w1) + x = autograd.add_bias(x, b1) + loss = autograd.softmax_cross_entropy(x, target) + sgd(loss) + + if i % 100 == 0: + print("%d, training loss = " % i, tensor.to_numpy(loss)[0])