Skip to content

Commit

Permalink
context interface
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Aug 5, 2024
1 parent 304fa59 commit 24c5cc6
Show file tree
Hide file tree
Showing 10 changed files with 240 additions and 51 deletions.
52 changes: 52 additions & 0 deletions ark/api/context.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include "ark/context.hpp"

#include "logging.hpp"
#include "model/model_context_manager.hpp"
#include "model/model_graph_impl.hpp"

namespace ark {

Context::Context(Model& model)
: context_manager_(std::make_shared<ModelContextManager>(model)) {
static size_t next_id = 0;
id_ = next_id++;
}

void Context::add(const std::string& key, const std::string& value,
ContextType type) {
Json value_json;
try {
value_json = Json::parse(value);
} catch (const ::nlohmann::json::parse_error& e) {
ERR(PlanError, "Failed to parse context value as JSON: `", value, "`");
}
if (type == ContextType::ContextTypeOverwrite) {
context_manager_->add(key, value_json);
} else if (type == ContextType::ContextTypeExtend) {
auto ctx = context_manager_->get(key);
if (ctx.empty()) {
context_manager_->add(key, value_json);
} else if (!ctx.is_object() || !value_json.is_object()) {
ERR(PlanError,
"Context value must be a JSON object when type is "
"ContextTypeExtend. Key: ",
key, ", old value: ", ctx.dump(), ", new value: ", value);
} else {
for (const auto& [k, v] : value_json.items()) {
ctx[k] = v;
}
context_manager_->add(key, ctx);
}
} else if (type == ContextType::ContextTypeImmutable) {
if (!context_manager_->has(key)) {
context_manager_->add(key, value);
}
} else {
ERR(PlanError, "Unknown context type");
}
}

} // namespace ark
60 changes: 60 additions & 0 deletions ark/api/context_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include "ark/context.hpp"

#include "model/model_node.hpp"
#include "unittest/unittest_utils.h"

ark::unittest::State test_context() {
ark::Model model;
ark::Tensor t0 = model.tensor({1}, ark::FP32);
ark::Tensor t1 = model.tensor({1}, ark::FP32);

// node 0
ark::Tensor t2 = model.add(t0, t1);

ark::Tensor t3;
ark::Tensor t4;
ark::Tensor t5;
{
// node 1
ark::Context ctx(model);
ctx.add("key0", ark::Json("val1").dump());
t3 = model.relu(t2);

// node 2
ctx.add("key1", ark::Json("val2").dump());
t4 = model.sqrt(t3);
}
{
// node 3
ark::Context ctx(model);
ctx.add("key0", ark::Json("val3").dump());
t5 = model.exp(t2);
}

UNITTEST_TRUE(model.verify());

auto compressed = model.compress();
UNITTEST_TRUE(compressed.verify());

auto nodes = compressed.nodes();
UNITTEST_EQ(nodes.size(), 4);

UNITTEST_EQ(nodes[0]->context.size(), 0);
UNITTEST_EQ(nodes[1]->context.size(), 1);
UNITTEST_EQ(nodes[1]->context.at("key0"), ark::Json("val1"));
UNITTEST_EQ(nodes[2]->context.size(), 2);
UNITTEST_EQ(nodes[2]->context.at("key0"), ark::Json("val1"));
UNITTEST_EQ(nodes[2]->context.at("key1"), ark::Json("val2"));
UNITTEST_EQ(nodes[3]->context.size(), 1);
UNITTEST_EQ(nodes[3]->context.at("key0"), ark::Json("val3"));

return ark::unittest::SUCCESS;
}

int main() {
UNITTEST(test_context);
return 0;
}
1 change: 1 addition & 0 deletions ark/include/ark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <ark/version.hpp>
// clang-format on

#include <ark/context.hpp>
#include <ark/data_type.hpp>
#include <ark/dims.hpp>
#include <ark/error.hpp>
Expand Down
75 changes: 75 additions & 0 deletions ark/include/ark/context.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#ifndef ARK_CONTEXT_HPP
#define ARK_CONTEXT_HPP

#include <ark/model.hpp>

