diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index d7c06e76c182..1788f1095fcf 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -25,7 +25,7 @@ from ...context import current_context from . import _internal as _npi -__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power'] +__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'hanning'] @set_module('mxnet.ndarray.numpy') @@ -293,3 +293,87 @@ def power(x1, x2, out=None): This is a scalar if both x1 and x2 are scalars. """ return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out) + + +@set_module('mxnet.ndarray.numpy') +def hanning(M, dtype=_np.float64, ctx=None): + r"""Return the Hanning window. + + The Hanning window is a taper formed by using a weighted cosine. + + Parameters + ---------- + M : int + Number of points in the output window. If zero or less, an + empty array is returned. + dtype : str or numpy.dtype, optional + An optional value type. Default is `numpy.float64`. Note that you need + select numpy.float32 or float64 in this operator. + ctx : Context, optional + An optional device context (default is the current default context). + + Returns + ------- + out : ndarray, shape(M,) + The window, with the maximum value normalized to one (the value + one appears only if `M` is odd). + + See Also + -------- + blackman, hamming + + Notes + ----- + The Hanning window is defined as + + .. math:: w(n) = 0.5 - 0.5cos\left(\frac{2\pi{n}}{M-1}\right) + \qquad 0 \leq n \leq M-1 + + The Hanning was named for Julius von Hann, an Austrian meteorologist. + It is also known as the Cosine Bell. Some authors prefer that it be + called a Hann window, to help avoid confusion with the very similar + Hamming window. + + Most references to the Hanning window come from the signal processing + literature, where it is used as one of many windowing functions for + smoothing values. It is also known as an apodization (which means + "removing the foot", i.e. smoothing discontinuities at the beginning + and end of the sampled signal) or tapering function. + + References + ---------- + .. [1] Blackman, R.B. and Tukey, J.W., (1958) The measurement of power + spectra, Dover Publications, New York. + .. [2] E.R. Kanasewich, "Time Sequence Analysis in Geophysics", + The University of Alberta Press, 1975, pp. 106-108. + .. [3] Wikipedia, "Window function", + http://en.wikipedia.org/wiki/Window_function + .. [4] W.H. Press, B.P. Flannery, S.A. Teukolsky, and W.T. Vetterling, + "Numerical Recipes", Cambridge University Press, 1986, page 425. + + Examples + -------- + >>> np.hanning(12) + array([0.00000000e+00, 7.93732437e-02, 2.92292528e-01, 5.71157416e-01, + 8.27430424e-01, 9.79746513e-01, 9.79746489e-01, 8.27430268e-01, + 5.71157270e-01, 2.92292448e-01, 7.93731320e-02, 1.06192832e-13], dtype=float64) + + Plot the window and its frequency response: + + >>> import matplotlib.pyplot as plt + >>> window = np.hanning(51) + >>> plt.plot(window.asnumpy()) + [] + >>> plt.title("Hann window") + Text(0.5, 1.0, 'Hann window') + >>> plt.ylabel("Amplitude") + Text(0, 0.5, 'Amplitude') + >>> plt.xlabel("Sample") + Text(0.5, 0, 'Sample') + >>> plt.show() + """ + if dtype is None: + dtype = _np.float64 + if ctx is None: + ctx = current_context() + return _npi.hanning(M, dtype=dtype, ctx=ctx) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 9e0c52dbfd68..ab7e8a8f6381 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -44,7 +44,7 @@ from ..ndarray.numpy import _internal as _npi __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', - 'mod', 'power'] + 'mod', 'power', 'hanning'] # This function is copied from ndarray.py since pylint @@ -1549,3 +1549,87 @@ def power(x1, x2, out=None): This is a scalar if both x1 and x2 are scalars. """ return _mx_nd_np.power(x1, x2, out=out) + + +@set_module('mxnet.numpy') +def hanning(M, dtype=_np.float64, ctx=None): + r"""Return the Hanning window. + + The Hanning window is a taper formed by using a weighted cosine. + + Parameters + ---------- + M : int + Number of points in the output window. If zero or less, an + empty array is returned. + dtype : str or numpy.dtype, optional + An optional value type. Default is `numpy.float64`. Note that you need + select numpy.float32 or float64 in this operator. + ctx : Context, optional + An optional device context (default is the current default context). + + Returns + ------- + out : ndarray, shape(M,) + The window, with the maximum value normalized to one (the value + one appears only if `M` is odd). + + See Also + -------- + blackman, hamming + + Notes + ----- + The Hanning window is defined as + + .. math:: w(n) = 0.5 - 0.5cos\left(\frac{2\pi{n}}{M-1}\right) + \qquad 0 \leq n \leq M-1 + + The Hanning was named for Julius von Hann, an Austrian meteorologist. + It is also known as the Cosine Bell. Some authors prefer that it be + called a Hann window, to help avoid confusion with the very similar + Hamming window. + + Most references to the Hanning window come from the signal processing + literature, where it is used as one of many windowing functions for + smoothing values. It is also known as an apodization (which means + "removing the foot", i.e. smoothing discontinuities at the beginning + and end of the sampled signal) or tapering function. + + References + ---------- + .. [1] Blackman, R.B. and Tukey, J.W., (1958) The measurement of power + spectra, Dover Publications, New York. + .. [2] E.R. Kanasewich, "Time Sequence Analysis in Geophysics", + The University of Alberta Press, 1975, pp. 106-108. + .. [3] Wikipedia, "Window function", + http://en.wikipedia.org/wiki/Window_function + .. [4] W.H. Press, B.P. Flannery, S.A. Teukolsky, and W.T. Vetterling, + "Numerical Recipes", Cambridge University Press, 1986, page 425. + + Examples + -------- + >>> np.hanning(12) + array([0.00000000e+00, 7.93732437e-02, 2.92292528e-01, 5.71157416e-01, + 8.27430424e-01, 9.79746513e-01, 9.79746489e-01, 8.27430268e-01, + 5.71157270e-01, 2.92292448e-01, 7.93731320e-02, 1.06192832e-13], dtype=float64) + + Plot the window and its frequency response: + + >>> import matplotlib.pyplot as plt + >>> window = np.hanning(51) + >>> plt.plot(window.asnumpy()) + [] + >>> plt.title("Hann window") + Text(0.5, 1.0, 'Hann window') + >>> plt.ylabel("Amplitude") + Text(0, 0.5, 'Amplitude') + >>> plt.xlabel("Sample") + Text(0.5, 0, 'Sample') + >>> plt.show() + """ + if dtype is None: + dtype = _np.float64 + if ctx is None: + ctx = current_context() + return _mx_nd_np.hanning(M, dtype=dtype, ctx=ctx) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 616f3066d98d..ce3ddc0361aa 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -28,7 +28,7 @@ from .._internal import _set_np_symbol_class from . import _internal as _npi -__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power'] +__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'hanning'] def _num_outputs(sym): @@ -1010,4 +1010,88 @@ def power(x1, x2, out=None): return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out) +@set_module('mxnet.symbol.numpy') +def hanning(M, dtype=_np.float64, ctx=None): + r"""Return the Hanning window. + + The Hanning window is a taper formed by using a weighted cosine. + + Parameters + ---------- + M : int + Number of points in the output window. If zero or less, an + empty array is returned. + dtype : str or numpy.dtype, optional + An optional value type. Default is `numpy.float64`. Note that you need + select numpy.float32 or float64 in this operator. + ctx : Context, optional + An optional device context (default is the current default context). + + Returns + ------- + out : _Symbol, shape(M,) + The window, with the maximum value normalized to one (the value + one appears only if `M` is odd). + + See Also + -------- + blackman, hamming + + Notes + ----- + The Hanning window is defined as + + .. math:: w(n) = 0.5 - 0.5cos\left(\frac{2\pi{n}}{M-1}\right) + \qquad 0 \leq n \leq M-1 + + The Hanning was named for Julius von Hann, an Austrian meteorologist. + It is also known as the Cosine Bell. Some authors prefer that it be + called a Hann window, to help avoid confusion with the very similar + Hamming window. + + Most references to the Hanning window come from the signal processing + literature, where it is used as one of many windowing functions for + smoothing values. It is also known as an apodization (which means + "removing the foot", i.e. smoothing discontinuities at the beginning + and end of the sampled signal) or tapering function. + + References + ---------- + .. [1] Blackman, R.B. and Tukey, J.W., (1958) The measurement of power + spectra, Dover Publications, New York. + .. [2] E.R. Kanasewich, "Time Sequence Analysis in Geophysics", + The University of Alberta Press, 1975, pp. 106-108. + .. [3] Wikipedia, "Window function", + http://en.wikipedia.org/wiki/Window_function + .. [4] W.H. Press, B.P. Flannery, S.A. Teukolsky, and W.T. Vetterling, + "Numerical Recipes", Cambridge University Press, 1986, page 425. + + Examples + -------- + >>> np.hanning(12) + array([0.00000000e+00, 7.93732437e-02, 2.92292528e-01, 5.71157416e-01, + 8.27430424e-01, 9.79746513e-01, 9.79746489e-01, 8.27430268e-01, + 5.71157270e-01, 2.92292448e-01, 7.93731320e-02, 1.06192832e-13], dtype=float64) + + Plot the window and its frequency response: + + >>> import matplotlib.pyplot as plt + >>> window = np.hanning(51) + >>> plt.plot(window.asnumpy()) + [] + >>> plt.title("Hann window") + Text(0.5, 1.0, 'Hann window') + >>> plt.ylabel("Amplitude") + Text(0, 0.5, 'Amplitude') + >>> plt.xlabel("Sample") + Text(0.5, 0, 'Sample') + >>> plt.show() + """ + if dtype is None: + dtype = _np.float64 + if ctx is None: + ctx = current_context() + return _npi.hanning(M, dtype=dtype, ctx=ctx) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/numpy/np_window_op.cc b/src/operator/numpy/np_window_op.cc new file mode 100644 index 000000000000..814744cb109d --- /dev/null +++ b/src/operator/numpy/np_window_op.cc @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_window_op.cc + * \brief CPU Implementation of unary op hanning, hamming, blackman window. + */ + +#include "np_window_op.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(NumpyWindowsParam); + +NNVM_REGISTER_OP(_npi_hanning) +.describe("Return the Hanning window." + "The Hanning window is a taper formed by using a weighted cosine.") +.set_num_inputs(0) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", NumpyWindowsShape) +.set_attr("FInferType", InitType) +.set_attr("FCompute", NumpyWindowCompute) +.add_arguments(NumpyWindowsParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_window_op.cu b/src/operator/numpy/np_window_op.cu new file mode 100644 index 000000000000..04bff6b50bd1 --- /dev/null +++ b/src/operator/numpy/np_window_op.cu @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_window_op.cu + * \brief CPU Implementation of unary op hanning, hamming, blackman window. + */ + +#include "np_window_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_hanning) +.set_attr("FCompute", NumpyWindowCompute); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_window_op.h b/src/operator/numpy/np_window_op.h new file mode 100644 index 000000000000..1932b1a0a01f --- /dev/null +++ b/src/operator/numpy/np_window_op.h @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_window_op.h + * \brief CPU Implementation of unary op hanning, hamming, blackman window. + */ + +#ifndef MXNET_OPERATOR_NUMPY_NP_WINDOW_OP_H_ +#define MXNET_OPERATOR_NUMPY_NP_WINDOW_OP_H_ + +#include +#include +#include "../tensor/init_op.h" + +namespace mxnet { +namespace op { + +#ifdef __CUDA_ARCH__ +__constant__ const float PI = 3.14159265358979323846; +#else +const float PI = 3.14159265358979323846; +using std::isnan; +#endif + +struct NumpyWindowsParam : public dmlc::Parameter { + dmlc::optional M; + std::string ctx; + int dtype; + DMLC_DECLARE_PARAMETER(NumpyWindowsParam) { + DMLC_DECLARE_FIELD(M) + .set_default(dmlc::optional()) + .describe("Number of points in the output window. " + "If zero or less, an empty array is returned."); + DMLC_DECLARE_FIELD(ctx) + .set_default("") + .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)." + "Only used for imperative calls."); + DMLC_DECLARE_FIELD(dtype) + .set_default(mshadow::kFloat64) + MXNET_ADD_ALL_TYPES + .describe("Data-type of the returned array."); + } +}; + +inline bool NumpyWindowsShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_shapes, + mxnet::ShapeVector* out_shapes) { + const NumpyWindowsParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_shapes->size(), 0U); + CHECK_EQ(out_shapes->size(), 1U); + CHECK(param.M.has_value()) << "missing 1 required positional argument: 'M'"; + int64_t out_size = param.M.value() <= 0 ? 0 : param.M.value(); + SHAPE_ASSIGN_CHECK(*out_shapes, 0, mxnet::TShape({static_cast(out_size)})); + return true; +} + +struct hanning_fwd { + template + MSHADOW_XINLINE static void Map(index_t i, index_t M, int req, DType* out) { + if (M == 1) { + KERNEL_ASSIGN(out[i], req, static_cast(1)); + } else { + KERNEL_ASSIGN(out[i], req, DType(0.5) - DType(0.5) * math::cos(DType(2 * PI * i / (M - 1)))); + } + } +}; + +struct hamming_fwd { + template + MSHADOW_XINLINE static void Map(index_t i, index_t M, int req, DType* out) { + if (M == 1) { + KERNEL_ASSIGN(out[i], req, static_cast(1)); + } else { + KERNEL_ASSIGN(out[i], req, + DType(0.54) - DType(0.46) * math::cos(DType(2 * PI * i / (M - 1)))); + } + } +}; + +struct blackman_fwd { + template + MSHADOW_XINLINE static void Map(index_t i, index_t M, int req, DType* out) { + if (M == 1) { + KERNEL_ASSIGN(out[i], req, static_cast(1)); + } else { + KERNEL_ASSIGN(out[i], req, DType(0.42) - DType(0.5) * math::cos(DType(2 * PI * i /(M - 1))) + + DType(0.08) * math::cos(DType(4 * PI * i /(M - 1)))); + } + } +}; + +template +void NumpyWindowCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + mshadow::Stream *s = ctx.get_stream(); + const NumpyWindowsParam& param = nnvm::get(attrs.parsed); + if (param.M.has_value() && param.M.value() <= 0) return; + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + if (window_select == 0) { + Kernel::Launch(s, outputs[0].Size(), static_cast(param.M.value()), + req[0], outputs[0].dptr()); + } else if (window_select == 1) { + Kernel::Launch(s, outputs[0].Size(), static_cast(param.M.value()), + req[0], outputs[0].dptr()); + } else if (window_select == 2) { + Kernel::Launch(s, outputs[0].Size(), static_cast(param.M.value()), + req[0], outputs[0].dptr()); + } else { + LOG(FATAL) << "window_select must be (0, 1, 2)"; + } + }); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_WINDOW_OP_H_ diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 0b5142dd8c8b..354fa29931c8 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -35,6 +35,7 @@ from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied from common import run_in_spawned_process from test_operator import * +from test_numpy_op import * from test_numpy_ndarray import * from test_optimizer import * from test_random import * diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py new file mode 100644 index 000000000000..dd8955fd66aa --- /dev/null +++ b/tests/python/unittest/test_numpy_op.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: skip-file +from __future__ import absolute_import +import numpy as _np +import mxnet as mx +from mxnet import np, npx +from mxnet.base import MXNetError +from mxnet.gluon import HybridBlock +from mxnet.base import MXNetError +from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray +from mxnet.test_utils import check_numeric_gradient, use_np +from common import assertRaises, with_seed +import random +import collections + + +@with_seed() +@use_np +def test_np_hanning(): + class TestHanning(HybridBlock): + def __init__(self, M, dtype): + super(TestHanning, self).__init__() + self._M = M + self._dtype = dtype + + def hybrid_forward(self, F, x, *args, **kwargs): + return x + F.np.hanning(M=self._M, dtype=self._dtype) + configs = [-10, -3, -1, 0, 1, 6, 10, 20] + dtypes = ['float32', 'float64'] + + for config in configs: + for dtype in dtypes: + x = np.zeros(shape=(), dtype=dtype) + for hybridize in [False, True]: + net = TestHanning(M=config, dtype=dtype) + np_out = _np.hanning(M=config) + if hybridize: + net.hybridize() + mx_out = net(x) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + mx_out = np.hanning(M=config, dtype=dtype) + np_out = _np.hanning(M=config) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + +@with_seed() +@use_np +def test_np_hanning(): + class TestHanning(HybridBlock): + def __init__(self, M, dtype): + super(TestHanning, self).__init__() + self._M = M + self._dtype = dtype + + def hybrid_forward(self, F, x, *args, **kwargs): + return x + F.np.hanning(M=self._M, dtype=self._dtype) + configs = [-10, -3, -1, 0, 1, 6, 10, 20] + dtypes = ['float32', 'float64'] + + for config in configs: + for dtype in dtypes: + x = np.zeros(shape=(), dtype=dtype) + for hybridize in [False, True]: + net = TestHanning(M=config, dtype=dtype) + np_out = _np.hanning(M=config) + if hybridize: + net.hybridize() + mx_out = net(x) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + mx_out = np.hanning(M=config, dtype=dtype) + np_out = _np.hanning(M=config) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + +if __name__ == '__main__': + import nose + nose.runmodule()