Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend plan interfaces #231

Merged
merged 13 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/ut-cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,3 @@ jobs:
- name: Run Tutorials
run: |
python3 ./examples/tutorial/quickstart_tutorial.py
python3 ./examples/tutorial/plan_tutorial.py
32 changes: 32 additions & 0 deletions ark/api/context.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include "context_impl.hpp"
#include "logging.hpp"

namespace ark {

Context::Context(Model& model) : impl_(std::make_shared<Impl>(model)) {}

int Context::id() const { return this->impl_->id_; }

std::string Context::get(const std::string& key) const {
if (!this->impl_->has(key)) {
return "";
}
return this->impl_->get(key).dump();
}

// Stores `value` under `key` in this context.
//
// `value` must be JSON text: it is parsed first and the parsed document is
// what gets handed to the implementation, together with `type`, which is
// forwarded unchanged.
//
// If `value` is not valid JSON, an `InvalidUsageError` is reported through
// `ERR`; the message carries both the offending text and the parser's own
// diagnostic so the caller can see where parsing failed.
void Context::set(const std::string& key, const std::string& value,
                  ContextType type) {
    Json value_json;
    try {
        value_json = Json::parse(value);
    } catch (const ::nlohmann::json::parse_error& e) {
        // Keep the parser's explanation (position and reason) instead of
        // discarding the caught exception.
        ERR(InvalidUsageError, "Failed to parse context value as JSON: `",
            value, "` (", e.what(), ")");
    }
    this->impl_->set(key, value_json, type);
}

} // namespace ark
81 changes: 81 additions & 0 deletions ark/api/context_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include "ark/context.hpp"

#include "model/model_node.hpp"
#include "unittest/unittest_utils.h"

// Builds a small model while opening two separate contexts, then checks
// (a) what each context reports through get() and (b) which context entries
// end up recorded on each node of the compressed model.
ark::unittest::State test_context() {
    // Serialized forms of the values stored in the contexts below.
    const auto val1_dump = ark::Json("val1").dump();
    const auto val2_dump = ark::Json("val2").dump();
    const auto val3_dump = ark::Json("val3").dump();

    ark::Model model;
    ark::Tensor input_a = model.tensor({1}, ark::FP32);
    ark::Tensor input_b = model.tensor({1}, ark::FP32);

    // node 0: created before any context exists
    ark::Tensor sum = model.add(input_a, input_b);

    ark::Tensor relu_out;
    ark::Tensor sqrt_out;
    ark::Tensor exp_out;
    {
        ark::Context ctx(model);

        // node 1: created with only key0 set
        ctx.set("key0", val1_dump);
        relu_out = model.relu(sum);

        UNITTEST_EQ(ctx.get("key0"), val1_dump);

        // node 2: created with key0 and key1 set
        ctx.set("key1", val2_dump);
        sqrt_out = model.sqrt(relu_out);

        UNITTEST_EQ(ctx.get("key0"), val1_dump);
        UNITTEST_EQ(ctx.get("key1"), val2_dump);
    }
    {
        ark::Context ctx(model);

        // node 3: created in a fresh context with a different key0 value
        ctx.set("key0", val3_dump);
        exp_out = model.exp(sum);

        UNITTEST_EQ(ctx.get("key0"), val3_dump);
        // key1 was only set in the previous scope; get() yields "" here.
        UNITTEST_EQ(ctx.get("key1"), "");
    }

    UNITTEST_TRUE(model.verify());

    auto compressed = model.compress();
    UNITTEST_TRUE(compressed.verify());

    auto nodes = compressed.nodes();
    UNITTEST_EQ(nodes.size(), 4);

    // Node 0 carries no context entries.
    UNITTEST_EQ(nodes[0]->context.size(), 0);

    // Node 1 carries key0 only.
    UNITTEST_EQ(nodes[1]->context.size(), 1);
    UNITTEST_EQ(nodes[1]->context.at("key0"), ark::Json("val1"));

    // Node 2 carries both key0 and key1.
    UNITTEST_EQ(nodes[2]->context.size(), 2);
    UNITTEST_EQ(nodes[2]->context.at("key0"), ark::Json("val1"));
    UNITTEST_EQ(nodes[2]->context.at("key1"), ark::Json("val2"));

    // Node 3 carries only the second context's key0.
    UNITTEST_EQ(nodes[3]->context.size(), 1);
    UNITTEST_EQ(nodes[3]->context.at("key0"), ark::Json("val3"));

    return ark::unittest::SUCCESS;
}

// Checks that Context::set() rejects a value that is not JSON text.
ark::unittest::State test_context_invalid() {
    ark::Model model;
    ark::Context ctx(model);

    ark::Tensor lhs = model.tensor({1}, ark::FP32);
    ark::Tensor rhs = model.tensor({1}, ark::FP32);
    ark::Tensor sum = model.add(lhs, rhs);

    // The bare word `val` (no quotes) is not a valid JSON document, so the
    // call is expected to raise InvalidUsageError.
    UNITTEST_THROW(ctx.set("key", "val"), ark::InvalidUsageError);

    return ark::unittest::SUCCESS;
}

// Test entry point: runs each context test case through the UNITTEST macro,
// in the order listed, and returns 0 once both calls have completed.
int main() {
UNITTEST(test_context);
UNITTEST(test_context_invalid);
return 0;
}
16 changes: 12 additions & 4 deletions ark/api/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,15 @@
plan_json = Json::parse(plan);
}

auto gpu_manager = GpuManager::get_instance(gpu_id_);
if (!gpu_manager->info().arch->belongs_to(
Arch::from_name(plan_json.at("Architecture")))) {
LOG(WARN, "Architecture name of the plan `",

Check warning on line 219 in ark/api/executor.cpp

View check run for this annotation

Codecov / codecov/patch

ark/api/executor.cpp#L219

Added line #L219 was not covered by tests
plan_json.at("Architecture").get<std::string>(),
"` is not compatible with the GPU architecture `",
gpu_manager->info().arch->name(), "`.");
}

buffer_id_to_offset_ = init_buffers(plan_json);

std::string buffer_id_to_offset_str;
Expand All @@ -224,7 +233,6 @@
codegen_ =
std::make_shared<CodeGenerator>(plan_json, buffer_id_to_offset_, name);

auto gpu_manager = GpuManager::get_instance(gpu_id_);
timer_begin_ = gpu_manager->create_event();
timer_end_ = gpu_manager->create_event();
buffer_ = gpu_manager->malloc(total_bytes_, 65536);
Expand Down Expand Up @@ -816,9 +824,9 @@
model.rank(), model.world_size(),
(gpu_id < 0) ? (model.rank() % get_env().num_ranks_per_host) : gpu_id,
name,
DefaultPlanner(model, (gpu_id < 0) ? (model.rank() %
get_env().num_ranks_per_host)
: gpu_id)
Planner(model, (gpu_id < 0)
? (model.rank() % get_env().num_ranks_per_host)
: gpu_id)
.plan()) {}

} // namespace ark
9 changes: 9 additions & 0 deletions ark/api/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@

namespace ark {

// Builds a model for the given rank and world size and assigns it an ID drawn
// from a counter shared by every call to this constructor.
Model::Model(int rank, int world_size) : ModelGraph(rank, world_size) {
    // Function-local static: starts at zero and keeps its value across calls.
    static size_t id_counter = 0;
    id_ = id_counter;
    ++id_counter;
}

Model::Model(const Model &other) : ModelGraph(other), id_(other.id()) {}

size_t Model::id() const { return id_; }

Model Model::compress() const {
Model model(*this);
model.compress_nodes();
Expand Down
Loading
Loading