namespace ark {

class ModelContextManager;

enum class ContextType {
ContextTypeOverwrite,
ContextTypeExtend,
ContextTypeImmutable,
};

class Context {
public:
///
/// Construct an empty context for the given model.
///
/// @param model The model to create the context for.
///
Context(Model& model);

/// Get the ID of this context.
size_t id() const { return id_; }

///
/// Add an item to the context.
///
/// The given context item is valid for the lifetime of the context
/// object. @p `value` is assumed to be a JSON string.
/// If @p `key` is already in use by another valid context item
/// of either the same or different context object for the same model,
/// the behavior is determined by the context type @p `type` as follows.
///
/// - `ContextTypeOverwrite` (default): The existing value will be
/// replaced with the new one while the context object is alive.
/// When the context object is destroyed, the previous value will be
/// restored.
///
/// - `ContextTypeExtend`: The new value will extend the existing
/// value while the context object is alive. This type is feasible only
/// when the value represents a JSON object, which is convertible to a
/// map. If the new JSON object has a key that already exists in the
/// existing JSON object, the value of the existing key will be
/// overwritten by the new value. When the context object is destroyed,
/// the previous value will be restored.
///
/// - `ContextTypeImmutable`: The new value will be adopted only when the
/// key does not exist in the existing context or when the value of the key
/// is empty. If the key already exists, the new value will be ignored.
/// When the context object is destroyed, if the key did not exist in the
/// existing context, the key will be removed.
/// Otherwise, nothing will be changed.
///
/// @param key The key of the context item.
/// @param value The value of the context item. The value is assumed to
/// be a JSON string.
/// @param type The context type. Default is `ContextTypeOverwrite`.
///
void add(const std::string& key, const std::string& value,
ContextType type = ContextType::ContextTypeOverwrite);

private:
std::shared_ptr<ModelContextManager> context_manager_;
size_t id_;
};

} // namespace ark

