From 6421f3de2b16d83f92f3c75c9e98b11179841041 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Wed, 10 Jun 2020 19:41:32 -0700 Subject: [PATCH] feat(//core/conversion): Evaluation of static conditionals works now Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/conversion.cpp | 54 +++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 10 deletions(-) diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp index 23193df7e4..aea53ff6b1 100644 --- a/core/conversion/conversion.cpp +++ b/core/conversion/conversion.cpp @@ -190,6 +190,8 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) { } } +void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n); + void MapIValues(ConversionCtx* ctx, c10::ArrayRef in_list, c10::ArrayRef out_list, int64_t in_offset, int64_t out_offset) { std::vector> input_output_pairs; std::transform(in_list.begin() + in_offset, in_list.end(), out_list.begin() + out_offset, @@ -204,6 +206,31 @@ void MapIValues(ConversionCtx* ctx, c10::ArrayRef in_l } } +void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n) { + auto condition = ctx->evaluated_value_map[n->input(0)].toBool(); + LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Evaluating block " << (int) condition); + auto b = condition ? n->blocks()[0] : n->blocks()[1]; + + for (const auto bn : b->nodes()) { + if (bn->kind() == torch::jit::prim::Loop) { + EvaluateLoopBlock(ctx, bn); + } else if (bn->kind() == torch::jit::prim::If) { + EvaluateConditionalBlock(ctx, bn); + } else { + TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile conditionals that are evaluatable at conversion time but node " << *bn << " cannot be evaluated.") + auto eval = EvaluateNode(ctx, bn); + if (!eval.value().isTensor()) { + LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be: " << eval.value()); + } else { + LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')'); + } + ctx->AssociateValueAndIValue(bn->output(0), eval.value()); + } + } + + MapIValues(ctx, b->outputs(), n->outputs(), 0, 0); +} + // TODO: With functionalization pass we may be able to make this into a regular evaluator later void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) { auto max_trip_count = ctx->evaluated_value_map[n->input(0)]; @@ -213,16 +240,21 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) { MapIValues(ctx, n->inputs(), n->outputs(), 2, 0); - LOG_DEBUG("(Loop Evaluation) Evaluating loop " << *n); - LOG_DEBUG("(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt()); - LOG_DEBUG("(Loop Evaluation) Start Condition: " << start_cond.toBool()); - LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt()); + LOG_DEBUG(ctx->logger, "(Loop Evaluation) Evaluating loop " << *n); + LOG_DEBUG(ctx->logger, "(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt()); + LOG_DEBUG(ctx->logger, "(Loop Evaluation) Start Condition: " << start_cond.toBool()); + LOG_DEBUG(ctx->logger, "(Loop Evaluation) Current Trip Count: " << trip_count.toInt()); while (start_cond.toBool() && trip_count.toInt() < max_trip_count.toInt()) { MapIValues(ctx, n->outputs(), n->blocks()[0]->inputs(), 0, 1); for (auto bn : n->blocks()[0]->nodes()) { - auto eval = EvaluateNode(ctx, bn); - if (eval) { + if (bn->kind() == torch::jit::prim::Loop) { + EvaluateLoopBlock(ctx, n); + } else if (bn->kind() == torch::jit::prim::If) { + EvaluateConditionalBlock(ctx, bn); + } else { + TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile loops that are evaluatable at conversion time but node " << *bn << " cannot be evaluated."); + auto eval = EvaluateNode(ctx, bn); if (!eval.value().isTensor()) { LOG_DEBUG(ctx->logger, "(Loop Evaluation) Found the value to be: " << eval.value()); } else { @@ -236,8 +268,8 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) { start_cond = ctx->evaluated_value_map[n->blocks()[0]->outputs()[0]]; auto new_trip_count = torch::jit::IValue(trip_count.toInt() + 1); trip_count.swap(new_trip_count); - LOG_DEBUG("(Loop Evaluation) Condition: " << start_cond.toBool()); - LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt()); + LOG_DEBUG(ctx->logger, "(Loop Evaluation) Condition: " << start_cond.toBool()); + LOG_DEBUG(ctx->logger, "(Loop Evaluation) Current Trip Count: " << trip_count.toInt()); } } @@ -255,6 +287,8 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver bool blacklisted = isNodeConversionBlacklisted(n); if (n->kind() == torch::jit::prim::Loop) { EvaluateLoopBlock(ctx, n); + } else if (n->kind() == torch::jit::prim::If) { + EvaluateConditionalBlock(ctx, n); } else if (to_eval) { auto eval = EvaluateNode(ctx, n); if (eval) { @@ -303,10 +337,10 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil std::set GetUnsupportedOpsInBlock(const torch::jit::Block* b ) { std::set unsupported_ops; for (const auto n : b->nodes()) { - if (n->kind() != torch::jit::prim::Loop && !OpSupported(n)) { + if (n->kind() != torch::jit::prim::Loop && n->kind() != torch::jit::prim::If && !OpSupported(n)) { auto schema = n->maybeSchema(); TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \ - << " (conversion.VerifyCoverterSupportForBlock"); + << " (conversion.VerifyCoverterSupportForBlock)"); std::stringstream ss; ss << *schema; unsupported_ops.insert(ss.str());