Skip to content

Commit

Permalink
add distribute support.
Browse files Browse the repository at this point in the history
  • Loading branch information
Duyi-Wang committed Nov 22, 2023
1 parent 6c63714 commit 3ea284e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
30 changes: 18 additions & 12 deletions src/pytorch/auto_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,24 @@ struct TorchAutoModel : torch::CustomClassHolder {
return ret;
}

void setPrefix(torch::Tensor inputIds) {
TORCH_CHECK(inputIds.dim() <= 2, "Prefix sharing input expected dim <= 2 but tensor has ", inputIds.dim());
inputIds.squeeze();
TORCH_CHECK(inputIds.dim() == 2, "Prefix sharing only support 1 prompt but input has ", inputIds.size(0));

int seqLen = inputIds.size(-1);

std::vector<int> prefixIds(seqLen);
int64_t *p = inputIds.data_ptr<int64_t>();
for (int i = 0; i < seqLen; ++i) {
prefixIds[i] = static_cast<int>(*p);
p += 1;
void setPrefix(torch::optional<torch::Tensor> inputIds) {
std::vector<int> prefixIds;
if (model->getRank() == 0) {
TORCH_CHECK(inputIds.has_value(), "Make sure master's prefix input is not None.")
TORCH_CHECK(inputIds.value().dim() <= 2, "Prefix sharing input expected dim <= 2 but tensor has ",
inputIds.value().dim());
inputIds.value().squeeze();
TORCH_CHECK(inputIds.value().dim() == 2, "Prefix sharing only support 1 prompt but input has ",
inputIds.value().size(0));

int seqLen = inputIds.value().size(-1);

prefixIds.resize(seqLen);
int64_t *p = inputIds.value().data_ptr<int64_t>();
for (int i = 0; i < seqLen; ++i) {
prefixIds[i] = static_cast<int>(*p);
p += 1;
}
}

model->setPrefix(prefixIds);
Expand Down
4 changes: 2 additions & 2 deletions src/xfastertransformer/automodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def input(self, input_ids=None):
def forward(self):
return self.model.generate()

def prefix_sharing(self, input_ids, truncate_tail=0):
if truncate_tail > 0:
def prefix_sharing(self, input_ids=None, truncate_tail=0):
if input_ids is not None and truncate_tail > 0:
input_ids = input_ids[:, :-truncate_tail]

self.model.set_prefix(input_ids)
Expand Down

0 comments on commit 3ea284e

Please sign in to comment.