Skip to content

Commit

Permalink
internal interface
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Aug 5, 2024
1 parent 6f7b184 commit d9bdcb6
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 49 deletions.
38 changes: 5 additions & 33 deletions ark/api/context.cpp
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include "ark/context.hpp"

#include "context_impl.hpp"
#include "logging.hpp"
#include "model/model_context_manager.hpp"
#include "model/model_graph_impl.hpp"

namespace ark {

Context::Context(Model& model)
: context_manager_(std::make_shared<ModelContextManager>(model)) {
static size_t next_id = 0;
id_ = next_id++;
}
Context::Context(Model& model) : impl_(std::make_shared<Impl>(model)) {}

size_t Context::id() const { return this->impl_->id_; }

Check warning on line 11 in ark/api/context.cpp

View check run for this annotation

Codecov / codecov/patch

ark/api/context.cpp#L11

Added line #L11 was not covered by tests

void Context::set(const std::string& key, const std::string& value,
ContextType type) {
Expand All @@ -24,30 +19,7 @@ void Context::set(const std::string& key, const std::string& value,
ERR(InvalidUsageError, "Failed to parse context value as JSON: `",

Check warning on line 19 in ark/api/context.cpp

View check run for this annotation

Codecov / codecov/patch

ark/api/context.cpp#L18-L19

Added lines #L18 - L19 were not covered by tests
value, "`");
}
if (type == ContextType::ContextTypeOverwrite) {
context_manager_->set(key, value_json);
} else if (type == ContextType::ContextTypeExtend) {
auto ctx = context_manager_->get(key);
if (ctx.empty()) {
context_manager_->set(key, value_json);
} else if (!ctx.is_object() || !value_json.is_object()) {
ERR(InvalidUsageError,
"Context value must be a JSON object when type is "
"ContextTypeExtend. Key: ",
key, ", old value: ", ctx.dump(), ", new value: ", value);
} else {
for (const auto& [k, v] : value_json.items()) {
ctx[k] = v;
}
context_manager_->set(key, ctx);
}
} else if (type == ContextType::ContextTypeImmutable) {
if (!context_manager_->has(key)) {
context_manager_->set(key, value);
}
} else {
ERR(InvalidUsageError, "Unknown context type");
}
this->impl_->set(key, value_json, type);
}

} // namespace ark
5 changes: 5 additions & 0 deletions ark/api/planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "ark/planner.hpp"

#include "ark/model.hpp"
#include "context_impl.hpp"
#include "env.h"
#include "file_io.h"
#include "gpu/gpu_manager.h"
Expand All @@ -13,6 +14,10 @@

