From 96888222ea5a24603ad21ea5184804ec3c5cf7d2 Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Sat, 9 May 2020 23:06:27 -0700 Subject: [PATCH] Fix interleave matmul doc (#18260) * fix doc * fix doc * fix axis Co-authored-by: Lin --- src/operator/contrib/transformer.cc | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/operator/contrib/transformer.cc b/src/operator/contrib/transformer.cc index 58826a2d96a8..1abd2a0ebce9 100644 --- a/src/operator/contrib/transformer.cc +++ b/src/operator/contrib/transformer.cc @@ -655,14 +655,16 @@ the input must be a single tensor of interleaved projections of queries, keys and values following the layout: (seq_length, batch_size, num_heads * head_dim * 3) -the equivalent code would be: -tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1)) -q_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3)) -q_proj = mx.nd.reshape(q_proj, shape=(-1, 0, 0), reverse=True) -q_proj = mx.nd.contrib.div_sqrt_dim(q_proj) -k_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3)) -k_proj = mx.nd.reshap(k_proj, shape=(-1, 0, 0), reverse=True) -output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True) +the equivalent code would be:: + + tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1)) + q_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3)) + q_proj = mx.nd.reshape(q_proj, shape=(-1, 0, 0), reverse=True) + q_proj = mx.nd.contrib.div_sqrt_dim(q_proj) + k_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3)) + k_proj = mx.nd.reshape(k_proj, shape=(-1, 0, 0), reverse=True) + output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True) + )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) @@ -703,9 +705,9 @@ the equivalent code would be: tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1)) v_proj = mx.nd.transpose(tmp[:,:,:,2,:], axes=(1, 2, 0, 3)) v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True) -output = mx.nd.batch_dot(attention, v_proj, transpose_b=True) +output = mx.nd.batch_dot(attention, v_proj) output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True) -output = mx.nd.transpose(output, axes=(0, 2, 1, 3)) +output = mx.nd.transpose(output, axes=(2, 0, 1, 3)) output = mx.nd.reshape(output, shape=(0, 0, -1)) )code" ADD_FILELINE) .set_num_inputs(2)