From a147c08f728b8d85aff6bb282532944dd2729c1f Mon Sep 17 00:00:00 2001 From: Mourad Gouicem Date: Tue, 30 Jul 2019 12:56:11 -0700 Subject: [PATCH] src: cpu: rnn: avoid unaligned pointers in vex instructions --- src/cpu/rnn/jit_uni_gru_cell_postgemm_1.hpp | 13 ++++--- src/cpu/rnn/jit_uni_gru_cell_postgemm_2.hpp | 9 +++-- src/cpu/rnn/jit_uni_gru_lbr_cell_postgemm.hpp | 34 ++++++++++++++----- src/cpu/rnn/jit_uni_lstm_cell_postgemm.hpp | 16 ++++++--- 4 files changed, 52 insertions(+), 20 deletions(-) diff --git a/src/cpu/rnn/jit_uni_gru_cell_postgemm_1.hpp b/src/cpu/rnn/jit_uni_gru_cell_postgemm_1.hpp index 53d75c698d5..b36e493862a 100644 --- a/src/cpu/rnn/jit_uni_gru_cell_postgemm_1.hpp +++ b/src/cpu/rnn/jit_uni_gru_cell_postgemm_1.hpp @@ -73,7 +73,7 @@ struct jit_uni_gru_cell_postgemm_part1_fwd: public jit_uni_rnn_postgemm Reg64 table_reg(rbx); // table is used for data scale and shifts // We skip vmm0 as it can be used by the injector for masks on sse4.1 - Vmm G0(1), G1(2); + Vmm G0(1), G1(2), tmp1_vmm(3); // We start code generations here preamble(); @@ -98,18 +98,23 @@ struct jit_uni_gru_cell_postgemm_part1_fwd: public jit_uni_rnn_postgemm { // Compute gate 0: G0 = sigmoid(G0 + b0) uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]); - uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); + uni_vmovups( + tmp1_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); + uni_vaddps(G0, G0, tmp1_vmm); sigmoid_injector_->compute_vector(G0.getIdx()); // we store it for use in postgemm_part2 uni_vmovups(ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size], G0); // Compute gate 1: G1 = sigmoid(G1 + b1) uni_vmovups(G1, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]); - uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); + uni_vmovups( + tmp1_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); + uni_vaddps(G1, G1, tmp1_vmm); sigmoid_injector_->compute_vector(G1.getIdx()); // states_t_l = states_tm1_l * G1 - uni_vmulps(G1, G1, ptr[addr_states_tm1_l_reg]); + uni_vmovups(tmp1_vmm, ptr[addr_states_tm1_l_reg]); + uni_vmulps(G1, G1, tmp1_vmm); uni_vmovups(ptr[addr_states_t_l_reg], G1); // increment address pointers diff --git a/src/cpu/rnn/jit_uni_gru_cell_postgemm_2.hpp b/src/cpu/rnn/jit_uni_gru_cell_postgemm_2.hpp index 0815f743cc4..2f08c2f348c 100644 --- a/src/cpu/rnn/jit_uni_gru_cell_postgemm_2.hpp +++ b/src/cpu/rnn/jit_uni_gru_cell_postgemm_2.hpp @@ -73,7 +73,7 @@ struct jit_uni_gru_cell_postgemm_part2_fwd: public jit_uni_rnn_postgemm Reg64 table_reg(rbx); // table is used for data scale and shifts // We skip vmm0 as it can be used by the injector for masks on sse4.1 - Vmm G0(1), G2(2), tmp1_vmm(3); + Vmm G0(1), G2(2), tmp1_vmm(3), tmp2_vmm(4); // constant table map Address one_addr = ptr[table_reg]; @@ -99,14 +99,17 @@ struct jit_uni_gru_cell_postgemm_part2_fwd: public jit_uni_rnn_postgemm { // Compute gate 2: G2 = tanh(G2 + b2) uni_vmovups(G2, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]); - uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); + uni_vmovups( + tmp1_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); + uni_vaddps(G2, G2, tmp1_vmm); tanh_injector_->compute_vector(G2.getIdx()); // states_t_l = states_tm1_l * G0 + (1 - G0) * G2 uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]); uni_vmovups(tmp1_vmm, one_addr); uni_vsubps(tmp1_vmm, tmp1_vmm, G0); - uni_vmulps(G0, G0, ptr[addr_states_tm1_l_reg]); + uni_vmovups(tmp2_vmm, ptr[addr_states_tm1_l_reg]); + uni_vmulps(G0, G0, tmp2_vmm); uni_vfmadd231ps(G0, tmp1_vmm, G2); uni_vmovups(ptr[addr_states_t_l_reg], G0); diff --git a/src/cpu/rnn/jit_uni_gru_lbr_cell_postgemm.hpp b/src/cpu/rnn/jit_uni_gru_lbr_cell_postgemm.hpp index 619bd32e2c2..fef8c1a4547 100644 --- a/src/cpu/rnn/jit_uni_gru_lbr_cell_postgemm.hpp +++ b/src/cpu/rnn/jit_uni_gru_lbr_cell_postgemm.hpp @@ -77,7 +77,7 @@ struct jit_uni_gru_lbr_cell_postgemm_fwd: public jit_uni_rnn_postgemm Reg64 table_reg(rbx); // table is used for data scale and shifts // We skip vmm0 as it can be used by the injector for masks on sse4.1 - Vmm G0(1), G1(2), G2(3), tmp1_vmm(5); + Vmm G0(1), G1(2), G2(3), tmp1_vmm(5), tmp2_vmm(6); // constant table map Address one_addr = ptr[table_reg]; @@ -114,27 +114,43 @@ struct jit_uni_gru_lbr_cell_postgemm_fwd: public jit_uni_rnn_postgemm { // Compute gate 0 uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]); - uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); - uni_vaddps(G0, G0, ptr[addr_ws_gemm_reg + 0 * rnn_.dic * gate_dt_size]); + uni_vmovups( + tmp1_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); + uni_vaddps(G0, G0, tmp1_vmm); + uni_vmovups(tmp1_vmm, + ptr[addr_ws_gemm_reg + 0 * rnn_.dic * gate_dt_size]); + uni_vaddps(G0, G0, tmp1_vmm); sigmoid_injector_->compute_vector(G0.getIdx()); // Compute gate 1 uni_vmovups(G1, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]); - uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); - uni_vaddps(G1, G1, ptr[addr_ws_gemm_reg + 1 * rnn_.dic * gate_dt_size]); + uni_vmovups( + tmp1_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); + uni_vaddps(G1, G1, tmp1_vmm); + uni_vmovups(tmp1_vmm, + ptr[addr_ws_gemm_reg + 1 * rnn_.dic * gate_dt_size]); + uni_vaddps(G1, G1, tmp1_vmm); sigmoid_injector_->compute_vector(G1.getIdx()); // compute last gate uni_vmovups(G2, ptr[addr_ws_gemm_reg + 2 * rnn_.dic * gate_dt_size]); - uni_vaddps(G2, G2, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]); - uni_vfmadd213ps(G2, G1, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]); // G2 * G1 + gates2 - uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); + uni_vmovups( + tmp1_vmm, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]); + uni_vaddps(G2, G2, tmp1_vmm); + uni_vmovups(tmp1_vmm, + ptr[addr_ws_gates_reg + + 2 * rnn_.dic * gate_dt_size]); // G2 * G1 + gates2 + uni_vfmadd213ps(G2, G1, tmp1_vmm); + uni_vmovups( + tmp1_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); + uni_vaddps(G2, G2, tmp1_vmm); tanh_injector_->compute_vector(G2.getIdx()); // states_t_l = states_tm1_l * G0 + (1 - G0) * G2 uni_vmovups(tmp1_vmm, one_addr); uni_vsubps(tmp1_vmm, tmp1_vmm, G0); - uni_vmulps(G0, G0, ptr[addr_states_tm1_l_reg]); + uni_vmovups(tmp2_vmm, ptr[addr_states_tm1_l_reg]); + uni_vmulps(G0, G0, tmp2_vmm); uni_vfmadd231ps(G0, tmp1_vmm, G2); // write back the result diff --git a/src/cpu/rnn/jit_uni_lstm_cell_postgemm.hpp b/src/cpu/rnn/jit_uni_lstm_cell_postgemm.hpp index ebccfce15ad..5c3ea419388 100644 --- a/src/cpu/rnn/jit_uni_lstm_cell_postgemm.hpp +++ b/src/cpu/rnn/jit_uni_lstm_cell_postgemm.hpp @@ -203,10 +203,18 @@ struct jit_uni_lstm_cell_postgemm_fwd: public jit_uni_rnn_postgemm } // add biases - uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); - uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); - uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); - uni_vaddps(G3, G3, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]); + uni_vmovups( + tmp1_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); + uni_vaddps(G0, G0, tmp1_vmm); + uni_vmovups( + tmp1_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); + uni_vaddps(G1, G1, tmp1_vmm); + uni_vmovups( + tmp1_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); + uni_vaddps(G2, G2, tmp1_vmm); + uni_vmovups( + tmp1_vmm, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]); + uni_vaddps(G3, G3, tmp1_vmm); // inject eltwise code sigmoid_injector_->compute_vector(G0.getIdx());