#endif // ARK_CONTEXT_HPP
1 change: 1 addition & 0 deletions ark/include/ark/model_graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class ModelGraph {
protected:
friend class Model;
friend class ModelContextManager;
friend class Context;

class Impl;
std::unique_ptr<Impl> impl_;
Expand Down
38 changes: 12 additions & 26 deletions ark/model/model_context_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,28 @@

#include "model_context_manager.hpp"

#include "model_graph_impl.hpp"

namespace ark {

class ModelContextManager::Impl {
public:
Impl(std::shared_ptr<ModelGraphContextStack> context_stack)
: context_stack_(context_stack) {}

void add(const std::string& key, const Json& value);

~Impl();
ModelContextManager::ModelContextManager(Model& model)
: context_stack_(model.impl_->context_stack_) {}

private:
std::shared_ptr<ModelGraphContextStack> context_stack_;
std::vector<std::string> keys_;
};
ModelContextManager::~ModelContextManager() {
for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
context_stack_->pop(*it);
}
}

void ModelContextManager::Impl::add(const std::string& key, const Json& value) {
void ModelContextManager::add(const std::string& key, const Json& value) {
context_stack_->push(key, value);
keys_.push_back(key);
}

ModelContextManager::Impl::~Impl() {
for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
context_stack_->pop(*it);
}
bool ModelContextManager::has(const std::string& key) const {
return context_stack_->has(key);
}

ModelContextManager::ModelContextManager(Model& model)
: impl_(std::make_shared<Impl>(model.impl_->context_stack_)) {}

ModelContextManager& ModelContextManager::add(const std::string& key,
const Json& value) {
impl_->add(key, value);
return *this;
Json ModelContextManager::get(const std::string& key) const {
return context_stack_->get(key);
}

} // namespace ark
13 changes: 10 additions & 3 deletions ark/model/model_context_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <map>

#include "ark/model.hpp"
#include "model_graph_impl.hpp"
#include "model_json.hpp"

namespace ark {
Expand All @@ -15,11 +16,17 @@ class ModelContextManager {
public:
ModelContextManager(Model& model);

ModelContextManager& add(const std::string& key, const Json& value);
~ModelContextManager();

void add(const std::string& key, const Json& value);

bool has(const std::string& key) const;

Json get(const std::string& key) const;

private:
class Impl;
std::shared_ptr<Impl> impl_;
std::shared_ptr<ModelGraphContextStack> context_stack_;
std::vector<std::string> keys_;
};

} // namespace ark
Expand Down
19 changes: 12 additions & 7 deletions ark/model/model_context_manager_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,27 @@ ark::unittest::State test_model_context_manager() {
ark::Model model;
ark::Tensor t0 = model.tensor({1}, ark::FP32);
ark::Tensor t1 = model.tensor({1}, ark::FP32);

// node 0
ark::Tensor t2 = model.add(t0, t1);

ark::Tensor t3;
ark::Tensor t4;
ark::Tensor t5;
{
// node 1
ark::ModelContextManager cm(model);
cm.add("key0", "val1");
cm.add("key0", ark::Json("val1"));
t3 = model.relu(t2);

cm.add("key1", "val2");
// node 2
cm.add("key1", ark::Json("val2"));
t4 = model.sqrt(t3);
}
{
// node 3
ark::ModelContextManager cm(model);
cm.add("key0", "val3");
cm.add("key0", ark::Json("val3"));
t5 = model.exp(t2);
}

Expand All @@ -39,12 +44,12 @@ ark::unittest::State test_model_context_manager() {

UNITTEST_EQ(nodes[0]->context.size(), 0);
UNITTEST_EQ(nodes[1]->context.size(), 1);
UNITTEST_EQ(nodes[1]->context.at("key0"), "val1");
UNITTEST_EQ(nodes[1]->context.at("key0"), ark::Json("val1"));
UNITTEST_EQ(nodes[2]->context.size(), 2);
UNITTEST_EQ(nodes[2]->context.at("key0"), "val1");
UNITTEST_EQ(nodes[2]->context.at("key1"), "val2");
UNITTEST_EQ(nodes[2]->context.at("key0"), ark::Json("val1"));
UNITTEST_EQ(nodes[2]->context.at("key1"), ark::Json("val2"));
UNITTEST_EQ(nodes[3]->context.size(), 1);
UNITTEST_EQ(nodes[3]->context.at("key0"), "val3");
UNITTEST_EQ(nodes[3]->context.at("key0"), ark::Json("val3"));

return ark::unittest::SUCCESS;
}
Expand Down
24 changes: 13 additions & 11 deletions ark/model/model_graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,23 @@ void ModelGraphContextStack::pop(const std::string &key) {
ERR(InternalError, "context stack is empty");
}
it->second.pop_back();
if (it->second.empty()) {
this->storage_.erase(it);
}
}

bool ModelGraphContextStack::has(const std::string &key) const {
return this->storage_.find(key) != this->storage_.end();
}

Json ModelGraphContextStack::get_context(const std::string &key) const {
if (this->storage_.find(key) == this->storage_.end() ||
this->storage_.at(key).empty()) {
return Json();
Json ModelGraphContextStack::get(const std::string &key) const {
if (this->has(key)) {
return *this->storage_.at(key).back();
}
return *this->storage_.at(key).back();
return Json();
}

std::map<std::string, Json> ModelGraphContextStack::get_context_all() const {
std::map<std::string, Json> ModelGraphContextStack::get_all() const {
std::map<std::string, Json> cur;
for (const auto &pair : this->storage_) {
if (!pair.second.empty()) {
Expand Down Expand Up @@ -172,10 +178,6 @@ bool ModelGraph::Impl::verify() const {
return true;
}

Json ModelGraph::Impl::get_context(const std::string &key) const {
return context_stack_->get_context(key);
}

ModelNodeRef ModelGraph::Impl::add_op(ModelOpRef op) {
for (auto &tns : op->input_tensors()) {
if (tensor_to_producer_op_.find(tns) == tensor_to_producer_op_.end()) {
Expand Down Expand Up @@ -214,7 +216,7 @@ ModelNodeRef ModelGraph::Impl::add_op(ModelOpRef op) {
producer->consumers.push_back(node);
}

node->context = context_stack_->get_context_all();
node->context = context_stack_->get_all();

nodes_.push_back(node);
return node;
Expand Down
Loading

0 comments on commit 24c5cc6

Please sign in to comment.