-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
240 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT license. | ||
|
||
#include "ark/context.hpp" | ||
|
||
#include "logging.hpp" | ||
#include "model/model_context_manager.hpp" | ||
#include "model/model_graph_impl.hpp" | ||
|
||
namespace ark { | ||
|
||
Context::Context(Model& model) | ||
: context_manager_(std::make_shared<ModelContextManager>(model)) { | ||
static size_t next_id = 0; | ||
id_ = next_id++; | ||
} | ||
|
||
void Context::add(const std::string& key, const std::string& value, | ||
ContextType type) { | ||
Json value_json; | ||
try { | ||
value_json = Json::parse(value); | ||
} catch (const ::nlohmann::json::parse_error& e) { | ||
ERR(PlanError, "Failed to parse context value as JSON: `", value, "`"); | ||
} | ||
if (type == ContextType::ContextTypeOverwrite) { | ||
context_manager_->add(key, value_json); | ||
} else if (type == ContextType::ContextTypeExtend) { | ||
auto ctx = context_manager_->get(key); | ||
if (ctx.empty()) { | ||
context_manager_->add(key, value_json); | ||
} else if (!ctx.is_object() || !value_json.is_object()) { | ||
ERR(PlanError, | ||
"Context value must be a JSON object when type is " | ||
"ContextTypeExtend. Key: ", | ||
key, ", old value: ", ctx.dump(), ", new value: ", value); | ||
} else { | ||
for (const auto& [k, v] : value_json.items()) { | ||
ctx[k] = v; | ||
} | ||
context_manager_->add(key, ctx); | ||
} | ||
} else if (type == ContextType::ContextTypeImmutable) { | ||
if (!context_manager_->has(key)) { | ||
context_manager_->add(key, value); | ||
} | ||
} else { | ||
ERR(PlanError, "Unknown context type"); | ||
} | ||
} | ||
|
||
} // namespace ark |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT license. | ||
|
||
#include "ark/context.hpp" | ||
|
||
#include "model/model_node.hpp" | ||
#include "unittest/unittest_utils.h" | ||
|
||
ark::unittest::State test_context() { | ||
ark::Model model; | ||
ark::Tensor t0 = model.tensor({1}, ark::FP32); | ||
ark::Tensor t1 = model.tensor({1}, ark::FP32); | ||
|
||
// node 0 | ||
ark::Tensor t2 = model.add(t0, t1); | ||
|
||
ark::Tensor t3; | ||
ark::Tensor t4; | ||
ark::Tensor t5; | ||
{ | ||
// node 1 | ||
ark::Context ctx(model); | ||
ctx.add("key0", ark::Json("val1").dump()); | ||
t3 = model.relu(t2); | ||
|
||
// node 2 | ||
ctx.add("key1", ark::Json("val2").dump()); | ||
t4 = model.sqrt(t3); | ||
} | ||
{ | ||
// node 3 | ||
ark::Context ctx(model); | ||
ctx.add("key0", ark::Json("val3").dump()); | ||
t5 = model.exp(t2); | ||
} | ||
|
||
UNITTEST_TRUE(model.verify()); | ||
|
||
auto compressed = model.compress(); | ||
UNITTEST_TRUE(compressed.verify()); | ||
|
||
auto nodes = compressed.nodes(); | ||
UNITTEST_EQ(nodes.size(), 4); | ||
|
||
UNITTEST_EQ(nodes[0]->context.size(), 0); | ||
UNITTEST_EQ(nodes[1]->context.size(), 1); | ||
UNITTEST_EQ(nodes[1]->context.at("key0"), ark::Json("val1")); | ||
UNITTEST_EQ(nodes[2]->context.size(), 2); | ||
UNITTEST_EQ(nodes[2]->context.at("key0"), ark::Json("val1")); | ||
UNITTEST_EQ(nodes[2]->context.at("key1"), ark::Json("val2")); | ||
UNITTEST_EQ(nodes[3]->context.size(), 1); | ||
UNITTEST_EQ(nodes[3]->context.at("key0"), ark::Json("val3")); | ||
|
||
return ark::unittest::SUCCESS; | ||
} | ||
|
||
int main() { | ||
UNITTEST(test_context); | ||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT license. | ||
|
||
#ifndef ARK_CONTEXT_HPP | ||
#define ARK_CONTEXT_HPP | ||
|
||
#include <ark/model.hpp> | ||
|
||
namespace ark { | ||
|
||
class ModelContextManager; | ||
|
||
enum class ContextType { | ||
ContextTypeOverwrite, | ||
ContextTypeExtend, | ||
ContextTypeImmutable, | ||
}; | ||
|
||
class Context { | ||
public: | ||
/// | ||
/// Construct an empty context for the given model. | ||
/// | ||
/// @param model The model to create the context for. | ||
/// | ||
Context(Model& model); | ||
|
||
/// Get the ID of this context. | ||
size_t id() const { return id_; } | ||
|
||
/// | ||
/// Add an item to the context. | ||
/// | ||
/// The given context item is valid for the lifetime of the context | ||
/// object. @p `value` is assumed to be a JSON string. | ||
/// If @p `key` is already in use by another valid context item | ||
/// of either the same or different context object for the same model, | ||
/// the behavior is determined by the context type @p `type` as follows. | ||
/// | ||
/// - `ContextTypeOverwrite` (default): The existing value will be | ||
/// replaced with the new one while the context object is alive. | ||
/// When the context object is destroyed, the previous value will be | ||
/// restored. | ||
/// | ||
/// - `ContextTypeExtend`: The new value will extend the existing | ||
/// value while the context object is alive. This type is feasible only | ||
/// when the value represents a JSON object, which is convertible to a | ||
/// map. If the new JSON object has a key that already exists in the | ||
/// existing JSON object, the value of the existing key will be | ||
/// overwritten by the new value. When the context object is destroyed, | ||
/// the previous value will be restored. | ||
/// | ||
/// - `ContextTypeImmutable`: The new value will be adopted only when the | ||
/// key does not exist in the existing context or when the value of the key | ||
/// is empty. If the key already exists, the new value will be ignored. | ||
/// When the context object is destroyed, if the key did not exist in the | ||
/// existing context, the key will be removed. | ||
/// Otherwise, nothing will be changed. | ||
/// | ||
/// @param key The key of the context item. | ||
/// @param value The value of the context item. The value is assumed to | ||
/// be a JSON string. | ||
/// @param type The context type. Default is `ContextTypeOverwrite`. | ||
/// | ||
void add(const std::string& key, const std::string& value, | ||
ContextType type = ContextType::ContextTypeOverwrite); | ||
|
||
private: | ||
std::shared_ptr<ModelContextManager> context_manager_; | ||
size_t id_; | ||
}; | ||
|
||
} // namespace ark | ||
|
||
#endif // ARK_CONTEXT_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.