Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[BACKPORT] [FEATURE] Add API to control denormalized computations (#2…
Browse files Browse the repository at this point in the history
…0387)

* [1.x] Add API to control denormalized computations

* Edit name and description

* Add direct imports

* Edit description

Co-authored-by: Andrzej Kotłowski <[email protected]>

* Sanity & review

* Return previous state of the FTZ flag

* Utilize Engine::PushSync

* Disable FTZ for numpy_interoperability case

* Update python/mxnet/util.py

Co-authored-by: Sheng Zha <[email protected]>

* Add required header & fix test

* Fix macro expansion

* Don't include x86instrin.h when compiling with MSVC

* Update documentation

Co-authored-by: Andrzej Kotłowski <[email protected]>
Co-authored-by: Sheng Zha <[email protected]>
  • Loading branch information
3 people authored Jul 30, 2021
1 parent 7077bc4 commit 1155c9e
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 2 deletions.
11 changes: 11 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,17 @@ MXNET_DLL int MXRandomSeed(int seed);
*/
MXNET_DLL int MXRandomSeedContext(int seed, int dev_type, int dev_id);

/*!
* \brief Change floating-point calculations when dealing with denormalized values.
* Currently this option is only supported in CPU backend.
* Flushing denormalized values to zero is enabled by default.
*
* \param value state of flush-to-zero and denormals-are-zero to set.
* \param prev_state state of flush-to-zero and denormals-are-zero before setting new state.
* \return 0 when success, -1 when failure happens.
*/
MXNET_DLL int MXSetFlushDenorms(bool value, bool* prev_state);

/*!
* \brief Notify the engine about a shutdown,
* This can help engine to print less messages into display.
Expand Down
2 changes: 2 additions & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ def _load_lib():
# library instance of mxnet
_LIB = _load_lib()

check_call(_LIB.MXSetFlushDenorms(ctypes.c_bool(True),
ctypes.byref(ctypes.c_bool())))
# type definitions
mx_int = ctypes.c_int
mx_uint = ctypes.c_uint
Expand Down
24 changes: 24 additions & 0 deletions python/mxnet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,3 +1200,27 @@ def get_rtc_compile_opts(ctx):
arch_opt = "--gpu-architecture={}_{}".format("sm" if should_compile_to_SASS else "compute",
device_cc_as_used)
return [arch_opt]

def set_flush_denorms(value):
"""Change floating-point calculations on CPU when dealing with denormalized values.
This is only applicable to architectures which supports flush-to-zero.
Denormalized values are positive and negative values that are very close to 0
(exponent is the smallest possible value).
Flushing denormalized values to 0 can speedup calculations if such values occurs,
but if fulfilling whole IEEE 754 standard is required this option should be disabled.
Flushing denormalized values is enabled in MXNet by default.
Parameters
----------
value : bool
State of flush-to-zero and denormals-are-zero in MXCSR register
Returns
-------
prev_state : bool
Previous state of flush-to-zero in MXCSR register
"""
ret = ctypes.c_bool()
passed_value = ctypes.c_bool(value)
check_call(_LIB.MXSetFlushDenorms(passed_value, ctypes.byref(ret)))
return ret.value
63 changes: 63 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,23 @@
#include "miniz.h"
#include "nnvm/pass_functions.h"

// FTZ only applies to SSE and AVX instructions.
#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
(defined(_M_IX86_FP) && _M_IX86_FP >= 1)
#define SUPPORT_FTZ_DMZ 1
#else
#define SUPPORT_FTZ_DMZ 0
#endif

#if SUPPORT_FTZ_DMZ
#include <immintrin.h>
#include <xmmintrin.h>
#endif
#if SUPPORT_FTZ_DMZ && !defined(_MSC_VER)
#include <x86intrin.h>
#endif


using namespace mxnet;

// Internal function to get the information
Expand Down Expand Up @@ -1587,6 +1604,52 @@ int MXRandomSeedContext(int seed, int dev_type, int dev_id) {
API_END();
}

int MXSetFlushDenorms(bool value, bool* prev_state) {
API_BEGIN();
*prev_state = false;

#if SUPPORT_FTZ_DMZ
std::function<bool()> is_dmz_flag_available = []() {
// Intel 64 and IA-32 Architectures Software Developer’s Manual: Vol. 1
// "Checking for the DAZ Flag in the MXCSR Register"
constexpr unsigned int mxcsr_mask_offset = 28;
constexpr unsigned int dmz_flag_offset = 5;
constexpr unsigned int fxsave_req_bytes = 512;

char* fxsave_area_ptr = reinterpret_cast<char*>(malloc(fxsave_req_bytes));
memset(fxsave_area_ptr, 0, fxsave_req_bytes); // fill memory with 0
_fxsave(fxsave_area_ptr);

char* mxcsr_mask_ptr = fxsave_area_ptr + mxcsr_mask_offset;
uint32_t mxcsr_mask = *(reinterpret_cast<uint32_t*>((mxcsr_mask_ptr)));
// DMZ flag is supported if sixth bit of MXCSR_MASK is hot
bool dmz_flag = (mxcsr_mask >> dmz_flag_offset) & 0x1;
free(fxsave_area_ptr);
return dmz_flag;
};

Engine::Get()->PushSync(
[value, prev_state, is_dmz_flag_available](RunContext rctx) {
const unsigned int DMZ_STATE = value ? _MM_DENORMALS_ZERO_ON : _MM_DENORMALS_ZERO_OFF;
const unsigned int FTZ_STATE = value ? _MM_FLUSH_ZERO_ON : _MM_FLUSH_ZERO_OFF;
*prev_state = _MM_GET_FLUSH_ZERO_MODE();
_MM_SET_FLUSH_ZERO_MODE(FTZ_STATE);

// If the DAZ flag is not supported, then it is a reserved bit and attempting to write a 1
// to it will cause a general-protection exception (#GP)
if (is_dmz_flag_available()) {
_MM_SET_DENORMALS_ZERO_MODE(DMZ_STATE);
}
}, Context::CPU(), {}, {},
FnProperty::kNormal, 0, "SetFlushDenorms");

Engine::Get()->WaitForAll();

#endif

API_END();
}

int MXNotifyShutdown() {
API_BEGIN();
mxnet::op::custom::CustomOperator::Get()->Stop();
Expand Down
8 changes: 6 additions & 2 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import numpy as _np
import unittest
import pytest
from mxnet import np
from mxnet import np, util
from mxnet.test_utils import assert_almost_equal
from mxnet.test_utils import use_np
from mxnet.test_utils import is_op_runnable
Expand Down Expand Up @@ -3341,7 +3341,11 @@ def test_np_array_function_protocol():
@with_array_ufunc_protocol
@pytest.mark.serial
def test_np_array_ufunc_protocol():
check_interoperability(_NUMPY_ARRAY_UFUNC_LIST)
prev_state = util.set_flush_denorms(False)
try:
check_interoperability(_NUMPY_ARRAY_UFUNC_LIST)
finally:
util.set_flush_denorms(prev_state)


@use_np
Expand Down

0 comments on commit 1155c9e

Please sign in to comment.