diff --git a/.github/workflows/ut-cuda.yml b/.github/workflows/ut-cuda.yml index 918c1a4a..4e573adf 100644 --- a/.github/workflows/ut-cuda.yml +++ b/.github/workflows/ut-cuda.yml @@ -88,4 +88,3 @@ jobs: - name: Run Tutorials run: | python3 ./examples/tutorial/quickstart_tutorial.py - python3 ./examples/tutorial/plan_tutorial.py diff --git a/ark/api/context.cpp b/ark/api/context.cpp new file mode 100644 index 00000000..76baedc8 --- /dev/null +++ b/ark/api/context.cpp @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "context_impl.hpp" +#include "logging.hpp" + +namespace ark { + +Context::Context(Model& model) : impl_(std::make_shared<Impl>(model)) {} + +int Context::id() const { return this->impl_->id_; } + +std::string Context::get(const std::string& key) const { + if (!this->impl_->has(key)) { + return ""; + } + return this->impl_->get(key).dump(); +} + +void Context::set(const std::string& key, const std::string& value, + ContextType type) { + Json value_json; + try { + value_json = Json::parse(value); + } catch (const ::nlohmann::json::parse_error& e) { + ERR(InvalidUsageError, "Failed to parse context value as JSON: `", + value, "`"); + } + this->impl_->set(key, value_json, type); +} + +} // namespace ark diff --git a/ark/api/context_test.cpp b/ark/api/context_test.cpp new file mode 100644 index 00000000..82bc38d4 --- /dev/null +++ b/ark/api/context_test.cpp @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "ark/context.hpp" + +#include "model/model_node.hpp" +#include "unittest/unittest_utils.h" + +ark::unittest::State test_context() { + ark::Model model; + ark::Tensor t0 = model.tensor({1}, ark::FP32); + ark::Tensor t1 = model.tensor({1}, ark::FP32); + + // node 0 + ark::Tensor t2 = model.add(t0, t1); + + ark::Tensor t3; + ark::Tensor t4; + ark::Tensor t5; + { + // node 1 + ark::Context ctx(model); + ctx.set("key0", ark::Json("val1").dump()); + t3 = model.relu(t2); + + UNITTEST_EQ(ctx.get("key0"), ark::Json("val1").dump()); + + // node 2 + ctx.set("key1", ark::Json("val2").dump()); + t4 = model.sqrt(t3); + + UNITTEST_EQ(ctx.get("key0"), ark::Json("val1").dump()); + UNITTEST_EQ(ctx.get("key1"), ark::Json("val2").dump()); + } + { + // node 3 + ark::Context ctx(model); + ctx.set("key0", ark::Json("val3").dump()); + t5 = model.exp(t2); + + UNITTEST_EQ(ctx.get("key0"), ark::Json("val3").dump()); + UNITTEST_EQ(ctx.get("key1"), ""); + } + + UNITTEST_TRUE(model.verify()); + + auto compressed = model.compress(); + UNITTEST_TRUE(compressed.verify()); + + auto nodes = compressed.nodes(); + UNITTEST_EQ(nodes.size(), 4); + + UNITTEST_EQ(nodes[0]->context.size(), 0); + UNITTEST_EQ(nodes[1]->context.size(), 1); + UNITTEST_EQ(nodes[1]->context.at("key0"), ark::Json("val1")); + UNITTEST_EQ(nodes[2]->context.size(), 2); + UNITTEST_EQ(nodes[2]->context.at("key0"), ark::Json("val1")); + UNITTEST_EQ(nodes[2]->context.at("key1"), ark::Json("val2")); + UNITTEST_EQ(nodes[3]->context.size(), 1); + UNITTEST_EQ(nodes[3]->context.at("key0"), ark::Json("val3")); + + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_context_invalid() { + ark::Model model; + ark::Context ctx(model); + ark::Tensor t0 = model.tensor({1}, ark::FP32); + ark::Tensor t1 = model.tensor({1}, ark::FP32); + ark::Tensor t2 = model.add(t0, t1); + + UNITTEST_THROW(ctx.set("key", "val"), ark::InvalidUsageError); + + return ark::unittest::SUCCESS; +} + +int main() { + UNITTEST(test_context); + UNITTEST(test_context_invalid); 
return 0; +} diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 1ec2da5e..d9fc9217 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -213,6 +213,15 @@ Executor::Impl::Impl(int rank, int world_size, int gpu_id, plan_json = Json::parse(plan); } + auto gpu_manager = GpuManager::get_instance(gpu_id_); + if (!gpu_manager->info().arch->belongs_to( + Arch::from_name(plan_json.at("Architecture")))) { + LOG(WARN, "Architecture name of the plan `", + plan_json.at("Architecture").get<std::string>(), + "` is not compatible with the GPU architecture `", + gpu_manager->info().arch->name(), "`."); + } + buffer_id_to_offset_ = init_buffers(plan_json); std::string buffer_id_to_offset_str; @@ -224,7 +233,6 @@ Executor::Impl::Impl(int rank, int world_size, int gpu_id, codegen_ = std::make_shared<CodeGenerator>(plan_json, buffer_id_to_offset_, name); - auto gpu_manager = GpuManager::get_instance(gpu_id_); timer_begin_ = gpu_manager->create_event(); timer_end_ = gpu_manager->create_event(); buffer_ = gpu_manager->malloc(total_bytes_, 65536); @@ -816,9 +824,9 @@ DefaultExecutor::DefaultExecutor(const Model &model, int gpu_id, model.rank(), model.world_size(), (gpu_id < 0) ? (model.rank() % get_env().num_ranks_per_host) : gpu_id, name, - DefaultPlanner(model, (gpu_id < 0) ? (model.rank() % - get_env().num_ranks_per_host) - : gpu_id) + Planner(model, (gpu_id < 0) + ? (model.rank() % get_env().num_ranks_per_host) + : gpu_id) .plan()) {} } // namespace ark diff --git a/ark/api/model.cpp b/ark/api/model.cpp index d62ce8f1..dcbd4940 100644 --- a/ark/api/model.cpp +++ b/ark/api/model.cpp @@ -9,6 +9,15 @@ namespace ark { +Model::Model(int rank, int world_size) : ModelGraph(rank, world_size) { + static size_t next_id = 0; + id_ = next_id++; +} + +Model::Model(const Model &other) : ModelGraph(other), id_(other.id()) {} + +size_t Model::id() const { return id_; } + Model Model::compress() const { Model model(*this); model.compress_nodes(); diff --git a/ark/api/planner.cpp b/ark/api/planner.cpp index 49556025..22b9b680 100644 --- a/ark/api/planner.cpp +++ b/ark/api/planner.cpp @@ -4,36 +4,108 @@ #include "ark/planner.hpp" #include "ark/model.hpp" +#include "context_impl.hpp" #include "env.h" #include "file_io.h" #include "gpu/gpu_manager.h" #include "model/model_json.hpp" #include "model/model_op.hpp" +#include "range.hpp" namespace ark { -class DefaultPlanner::Impl { +PlannerContext::PlannerContext(Model &model) : Context(model) { + this->impl_->set("Id", this->id(), ContextType::Immutable); +} + +void PlannerContext::check_range(const std::string &key, + const Range<int> &range) { + auto prev = this->impl_->get(key); + if (prev.empty()) { + // ok + return; + } + auto prev_vec = prev.get<std::vector<int>>(); + if (prev_vec.size() < 2 || prev_vec.size() > 3) { + ERR(InternalError, "unexpected"); + } + int prev_step = (prev_vec.size() == 3) ? 
prev_vec[2] : 1; + Range<int> prev_range(prev_vec[0], prev_vec[1], prev_step); + if (!range.is_subset_of(prev_range)) { + ERR(PlanError, "New ", key, " ", range, + " is not a subset of the previous range ", prev_range); + } +} + +void PlannerContext::processor_range(int start, int end, int step) { + check_range("ProcessorRange", {start, end, step}); + if (step == 1) { + this->impl_->set("ProcessorRange", {start, end}, + ContextType::Overwrite); + } else { + this->impl_->set("ProcessorRange", {start, end, step}, + ContextType::Overwrite); + } +} + +void PlannerContext::warp_range(int start, int end, int step) { + check_range("WarpRange", {start, end, step}); + if (step == 1) { + this->impl_->set("WarpRange", {start, end}, ContextType::Overwrite); + } else { + this->impl_->set("WarpRange", {start, end, step}, + ContextType::Overwrite); + } +} + +void PlannerContext::sram_range(int start, int end, int step) { + check_range("SramRange", {start, end, step}); + if (step == 1) { + this->impl_->set("SramRange", {start, end}, ContextType::Overwrite); + } else { + this->impl_->set("SramRange", {start, end, step}, + ContextType::Overwrite); + } +} + +void PlannerContext::sync(bool sync) { + if (sync) { + // `true` should not overwrite `false`. + if (this->impl_->get("Sync") == Json(false)) { + LOG(WARN, "Ignoring sync(true) while sync(false) is already set"); + return; + } + this->impl_->set("Sync", true, ContextType::Immutable); + } else { + this->impl_->set("Sync", false, ContextType::Overwrite); + } +} + +void PlannerContext::config(const std::string &config) { + this->impl_->set("Config", Json::parse(config), ContextType::Extend); +} + +class Planner::Impl { public: - Impl(const Model &model, int gpu_id); + Impl(const Model &model, int device_id); - void install_config_rule(DefaultPlanner::ConfigRule rule); + void install_config_rule(Planner::ConfigRule rule); std::string plan(bool pretty) const; protected: - friend class DefaultPlanner; + friend class Planner; Model model_; - int gpu_id_; - std::vector<DefaultPlanner::ConfigRule> config_rules_; + int device_id_; + std::vector<Planner::ConfigRule> config_rules_; }; -DefaultPlanner::Impl::Impl(const Model &model, int gpu_id) - : model_(model.compress()), gpu_id_(gpu_id) {} +Planner::Impl::Impl(const Model &model, int device_id) + : model_(model.compress()), device_id_(device_id) {} -void DefaultPlanner::Impl::install_config_rule( - DefaultPlanner::ConfigRule rule) { +void Planner::Impl::install_config_rule(Planner::ConfigRule rule) { config_rules_.push_back( [rule](const std::string &op, const std::string &arch) -> std::string { try { @@ -52,23 +124,35 @@ static void check_config_field(const ModelOpRef op, const Json &config, } } -std::string DefaultPlanner::Impl::plan(bool pretty) const { - const auto gpu_info = GpuManager::get_instance(gpu_id_)->info(); +std::string Planner::Impl::plan(bool pretty) const { + const auto gpu_info = GpuManager::get_instance(device_id_)->info(); size_t num_sm = gpu_info.num_sm; Json task_infos = Json::array(); Json processor_groups = Json::array(); - size_t max_num_warps = 1; - size_t max_num_processors = 1; - size_t next_node_id = 0; + size_t max_processor_id = 1; + size_t max_warp_id = 1; + size_t next_task_id = 0; + int prev_ctx_id = -1; + bool first_op = true; + + auto get_context = [&](const ModelNodeRef &node, + const std::string &key) -> Json { + if (node->context.find(key) != node->context.end()) { + return node->context.at(key); + } + return Json(); + }; + for (const auto &node : model_.nodes()) { const auto &op = node->op; if (op->is_virtual()) continue; - 
Json task_info; - task_info["Id"] = next_node_id++; + auto ctx_config = get_context(node, "Config"); Json config; - if (!config_rules_.empty()) { + if (!ctx_config.empty()) { + config = ctx_config; + } else if (!config_rules_.empty()) { const std::string op_str = op->serialize().dump(); for (auto &rule : config_rules_) { auto config_str = rule(op_str, gpu_info.arch->name()); @@ -87,38 +171,79 @@ std::string DefaultPlanner::Impl::plan(bool pretty) const { size_t num_warps = config["NumWarps"]; size_t num_tasks = config["NumTasks"]; size_t sram_bytes = config["SramBytes"]; - task_info["NumWarps"] = num_warps; - task_info["SramBytes"] = sram_bytes; - - max_num_warps = std::max(max_num_warps, num_warps); - - task_info["Ops"] = Json::array(); - task_info["Ops"].push_back(op->serialize()); - task_info["Ops"][0]["Config"] = config; - task_infos.push_back(task_info); - - Json resource_group; - size_t num_processors = std::min(num_sm, num_tasks); - max_num_processors = std::max(max_num_processors, num_processors); - resource_group["ProcessorRange"] = {0, num_processors}; - resource_group["WarpRange"] = {0, num_warps}; - resource_group["SramRange"] = {0, sram_bytes}; - resource_group["TaskGroups"] = {{{"TaskId", task_info["Id"]}, - {"TaskRange", {0, num_tasks}}, - {"Granularity", 1}}}; - - Json processor_group; - processor_group["ProcessorRange"] = {0, num_processors}; - processor_group["ResourceGroups"] = Json::array(); - processor_group["ResourceGroups"].push_back(resource_group); - processor_groups.push_back(processor_group); + + size_t granularity = config.value("Granularity", 1); + auto ctx_id = get_context(node, "Id"); + auto ctx_sync = get_context(node, "Sync"); + int id = ctx_id.empty() ? -1 : ctx_id.get(); + bool sync = ctx_sync.empty() ? true : ctx_sync.get(); + if (id == prev_ctx_id && !sync) { + auto &task_info = task_infos.back(); + task_info["NumWarps"] = + std::max(task_info["NumWarps"].get(), num_warps); + task_info["SramBytes"] = + std::max(task_info["SramBytes"].get(), sram_bytes); + task_info["Ops"].push_back(op->serialize()); + task_info["Ops"].back()["Config"] = config; + } else { + Json task_info; + task_info["Id"] = first_op ? 
next_task_id : ++next_task_id; + task_info["NumWarps"] = num_warps; + task_info["SramBytes"] = sram_bytes; + task_info["Ops"] = Json::array(); + task_info["Ops"].push_back(op->serialize()); + task_info["Ops"][0]["Config"] = config; + task_infos.push_back(task_info); + + auto ctx_processor_range = get_context(node, "ProcessorRange"); + auto ctx_warp_range = get_context(node, "WarpRange"); + auto ctx_sram_range = get_context(node, "SramRange"); + + Json processor_group; + if (!ctx_processor_range.empty()) { + processor_group["ProcessorRange"] = ctx_processor_range; + max_processor_id = std::max( + max_processor_id, ctx_processor_range[1].get<size_t>()); + } else { + size_t num_processors = std::min(num_sm, num_tasks); + processor_group["ProcessorRange"] = {0, num_processors}; + max_processor_id = std::max(max_processor_id, num_processors); + } + + Json resource_group; + resource_group["ProcessorRange"] = + processor_group["ProcessorRange"]; + if (!ctx_warp_range.empty()) { + resource_group["WarpRange"] = ctx_warp_range; + max_warp_id = + std::max(max_warp_id, ctx_warp_range[1].get<size_t>()); + } else { + resource_group["WarpRange"] = {0, num_warps}; + max_warp_id = std::max(max_warp_id, num_warps); + } + if (!ctx_sram_range.empty()) { + resource_group["SramRange"] = ctx_sram_range; + } else { + resource_group["SramRange"] = {0, sram_bytes}; + } + resource_group["TaskGroups"] = {{{"TaskId", task_info["Id"]}, + {"TaskRange", {0, num_tasks}}, + {"Granularity", granularity}}}; + + processor_group["ResourceGroups"] = Json::array(); + processor_group["ResourceGroups"].push_back(resource_group); + processor_groups.push_back(processor_group); + } + prev_ctx_id = id; + first_op = false; } Json plan; plan["Rank"] = model_.rank(); plan["WorldSize"] = model_.world_size(); - plan["NumProcessors"] = max_num_processors; - plan["NumWarpsPerProcessor"] = max_num_warps; + plan["Architecture"] = gpu_info.arch->name(); + plan["NumProcessors"] = max_processor_id; + plan["NumWarpsPerProcessor"] = max_warp_id; plan["TaskInfos"] = task_infos; plan["ProcessorGroups"] = processor_groups; @@ -129,23 +254,22 @@ std::string DefaultPlanner::Impl::plan(bool pretty) const { plan_str = plan.dump(); } const auto &tmp = get_env().path_tmp_dir; - write_file(tmp + "/model_gpu" + std::to_string(gpu_id_) + ".json", + write_file(tmp + "/model_gpu" + std::to_string(device_id_) + ".json", model_.serialize()); - write_file(tmp + "/plan_gpu" + std::to_string(gpu_id_) + ".json", plan_str); + write_file(tmp + "/plan_gpu" + std::to_string(device_id_) + ".json", + plan_str); return plan_str; } -DefaultPlanner::DefaultPlanner(const Model &model, int gpu_id) - : impl_(std::make_unique<Impl>(model, gpu_id)) {} +Planner::Planner(const Model &model, int device_id) + : impl_(std::make_unique<Impl>(model, device_id)) {} -DefaultPlanner::~DefaultPlanner() = default; +Planner::~Planner() = default; -void DefaultPlanner::install_config_rule(DefaultPlanner::ConfigRule rule) { +void Planner::install_config_rule(Planner::ConfigRule rule) { impl_->install_config_rule(rule); } -std::string DefaultPlanner::plan(bool pretty) const { - return impl_->plan(pretty); -} +std::string Planner::plan(bool pretty) const { return impl_->plan(pretty); } } // namespace ark diff --git a/ark/api/planner_test.cpp b/ark/api/planner_test.cpp new file mode 100644 index 00000000..011b25d8 --- /dev/null +++ b/ark/api/planner_test.cpp @@ -0,0 +1,329 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +#include "ark/planner.hpp" + +#include "model/model_node.hpp" +#include "unittest/unittest_utils.h" + +ark::unittest::State test_planner_context_processor_range() { + ark::Model model; + ark::Tensor t0 = model.tensor({1}, ark::FP32); + ark::Tensor t1 = model.tensor({1}, ark::FP32); + + // node 0 + ark::Tensor t2 = model.add(t0, t1); + + ark::Tensor t3; + ark::Tensor t4; + ark::Tensor t5; + { + // node 1 + ark::PlannerContext ctx(model); + ctx.processor_range(0, 4); + t3 = model.relu(t2); + + UNITTEST_EQ(ctx.get("ProcessorRange"), ark::Json({0, 4}).dump()); + + // node 2 + ctx.processor_range(2, 4); + t4 = model.sqrt(t3); + + UNITTEST_EQ(ctx.get("ProcessorRange"), ark::Json({2, 4}).dump()); + + // Invalid usage: range (0, 4) is out of previous range (2, 4) + UNITTEST_THROW(ctx.processor_range(0, 4), ark::PlanError); + } + { + // node 3 + ark::PlannerContext ctx(model); + ctx.processor_range(2, 6, 2); + t5 = model.exp(t2); + + UNITTEST_EQ(ctx.get("ProcessorRange"), ark::Json({2, 6, 2}).dump()); + } + + UNITTEST_TRUE(model.verify()); + + auto compressed = model.compress(); + UNITTEST_TRUE(compressed.verify()); + + auto nodes = compressed.nodes(); + UNITTEST_EQ(nodes.size(), 4); + + UNITTEST_EQ(nodes[0]->context.size(), 0); + UNITTEST_GE(nodes[1]->context.size(), 1); + UNITTEST_EQ(nodes[1]->context.at("ProcessorRange"), ark::Json({0, 4})); + UNITTEST_GE(nodes[2]->context.size(), 1); + UNITTEST_EQ(nodes[2]->context.at("ProcessorRange"), ark::Json({2, 4})); + UNITTEST_GE(nodes[3]->context.size(), 1); + UNITTEST_EQ(nodes[3]->context.at("ProcessorRange"), ark::Json({2, 6, 2})); + + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_planner_context_warp_range() { + ark::Model model; + ark::Tensor t0 = model.tensor({1}, ark::FP32); + ark::Tensor t1 = model.tensor({1}, ark::FP32); + + // node 0 + ark::Tensor t2 = model.add(t0, t1); + + ark::Tensor t3; + ark::Tensor t4; + ark::Tensor t5; + { + // node 1 + ark::PlannerContext ctx(model); + ctx.warp_range(0, 4); + t3 = model.relu(t2); + + UNITTEST_EQ(ctx.get("WarpRange"), ark::Json({0, 4}).dump()); + + // node 2 + ctx.warp_range(2, 4); + t4 = model.sqrt(t3); + + UNITTEST_EQ(ctx.get("WarpRange"), ark::Json({2, 4}).dump()); + + // Invalid usage: range (0, 4) is out of previous range (2, 4) + UNITTEST_THROW(ctx.warp_range(0, 4), ark::PlanError); + } + { + // node 3 + ark::PlannerContext ctx(model); + ctx.warp_range(2, 6, 2); + t5 = model.exp(t2); + + UNITTEST_EQ(ctx.get("WarpRange"), ark::Json({2, 6, 2}).dump()); + } + + UNITTEST_TRUE(model.verify()); + + auto compressed = model.compress(); + UNITTEST_TRUE(compressed.verify()); + + auto nodes = compressed.nodes(); + UNITTEST_EQ(nodes.size(), 4); + + UNITTEST_EQ(nodes[0]->context.size(), 0); + UNITTEST_GE(nodes[1]->context.size(), 1); + UNITTEST_EQ(nodes[1]->context.at("WarpRange"), ark::Json({0, 4})); + UNITTEST_GE(nodes[2]->context.size(), 1); + UNITTEST_EQ(nodes[2]->context.at("WarpRange"), ark::Json({2, 4})); + UNITTEST_GE(nodes[3]->context.size(), 1); + UNITTEST_EQ(nodes[3]->context.at("WarpRange"), ark::Json({2, 6, 2})); + + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_planner_context_sram_range() { + ark::Model model; + ark::Tensor t0 = model.tensor({1}, ark::FP32); + ark::Tensor t1 = model.tensor({1}, ark::FP32); + + // node 0 + ark::Tensor t2 = model.add(t0, t1); + + ark::Tensor t3; + ark::Tensor t4; + ark::Tensor t5; + { + // node 1 + ark::PlannerContext ctx(model); + ctx.sram_range(0, 4); + t3 = model.relu(t2); + + UNITTEST_EQ(ctx.get("SramRange"), ark::Json({0, 
4}).dump()); + + // node 2 + ctx.sram_range(2, 4); + t4 = model.sqrt(t3); + + UNITTEST_EQ(ctx.get("SramRange"), ark::Json({2, 4}).dump()); + + // Invalid usage: range (0, 4) is out of previous range (2, 4) + UNITTEST_THROW(ctx.sram_range(0, 4), ark::PlanError); + } + { + // node 3 + ark::PlannerContext ctx(model); + ctx.sram_range(2, 6, 2); + t5 = model.exp(t2); + + UNITTEST_EQ(ctx.get("SramRange"), ark::Json({2, 6, 2}).dump()); + } + + UNITTEST_TRUE(model.verify()); + + auto compressed = model.compress(); + UNITTEST_TRUE(compressed.verify()); + + auto nodes = compressed.nodes(); + UNITTEST_EQ(nodes.size(), 4); + + UNITTEST_EQ(nodes[0]->context.size(), 0); + UNITTEST_GE(nodes[1]->context.size(), 1); + UNITTEST_EQ(nodes[1]->context.at("SramRange"), ark::Json({0, 4})); + UNITTEST_GE(nodes[2]->context.size(), 1); + UNITTEST_EQ(nodes[2]->context.at("SramRange"), ark::Json({2, 4})); + UNITTEST_GE(nodes[3]->context.size(), 1); + UNITTEST_EQ(nodes[3]->context.at("SramRange"), ark::Json({2, 6, 2})); + + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_planner_context_sync() { + ark::Model model; + ark::Tensor t0 = model.tensor({1}, ark::FP32); + ark::Tensor t1 = model.tensor({1}, ark::FP32); + + // node 0 + ark::Tensor t2 = model.add(t0, t1); + + ark::Tensor t3; + ark::Tensor t4; + ark::Tensor t5; + { + // node 1 + ark::PlannerContext ctx(model); + ctx.sync(false); + t3 = model.relu(t2); + + UNITTEST_EQ(ctx.get("Sync"), ark::Json(false).dump()); + + // node 2 + ctx.sync(true); // will be ignored with a warning message + t4 = model.sqrt(t3); + + UNITTEST_EQ(ctx.get("Sync"), ark::Json(false).dump()); + } + { + // node 3 + ark::PlannerContext ctx(model); + ctx.sync(true); + t5 = model.exp(t2); + + UNITTEST_EQ(ctx.get("Sync"), ark::Json(true).dump()); + } + + UNITTEST_TRUE(model.verify()); + + auto compressed = model.compress(); + UNITTEST_TRUE(compressed.verify()); + + auto nodes = compressed.nodes(); + UNITTEST_EQ(nodes.size(), 4); + + UNITTEST_EQ(nodes[0]->context.size(), 0); + UNITTEST_GE(nodes[1]->context.size(), 1); + UNITTEST_EQ(nodes[1]->context.at("Sync"), ark::Json(false)); + UNITTEST_GE(nodes[2]->context.size(), 1); + UNITTEST_EQ(nodes[2]->context.at("Sync"), ark::Json(false)); + UNITTEST_GE(nodes[3]->context.size(), 1); + UNITTEST_EQ(nodes[3]->context.at("Sync"), ark::Json(true)); + + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_planner_context_config() { + ark::Model model; + ark::Tensor t0 = model.tensor({1}, ark::FP32); + ark::Tensor t1 = model.tensor({1}, ark::FP32); + + // node 0 + ark::Tensor t2 = model.add(t0, t1); + + ark::Tensor t3; + ark::Tensor t4; + ark::Tensor t5; + { + // node 1 + ark::PlannerContext ctx(model); + ctx.config(ark::Json({{"key0", "val1"}}).dump()); + t3 = model.relu(t2); + + UNITTEST_EQ(ctx.get("Config"), ark::Json({{"key0", "val1"}}).dump()); + + // node 2 + ctx.config(ark::Json({{"key1", "val2"}}).dump()); + t4 = model.sqrt(t3); + + UNITTEST_EQ(ctx.get("Config"), + ark::Json({{"key0", "val1"}, {"key1", "val2"}}).dump()); + } + { + // node 3 + ark::PlannerContext ctx(model); + ctx.config(ark::Json({{"key2", "val3"}}).dump()); + t5 = model.exp(t2); + + UNITTEST_EQ(ctx.get("Config"), ark::Json({{"key2", "val3"}}).dump()); + } + + UNITTEST_TRUE(model.verify()); + + auto compressed = model.compress(); + UNITTEST_TRUE(compressed.verify()); + + auto nodes = compressed.nodes(); + UNITTEST_EQ(nodes.size(), 4); + + UNITTEST_EQ(nodes[0]->context.size(), 0); + UNITTEST_GE(nodes[1]->context.size(), 1); + 
UNITTEST_EQ(nodes[1]->context.at("Config"), ark::Json({{"key0", "val1"}})); + UNITTEST_GE(nodes[2]->context.size(), 1); + UNITTEST_EQ(nodes[2]->context.at("Config"), + ark::Json({{"key0", "val1"}, {"key1", "val2"}})); + UNITTEST_GE(nodes[3]->context.size(), 1); + UNITTEST_EQ(nodes[3]->context.at("Config"), ark::Json({{"key2", "val3"}})); + + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_planner_context_plan() { + ark::Model model; + ark::PlannerContext ctx(model); + ctx.processor_range(0, 2); + ctx.warp_range(0, 4); + ctx.sram_range(0, 0); + ctx.sync(false); + ark::Json cfg({{"NumWarps", 1}, + {"SramBytes", 0}, + {"NumTasks", 1}, + {"Tile", {1, 64}}}); + ctx.config(cfg.dump()); + + ark::Tensor t0 = model.tensor({1024}, ark::FP32); + ark::Tensor t1 = model.mul(t0, 0.5); + ark::Tensor t2 = model.add(t0, t1); + + ark::Planner planner(model, 0); + auto plan = ark::Json::parse(planner.plan(false)); + + UNITTEST_EQ(plan["NumProcessors"].get(), 2); + UNITTEST_EQ(plan["NumWarpsPerProcessor"].get(), 4); + UNITTEST_EQ(plan["TaskInfos"].size(), 1); + UNITTEST_EQ(plan["TaskInfos"][0]["NumWarps"], 1); + UNITTEST_EQ(plan["TaskInfos"][0]["SramBytes"], 0); + UNITTEST_EQ(plan["TaskInfos"][0]["Ops"].size(), 2); + UNITTEST_EQ(plan["TaskInfos"][0]["Ops"][0]["Type"].get(), + "ScalarMul"); + UNITTEST_EQ(plan["TaskInfos"][0]["Ops"][0]["Config"], cfg); + UNITTEST_EQ(plan["TaskInfos"][0]["Ops"][1]["Type"].get(), + "Add"); + UNITTEST_EQ(plan["TaskInfos"][0]["Ops"][1]["Config"], cfg); + + return ark::unittest::SUCCESS; +} + +int main() { + UNITTEST(test_planner_context_processor_range); + UNITTEST(test_planner_context_warp_range); + UNITTEST(test_planner_context_sram_range); + UNITTEST(test_planner_context_sync); + UNITTEST(test_planner_context_config); + UNITTEST(test_planner_context_plan); + return 0; +} diff --git a/ark/context_impl.cpp b/ark/context_impl.cpp new file mode 100644 index 00000000..9a2692ea --- /dev/null +++ b/ark/context_impl.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "context_impl.hpp" + +#include "logging.hpp" +#include "model/model_context_manager.hpp" +#include "model/model_graph_impl.hpp" + +namespace ark { + +Context::Impl::Impl(Model& model) + : context_manager_(std::make_shared(model)) { + static int next_id = 0; + id_ = next_id++; +} + +Json Context::Impl::get(const std::string& key) const { + return context_manager_->get(key); +} + +void Context::Impl::set(const std::string& key, const Json& value_json, + ContextType type) { + if (type == ContextType::Overwrite) { + context_manager_->set(key, value_json); + } else if (type == ContextType::Extend) { + auto ctx = context_manager_->get(key); + if (ctx.empty()) { + context_manager_->set(key, value_json); + } else if (!ctx.is_object() || !value_json.is_object()) { + ERR(InvalidUsageError, + "Context value must be a JSON object when type is " + "ContextTypeExtend. 
Key: ", + key, ", old value: ", ctx.dump(), + ", new value: ", value_json.dump()); + } else { + for (const auto& [k, v] : value_json.items()) { + ctx[k] = v; + } + context_manager_->set(key, ctx); + } + } else if (type == ContextType::Immutable) { + if (!context_manager_->has(key)) { + context_manager_->set(key, value_json); + } + } else { + ERR(InvalidUsageError, "Unknown context type"); + } +} + +bool Context::Impl::has(const std::string& key) const { + return context_manager_->has(key); +} + +} // namespace ark diff --git a/ark/context_impl.hpp b/ark/context_impl.hpp new file mode 100644 index 00000000..1a77891b --- /dev/null +++ b/ark/context_impl.hpp @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_CONTEXT_IMPL_HPP_ +#define ARK_CONTEXT_IMPL_HPP_ + +#include "ark/context.hpp" +#include "model/model_json.hpp" + +namespace ark { + +class ModelContextManager; + +class Context::Impl { + public: + Impl(Model& model); + + Json get(const std::string& key) const; + + void set(const std::string& key, const Json& value_json, ContextType type); + + bool has(const std::string& key) const; + + protected: + friend class Context; + + std::shared_ptr context_manager_; + int id_; +}; + +} // namespace ark + +#endif // ARK_CONTEXT_IMPL_HPP_ diff --git a/ark/include/ark.hpp b/ark/include/ark.hpp index 2ca79617..b1955bf9 100644 --- a/ark/include/ark.hpp +++ b/ark/include/ark.hpp @@ -8,6 +8,7 @@ #include // clang-format on +#include #include #include #include diff --git a/ark/include/ark/context.hpp b/ark/include/ark/context.hpp new file mode 100644 index 00000000..f3eef283 --- /dev/null +++ b/ark/include/ark/context.hpp @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_CONTEXT_HPP +#define ARK_CONTEXT_HPP + +#include + +namespace ark { + +enum class ContextType { + Overwrite, + Extend, + Immutable, +}; + +class Context { + public: + /// + /// Construct an empty context for the given model. + /// + /// @param model The model to create the context for. + /// + Context(Model& model); + + /// Get the ID of this context. + int id() const; + + /// Get context value by key. + /// @param key The key of the context item. + /// @return The value of the context item. If the key does not exist, + /// an empty string is returned. + std::string get(const std::string& key) const; + + /// + /// Add an item to the context. + /// + /// The given context item is valid for the lifetime of the context + /// object. @p `value` is assumed to be a JSON string. + /// If @p `key` is already in use by another valid context item + /// of either the same or different context object for the same model, + /// the behavior is determined by the context type @p `type` as follows. + /// + /// - `ContextType::Overwrite` (default): The existing value will be + /// replaced with the new one while the context object is alive. + /// When the context object is destroyed, the previous value will be + /// restored. + /// + /// - `ContextType::Extend`: The new value will extend the existing + /// value while the context object is alive. This type is feasible only + /// when the value represents a JSON object, which is convertible to a + /// map. If the new JSON object has a key that already exists in the + /// existing JSON object, the value of the existing key will be + /// overwritten by the new value. When the context object is destroyed, + /// the previous value will be restored. 
+ /// + /// - `ContextType::Immutable`: The new value will be adopted only when the + /// key does not exist in the existing context or when the value of the key + /// is empty. If the key already exists, the new value will be ignored. + /// When the context object is destroyed, if the key did not exist in the + /// existing context, the key will be removed. + /// Otherwise, nothing will be changed. + /// + /// @param key The key of the context item. + /// @param value The value of the context item. The value is assumed to + /// be a JSON string. An empty JSON string is also allowed. + /// @param type The context type. Default is `ContextType::Overwrite`. + /// + /// @throw `InvalidUsageError` In the following cases: + /// + /// - The value cannot be parsed as JSON. + /// + /// - The value is not a JSON object when the context type is + /// `ContextType::Extend`. + /// + /// - The context type is unknown. + /// + void set(const std::string& key, const std::string& value, + ContextType type = ContextType::Overwrite); + + protected: + friend class PlannerContext; + + class Impl; + std::shared_ptr<Impl> impl_; +}; + +} // namespace ark + +#endif // ARK_CONTEXT_HPP diff --git a/ark/include/ark/model.hpp b/ark/include/ark/model.hpp index 66551a03..9766b023 100644 --- a/ark/include/ark/model.hpp +++ b/ark/include/ark/model.hpp @@ -17,15 +17,21 @@ namespace ark { class Model : public ModelGraph { private: + size_t id_; std::set<int> tags_; public: - Model(int rank = 0, int world_size = 1) : ModelGraph(rank, world_size) {} - Model(const Model &other) : ModelGraph(other) {} + Model(int rank = 0, int world_size = 1); + + Model(const Model &other); + ~Model() {} Model &operator=(const Model &other) = default; + /// Get the unique identifier of the model. + size_t id() const; + Model compress() const; int unique_tag(); @@ -33,6 +39,14 @@ class Model : public ModelGraph { Tensor constant(float val, const Dims &shape, DataType data_type, const std::string &name = ""); + /// No operation. + /// + /// This operator can be used to prevent unused tensors from being optimized + /// out by the compiler. + /// + /// @param input Input tensor. + /// @param name Name of the operator. + /// void noop(Tensor input, const std::string &name = ""); /// Returns a tensor object. 
diff --git a/ark/include/ark/model_graph.hpp b/ark/include/ark/model_graph.hpp index bd7c5903..29074630 100644 --- a/ark/include/ark/model_graph.hpp +++ b/ark/include/ark/model_graph.hpp @@ -38,6 +38,8 @@ class ModelGraph { protected: friend class Model; + friend class ModelContextManager; + friend class Context; class Impl; std::unique_ptr<Impl> impl_; diff --git a/ark/include/ark/planner.hpp b/ark/include/ark/planner.hpp index 13f3158f..9547848b 100644 --- a/ark/include/ark/planner.hpp +++ b/ark/include/ark/planner.hpp @@ -4,19 +4,39 @@ #ifndef ARK_PLANNER_HPP #define ARK_PLANNER_HPP +#include <ark/context.hpp> #include <functional> #include <memory> #include <string> namespace ark { -class Model; +template <typename T> +class Range; -class DefaultPlanner { +class PlannerContext : public Context { public: - DefaultPlanner(const Model &model, int gpu_id); + PlannerContext(Model& model); - ~DefaultPlanner(); + void processor_range(int start, int end, int step = 1); + + void warp_range(int start, int end, int step = 1); + + void sram_range(int start, int end, int step = 1); + + void sync(bool sync); + + void config(const std::string& config); + + private: + void check_range(const std::string& key, const Range<int>& range); +}; + +class Planner { + public: + Planner(const Model& model, int device_id); + + ~Planner(); using ConfigRule = std::function<std::string(const std::string &, const std::string &)>; diff --git a/ark/model/model_context_manager.cpp b/ark/model/model_context_manager.cpp new file mode 100644 index 00000000..f1bb62e9 --- /dev/null +++ b/ark/model/model_context_manager.cpp @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "model_context_manager.hpp" + +namespace ark { + +ModelContextManager::ModelContextManager(Model& model) + : context_stack_(model.impl_->context_stack_) {} + +ModelContextManager::~ModelContextManager() { + for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { + context_stack_->pop(*it); + } +} + +void ModelContextManager::set(const std::string& key, const Json& value) { + context_stack_->push(key, value); + keys_.push_back(key); +} + +bool ModelContextManager::has(const std::string& key) const { + return context_stack_->has(key); +} + +Json ModelContextManager::get(const std::string& key) const { + return context_stack_->get(key); +} + +} // namespace ark diff --git a/ark/model/model_context_manager.hpp b/ark/model/model_context_manager.hpp new file mode 100644 index 00000000..6aa91692 --- /dev/null +++ b/ark/model/model_context_manager.hpp @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_MODEL_CONTEXT_MANAGER_HPP_ +#define ARK_MODEL_CONTEXT_MANAGER_HPP_ + +#include <vector> + +#include "ark/model.hpp" +#include "model_graph_impl.hpp" +#include "model_json.hpp" + +namespace ark { + +class ModelContextManager { + public: + ModelContextManager(Model& model); + + ~ModelContextManager(); + + void set(const std::string& key, const Json& value); + + bool has(const std::string& key) const; + + Json get(const std::string& key) const; + + private: + std::shared_ptr<ModelGraphContextStack> context_stack_; + std::vector<std::string> keys_; +}; + +} // namespace ark + +#endif // ARK_MODEL_CONTEXT_MANAGER_HPP_ diff --git a/ark/model/model_context_manager_test.cpp b/ark/model/model_context_manager_test.cpp new file mode 100644 index 00000000..b63f03ca --- /dev/null +++ b/ark/model/model_context_manager_test.cpp @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +#include "model_context_manager.hpp" + +#include "model_node.hpp" +#include "unittest/unittest_utils.h" + +ark::unittest::State test_model_context_manager() { + ark::Model model; + ark::Tensor t0 = model.tensor({1}, ark::FP32); + ark::Tensor t1 = model.tensor({1}, ark::FP32); + + // node 0 + ark::Tensor t2 = model.add(t0, t1); + + ark::Tensor t3; + ark::Tensor t4; + ark::Tensor t5; + { + // node 1 + ark::ModelContextManager cm(model); + cm.set("key0", ark::Json("val1")); + t3 = model.relu(t2); + + // node 2 + cm.set("key1", ark::Json("val2")); + t4 = model.sqrt(t3); + } + { + // node 3 + ark::ModelContextManager cm(model); + cm.set("key0", ark::Json("val3")); + t5 = model.exp(t2); + } + + UNITTEST_TRUE(model.verify()); + + auto compressed = model.compress(); + UNITTEST_TRUE(compressed.verify()); + + auto nodes = compressed.nodes(); + UNITTEST_EQ(nodes.size(), 4); + + UNITTEST_EQ(nodes[0]->context.size(), 0); + UNITTEST_EQ(nodes[1]->context.size(), 1); + UNITTEST_EQ(nodes[1]->context.at("key0"), ark::Json("val1")); + UNITTEST_EQ(nodes[2]->context.size(), 2); + UNITTEST_EQ(nodes[2]->context.at("key0"), ark::Json("val1")); + UNITTEST_EQ(nodes[2]->context.at("key1"), ark::Json("val2")); + UNITTEST_EQ(nodes[3]->context.size(), 1); + UNITTEST_EQ(nodes[3]->context.at("key0"), ark::Json("val3")); + + return ark::unittest::SUCCESS; +} + +int main() { + UNITTEST(test_model_context_manager); + return 0; +} diff --git a/ark/model/model_graph_impl.cpp b/ark/model/model_graph_impl.cpp index e187468e..7c1ea3fb 100644 --- a/ark/model/model_graph_impl.cpp +++ b/ark/model/model_graph_impl.cpp @@ -17,6 +17,51 @@ namespace ark { +ModelGraphContextStack::ModelGraphContextStack( + const ModelGraphContextStack &other) { + for (const auto &pair : other.storage_) { + for (const auto &value : pair.second) { + this->storage_[pair.first].push_back(value); + } + } +} + +void ModelGraphContextStack::push(const std::string &key, const Json &value) { + this->storage_[key].push_back(std::make_shared(value)); +} + +void ModelGraphContextStack::pop(const std::string &key) { + auto it = this->storage_.find(key); + if (it == this->storage_.end() || it->second.empty()) { + ERR(InternalError, "context stack is empty"); + } + it->second.pop_back(); + if (it->second.empty()) { + this->storage_.erase(it); + } +} + +bool ModelGraphContextStack::has(const std::string &key) const { + return this->storage_.find(key) != this->storage_.end(); +} + +Json ModelGraphContextStack::get(const std::string &key) const { + if (this->has(key)) { + return *this->storage_.at(key).back(); + } + return Json(); +} + +std::map ModelGraphContextStack::get_all() const { + std::map cur; + for (const auto &pair : this->storage_) { + if (!pair.second.empty()) { + cur[pair.first] = *pair.second.back(); + } + } + return cur; +} + ModelGraph::Impl::Impl(const ModelGraph::Impl &other) { *this = other; } ModelGraph::Impl &ModelGraph::Impl::operator=(const ModelGraph::Impl &other) { @@ -25,6 +70,7 @@ ModelGraph::Impl &ModelGraph::Impl::operator=(const ModelGraph::Impl &other) { for (const auto &node : other.nodes_) { ModelNodeRef new_node = std::make_shared(); new_node->op = node->op; + new_node->context = node->context; node_map.emplace(node, new_node); nodes_.push_back(new_node); } @@ -61,6 +107,8 @@ ModelGraph::Impl &ModelGraph::Impl::operator=(const ModelGraph::Impl &other) { rank_ = other.rank_; world_size_ = other.world_size_; compressed_ = other.compressed_; + context_stack_ = + std::make_shared(*(other.context_stack_)); return *this; } @@ -168,6 +216,8 
@@ ModelNodeRef ModelGraph::Impl::add_op(ModelOpRef op) { producer->consumers.push_back(node); } + node->context = context_stack_->get_all(); + nodes_.push_back(node); return node; } diff --git a/ark/model/model_graph_impl.hpp b/ark/model/model_graph_impl.hpp index ae471831..62944f99 100644 --- a/ark/model/model_graph_impl.hpp +++ b/ark/model/model_graph_impl.hpp @@ -4,6 +4,7 @@ #ifndef ARK_MODEL_GRAPH_IMPL_HPP_ #define ARK_MODEL_GRAPH_IMPL_HPP_ +#include <map> #include <memory> #include <string> #include <vector> @@ -18,10 +19,35 @@ namespace ark { +class ModelGraphContextStack { + private: + std::map<std::string, std::vector<std::shared_ptr<Json>>> storage_; + + public: + ModelGraphContextStack() = default; + + ModelGraphContextStack(const ModelGraphContextStack &other); + + ~ModelGraphContextStack() = default; + + void push(const std::string &key, const Json &value); + + void pop(const std::string &key); + + bool has(const std::string &key) const; + + Json get(const std::string &key) const; + + std::map<std::string, Json> get_all() const; +}; + class ModelGraph::Impl { public: Impl(int rank, int world_size) - : rank_(rank), world_size_(world_size), compressed_(false){}; + : rank_(rank), + world_size_(world_size), + compressed_(false), + context_stack_(std::make_shared<ModelGraphContextStack>()){}; Impl(const Impl &other); @@ -93,6 +119,12 @@ class ModelGraph::Impl { /// True if `compress_nodes` has been called. bool compressed_; + + protected: + friend class ModelContextManager; + + /// Graph context stack. + std::shared_ptr<ModelGraphContextStack> context_stack_; }; } // namespace ark diff --git a/ark/model/model_json.cpp b/ark/model/model_json.cpp index bdef38d4..b82f9e48 100644 --- a/ark/model/model_json.cpp +++ b/ark/model/model_json.cpp @@ -257,9 +257,13 @@ static void verify_format_processor_group(const Json &json) { } static void verify_format_plan(const Json &json) { - const std::vector<std::string> required_fields = { - "Rank", "WorldSize", "NumProcessors", "NumWarpsPerProcessor", - "TaskInfos", "ProcessorGroups"}; + const std::vector<std::string> required_fields = {"Rank", + "WorldSize", + "Architecture", + "NumProcessors", + "NumWarpsPerProcessor", + "TaskInfos", + "ProcessorGroups"}; for (const auto &field : required_fields) { if (!json.contains(field)) { ERR(PlanError, field + " not found"); @@ -279,7 +283,17 @@ } } -PlanJson::PlanJson(const Json &json) : Json(json) { verify_format_plan(*this); } +PlanJson::PlanJson(const Json &json) + : Json((json != nullptr) ? 
json + : Json{{"Rank", 0}, + {"WorldSize", 1}, + {"NumProcessors", 1}, + {"NumWarpsPerProcessor", 1}, + {"TaskInfos", Json::array()}, + {"ProcessorGroups", Json::array()}}) { + verify_format_plan(*this); +} static std::stringstream &dump_pretty_plan(const Json &json, std::stringstream &ss, int indent, @@ -290,6 +303,9 @@ static std::stringstream &dump_pretty_plan(const Json &json, dump_pretty_item(json.at("WorldSize"), "WorldSize", ss, indent + indent_step) << ",\n"; + dump_pretty_item(json.at("Architecture"), "Architecture", ss, + indent + indent_step) + << ",\n"; dump_pretty_item(json.at("NumProcessors"), "NumProcessors", ss, indent + indent_step) << ",\n"; diff --git a/ark/model/model_json.hpp b/ark/model/model_json.hpp index cf5fbbce..e42640a9 100644 --- a/ark/model/model_json.hpp +++ b/ark/model/model_json.hpp @@ -18,7 +18,7 @@ class ModelJson : public Json { class PlanJson : public Json { public: - PlanJson(const Json &json); + PlanJson(const Json &json = nullptr); std::string dump_pretty(int indent = 0, int indent_step = 2) const; }; diff --git a/ark/model/model_node.hpp b/ark/model/model_node.hpp index 264e891e..ca97f454 100644 --- a/ark/model/model_node.hpp +++ b/ark/model/model_node.hpp @@ -8,6 +8,7 @@ #include #include "ark/model_ref.hpp" +#include "model_json.hpp" #include "unique_list.hpp" namespace ark { @@ -25,6 +26,9 @@ class ModelNode { /// The list of @ref ModelNode that this @ref ModelNode depends on. UniqueList producers; + + /// Graph context of this node. + std::map context; }; } // namespace ark diff --git a/ark/model/model_op.cpp b/ark/model/model_op.cpp index f4685c99..173d1a92 100644 --- a/ark/model/model_op.cpp +++ b/ark/model/model_op.cpp @@ -199,8 +199,11 @@ std::shared_ptr ModelOp::deserialize(const Json &serialized) { } else if (!serialized.contains("Args")) { ERR(ModelError, "ModelOp deserialization failed: missing Args"); } + // Run `ModelOpT::from_name` before `construct()` to ensure all operators + // are registered. + auto op_type = ModelOpT::from_name(serialized["Type"]); auto ret = model_op_factory()->construct(serialized["Type"]); - ret->type_ = ModelOpT::from_name(serialized["Type"]); + ret->type_ = op_type; ret->name_ = serialized["Name"]; ret->is_virtual_ = serialized["IsVirtual"]; for (const auto &t : serialized["ReadTensors"]) { diff --git a/ark/ops/ops_communication_test.cpp b/ark/ops/ops_communication_test.cpp index 2b63642e..f2a66f08 100644 --- a/ark/ops/ops_communication_test.cpp +++ b/ark/ops/ops_communication_test.cpp @@ -229,7 +229,7 @@ ark::unittest::State test_communication_send_recv_bidir_sm() { ark::Tensor tns2 = model.identity(tns2_data, {tns}); tns2 = model.recv(tns2_data, remote_gpu_id, tag); - ark::DefaultPlanner planner(model, gpu_id); + ark::Planner planner(model, gpu_id); planner.install_config_rule(config_rule); ark::Executor exe(gpu_id, 2, gpu_id, "Executor", planner.plan()); exe.compile(); @@ -275,7 +275,7 @@ ark::unittest::State test_communication_send_recv_bidir_sm() { ark::Tensor sum = model.add(tns2, tns_data); - ark::DefaultPlanner planner(model, gpu_id); + ark::Planner planner(model, gpu_id); planner.install_config_rule(config_rule); ark::Executor exe(gpu_id, 2, gpu_id, "Executor", planner.plan()); exe.compile(); diff --git a/ark/range.cpp b/ark/range.cpp new file mode 100644 index 00000000..348d2a46 --- /dev/null +++ b/ark/range.cpp @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +#include "range.hpp" + +namespace ark { + +std::ostream& operator<<(std::ostream& os, const Range& range) { + if (range.step() == 1) { + os << "(" << *range.begin() << ", " << *range.end() << ")"; + } else { + os << "(" << *range.begin() << ", " << *range.end() << ", " + << range.step() << ")"; + } + return os; +} + +} // namespace ark diff --git a/ark/range.hpp b/ark/range.hpp index 47a312a7..18c8170d 100644 --- a/ark/range.hpp +++ b/ark/range.hpp @@ -4,6 +4,7 @@ #ifndef ARK_RANGE_HPP_ #define ARK_RANGE_HPP_ +#include #include namespace ark { @@ -71,7 +72,7 @@ class Range { T size() const { return (end_ - begin_) / step_; } - std::vector intersection(const Range &other) { + std::vector intersection(const Range &other) const { T begin, step; T opp_begin, opp_step; if (begin_ > other.begin_) { @@ -98,6 +99,10 @@ class Range { return result; } + bool is_subset_of(const Range &other) const { + return intersection(other).size() == static_cast(size()); + } + private: T begin_; T end_; @@ -119,6 +124,8 @@ Range range(T begin, T end, T step) { return Range(begin, end, step); } +std::ostream &operator<<(std::ostream &os, const Range &range); + } // namespace ark #endif // ARK_RANGE_HPP_ diff --git a/ark/unittest/unittest_utils.h b/ark/unittest/unittest_utils.h index 3cfc6803..423c8536 100644 --- a/ark/unittest/unittest_utils.h +++ b/ark/unittest/unittest_utils.h @@ -143,6 +143,42 @@ std::string get_kernel_code(const std::string &name); ") >= `" #exp1 "` (value: ", _v1, ")"); \ } while (0) +// Check if the `exp0` is less than or equal to `exp1`. +#define UNITTEST_LE(exp0, exp1) \ + do { \ + auto _v0 = (exp0); \ + auto _v1 = (exp1); \ + if (_v0 <= static_cast(_v1)) { \ + break; \ + } \ + UNITTEST_FEXIT("`" #exp0 "` (value: ", _v0, \ + ") > `" #exp1 "` (value: ", _v1, ")"); \ + } while (0) + +// Check if the `exp0` is greater than `exp1`. +#define UNITTEST_GT(exp0, exp1) \ + do { \ + auto _v0 = (exp0); \ + auto _v1 = (exp1); \ + if (_v0 > static_cast(_v1)) { \ + break; \ + } \ + UNITTEST_FEXIT("`" #exp0 "` (value: ", _v0, \ + ") <= `" #exp1 "` (value: ", _v1, ")"); \ + } while (0) + +// Check if the `exp0` is greater than or equal to `exp1`. +#define UNITTEST_GE(exp0, exp1) \ + do { \ + auto _v0 = (exp0); \ + auto _v1 = (exp1); \ + if (_v0 >= static_cast(_v1)) { \ + break; \ + } \ + UNITTEST_FEXIT("`" #exp0 "` (value: ", _v0, \ + ") < `" #exp1 "` (value: ", _v1, ")"); \ + } while (0) + // Check if the given expression throws a given exception. #define UNITTEST_THROW(exp, exception) \ do { \ diff --git a/examples/tutorial/plan_tutorial.py b/examples/tutorial/plan_tutorial.py index 056523e1..56002152 100644 --- a/examples/tutorial/plan_tutorial.py +++ b/examples/tutorial/plan_tutorial.py @@ -331,7 +331,7 @@ def main(plan_path: str): # Calculate default result ground_truth = None with ark.Runtime.get_runtime() as rt: - planner = ark.DefaultPlanner() + planner = ark.Planner() # If this rule is installed, default planner will perform the same as # `plan_1_larger_tile.json` on A100. diff --git a/examples/tutorial/planner_tutorial.py b/examples/tutorial/planner_tutorial.py new file mode 100644 index 00000000..1f6c3ac5 --- /dev/null +++ b/examples/tutorial/planner_tutorial.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import ark +import time +import torch +import torch.nn.functional as F + + +class VanillaSoftmax(ark.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + max = ark.reduce_max(input, axis=-1) + output = ark.sub(input, max) + output = ark.exp(output) + sum = ark.reduce_sum(output, axis=-1) + output = ark.div(output, sum) + return output + + +class Softmax(ark.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + with ark.PlannerContext( + warp_range=[0, 8], + sram_range=[0, 0], + sync=False, + config={ + "NumWarps": 1, + "SramBytes": 0, + "NumTasks": 65536, + }, + ): + with ark.PlannerContext(config={"ImplType": "WarpWise"}): + max = ark.reduce_max(input, axis=-1) + with ark.PlannerContext(config={"Tile": [1, 2048]}): + output = ark.sub(input, max) + output = ark.exp(output) + with ark.PlannerContext(config={"ImplType": "WarpWise"}): + sum = ark.reduce_sum(output, axis=-1) + with ark.PlannerContext(config={"Tile": [1, 2048]}): + output = ark.div(output, sum) + return output + + +def eval(tensor: ark.Tensor): + with ark.Runtime() as rt: + rt.launch() + rt.run() + return tensor.to_torch() + + +def perf(): + with ark.Runtime() as rt: + rt.launch() + + start = time.time() + rt.run(iter=1000) + end = time.time() + return (end - start) / 1000 + + +if __name__ == "__main__": + ark.init() + + shape = (32, 2048, 2048) + + # input = torch.randn(*shape).to("cuda:0") + input = ark.tensor(shape) + + output = Softmax()(input) + + # if torch.allclose(eval(output), F.softmax(input, dim=-1), atol=1e-5): + # print("Correct result") + # else: + # print("Incorrect result") + + print(f"Performance: {(perf() * 1e3):.3f} ms/iter") diff --git a/python/ark/__init__.py b/python/ark/__init__.py index a85c91d4..e8dc7e6c 100644 --- a/python/ark/__init__.py +++ b/python/ark/__init__.py @@ -39,7 +39,7 @@ def set_world_size(world_size): from .init import init from .tensor import Dims, Tensor, Parameter from .module import Module -from .runtime import Runtime, DefaultPlanner +from .runtime import Runtime from .serialize import save, load from .data_type import ( DataType, @@ -91,13 +91,5 @@ def set_world_size(world_size): ones, zeros, ) -from .error import ( - BaseError, - InternalError, - InvalidUsageError, - ModelError, - PlanError, - UnsupportedError, - SystemError, - GpuError, -) +from .planner import * +from .error import * diff --git a/python/ark/planner.py b/python/ark/planner.py new file mode 100644 index 00000000..e7eb2e7e --- /dev/null +++ b/python/ark/planner.py @@ -0,0 +1,230 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import copy +import json +from typing import Callable, Dict, List, Any + +from _ark_core import _Planner, _PlannerContext +from .model import Model + + +def idnt(indent): + return " " * indent + + +def dquote(s): + return '"' + s + '"' + + +def denser_json_obj(obj, key, level, indent, indent_step, ret=""): + if len(obj) == 0: + if key: + return ret + idnt(indent) + dquote(key) + ": {}" + else: + return ret + idnt(indent) + "{}" + ret += idnt(indent) + if key: + ret += dquote(key) + ": {\n" + else: + ret += "{\n" + num_item = len(obj) + for k, v in obj.items(): + is_obj_or_arr = isinstance(v, dict) or isinstance(v, list) + is_num_arr = isinstance(v, list) and v and isinstance(v[0], int) + if level <= 0 or not is_obj_or_arr or is_num_arr: + ret += ( + idnt(indent + indent_step) + + dquote(k) + + ": " + + json.dumps(v, separators=(",", ":")) + ) + elif isinstance(v, dict): + ret += denser_json_obj( + v, k, level - 1, indent + indent_step, indent_step + ) + elif isinstance(v, list): + ret += denser_json_arr( + v, k, level - 1, indent + indent_step, indent_step + ) + num_item -= 1 + if num_item > 0: + ret += ",\n" + else: + ret += "\n" + ret += idnt(indent) + "}" + return ret + + +def denser_json_arr(obj, key, level, indent, indent_step, ret=""): + if len(obj) == 0: + if key: + return ret + idnt(indent) + dquote(key) + ": []" + else: + return ret + idnt(indent) + "[]" + ret += idnt(indent) + if key: + ret += dquote(key) + ": [\n" + else: + ret += "[\n" + num_item = len(obj) + for v in obj: + is_obj_or_arr = isinstance(v, dict) or isinstance(v, list) + is_num_arr = ( + isinstance(v, list) + and v + and (isinstance(v[0], int) or isinstance(v[0], float)) + ) + if level <= 0 or not is_obj_or_arr or is_num_arr: + ret += idnt(indent + indent_step) + json.dumps( + v, separators=(",", ":") + ) + elif isinstance(v, dict): + ret += denser_json_obj( + v, "", level - 1, indent + indent_step, indent_step + ) + elif isinstance(v, list): + ret += denser_json_arr( + v, "", level - 1, indent + indent_step, indent_step + ) + num_item -= 1 + if num_item > 0: + ret += ",\n" + else: + ret += "\n" + ret += idnt(indent) + "]" + return ret + + +def denser_json(obj, level, indent_step=2): + if isinstance(obj, dict): + return denser_json_obj(obj, "", level, 0, indent_step, "") + elif isinstance(obj, list): + return denser_json_arr(obj, "", level, 0, indent_step, "") + return json.dumps(obj, indent=indent_step) + + +class Plan: + def __init__(self, plan: Dict[str, Any]): + if plan is None: + plan = {} + plan["Rank"] = 0 + plan["WorldSize"] = 1 + plan["Architecture"] = "ANY" + plan["NumProcessors"] = 1 + plan["NumWarpsPerProcessor"] = 1 + plan["TaskInfos"] = [] + plan["ProcessorGroups"] = [] + else: + plan = copy.deepcopy(plan) + self.plan = plan + + def __str__(self) -> str: + return denser_json(self.plan, 5) + + @property + def rank(self) -> int: + return self.plan["Rank"] + + @property + def world_size(self) -> int: + return self.plan["WorldSize"] + + @property + def architecture(self) -> str: + return self.plan["Architecture"] + + @property + def num_processors(self) -> int: + return self.plan["NumProcessors"] + + @property + def num_warps_per_processor(self) -> int: + return self.plan["NumWarpsPerProcessor"] + + @property + def task_infos(self) -> List[Dict[str, Any]]: + return self.plan["TaskInfos"] + + @property + def processor_groups(self) -> List[Dict[str, Any]]: + return self.plan["ProcessorGroups"] + + @staticmethod + def from_str(plan_str: str) -> "Plan": + plan = json.loads(plan_str) + return Plan(plan) + 
+ @staticmethod + def from_file(file_path: str) -> "Plan": + with open(file_path, "r") as f: + plan = json.load(f) + return Plan(plan) + + +class PlannerContext(_PlannerContext): + def __init__(self, **kwargs): + """ + Plan manager for specifying the parallelization and tiling configuration of the operators in the context. + + Args: + processor_range (List[int], optional): The range of processors to be used. Defaults to None. + warp_range (List[int], optional): The range of warps to be used. Defaults to None. + sram_range (List[int], optional): The range of SRAMs to be used. Defaults to None. + sync (bool, optional): Whether to synchronize the execution. Defaults to True. + config (Dict[str, Any], optional): The configuration for the operators. Defaults to None. + """ + super().__init__(Model.get_model()) + prange: List[int] = kwargs.get("processor_range", None) + wrange: List[int] = kwargs.get("warp_range", None) + srange: List[int] = kwargs.get("sram_range", None) + sync: bool = kwargs.get("sync", True) + config: Dict[str, Any] = kwargs.get("config", None) + + if prange is not None: + self.processor_range(*prange) + if wrange is not None: + self.warp_range(*wrange) + if srange is not None: + self.sram_range(*srange) + if sync is False: + self.sync(sync) + if config is not None: + self.config(json.dumps(config)) + + def __enter__(self) -> "PlannerContext": + """ + Enter the plan manager. + """ + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + """ + Exit the plan manager. + """ + del self + + +class Planner(_Planner): + def __init__(self, device_id: int = 0): + compressed = Model.get_model().compress() + super().__init__(compressed, device_id) + + def install_config_rule(self, rule: Callable[[str, str], str]): + """ + Install a configuration rule. + + Args: + rule: A function that takes an operator description and a target + architecture name and returns a configuration description. + """ + super().install_config_rule(rule) + + def plan(self) -> Plan: + """ + Generate an execution plan. + """ + return Plan.from_str(super().plan(pretty=False)) + + +__all__ = ["Plan", "PlannerContext", "Planner"] diff --git a/python/ark/runtime.py b/python/ark/runtime.py index 7480ce7d..d29b036c 100644 --- a/python/ark/runtime.py +++ b/python/ark/runtime.py @@ -3,10 +3,9 @@ import logging from enum import Enum -from typing import Callable -from _ark_core import _Executor, _DefaultPlanner -from .model import Model +from _ark_core import _Executor +from .planner import Planner, Plan class _RuntimeState: @@ -18,31 +17,6 @@ class _RuntimeState: executor = None -class DefaultPlanner(_DefaultPlanner): - def __init__(self, gpu_id: int = 0): - compressed = Model.get_model().compress() - super().__init__(compressed, gpu_id) - - def install_config_rule(self, rule: Callable[[str, str], str]): - """ - Install a configuration rule. - - Args: - rule: A function that takes an operator description and a target - architecture name and returns a configuration description. - """ - super().install_config_rule(rule) - - def plan(self, pretty: bool = True) -> str: - """ - Generate an execution plan. - - Args: - pretty: Whether to generate a pretty plan. - """ - return super().plan(pretty) - - class Executor(_Executor): pass @@ -101,11 +75,8 @@ def running(self) -> bool: def launch( self, - rank: int = 0, - world_size: int = 1, - gpu_id: int = 0, - plan: str = "", - plan_path: str = "", + plan: Plan = None, + device_id: int = 0, ): """ Create an executor and schedule the ARK model. 
The scheduler will generate @@ -115,12 +86,7 @@ if self.launched(): logging.warn("Runtime is already launched, skip launching") return - if not plan: - if not plan_path: - plan = DefaultPlanner(gpu_id).plan() - else: - with open(plan_path, "r") as f: - plan = f.read() + plan = Planner(device_id).plan() if plan is None else plan # If the RuntimeState is init, we need to create a new executor and # compile the kernels if self.state == Runtime.State.Init: @@ -130,11 +96,11 @@ _RuntimeState.executor.destroy() _RuntimeState.executor = Executor( - rank, - world_size, - gpu_id, + plan.rank, + plan.world_size, + device_id, "ArkRuntime", - plan, + str(plan), ) self.executor = _RuntimeState.executor self.executor.compile() diff --git a/python/model_py.cpp b/python/model_py.cpp index 2d1e5f63..46c70a7d 100644 --- a/python/model_py.cpp +++ b/python/model_py.cpp @@ -15,6 +15,7 @@ void register_model(py::module &m) { .def(py::init<int, int>(), py::arg("rank"), py::arg("world_size")) .def("rank", &ark::Model::rank) .def("world_size", &ark::Model::world_size) + .def("id", &ark::Model::id) .def("compress", &ark::Model::compress) .def("add", py::overload_cast<ark::Tensor, ark::Tensor, ark::Tensor, const std::string &>(&ark::Model::add), diff --git a/python/planner_py.cpp b/python/planner_py.cpp --- a/python/planner_py.cpp +++ b/python/planner_py.cpp @@ void register_planner(py::module &m) { - py::class_<ark::DefaultPlanner>(m, "_DefaultPlanner") + py::class_<ark::PlannerContext>(m, "_PlannerContext") + .def(py::init<ark::Model &>()) + .def("processor_range", &ark::PlannerContext::processor_range, + py::arg("start"), py::arg("end"), py::arg("step") = 1) + .def("warp_range", &ark::PlannerContext::warp_range, py::arg("start"), + py::arg("end"), py::arg("step") = 1) + .def("sram_range", &ark::PlannerContext::sram_range, py::arg("start"), + py::arg("end"), py::arg("step") = 1) + .def("sync", &ark::PlannerContext::sync, py::arg("sync")) + .def("config", &ark::PlannerContext::config, py::arg("config")); + + py::class_<ark::Planner>(m, "_Planner") .def(py::init<const ark::Model &, int>()) .def("install_config_rule", - [](ark::DefaultPlanner *self, const py::function &rule) { + [](ark::Planner *self, const py::function &rule) { self->install_config_rule( [rule](const std::string &op, const std::string &arch) { return rule(op, arch).cast<std::string>(); }); }) - .def("plan", &ark::DefaultPlanner::plan, py::arg("pretty") = true); + .def("plan", &ark::Planner::plan, py::arg("pretty") = true); } diff --git a/python/unittest/test_runtime.py b/python/unittest/test_runtime.py index bd9098fe..d91fd85c 100644 --- a/python/unittest/test_runtime.py +++ b/python/unittest/test_runtime.py @@ -2,29 +2,17 @@ # Licensed under the MIT license. import ark -import json def test_runtime_relaunch(): ark.init() - empty_plan = json.dumps( - { - "Rank": 0, - "WorldSize": 1, - "NumProcessors": 1, - "NumWarpsPerProcessor": 1, - "TaskInfos": [], - "ProcessorGroups": [], - } - ) - with ark.Runtime.get_runtime() as rt: assert rt.launched() == False - rt.launch(plan=empty_plan) + rt.launch() assert rt.launched() == True with ark.Runtime.get_runtime() as rt: assert rt.launched() == False - rt.launch(plan=empty_plan) + rt.launch() assert rt.launched() == True
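For reference, the sketch below shows how the pieces added by this diff compose from Python, based on examples/tutorial/planner_tutorial.py and python/ark/planner.py above. It is a minimal illustration under assumptions, not part of the diff: the tensor shape and the config values (NumWarps, SramBytes, NumTasks) are made-up placeholders, and whether they are sensible for exp depends on the operator's config rules.

import ark

ark.init()

x = ark.tensor((32, 2048))

# Operators created while the context is alive inherit its hints:
# processors 0-3, warps 0-7, no SRAM reservation, and no sync between
# the ops that share the context.
with ark.PlannerContext(
    processor_range=[0, 4],
    warp_range=[0, 8],
    sram_range=[0, 0],
    sync=False,
    config={"NumWarps": 1, "SramBytes": 0, "NumTasks": 32},
):
    y = ark.exp(x)

planner = ark.Planner(device_id=0)
plan = planner.plan()  # a Plan object; str(plan) is the plan JSON

with ark.Runtime.get_runtime() as rt:
    rt.launch(plan=plan, device_id=0)  # plan=None would build one implicitly
    rt.run()
    out = y.to_torch()  # requires torch, as in planner_tutorial.py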