Skip to content

Commit

Permalink
Reduced the number of times indices need to be copied to the GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
rhenry-nv committed Oct 9, 2020
1 parent 0ff56ea commit 201dc0a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
6 changes: 4 additions & 2 deletions src/layers/generic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,10 @@ namespace marian {
}

// if selIdx are given, then we must reshuffle accordingly
if (!hypIndices.empty()) // use the same function that shuffles decoder state
sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false);
if (!hypIndices.empty()) { // use the same function that shuffles decoder state
auto indices = graph()->indices(hypIndices);
sel = rnn::State::select(sel, indices, (int)beamSize, /*isBatchMajor=*/false);
}
return sel;
}

Expand Down
3 changes: 2 additions & 1 deletion src/models/states.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class EncoderState {
// Sub-select active batch entries from encoder context and context mask
Ptr<EncoderState> select(const std::vector<IndexType>& batchIndices) { // [batchIndex] indices of active batch entries
// Dimension -2 is OK for both, RNN and Transformer models as the encoder context in Transformer gets transposed to the same dimension layout
// NOTE(review): this diff hunk appears to show both the pre-change return (passing batchIndices
// directly to index_select) and its replacement below — confirm against the repository.
return New<EncoderState>(index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_);
// Create the index node once on context_'s graph and reuse it for both index_select calls
// (context and mask) instead of handing the raw vector to each call separately.
auto indices = context_->graph()->indices(batchIndices);
// Returns a new EncoderState built from the selected context, the selected mask, and the unchanged batch_.
return New<EncoderState>(index_select(context_, -2, indices), index_select(mask_, -2, indices), batch_);
}
};

Expand Down
27 changes: 22 additions & 5 deletions src/rnn/types.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include "common/definitions.h"
#include "common/shape.h"
#include "marian.h"

#include <iostream>
Expand All @@ -12,23 +14,22 @@ struct State {
Expr output;
Expr cell;

// Apply the static select() below to both members of this State (output and cell),
// using the same selIdx, beamSize and isBatchMajor, and return the pair as a new State.
// NOTE(review): the two signature lines that follow appear to be the before/after versions
// from the diff view (std::vector-based vs. Expr-based selIdx) — confirm against the repository.
State select(const std::vector<IndexType>& selIdx, // [beamIndex * activeBatchSize + batchIndex]
State select(Expr selIdx, // [beamIndex * activeBatchSize + batchIndex]
int beamSize, bool isBatchMajor) const {
return{ select(output, selIdx, beamSize, isBatchMajor),
select(cell, selIdx, beamSize, isBatchMajor) };
}

// this function is also called by Logits
static Expr select(Expr sel, // [beamSize, dimTime, dimBatch, dimDepth] or [beamSize, dimBatch, dimTime, dimDepth] (dimTime = 1 for RNN)
const std::vector<IndexType>& selIdx, // [beamIndex * activeBatchSize + batchIndex]
Expr selIdx, // [beamIndex * activeBatchSize + batchIndex]
int beamSize, bool isBatchMajor)
{
if (!sel)
return sel; // keep nullptr untouched

sel = atleast_4d(sel);

int dimBatch = (int)selIdx.size() / beamSize;
int dimBatch =(int) selIdx->shape().elements()/beamSize;
int dimDepth = sel->shape()[-1];
int dimTime = isBatchMajor ? sel->shape()[-2] : sel->shape()[-3];

Expand Down Expand Up @@ -83,8 +84,24 @@ class States {
// Run State::select on every entry of states_ with the same indices and collect the results
// in a new States object, which is returned.
States select(const std::vector<IndexType>& selIdx, // [beamIndex * activeBatchSize + batchIndex]
int beamSize, bool isBatchMajor) const {
States selected;
Expr indices;
// The index node is created once, on the graph of the first state that has a cell or an output,
// and then shared by all states in the loop further down.
// NOTE(review): the original author suspected this may not work if the model is split among
// GPUs, but was not sure whether that matters — confirm.
// If no state has a cell or an output, indices is left default-constructed.

for (auto& state : states_) {
if (state.cell) {
indices = state.cell->graph()->indices(selIdx);
break;
}

if (state.output) {
indices = state.output->graph()->indices(selIdx);
break;
}
}

// GPU OPT: Implement kernel to batch these on GPU
// NOTE(review): the two push_back lines below appear to be the before/after versions from the
// diff view (passing selIdx vs. the shared indices node) — confirm against the repository.
for(auto& state : states_)
selected.push_back(state.select(selIdx, beamSize, isBatchMajor));
selected.push_back(state.select(indices, beamSize, isBatchMajor));
return selected;
}

Expand Down

0 comments on commit 201dc0a

Please sign in to comment.