Commit b374889
…to enable-fc-bwd
TaoLv committed Feb 11, 2020
2 parents 1bf97fa + 6c61afb
Showing 24 changed files with 850 additions and 79 deletions.
78 changes: 78 additions & 0 deletions benchmark/opperf/nd_operations/linalg_operators.py
@@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Performance benchmark tests for MXNet NDArray Linear Algebra Operations.
The following 17 Linear Algebra operators are covered:
['linalg_potri', 'linalg_gemm2', 'linalg_extractdiag', 'linalg_trsm', 'linalg_gelqf', 'linalg_gemm', 'linalg_sumlogdiag',
'linalg_potrf', 'linalg_makediag', 'linalg_syrk', 'linalg_maketrian', 'linalg_trmm', 'linalg_extracttrian',
'linalg_slogdet', 'linalg_det', 'linalg_inverse', 'moments']
"""

import mxnet as mx

from benchmark.opperf.utils.benchmark_utils import run_op_benchmarks
from benchmark.opperf.utils.op_registry_utils import get_all_linalg_operators

from benchmark.opperf.utils.benchmark_utils import run_performance_test
from benchmark.opperf.utils.common_utils import merge_map_list
from benchmark.opperf.rules.default_params import MX_OP_MODULE

def run_linalg_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='native', warmup=25, runs=100):
"""Runs benchmarks with the given context and precision (dtype) for all the linear algebra
operators in MXNet.
Parameters
----------
ctx: mx.ctx
        Context to run benchmarks on.
dtype: str, default 'float32'
Precision to use for benchmarks
profiler: str, default 'native'
Type of Profiler to use (native/python)
warmup: int, default 25
Number of times to run for warmup
runs: int, default 100
Number of runs to capture benchmark results
Returns
-------
Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.
"""
# Individual tests for ops with specific requirements on input data
# linalg_potrf requires a positive definite matrix as input
linalg_potrf_benchmark = run_performance_test(getattr(MX_OP_MODULE, "linalg_potrf"),
run_backward=False,
dtype=dtype,
ctx=ctx,
profiler=profiler,
inputs=[{"A": [[1, 0],
[0, 1]]},
{"A": [[2, -1, 0],
[-1, 2, -1],
[0, -1, 2]]}],
warmup=warmup,
runs=runs)

# Fetch all Linear Algebra Operators
mx_linalg_ops = get_all_linalg_operators()
# Run benchmarks
mx_linalg_op_results = run_op_benchmarks(mx_linalg_ops, dtype, ctx, profiler, warmup, runs)
return merge_map_list(linalg_potrf_benchmark + [mx_linalg_op_results])
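
For reference, a minimal standalone invocation of this entry point might look like the sketch below; the reduced iteration counts are illustrative and not part of this change.

import mxnet as mx

from benchmark.opperf.nd_operations.linalg_operators import run_linalg_operators_benchmarks

# Quick sanity run with fewer iterations than the defaults (warmup=25, runs=100).
results = run_linalg_operators_benchmarks(ctx=mx.cpu(), dtype='float32',
                                          profiler='native', warmup=5, runs=10)
for op_name, metrics in results.items():
    print(op_name, metrics)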
4 changes: 4 additions & 0 deletions benchmark/opperf/opperf.py
@@ -43,6 +43,7 @@
from benchmark.opperf.nd_operations.array_rearrange import run_rearrange_operators_benchmarks
from benchmark.opperf.nd_operations.indexing_routines import run_indexing_routines_benchmarks
from benchmark.opperf.nd_operations.nn_loss_operators import run_loss_operators_benchmarks
from benchmark.opperf.nd_operations.linalg_operators import run_linalg_operators_benchmarks

from benchmark.opperf.utils.common_utils import merge_map_list, save_to_file
from benchmark.opperf.utils.op_registry_utils import get_operators_with_no_benchmark, \
@@ -114,6 +115,9 @@ def run_all_mxnet_operator_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='n
# Run all NN loss operations benchmarks with default input values
mxnet_operator_benchmark_results.append(run_loss_operators_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler))

# Run all Linear Algebra operations benchmarks with default input values
mxnet_operator_benchmark_results.append(run_linalg_operators_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler))

# ****************************** PREPARE FINAL RESULTS ********************************
final_benchmark_result_map = merge_map_list(mxnet_operator_benchmark_results)
return final_benchmark_result_map
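
With this hook in place, the linalg benchmarks are picked up by the full suite. Assuming the existing opperf CLI (the flags below follow the opperf README and are not part of this diff), the whole suite can be launched from the MXNet source root with:

python benchmark/opperf/opperf.py --output-format json --output-file mxnet_operator_benchmark_results.json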
17 changes: 15 additions & 2 deletions benchmark/opperf/rules/default_params.py
@@ -135,6 +135,13 @@
DEFAULT_DATA_SMCE = [(1024, 1024)]
DEFAULT_LABEL_SMCE = [(1024,)]

# For linalg operators
DEFAULT_A = [(1024, 1024)]
DEFAULT_B = [(1024, 1024)]
DEFAULT_C = [(1024, 1024)]
DEFAULT_A_MT = [(1024, 1035)]
DEFAULT_AXES = [[0, 1]]

# Default Inputs. MXNet Op Param Name to Default Input mapping
DEFAULTS_INPUTS = {"data": DEFAULT_DATA,
"sample": DEFAULT_SAMPLE,
@@ -206,7 +213,12 @@
"transform_type": DEFAULT_TRANSFORM_TYPE,
"data_gridgenerator": DEFAULT_DATA_GRIDGEN,
"target_shape_gridgenerator": DEFAULT_TARGET_SHAPE,
"data_sample_multinomial": DEFAULT_DATA_SM}
"data_sample_multinomial": DEFAULT_DATA_SM,
"A": DEFAULT_A,
"B": DEFAULT_B,
"C": DEFAULT_C,
"A_linalg_maketrian": DEFAULT_A_MT,
"axes": DEFAULT_AXES}


# These are names of MXNet operator parameters that are of type NDArray.
@@ -219,4 +231,5 @@
"low", "high", "weight", "bias", "moving_mean", "moving_var",
"weight", "weight32", "grad", "mean", "var", "mom", "n", "d",
"v", "z", "g", "delta", "args", "indices", "shape_like", "y",
"x", "condition", "a", "index", "raveL_data", "label", "grid"]
"x", "condition", "a", "index", "raveL_data", "label", "grid",
"A", "B", "C"]
2 changes: 1 addition & 1 deletion benchmark/opperf/utils/benchmark_utils.py
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@
from benchmark.opperf.rules.default_params import PARAMS_OF_TYPE_NDARRAY
from .profiler_utils import cpp_profile, python_profile

no_backward = ['gather_nd', 'softmax_cross_entropy']
no_backward = ['gather_nd', 'softmax_cross_entropy', 'linalg_gelqf', 'linalg_slogdet', 'moments']

def _prepare_op_inputs(inputs, run_backward, dtype, ctx):
mx.random.seed(41)
26 changes: 25 additions & 1 deletion benchmark/opperf/utils/op_registry_utils.py
@@ -117,7 +117,8 @@ def prepare_op_inputs(op, arg_params):

# A 3D tensor is needed by the following ops
ops_3d = ['CTCLoss', 'ctc_loss']
custom_data = ['BilinearSampler', 'GridGenerator', 'sample_multinomial']

custom_data = ['BilinearSampler', 'GridGenerator', 'sample_multinomial', 'linalg_maketrian']

# Prepare op to default input mapping
arg_values = {}
@@ -267,6 +268,29 @@ def get_all_random_sampling_operators():
return random_sampling_mx_operators


def get_all_linalg_operators():
"""Gets all Linear Algebra operators registered with MXNet.
Returns
-------
{"operator_name": {"has_backward", "nd_op_handle", "params"}}
"""
other_linalg_ops = ['moments']

# Already tested linalg_potrf independently
independently_tested = ['linalg_potrf']

# Get all mxnet operators
mx_operators = _get_all_mxnet_operators()

# Filter for Linear Algebra operators
linalg_mx_operators = {}
for op_name, _ in mx_operators.items():
if (op_name.startswith("linalg_") and op_name not in independently_tested) or op_name in other_linalg_ops:
linalg_mx_operators[op_name] = mx_operators[op_name]
return linalg_mx_operators
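
A quick sanity check of this filter, hypothetical but consistent with the module docstring of linalg_operators.py above:

from benchmark.opperf.utils.op_registry_utils import get_all_linalg_operators

ops = get_all_linalg_operators()
assert 'moments' in ops            # included via other_linalg_ops
assert 'linalg_potrf' not in ops   # benchmarked separately with positive definite inputs
print(sorted(ops.keys()))          # the remaining linalg_* operators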


def get_all_reduction_operators():
"""Gets all Reduction operators registered with MXNet.
73 changes: 73 additions & 0 deletions python/mxnet/gluon/data/vision/transforms.py
@@ -20,6 +20,8 @@
"Image transforms."

import random
import numpy as np

from ...block import Block, HybridBlock
from ...nn import Sequential, HybridSequential
from .... import image
@@ -198,6 +200,77 @@ def hybrid_forward(self, F, x):
return F.image.normalize(x, self._mean, self._std)


class Rotate(Block):
"""Rotate the input image by a given angle. Keeps the original image shape.
Parameters
----------
rotation_degrees : float32
Desired rotation angle in degrees.
zoom_in : bool
Zoom in image so that no padding is present in final output.
zoom_out : bool
Zoom out image so that the entire original image is present in final output.
Inputs:
- **data**: input tensor with (C x H x W) or (N x C x H x W) shape.
Outputs:
- **out**: output tensor with (C x H x W) or (N x C x H x W) shape.
"""
def __init__(self, rotation_degrees, zoom_in=False, zoom_out=False):
super(Rotate, self).__init__()
self._args = (rotation_degrees, zoom_in, zoom_out)

def forward(self, x):
if x.dtype is not np.float32:
raise TypeError("This transformation only supports float32. "
"Consider calling it after ToTensor")
return image.imrotate(x, *self._args)
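
A minimal usage sketch for the new transform (the 30-degree angle and shapes are illustrative; ToTensor supplies the float32 C x H x W layout the transform requires):

from mxnet import nd
from mxnet.gluon.data.vision import transforms

aug = transforms.Compose([
    transforms.ToTensor(),                   # H x W x C uint8 -> C x H x W float32
    transforms.Rotate(30.0, zoom_out=True),  # zoom out so the whole rotated image fits
])
img = nd.random.uniform(0, 255, shape=(224, 224, 3)).astype('uint8')
out = aug(img)                               # shape (3, 224, 224), dtype float32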


class RandomRotation(Block):
"""Random rotate the input image by a random angle.
Keeps the original image shape and aspect ratio.
Parameters
----------
angle_limits: tuple
        Tuple of 2 elements containing the lower and upper limits
        for rotation angles in degrees.
zoom_in : bool
Zoom in image so that no padding is present in final output.
zoom_out : bool
Zoom out image so that the entire original image is present in final output.
    rotate_with_proba : float32, default 1.0
        Probability of applying the rotation; with probability
        1 - rotate_with_proba the input is returned unchanged.
Inputs:
- **data**: input tensor with (C x H x W) or (N x C x H x W) shape.
Outputs:
- **out**: output tensor with (C x H x W) or (N x C x H x W) shape.
"""
def __init__(self, angle_limits, zoom_in=False, zoom_out=False, rotate_with_proba=1.0):
super(RandomRotation, self).__init__()
lower, upper = angle_limits
if lower >= upper:
raise ValueError("`angle_limits` must be an ordered tuple")
if rotate_with_proba < 0 or rotate_with_proba > 1:
raise ValueError("Probability of rotating the image should be between 0 and 1")
self._args = (angle_limits, zoom_in, zoom_out)
self._rotate_with_proba = rotate_with_proba

def forward(self, x):
if np.random.random() > self._rotate_with_proba:
return x
if x.dtype is not np.float32:
raise TypeError("This transformation only supports float32. "
"Consider calling it after ToTensor")
return image.random_rotate(x, *self._args)
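
And similarly for the probabilistic variant, rotating roughly half of the inputs by an angle drawn from [-45, 45] degrees (values illustrative; reuses the imports and tensor from the previous sketch):

rand_aug = transforms.RandomRotation((-45.0, 45.0), zoom_in=True, rotate_with_proba=0.5)
out = rand_aug(out)  # float32 C x H x W tensor in, same shape out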


class RandomResizedCrop(Block):
"""Crop the input image with random scale and aspect ratio.