Skip to content

Commit

Permalink
all: rnn: align computing (per-oc) mask with the library
Browse files Browse the repository at this point in the history
The i-th bit in the mask corresponds to the i-th dimension.
The dimensions are enumerated from the outermost one, aka C-array style.
  • Loading branch information
Fomenko, Evarist M committed Dec 5, 2019
1 parent 28f4c96 commit ff3ffab
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
4 changes: 3 additions & 1 deletion examples/cpu_rnn_inference_int8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,9 @@ void simple_net() {
///
const float data_shift = 64.;
const float data_scale = 63.;
const int weights_scale_mask = 3; // 11 for last two dimensions of ldigo
const int weights_scale_mask = 0
+ (1 << 3) // bit, indicating the unique scales for `g` dim in `ldigo`
+ (1 << 4); // bit, indicating the unique scales for `o` dim in `ldigo`
//[quantize]
std::vector<float> weights_scales(lstm_n_gates * feature_size);
// assign halves of vector with arbitrary values
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/rnn/rnn_reorders.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ struct rnn_weights_reorder_t : public primitive_impl_t {
if (itag == format_tag::undef) return invalid_arguments;

const int mask = attr->rnn_weights_qparams_.mask_;
if (!utils::one_of(mask, 0, 3)) return unimplemented;
if (!utils::one_of(mask, 0, 24)) return unimplemented;

auto _pd = new pd_t(
engine, attr, src_engine, src_md, dst_engine, dst_md);
Expand Down
2 changes: 1 addition & 1 deletion tests/benchdnn/rnn/rnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void create_dnnl_rnn_attr(const prb_t &p, dnnl_primitive_attr_t *dnnl_attr) {

if (p.scale_policy == policy_t::PER_OC) {
DNN_SAFE_V(dnnl_primitive_attr_set_rnn_weights_qparams(
*dnnl_attr, p.dic * p.n_gates(), 0x3, p.wei_oc_scales));
*dnnl_attr, p.dic * p.n_gates(), 0x18, p.wei_oc_scales));
} else if (p.scale_policy == policy_t::COMMON && p.wei_scale != 1.) {
DNN_SAFE_V(dnnl_primitive_attr_set_rnn_weights_qparams(
*dnnl_attr, 1, 0, &p.wei_scale));
Expand Down

0 comments on commit ff3ffab

Please sign in to comment.