diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 6eec767611e0..600e3c565358 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -290,14 +290,14 @@ class CanonicalSimplifier { }; /*! - * \brief A RAII constraint context. + * \brief Constraint context. * * \code * * Var("x"); * arith::Analyzer analyzer; * { - * arith::ConstraintContext cctx(&analyzer, x % 3 == 0); + * With scope(&analyzer, x % 3 == 0); * CHECK_EQ(analyzer.modular_set(x)->coeff, 3); * } * // constraint no longer in effect. @@ -306,19 +306,24 @@ class CanonicalSimplifier { * \endcode */ class ConstraintContext { - public: + private: + // declare friend to enable with. + friend class With; /*! * \brief Construct a constraint context. * \param analyzer The analyzer. * \param constraint The constraint to be applied. */ - ConstraintContext(Analyzer* analyzer, const Expr& constraint) DMLC_THROW_EXCEPTION; - /*! \brief destructor */ - ~ConstraintContext() DMLC_THROW_EXCEPTION { - exit_(); - } - - private: + ConstraintContext(Analyzer* analyzer, Expr constraint) + : analyzer_(analyzer), constraint_(constraint) {} + // enter the scope. + void EnterWithScope(); + // exit the scope. + void ExitWithScope(); + /*! \brief The analyzer */ + Analyzer* analyzer_; + /*! \brief The constraint */ + Expr constraint_; /*! \brief function to be called in recovery */ std::function exit_; }; diff --git a/include/tvm/base.h b/include/tvm/base.h index 049a427ffce8..f358f7f5d447 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -101,6 +101,50 @@ using ::tvm::AttrVisitor; TVM_DEFINE_NODE_REF_COW(NodeName); \ }; +/*! + * \brief RAII wrapper function to enter and exit a context object + * similar to python's with syntax. + * + * \code + * // context class + * class MyContext { + * private: + * friend class With; + MyContext(arguments); + * void EnterWithScope(); + * void ExitWithScope(); + * }; + * + * { + * With scope(arguments); + * // effect take place. + * } + * \endcode + * + * \tparam ContextType Type of the context object. + */ +template +class With { + public: + /*! + * \brief constructor. + * Enter the scope of the context. + */ + template + explicit With(Args&& ...args) + : ctx_(std::forward(args)...) { + ctx_.EnterWithScope(); + } + /*! \brief destructor, leaves the scope of the context. */ + ~With() DMLC_THROW_EXCEPTION { + ctx_.ExitWithScope(); + } + + private: + /*! \brief internal context type. */ + ContextType ctx_; +}; + /*! * \brief save the node as well as all the node it depends on as json. * This can be used to serialize any TVM object diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 7fb456c823a7..187a74552241 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -37,7 +37,7 @@ namespace tvm { /*! * \brief Container for target device information. -* Use target::llvm, target::cuda etc functions instead of constructing directly. +* Use target::llvm, target::cuda etc functions instead of constructing directly. */ class TargetNode : public Node { public: @@ -89,65 +89,47 @@ class TargetNode : public Node { mutable std::string str_repr_; }; +/*! \brief reference cpass to the target. */ class Target : public NodeRef { public: Target() {} explicit Target(NodePtr n) : NodeRef(n) {} - /*! * \brief Create a Target given a string * \param target_str the string to parse */ - TVM_DLL static Target create(const std::string& target_str); - - /*! - * \brief Push a new target context onto the thread local stack. The Target on top of - * the stack is used to determine which specialization to use when invoking a GenericFunc. - * \param target The target to set as the current context. - */ - TVM_DLL static void EnterTargetScope(const tvm::Target& target); - - /*! - * \brief Pop a target off the thread local context stack, restoring the previous target - * as the current context. - */ - TVM_DLL static void ExitTargetScope(); - + TVM_DLL static Target Create(const std::string& target_str); /*! - * \brief Get the current target context from thread local storage. - * \param allow_not_defined If the context stack is empty and this is set to true, an - * undefined Target will be returned. Otherwise, an empty context stack will cause a - * runtime error. - * \return The target that is the current context. The target may not be defined if - * allow_not_defined is true. - */ - TVM_DLL static tvm::Target current_target(bool allow_not_defined = true); + * \brief Get the current target context from thread local storage. + * \param allow_not_defined If the context stack is empty and this is set to true, an + * undefined Target will be returned. Otherwise, an empty context stack will cause a + * runtime error. + * \return The target that is the current context. The target may not be defined if + * allow_not_defined is true. + */ + TVM_DLL static tvm::Target Current(bool allow_not_defined = true); - inline const TargetNode* operator->() const { + const TargetNode* operator->() const { return static_cast(node_.get()); } using ContainerType = TargetNode; -}; - -/*! - * \brief RAII container to provide a scoped target context. Pushes a target onto the - * context stack when constructed, and pops it when destructed. - */ -struct TargetContext { + class Internal; + private: + // enable with syntax. + friend class Internal; + friend class With; /*! - * \brief Enter a new target context. The given target becomes the new current context. - * When the TargetContext is destructed, the previous context is restored. - * \param target The target to set as the new current context. + * \brief Push a new target context onto the thread local stack. + * The Target on top of the stack is used to determine which + * specialization to use when invoking a GenericFunc. */ - explicit TargetContext(const tvm::Target& target) { - Target::EnterTargetScope(target); - } - - /*! \brief Destructor. Pops the context off the thread local stack. */ - ~TargetContext() { - Target::ExitTargetScope(); - } + TVM_DLL void EnterWithScope(); + /*! + * \brief Pop a target off the thread local context stack, + * restoring the previous target as the current context. + */ + TVM_DLL void ExitWithScope(); }; /*! \brief This namespace provides functions to construct Target instances */ @@ -190,11 +172,9 @@ TVM_DLL Target stackvm(const std::vector& options = } // namespace target -class BuildConfig; - /*! -* \brief Container for build configuration options -*/ + * \brief Container for build configuration options + */ class BuildConfigNode : public Node { public: /*! @@ -271,69 +251,48 @@ class BuildConfigNode : public Node { }; /*! -* \brief Container for build configuration options -*/ + * \brief Build configuration for compilations. + */ class BuildConfig : public ::tvm::NodeRef { public: BuildConfig() {} explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {} - const BuildConfigNode* operator->() const { return static_cast(node_.get()); } - BuildConfigNode* operator->() { return static_cast(node_.get()); } - /*! - * \brief Push a new BuildConfig context onto the thread local stack. - * \param build_config The configuration to set as the current context. + * \brief Construct a BuildConfig containing a empty build config node. + * \return The new BuildConfig */ - TVM_DLL static void EnterBuildConfigScope(const tvm::BuildConfig& build_config); - - /*! - * \brief Pop a build config off the thread local context stack, restoring the previous - * configuration as the current context. - */ - TVM_DLL static void ExitBuildConfigScope(); - + TVM_DLL static BuildConfig Create(); /*! * \brief Get the current BuildConfig context from thread local storage, or a default * configuration if a BuildConfig scope has not been entered. * \return The configuration that is the current context. */ - TVM_DLL static tvm::BuildConfig Current(); + TVM_DLL static BuildConfig Current(); using ContainerType = BuildConfigNode; -}; + class Internal; -/*! - * \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the - * context stack when constructed, and pops it when destructed. - */ -struct BuildConfigContext { + private: + // Enable with syntax. + friend class With; /*! - * \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current - * context. When the BuildConfigContext is destructed, the previous context is restored. - * \param build_config The BuildConfig to set as the new current context. + * \brief Push a new BuildConfig context onto the thread local stack. */ - explicit BuildConfigContext(const tvm::BuildConfig& build_config) { - BuildConfig::EnterBuildConfigScope(build_config); - } + TVM_DLL void EnterWithScope(); - /*! \brief Destructor. Pops the context off the thread local stack. */ - ~BuildConfigContext() { - BuildConfig::ExitBuildConfigScope(); - } + /*! + * \brief Pop a build config off the thread local context stack, + * restoring the previous configuration as the current context. + */ + TVM_DLL void ExitWithScope(); }; -/*! -* \brief Construct a BuildConfig containing a new BuildConfigNode -* \return The new BuildConfig -*/ -TVM_DLL BuildConfig build_config(); - /*! * \brief Build a LoweredFunc given a schedule, args and binds * \param sch The schedule to lower. diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index a28ab98fb60e..76170a844db1 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -187,7 +187,7 @@ def __enter__(self): def __exit__(self, ptype, value, trace): if self.dump_pass_ir: BuildConfig._dump_ir.exit() - _api_internal._ExitBuildConfigScope() + _api_internal._ExitBuildConfigScope(self) def __setattr__(self, name, value): if name in BuildConfig._node_defaults: diff --git a/python/tvm/target.py b/python/tvm/target.py index eff0088b37ce..828fff8e228c 100644 --- a/python/tvm/target.py +++ b/python/tvm/target.py @@ -133,7 +133,7 @@ def __enter__(self): return self def __exit__(self, ptype, value, trace): - _api_internal._ExitTargetScope() + _api_internal._ExitTargetScope(self) @register_node diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 55a706420f06..4d5d8bdf58d3 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -123,8 +123,8 @@ TVM_REGISTER_API("arith._CreateAnalyzer") return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { // can't use make_shared due to noexcept(false) decl in destructor, // see https://stackoverflow.com/a/43907314 - auto ctx = - std::shared_ptr(new ConstraintContext(self.get(), args[0])); + auto ctx = std::shared_ptr >( + new With(self.get(), args[0])); auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { ctx.reset(); }; diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 420d6f9c1d0d..bd8c7005f458 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -54,10 +54,12 @@ void Analyzer::Bind(const VarExpr& v, const Range& range) { // skip rewrite simplify } -ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) { + +void ConstraintContext::EnterWithScope() { + CHECK(exit_ == nullptr); // entering the scope. - auto f0 = analyzer->const_int_bound.EnterConstraint(constraint); - auto f1 = analyzer->modular_set.EnterConstraint(constraint); + auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_); + auto f1 = analyzer_->modular_set.EnterConstraint(constraint_); // recovery function. exit_ = [f0, f1]() { if (f1 != nullptr) f1(); @@ -65,6 +67,11 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) }; } +void ConstraintContext::ExitWithScope() { + CHECK(exit_ != nullptr); + exit_(); +} + bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { if (const auto* ptr = expr.as()) { return ptr->value > lower_bound; diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 58d2b83a223a..0de2a2535ae7 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -1200,11 +1200,11 @@ Mutate_(const Select* op, const Expr& self) { Expr cond = Mutate(op->condition); Expr true_value, false_value; { - ConstraintContext constraint(parent_, cond); + With constraint(parent_, cond); true_value = Mutate(op->true_value); } { - ConstraintContext constraint(parent_, Mutate(Not::make(cond))); + With constraint(parent_, Mutate(Not::make(cond))); false_value = Mutate(op->false_value); } if (is_zero(cond)) { @@ -1237,11 +1237,11 @@ Mutate_(const Call* op, const Expr& self) { Expr cond = Mutate(op->args[0]); Expr true_value, false_value; { - ConstraintContext constraint(parent_, cond); + With constraint(parent_, cond); true_value = Mutate(op->args[1]); } { - ConstraintContext constraint(parent_, Mutate(Not::make(cond))); + With constraint(parent_, Mutate(Not::make(cond))); false_value = Mutate(op->args[2]); } if (is_zero(cond)) { diff --git a/src/arithmetic/stmt_simplify.cc b/src/arithmetic/stmt_simplify.cc index c793214b92f4..403187eb39fd 100644 --- a/src/arithmetic/stmt_simplify.cc +++ b/src/arithmetic/stmt_simplify.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -48,11 +48,11 @@ class StmtSimplifier : public IRMutator { Expr condition = this->Mutate(op->condition); Stmt then_case, else_case; { - ConstraintContext ctx(&analyzer_, condition); + With ctx(&analyzer_, condition); then_case = this->Mutate(op->then_case); } if (op->else_case.defined()) { - ConstraintContext ctx(&analyzer_, Mutate(Not::make(condition))); + With ctx(&analyzer_, Mutate(Not::make(condition))); else_case = this->Mutate(op->else_case); } if (is_one(condition)) return then_case; @@ -94,7 +94,7 @@ class StmtSimplifier : public IRMutator { Stmt Mutate_(const AssertStmt* op, const Stmt& s) final { Expr condition = this->Mutate(op->condition); Expr message = this->Mutate(op->message); - ConstraintContext ctx(&analyzer_, condition); + With ctx(&analyzer_, condition); Stmt body = this->Mutate(op->body); if (condition.same_as(op->condition) && diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index ac6b797d9683..834b4eea7e3f 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * Compile executable modules. * \file build_module.cc */ @@ -148,8 +147,7 @@ TVM_REGISTER_API("_TargetCreate") TVM_REGISTER_API("_TargetFromString") .set_body([](TVMArgs args, TVMRetValue* ret) { std::string target_str = args[0]; - - *ret = Target::create(target_str); + *ret = Target::Create(target_str); }); std::vector TargetNode::keys() const { @@ -207,7 +205,7 @@ std::string GetDeviceName(const std::string& target_str) { return ""; } -Target Target::create(const std::string& target_str) { +Target Target::Create(const std::string& target_str) { if (target_str.length() == 0) { LOG(ERROR) << "target_str must not be empty"; } @@ -231,25 +229,24 @@ Target Target::create(const std::string& target_str) { struct TVMTargetThreadLocalEntry { /*! \brief The current target context */ std::stack context_stack; - - TVMTargetThreadLocalEntry() { - } }; /*! \brief Thread local store to hold the Target context stack. */ typedef dmlc::ThreadLocalStore TVMTargetThreadLocalStore; -void Target::EnterTargetScope(const tvm::Target& target) { +void Target::EnterWithScope() { TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); - entry->context_stack.push(target); + entry->context_stack.push(*this); } -void Target::ExitTargetScope() { +void Target::ExitWithScope() { TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); + CHECK(!entry->context_stack.empty()); + CHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); } -tvm::Target Target::current_target(bool allow_not_defined) { +tvm::Target Target::Current(bool allow_not_defined) { TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); if (entry->context_stack.size() > 0) { return entry->context_stack.top(); @@ -574,7 +571,7 @@ runtime::Module build(const Map>& inputs, const BuildConfig& config) { Map> updated_input; for (const auto& it : inputs) { - auto target = Target::create(it.first); + auto target = Target::Create(it.first); updated_input.Set(target, it.second); } return build(updated_input, target_host, config); @@ -589,33 +586,35 @@ runtime::Module build(const Array& funcs, return build(inputs, target_host, config); } -BuildConfig build_config() { +BuildConfig BuildConfig::Create() { return BuildConfig(make_node()); } /*! \brief Entry to hold the BuildConfig context stack. */ struct TVMBuildConfigThreadLocalEntry { /*! \brief The default build config if the stack is empty */ - tvm::BuildConfig default_config; + BuildConfig default_config; /*! \brief The current build config context */ - std::stack context_stack; + std::stack context_stack; TVMBuildConfigThreadLocalEntry() : - default_config(build_config()) { + default_config(BuildConfig::Create()) { } }; /*! \brief Thread local store to hold the BuildConfig context stack. */ typedef dmlc::ThreadLocalStore TVMBuildConfigThreadLocalStore; -void BuildConfig::EnterBuildConfigScope(const tvm::BuildConfig& build_config) { +void BuildConfig::EnterWithScope() { TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); - entry->context_stack.push(build_config); + entry->context_stack.push(*this); } -void BuildConfig::ExitBuildConfigScope() { +void BuildConfig::ExitWithScope() { TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); + CHECK(!entry->context_stack.empty()); + CHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); } @@ -714,7 +713,7 @@ GenericFunc& GenericFunc::register_func(const std::vector& tags, void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { auto node = static_cast(node_.get()); - auto target = Target::current_target(true); + auto target = Target::Current(true); PackedFunc func; if (target.defined()) { @@ -740,16 +739,21 @@ TVM_REGISTER_API("_GetCurrentBuildConfig") *ret = BuildConfig::Current(); }); +class BuildConfig::Internal { + public: + static void EnterScope(BuildConfig target) { + target.EnterWithScope(); + } + static void ExitScope(BuildConfig target) { + target.ExitWithScope(); + } +}; + TVM_REGISTER_API("_EnterBuildConfigScope") -.set_body([](TVMArgs args, TVMRetValue* ret) { - BuildConfig target = args[0]; - BuildConfig::EnterBuildConfigScope(target); - }); +.set_body_typed(BuildConfig::Internal::EnterScope); TVM_REGISTER_API("_ExitBuildConfigScope") -.set_body([](TVMArgs args, TVMRetValue* ret) { - BuildConfig::ExitBuildConfigScope(); - }); +.set_body_typed(BuildConfig::Internal::ExitScope); TVM_REGISTER_API("_BuildConfigSetAddLowerPass") .set_body([](TVMArgs args, TVMRetValue* ret) { @@ -836,18 +840,23 @@ TVM_REGISTER_API("_GenericFuncCallFunc") TVM_REGISTER_API("_GetCurrentTarget") .set_body([](TVMArgs args, TVMRetValue* ret) { bool allow_not_defined = args[0]; - *ret = Target::current_target(allow_not_defined); + *ret = Target::Current(allow_not_defined); }); +class Target::Internal { + public: + static void EnterScope(Target target) { + target.EnterWithScope(); + } + static void ExitScope(Target target) { + target.ExitWithScope(); + } +}; + TVM_REGISTER_API("_EnterTargetScope") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Target target = args[0]; - Target::EnterTargetScope(target); - }); +.set_body_typed(Target::Internal::EnterScope); TVM_REGISTER_API("_ExitTargetScope") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Target::ExitTargetScope(); - }); +.set_body_typed(Target::Internal::ExitScope); } // namespace tvm diff --git a/src/codegen/codegen_aocl.cc b/src/codegen/codegen_aocl.cc index 6f899cbb0b53..03b9b6869d17 100644 --- a/src/codegen/codegen_aocl.cc +++ b/src/codegen/codegen_aocl.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -54,7 +54,7 @@ runtime::Module BuildAOCL(Array funcs, std::string target_str, std::string cmd = "aoc aocl.cl"; // AOCL supports fp64. cmd += " -Dcl_khr_fp64"; - Target target = Target::create(target_str); + Target target = Target::Create(target_str); if (target->device_name != "") { cmd += " -board=" + target->device_name; } diff --git a/src/codegen/codegen_vhls.cc b/src/codegen/codegen_vhls.cc index a18312fe6af5..4d86cc5b4b00 100644 --- a/src/codegen/codegen_vhls.cc +++ b/src/codegen/codegen_vhls.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -155,7 +155,7 @@ runtime::Module BuildSDAccel(Array funcs, std::string target_str) { std::string xclbin; if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) { - Target target = Target::create(target_str); + Target target = Target::Create(target_str); xclbin = (*f)(kernel_info, target->device_name).operator std::string(); } else { LOG(FATAL) << "Cannot compile Vivado HLS code."; diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index bedcdc79ff1f..1e56583a37fd 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -1142,7 +1142,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { } void CodeGenLLVM::VisitStmt_(const AssertStmt* op) { - arith::ConstraintContext cctx(analyzer_.get(), op->condition); + With cctx(analyzer_.get(), op->condition); this->VisitStmt(op->body); } diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index e6fc0088dc81..fd113ca4614a 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -626,7 +626,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) { } void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) { - arith::ConstraintContext cctx(analyzer_.get(), op->condition); + With cctx(analyzer_.get(), op->condition); this->VisitStmt(op->body); } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 8a0c32fc6684..3b1491072d25 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -445,7 +445,7 @@ class RelayBuildModule : public runtime::ModuleNode { if (targets.size() == 1) { func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); for (const auto& kv : targets) { - TargetContext tctx(kv.second); + With tctx(kv.second); func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func); } } else { @@ -466,9 +466,9 @@ class RelayBuildModule : public runtime::ModuleNode { */ Target CreateDefaultTarget(int device_type) { std::string name = runtime::DeviceName(device_type); - if (name == "cpu") return Target::create("llvm"); - if (name == "gpu") return Target::create("cuda"); - return Target::create(name); + if (name == "cpu") return Target::Create("llvm"); + if (name == "gpu") return Target::Create("cuda"); + return Target::Create(name); } /*! * \brief Update the target and fallback device required for heterogeneous @@ -548,7 +548,7 @@ class RelayBuildModule : public runtime::ModuleNode { const RelayBuildConfig& cfg, const std::unordered_map ¶ms) { // convert - tvm_cfg_ = build_config(); + tvm_cfg_ = BuildConfig::Create(); TargetsMap device_target; if (targets_.size() > 1) { device_target = UpdateHeterogeneousInputs(targets_, cfg); diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index a824c457107a..f11dd2875b80 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -344,7 +344,7 @@ class CompileEngineImpl : public CompileEngineNode { cache_[key] = value; } // Enforce use the target. - TargetContext target_ctx(key->target); + With target_scope(key->target); CHECK(!value->cached_func.defined()); auto spair = CreateSchedule(key->source_func, key->target); @@ -371,7 +371,7 @@ class CompileEngineImpl : public CompileEngineNode { cache_node->funcs = (*f)( spair.first, all_args, cache_node->func_name, key->source_func); } else { - tvm::BuildConfig bcfg = tvm::build_config(); + tvm::BuildConfig bcfg = BuildConfig::Create(); std::unordered_map binds; cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg); } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 97f03c629cb7..602e92759624 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -364,7 +364,7 @@ struct VMCompiler : ExprFunctor { // Next generate the invoke instruction. CHECK(func->IsPrimitive()); - auto target = Target::create("llvm"); + auto target = Target::Create("llvm"); auto key = CCacheKeyNode::make(func, target); auto cfunc = engine->Lower(key); // TODO(jroesch): support lowered funcs for multiple targets @@ -502,7 +502,7 @@ void PopulatePackedFuncMap(const std::vector& lowered_funcs, runtime::Module mod; if (lowered_funcs.size() > 0) { // TODO(@jroesch): we need to read target from build config - Target target = Target::create("llvm"); + Target target = Target::Create("llvm"); if (const auto* f = runtime::Registry::Get("relay.backend.build")) { mod = (*f)(tvm::Array(lowered_funcs.begin(), lowered_funcs.end()), target); } else { diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 45aa449e72ab..c085d80d06e2 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -203,10 +203,10 @@ Expr FoldConstant(const Expr& expr) { DLContext ctx; ctx.device_type = kDLCPU; ctx.device_id = 0; - Target target = Target::create("llvm"); + Target target = Target::Create("llvm"); // use a fresh build context // in case we are already in a build context. - BuildConfigContext fresh_build_ctx(build_config()); + With fresh_build_ctx(BuildConfig::Create()); return ConstantFolder(CreateInterpreter( Module(nullptr), ctx, target)).Mutate(expr); diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 5349532ca697..ad861743dfd5 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -375,10 +375,10 @@ DLContext CPUContext() { } FInterpreter CPUInterpreter() { - Target target = Target::create("llvm"); + Target target = Target::Create("llvm"); // use a fresh build context // in case we are already in a build context. - BuildConfigContext fresh_build_ctx(build_config()); + With fresh_build_ctx(BuildConfig::Create()); return CreateInterpreter(Module(nullptr), CPUContext(), target); } diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 393714d8f636..6dbd78e9566d 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -50,14 +50,14 @@ TEST(BuildModule, Basic) { auto args = Array({ A, B, C }); std::unordered_map binds; - auto config = build_config(); + auto config = BuildConfig::Create(); auto target = target::llvm(); auto lowered = lower(s, args, "func", binds, config); auto module = build(lowered, target, Target(), config); - auto mali_target = Target::create("opencl -model=Mali-T860MP4@800Mhz -device=mali"); - CHECK_EQ(mali_target->str(), "opencl -model=Mali-T860MP4@800Mhz -device=mali"); + auto mali_target = Target::Create("opencl -model=Mali-T860MP4@800Mhz -device=mali"); + CHECK_EQ(mali_target->str(), "opencl -model=Mali-T860MP4@800Mhz -device=mali"); } TEST(BuildModule, Heterogeneous) { @@ -105,7 +105,7 @@ TEST(BuildModule, Heterogeneous) { auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add}); auto s2 = create_schedule({elemwise_sub->op}); - auto config = build_config(); + auto config = BuildConfig::Create(); auto args1 = Array({A, B, elemwise_add}); auto args2 = Array({copy, C, elemwise_sub}); diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index a1ab29959127..3f46eed9f10e 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -75,7 +75,7 @@ TEST(Relay, BuildModule) { auto json_f = build_mod.GetFunction("get_graph_json", false); auto mod_f = build_mod.GetFunction("get_module", false); Map targets; - Target llvm_tgt = Target::create("llvm"); + Target llvm_tgt = Target::Create("llvm"); targets.Set(0, llvm_tgt); build_f(func, targets, llvm_tgt); std::string json = json_f(); diff --git a/topi/src/topi.cc b/topi/src/topi.cc index d3e0bc938f7c..57a2743ae6d0 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -94,7 +94,7 @@ inline bool IsTensorType(TVMArgValue arg) { TVM_REGISTER_GLOBAL("topi.TEST_create_target") .set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = tvm::Target::create(args[0]); + *rv = tvm::Target::Create(args[0]); }); /* Ops from broadcast.h */ @@ -640,7 +640,7 @@ using FTVMScheduleBuilder = std::function< */ inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) { return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { - auto target = Target::current_target(false); + auto target = Target::Current(false); Array outs; NodeRef argNodeRef = args[0]; if (argNodeRef->type_index() == outs->type_index()) { @@ -712,7 +712,7 @@ using FTVMDenseOpBuilder = std::function