Skip to content

Commit

Permalink
Refactor the compile engine into a cleaner interface.
Browse files Browse the repository at this point in the history
Duplicate the CompileEngine interface.

Refactor the graph_runtime_codegen to invoke the new LowerTE pass

More changes

Things appear to be working

Some tracing to get Relay code to flow through too.

Disable some assertions as exp.

Tweak printing for now

Fix a few bugs: (#13)

1. Don't add relay main function to list of lowered TIR functions
2. Don't skip visiting call to relay function in graph runtime codegen

Remove debug prints.

Start refactoring

Split out shared data structures

Fix implicit duplicate decl of IsDynamic

Clean up handling of name + global prim fn

Clean up the code and debug issue introduced by previous hack

Clean up the debugging

Do C++ lint clean up

Update src/relay/backend/graph_executor_codegen.cc

Co-authored-by: Chris Sullivan <[email protected]>

Clean up handling of external functions

Add more error messages

More clean up

Update src/runtime/graph_executor/graph_executor.cc

Co-authored-by: Chris Sullivan <[email protected]>

Update src/runtime/graph_executor/graph_executor.cc

Co-authored-by: Chris Sullivan <[email protected]>

Update src/relay/backend/te_compiler.h

Co-authored-by: Haichen Shen <[email protected]>

Update src/relay/backend/te_compiler.h

Co-authored-by: Haichen Shen <[email protected]>

Fix

CR

More CR

Format

Fix lowering path for C++

Fix tests

Remove uncessary change

Clean up a few more things

CI fix

Fix the default context

Fix

Fix broken test cases

Update

Fix

WIP

Clean up storage data structures

WIP

WIP

Fix build errors

Remove TVMLower

Fix lint

Lint again

fix black

Move UpdateMainWorkspaceSize into te_compiler.cc

Fix link errors

Formatting

Change UpdateMainWorkspaceSize to return Map<String, FunctionInfo>

Workaround for GCC 5 error caused by enums in maps (GCC 5 is on i386 CI)

Testing how functions should be named

Lint

Change how function metadata is updated

Attempt to update aot_executor_codegen to use new StaticMemoryPlan instead of storage_device_map

Pass memory plan through LowerTE into UpdateMainWorkspaceSize so that we don't need to run GraphPlanMemory an extra time

Fix return in UpdateMainWorkspaceSize

Lint

Try to fix UpdateMainWorkspaceSize

Fix construction of static memory plan

Clean up code while debugging

Adding UpdateWorkspaceSize back

Add closure + call to UpdateFunctionMetadata (WIP)

UpdateFunctionMetadata builds; weird error with device ctx map though. Not sure if it came from this change or something else

Add some debugging of UpdateMainWorkspaceSize

Starting to move UpdateFunctionMetadata call to use process_fn infra

UWhat target should be passed to UpdateFunctionMetadata?

UpdateFunctionMetadata is not workinggg

Added some comments about UpdateFunctionMetadata for Jared

Fix the creation of function metadata

Try another stab at cleaning up the information

Fix

Port StorageInfo and StaticMemoryPlan data structure (apache#8297)

Restoring reshape opt

Fix tests

Caught a nasty typo from Lily, Map::Set does not mutate

Format

Disable stupid Google style warning
  • Loading branch information
jroesch committed Jul 6, 2021
1 parent bbfc52c commit 393ca27
Show file tree
Hide file tree
Showing 27 changed files with 2,226 additions and 1,124 deletions.
12 changes: 12 additions & 0 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
}
};

/*!
* \brief Options for the operators used to annotate a compiler.
*/
struct TIRCallAttrs : public tvm::AttrsNode<TIRCallAttrs> {
/*! \brief A 3rd party compiler for code generation. */
Map<String, ObjectRef> metadata;

TVM_DECLARE_ATTRS(TIRCallAttrs, "relay.attrs.TIRCallAttrs") {
TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function call.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
1 change: 1 addition & 0 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def add_workload_input_names(self, workload_key, input_names):
@tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite")
def enter_layout_rewrite():
"""Enter layout rewrite tracing environment"""
# import pdb; pdb.set_trace()
env = TracingEnvironment(TracingMode.PREPARE_LAYOUT_REWRITE)
env.__enter__()

Expand Down
1 change: 1 addition & 0 deletions python/tvm/micro/model_library_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def _build_function_memory_map(function_metadata):
2.) A global memory requirement if all functions are executed sequentially
"""
device_max_workspace = dict()
print("TOTAL FUNCTION METADATA: ", function_metadata)
main_func_metadata = function_metadata[MAIN_FUNC_NAME_STR]
num_targets = len(main_func_metadata.workspace_sizes.items())
func_entries = []
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/backend/compile_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def dump(self):
res += "------------------------------------\n"
res += "target={}\n".format(k.target)
res += "use_count={}\n".format(v.use_count)
res += "func_name={}\n".format(v.cached_func.func_name)
res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint)
res += "----relay function----\n"
res += k.source_func.astext() + "\n"
res += "----tir function----- \n"
Expand All @@ -444,7 +444,7 @@ def dump(self):
res += "------------------------------------\n"
res += "target={}\n".format(k.target)
res += "use_count={}\n".format(v.use_count)
res += "func_name={}\n".format(v.cached_func.func_name)
res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint)
res += "----relay function----\n"
res += k.source_func.astext() + "\n"
res += "----tir function----- \n"
Expand Down
17 changes: 16 additions & 1 deletion python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import tvm._ffi
from tvm._ffi import base as _base
from tvm.runtime import NDArray, ndarray as _nd
from tvm.ir import RelayExpr, GlobalVar
from tvm.ir import RelayExpr, GlobalVar, Node

from .base import RelayNode
from . import _ffi_api
Expand Down Expand Up @@ -538,3 +538,18 @@ def bind(expr, binds):
The expression or function after binding.
"""
return _ffi_api.Bind(expr, binds)


@tvm._ffi.register_object("relay.StorageInfo")
class StorageInfo(Node):
@property
def storage_ids(self):
return _ffi_api.StorageInfoStorageIds(self)

@property
def device_types(self):
return _ffi_api.StorageInfoDeviceTypes(self)

@property
def storage_sizes(self):
return _ffi_api.StorageInfoStorageSizes(self)
10 changes: 7 additions & 3 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,14 +437,18 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
}

if (target->kind->device_type == kDLCPU && target_host == target) {
ICHECK(mdevice->functions.empty()) << "No device code should be generated when target "
<< "and host_target are both llvm target."
<< "\n";
// TODO(@jroesch): This check is no longer true we need to figure out if we care about this.
// We need to relax this check for just TIR functions.
// ICHECK(mdevice->functions.empty()) << "No device code should be generated when target "
// << "and host_target are both llvm target."
// << "\n";
}

return {mhost, mdevice};
}

// Can we make this take one annotated IRModule?
//
// Build for heterogeneous execution.
runtime::Module build(const Map<Target, IRModule>& inputs_arg, const Target& target_host_arg) {
auto pass_ctx = transform::PassContext::Current();
Expand Down
Loading

0 comments on commit 393ca27

Please sign in to comment.