diff --git a/benchmark/python/sparse_end2end.py b/benchmark/python/sparse_end2end.py
index f709499408c4..c51f44ba0702 100644
--- a/benchmark/python/sparse_end2end.py
+++ b/benchmark/python/sparse_end2end.py
@@ -25,7 +25,7 @@
 parser.add_argument('--num-gpu', type=int, default=0,
                     help='number of gpus to use. 0 means using cpu(0);'
                          'otherwise, use gpu(0),...,gpu(num_gpu-1)')
-parser.add_argument('--output-dim', type=int, default=1,
+parser.add_argument('--output-dim', type=int, default=4,
                     help='number of columns of the forward output')
 
 
@@ -178,9 +178,6 @@ def get_sym(feature_dim):
             nbatch += 1
             batch = next_batch
 
-            mod.forward_backward(batch)
-            # update parameters
-            mod.update()
             # if have kvstore, need to pull corresponding rows of
             # the weights to each context
             if kv is not None:
@@ -196,6 +193,10 @@ def get_sym(feature_dim):
                     row_idx_array.append(row_indices[indptr[s.start]:indptr[s.stop]])
                 kv.row_sparse_pull('w', weight_array, priority=-index, row_ids=row_idx_array)
 
+            mod.forward_backward(batch)
+            # update parameters
+            mod.update()
+
             try:
                 # pre fetch next batch
                 next_batch = next(data_iter)