Skip to content

Commit

Permalink
feat(//core/conversion): Evaluation of static conditionals works now
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Jun 11, 2020
1 parent 7466b8a commit 6421f3d
Showing 1 changed file with 44 additions and 10 deletions.
54 changes: 44 additions & 10 deletions core/conversion/conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
}
}

void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n);

void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_list, c10::ArrayRef<const torch::jit::Value*> out_list, int64_t in_offset, int64_t out_offset) {
std::vector<std::pair<const torch::jit::Value*, const torch::jit::Value*>> input_output_pairs;
std::transform(in_list.begin() + in_offset, in_list.end(), out_list.begin() + out_offset,
Expand All @@ -204,6 +206,31 @@ void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_l
}
}

void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
auto condition = ctx->evaluated_value_map[n->input(0)].toBool();
LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Evaluating block " << (int) condition);
auto b = condition ? n->blocks()[0] : n->blocks()[1];

for (const auto bn : b->nodes()) {
if (bn->kind() == torch::jit::prim::Loop) {
EvaluateLoopBlock(ctx, bn);
} else if (bn->kind() == torch::jit::prim::If) {
EvaluateConditionalBlock(ctx, bn);
} else {
TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile conditionals that are evaluatable at conversion time but node " << *bn << " cannot be evaluated.")
auto eval = EvaluateNode(ctx, bn);
if (!eval.value().isTensor()) {
LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be: " << eval.value());
} else {
LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
}
ctx->AssociateValueAndIValue(bn->output(0), eval.value());
}
}

MapIValues(ctx, b->outputs(), n->outputs(), 0, 0);
}

// TODO: With functionalization pass we may be able to make this into a regular evaluator later
void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
auto max_trip_count = ctx->evaluated_value_map[n->input(0)];
Expand All @@ -213,16 +240,21 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {

MapIValues(ctx, n->inputs(), n->outputs(), 2, 0);

LOG_DEBUG("(Loop Evaluation) Evaluating loop " << *n);
LOG_DEBUG("(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt());
LOG_DEBUG("(Loop Evaluation) Start Condition: " << start_cond.toBool());
LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Evaluating loop " << *n);
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt());
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Start Condition: " << start_cond.toBool());
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Current Trip Count: " << trip_count.toInt());

while (start_cond.toBool() && trip_count.toInt() < max_trip_count.toInt()) {
MapIValues(ctx, n->outputs(), n->blocks()[0]->inputs(), 0, 1);
for (auto bn : n->blocks()[0]->nodes()) {
auto eval = EvaluateNode(ctx, bn);
if (eval) {
if (bn->kind() == torch::jit::prim::Loop) {
EvaluateLoopBlock(ctx, n);
} else if (bn->kind() == torch::jit::prim::If) {
EvaluateConditionalBlock(ctx, bn);
} else {
TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile loops that are evaluatable at conversion time but node " << *bn << " cannot be evaluated.");
auto eval = EvaluateNode(ctx, bn);
if (!eval.value().isTensor()) {
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Found the value to be: " << eval.value());
} else {
Expand All @@ -236,8 +268,8 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
start_cond = ctx->evaluated_value_map[n->blocks()[0]->outputs()[0]];
auto new_trip_count = torch::jit::IValue(trip_count.toInt() + 1);
trip_count.swap(new_trip_count);
LOG_DEBUG("(Loop Evaluation) Condition: " << start_cond.toBool());
LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Condition: " << start_cond.toBool());
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
}
}

Expand All @@ -255,6 +287,8 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver
bool blacklisted = isNodeConversionBlacklisted(n);
if (n->kind() == torch::jit::prim::Loop) {
EvaluateLoopBlock(ctx, n);
} else if (n->kind() == torch::jit::prim::If) {
EvaluateConditionalBlock(ctx, n);
} else if (to_eval) {
auto eval = EvaluateNode(ctx, n);
if (eval) {
Expand Down Expand Up @@ -303,10 +337,10 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil
std::set<std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b ) {
std::set<std::string> unsupported_ops;
for (const auto n : b->nodes()) {
if (n->kind() != torch::jit::prim::Loop && !OpSupported(n)) {
if (n->kind() != torch::jit::prim::Loop && n->kind() != torch::jit::prim::If && !OpSupported(n)) {
auto schema = n->maybeSchema();
TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
<< " (conversion.VerifyCoverterSupportForBlock");
<< " (conversion.VerifyCoverterSupportForBlock)");
std::stringstream ss;
ss << *schema;
unsupported_ops.insert(ss.str());
Expand Down

0 comments on commit 6421f3d

Please sign in to comment.