Skip to content

Commit

Permalink
src: cpu: rnn: avoid unaligned pointers in vex instructions
Browse files Browse the repository at this point in the history
  • Loading branch information
mgouicem committed Aug 12, 2019
1 parent eb3c866 commit a147c08
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 20 deletions.
13 changes: 9 additions & 4 deletions src/cpu/rnn/jit_uni_gru_cell_postgemm_1.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ struct jit_uni_gru_cell_postgemm_part1_fwd: public jit_uni_rnn_postgemm
Reg64 table_reg(rbx); // table is used for data scale and shifts

// We skip vmm0 as it can be used by the injector for masks on sse4.1
Vmm G0(1), G1(2);
Vmm G0(1), G1(2), tmp1_vmm(3);

// We start code generation here
preamble();
Expand All @@ -98,18 +98,23 @@ struct jit_uni_gru_cell_postgemm_part1_fwd: public jit_uni_rnn_postgemm
{
// Compute gate 0: G0 = sigmoid(G0 + b0)
uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]);
uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]);
uni_vmovups(
tmp1_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]);
uni_vaddps(G0, G0, tmp1_vmm);
sigmoid_injector_->compute_vector(G0.getIdx());
// we store it for use in postgemm_part2
uni_vmovups(ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size], G0);

// Compute gate 1: G1 = sigmoid(G1 + b1)
uni_vmovups(G1, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]);
uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]);
uni_vmovups(
tmp1_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]);
uni_vaddps(G1, G1, tmp1_vmm);
sigmoid_injector_->compute_vector(G1.getIdx());

// states_t_l = states_tm1_l * G1
uni_vmulps(G1, G1, ptr[addr_states_tm1_l_reg]);
uni_vmovups(tmp1_vmm, ptr[addr_states_tm1_l_reg]);
uni_vmulps(G1, G1, tmp1_vmm);
uni_vmovups(ptr[addr_states_t_l_reg], G1);

// increment address pointers
Expand Down
9 changes: 6 additions & 3 deletions src/cpu/rnn/jit_uni_gru_cell_postgemm_2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ struct jit_uni_gru_cell_postgemm_part2_fwd: public jit_uni_rnn_postgemm
Reg64 table_reg(rbx); // table is used for data scale and shifts

// We skip vmm0 as it can be used by the injector for masks on sse4.1
Vmm G0(1), G2(2), tmp1_vmm(3);
Vmm G0(1), G2(2), tmp1_vmm(3), tmp2_vmm(4);

// constant table map
Address one_addr = ptr[table_reg];
Expand All @@ -99,14 +99,17 @@ struct jit_uni_gru_cell_postgemm_part2_fwd: public jit_uni_rnn_postgemm
{
// Compute gate 2: G2 = tanh(G2 + b2)
uni_vmovups(G2, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]);
uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]);
uni_vmovups(
tmp1_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]);
uni_vaddps(G2, G2, tmp1_vmm);
tanh_injector_->compute_vector(G2.getIdx());

// states_t_l = states_tm1_l * G0 + (1 - G0) * G2
uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]);
uni_vmovups(tmp1_vmm, one_addr);
uni_vsubps(tmp1_vmm, tmp1_vmm, G0);
uni_vmulps(G0, G0, ptr[addr_states_tm1_l_reg]);
uni_vmovups(tmp2_vmm, ptr[addr_states_tm1_l_reg]);
uni_vmulps(G0, G0, tmp2_vmm);
uni_vfmadd231ps(G0, tmp1_vmm, G2);
uni_vmovups(ptr[addr_states_t_l_reg], G0);

Expand Down
34 changes: 25 additions & 9 deletions src/cpu/rnn/jit_uni_gru_lbr_cell_postgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ struct jit_uni_gru_lbr_cell_postgemm_fwd: public jit_uni_rnn_postgemm
Reg64 table_reg(rbx); // table is used for data scale and shifts

// We skip vmm0 as it can be used by the injector for masks on sse4.1
Vmm G0(1), G1(2), G2(3), tmp1_vmm(5);
Vmm G0(1), G1(2), G2(3), tmp1_vmm(5), tmp2_vmm(6);

// constant table map
Address one_addr = ptr[table_reg];
Expand Down Expand Up @@ -114,27 +114,43 @@ struct jit_uni_gru_lbr_cell_postgemm_fwd: public jit_uni_rnn_postgemm
{
// Compute gate 0
uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]);
uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]);
uni_vaddps(G0, G0, ptr[addr_ws_gemm_reg + 0 * rnn_.dic * gate_dt_size]);
uni_vmovups(
tmp1_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]);
uni_vaddps(G0, G0, tmp1_vmm);
uni_vmovups(tmp1_vmm,
ptr[addr_ws_gemm_reg + 0 * rnn_.dic * gate_dt_size]);
uni_vaddps(G0, G0, tmp1_vmm);
sigmoid_injector_->compute_vector(G0.getIdx());

// Compute gate 1
uni_vmovups(G1, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]);
uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]);
uni_vaddps(G1, G1, ptr[addr_ws_gemm_reg + 1 * rnn_.dic * gate_dt_size]);
uni_vmovups(
tmp1_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]);
uni_vaddps(G1, G1, tmp1_vmm);
uni_vmovups(tmp1_vmm,
ptr[addr_ws_gemm_reg + 1 * rnn_.dic * gate_dt_size]);
uni_vaddps(G1, G1, tmp1_vmm);
sigmoid_injector_->compute_vector(G1.getIdx());

// compute last gate
uni_vmovups(G2, ptr[addr_ws_gemm_reg + 2 * rnn_.dic * gate_dt_size]);
uni_vaddps(G2, G2, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]);
uni_vfmadd213ps(G2, G1, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]); // G2 * G1 + gates2
uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]);
uni_vmovups(
tmp1_vmm, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]);
uni_vaddps(G2, G2, tmp1_vmm);
uni_vmovups(tmp1_vmm,
ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]);
uni_vfmadd213ps(G2, G1, tmp1_vmm); // G2 * G1 + gates2
uni_vmovups(
tmp1_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]);
uni_vaddps(G2, G2, tmp1_vmm);
tanh_injector_->compute_vector(G2.getIdx());

// states_t_l = states_tm1_l * G0 + (1 - G0) * G2
uni_vmovups(tmp1_vmm, one_addr);
uni_vsubps(tmp1_vmm, tmp1_vmm, G0);
uni_vmulps(G0, G0, ptr[addr_states_tm1_l_reg]);
uni_vmovups(tmp2_vmm, ptr[addr_states_tm1_l_reg]);
uni_vmulps(G0, G0, tmp2_vmm);
uni_vfmadd231ps(G0, tmp1_vmm, G2);

// write back the result
Expand Down
16 changes: 12 additions & 4 deletions src/cpu/rnn/jit_uni_lstm_cell_postgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,18 @@ struct jit_uni_lstm_cell_postgemm_fwd: public jit_uni_rnn_postgemm
}

// add biases
uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]);
uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]);
uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]);
uni_vaddps(G3, G3, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]);
uni_vmovups(
tmp1_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]);
uni_vaddps(G0, G0, tmp1_vmm);
uni_vmovups(
tmp1_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]);
uni_vaddps(G1, G1, tmp1_vmm);
uni_vmovups(
tmp1_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]);
uni_vaddps(G2, G2, tmp1_vmm);
uni_vmovups(
tmp1_vmm, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]);
uni_vaddps(G3, G3, tmp1_vmm);

// inject eltwise code
sigmoid_injector_->compute_vector(G0.getIdx());
Expand Down

0 comments on commit a147c08

Please sign in to comment.