Skip to content

Commit

Permalink
Expose attrs argument of "ir.IRModule" to Rust bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg authored and tqchen committed Apr 6, 2023
1 parent ff5118f commit b228037
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
16 changes: 11 additions & 5 deletions rust/tvm/src/ir/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::runtime::array::Array;
use crate::runtime::function::Result;
use crate::runtime::map::Map;
use crate::runtime::string::String as TVMString;
use crate::runtime::{external, IsObjectRef, Object};
use crate::runtime::{external, IsObjectRef, Object, ObjectRef};

use super::expr::GlobalVar;
use super::function::BaseFunc;
Expand Down Expand Up @@ -62,7 +62,7 @@ external! {
#[name("relay.parser.ParseExpr")]
fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule;
#[name("ir.IRModule")]
fn module_new(funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule;
fn module_new(funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>, attrs: Map<TVMString, ObjectRef>) -> IRModule;
// Module methods
#[name("ir.Module_Add")]
fn module_add(module: IRModule, type_name: GlobalVar, expr: BaseFunc, update: bool) -> IRModule;
Expand Down Expand Up @@ -99,18 +99,24 @@ external! {
// Note: we don't expose update here as update is going to be removed.

impl IRModule {
pub fn new<'a, F, T>(funcs: F, types: T) -> Result<IRModule>
pub fn new<'a, F, T, A>(funcs: F, types: T, attrs: A) -> Result<IRModule>
where
F: IntoIterator<Item = (&'a GlobalVar, &'a BaseFunc)>,
T: IntoIterator<Item = (&'a GlobalTypeVar, &'a TypeData)>,
A: IntoIterator<Item = (&'a TVMString, &'a ObjectRef)>,
{
module_new(Map::from_iter(funcs), Map::from_iter(types))
module_new(
Map::from_iter(funcs),
Map::from_iter(types),
Map::from_iter(attrs),
)
}

pub fn empty() -> Result<IRModule> {
let funcs = HashMap::<GlobalVar, BaseFunc>::new();
let types = HashMap::<GlobalTypeVar, TypeData>::new();
IRModule::new(funcs.iter(), types.iter())
let attrs = HashMap::<TVMString, ObjectRef>::new();
IRModule::new(funcs.iter(), types.iter(), attrs.iter())
}

pub fn parse<N, S>(file_name: N, source: S) -> Result<IRModule>
Expand Down
16 changes: 15 additions & 1 deletion src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,21 @@ TVM_REGISTER_NODE_TYPE(IRModuleNode);

TVM_REGISTER_GLOBAL("ir.IRModule")
.set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs, tvm::Map<GlobalTypeVar, TypeData> types,
tvm::DictAttrs attrs) { return IRModule(funcs, types, {}, {}, attrs); });
tvm::ObjectRef attrs) {
auto dict_attrs = [&attrs]() {
if (!attrs.defined()) {
return DictAttrs();
} else if (auto* as_dict_attrs = attrs.as<tvm::DictAttrsNode>()) {
return GetRef<tvm::DictAttrs>(as_dict_attrs);
} else if (attrs.as<tvm::MapNode>()) {
return tvm::DictAttrs(Downcast<Map<String, ObjectRef>>(attrs));
} else {
LOG(FATAL) << "Expected attrs argument to be either DictAttrs or Map<String,ObjectRef>";
}
}();

return IRModule(funcs, types, {}, {}, dict_attrs);
});

TVM_REGISTER_GLOBAL("ir.Module_Add")
.set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule {
Expand Down

0 comments on commit b228037

Please sign in to comment.