namespace ark {

void PlannerContext::set_sync(bool sync) {
this->impl_->set("Sync", sync, ContextType::Immutable);

Check warning on line 18 in ark/api/planner.cpp

View check run for this annotation

Codecov / codecov/patch

ark/api/planner.cpp#L17-L18

Added lines #L17 - L18 were not covered by tests
}

class DefaultPlanner::Impl {
public:
Impl(const Model &model, int gpu_id);
Expand Down
47 changes: 47 additions & 0 deletions ark/context_impl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include "context_impl.hpp"

#include "logging.hpp"
#include "model/model_context_manager.hpp"
#include "model/model_graph_impl.hpp"

namespace ark {

Context::Impl::Impl(Model& model)
: context_manager_(std::make_shared<ModelContextManager>(model)) {
static size_t next_id = 0;
id_ = next_id++;
}

void Context::Impl::set(const std::string& key, const Json& value_json,
ContextType type) {
if (type == ContextType::Overwrite) {
context_manager_->set(key, value_json);
} else if (type == ContextType::Extend) {
auto ctx = context_manager_->get(key);
if (ctx.empty()) {
context_manager_->set(key, value_json);
} else if (!ctx.is_object() || !value_json.is_object()) {
ERR(InvalidUsageError,

Check warning on line 27 in ark/context_impl.cpp

View check run for this annotation

Codecov / codecov/patch

ark/context_impl.cpp#L22-L27

Added lines #L22 - L27 were not covered by tests
"Context value must be a JSON object when type is "
"ContextTypeExtend. Key: ",
key, ", old value: ", ctx.dump(),
", new value: ", value_json.dump());
} else {
for (const auto& [k, v] : value_json.items()) {
ctx[k] = v;

Check warning on line 34 in ark/context_impl.cpp

View check run for this annotation

Codecov / codecov/patch

ark/context_impl.cpp#L33-L34

Added lines #L33 - L34 were not covered by tests
}
context_manager_->set(key, ctx);

Check warning on line 36 in ark/context_impl.cpp

View check run for this annotation

Codecov / codecov/patch

ark/context_impl.cpp#L36

Added line #L36 was not covered by tests
}
} else if (type == ContextType::Immutable) {
if (!context_manager_->has(key)) {
context_manager_->set(key, value_json);

Check warning on line 40 in ark/context_impl.cpp

View check run for this annotation

Codecov / codecov/patch

ark/context_impl.cpp#L38-L40

Added lines #L38 - L40 were not covered by tests
}
} else {
ERR(InvalidUsageError, "Unknown context type");

Check warning on line 43 in ark/context_impl.cpp

View check run for this annotation

Codecov / codecov/patch

ark/context_impl.cpp#L43

Added line #L43 was not covered by tests
}
}

} // namespace ark
29 changes: 29 additions & 0 deletions ark/context_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#ifndef ARK_CONTEXT_IMPL_HPP_
#define ARK_CONTEXT_IMPL_HPP_

#include "ark/context.hpp"
#include "model/model_json.hpp"

namespace ark {

class ModelContextManager;

class Context::Impl {
public:
Impl(Model& model);

void set(const std::string& key, const Json& value_json, ContextType type);

protected:
friend class Context;

std::shared_ptr<ModelContextManager> context_manager_;
size_t id_;
};

} // namespace ark

#endif // ARK_CONTEXT_IMPL_HPP_
30 changes: 15 additions & 15 deletions ark/include/ark/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@

namespace ark {

class ModelContextManager;

enum class ContextType {
ContextTypeOverwrite,
ContextTypeExtend,
ContextTypeImmutable,
Overwrite,
Extend,
Immutable,
};

class Context {
Expand All @@ -26,7 +24,7 @@ class Context {
Context(Model& model);

/// Get the ID of this context.
size_t id() const { return id_; }
size_t id() const;

///
/// Add an item to the context.
Expand All @@ -37,20 +35,20 @@ class Context {
/// of either the same or different context object for the same model,
/// the behavior is determined by the context type @p `type` as follows.
///
/// - `ContextTypeOverwrite` (default): The existing value will be
/// - `ContextType::Overwrite` (default): The existing value will be
/// replaced with the new one while the context object is alive.
/// When the context object is destroyed, the previous value will be
/// restored.
///
/// - `ContextTypeExtend`: The new value will extend the existing
/// - `ContextType::Extend`: The new value will extend the existing
/// value while the context object is alive. This type is feasible only
/// when the value represents a JSON object, which is convertible to a
/// map. If the new JSON object has a key that already exists in the
/// existing JSON object, the value of the existing key will be
/// overwritten by the new value. When the context object is destroyed,
/// the previous value will be restored.
///
/// - `ContextTypeImmutable`: The new value will be adopted only when the
/// - `ContextType::Immutable`: The new value will be adopted only when the
/// key does not exist in the existing context or when the value of the key
/// is empty. If the key already exists, the new value will be ignored.
/// When the context object is destroyed, if the key did not exist in the
Expand All @@ -60,23 +58,25 @@ class Context {
/// @param key The key of the context item.
/// @param value The value of the context item. The value is assumed to
/// be a JSON string. An empty JSON string is also allowed.
/// @param type The context type. Default is `ContextTypeOverwrite`.
/// @param type The context type. Default is `ContextType::Overwrite`.
///
/// @throw `InvalidUsageError` In the following cases:
///
/// - The value cannot be parsed as JSON.
///
/// - The value is not a JSON object when the context type is
/// `ContextTypeExtend`.
/// `ContextType::Extend`.
///
/// - The context type is unknown.
///
void set(const std::string& key, const std::string& value,
ContextType type = ContextType::ContextTypeOverwrite);
ContextType type = ContextType::Overwrite);

protected:
friend class PlannerContext;

private:
std::shared_ptr<ModelContextManager> context_manager_;
size_t id_;
class Impl;
std::shared_ptr<Impl> impl_;
};

} // namespace ark
Expand Down
8 changes: 7 additions & 1 deletion ark/include/ark/planner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,19 @@
#ifndef ARK_PLANNER_HPP
#define ARK_PLANNER_HPP

#include <ark/context.hpp>
#include <functional>
#include <memory>
#include <string>

namespace ark {

class Model;
class PlannerContext : public Context {
public:
PlannerContext(Model &model) : Context(model) {}

void set_sync(bool sync);
};

class DefaultPlanner {
public:
Expand Down

0 comments on commit d9bdcb6

Please sign in to comment.