[WIP] Switch to using ProfiledTensorType #68
base: master
Conversation
```cpp
auto csizes = ptt->sizes().concrete_sizes();
TORCH_INTERNAL_ASSERT(csizes.has_value());
```
should we handle the case when this isn't true?
I think we shouldn't just create a TVMCompGroup in this case. Even though we can compile at runtime, if we get a workload where shapes change all the time it will be pretty wasteful. I'm looking into changing the Fuser to not fuse things that don't have ProfiledTensorTypes.
if it's alternating between 20 batch-size shapes but run 1M times, it's probably still worth compiling each shape
Hopefully, this will be handled via bailouts: we specialize for shape_set1; then if we see another frequent set, we will specialize for that one as well, and so on.
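For illustration, a minimal sketch of the guarded check this thread is asking for, rather than the hard assert; the helper name `canLowerToTVM` is hypothetical and not from the PR:

```cpp
// Hedged sketch: refuse to build a TVMCompGroup when the profiler never
// recorded concrete shapes, instead of asserting. `canLowerToTVM` is a
// hypothetical helper; ProfiledTensorType comes from torch::jit.
bool canLowerToTVM(const ProfiledTensorType* ptt) {
  auto csizes = ptt->sizes().concrete_sizes();
  if (!csizes.has_value()) {
    // No concrete shapes were profiled for this value; compiling a
    // shape-specialized TVM kernel here would be guesswork, so let the
    // JIT interpreter run the subgraph instead.
    return false;
  }
  return true;
}
```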
```cpp
kv.first->inferTypeFrom(kv.second.toTensor());
// TODO: convince Fuser to NOT create TVMCompilationGroups
// if ANY of subgraph inputs weren't profiled
TORCH_INTERNAL_ASSERT(kv.first->type()->cast<ProfiledTensorType>());
```
I don't think we need this, as it is done on line 111
```cpp
}
// bail-out mechanism: try to convert to Relay; if converting the graph
// fails for any reason (e.g. an op difference), then depending on the
// user's preference either throw or fall back to the JIT interpreter
// for execution
cache_ = TVMObject {};
```
at this point, if we run a graph twice with different sizes, will we still get TVM-compiled code each time?
Is the logic for that moved up into the profiled executor?
yup, we will bail out, profile again, and generate a graph with TVMCompGroups for the different shapes.
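Roughly, the bail-out path under discussion could look like the sketch below; `convertToRelay` and `compileRelayModule` are hypothetical stand-ins for the actual conversion entry points, and conversion failures are assumed to throw:

```cpp
// Sketch of the bail-out flow, under the assumptions above.
try {
  auto relay_mod = convertToRelay(subgraph);  // may throw on unsupported ops
  cache_ = compileRelayModule(relay_mod);     // cache_: an optional<TVMObject>
} catch (const std::exception& e) {
  std::cerr << "Relay conversion failed: " << e.what() << "\n";
  // Record the failure so a later call with the same inputs does not
  // pay for another conversion attempt that is known to fail.
  cache_ = TVMObject{};
  (*cache_).invalid = true;
}
```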
```cpp
<< e.what() << "\n";
}

if ((*cache_).invalid)
```
can this block be merged back into the block on line 260?
```cpp
cache_[spec].set_input = run_mod.GetFunction("set_input_zero_copy", false);
cache_[spec].kernel = run_mod.GetFunction("run", false);
cache_[spec].get_output = run_mod.GetFunction("get_output", false);
(*cache_).set_input = run_mod.GetFunction("set_input_zero_copy", false);
```
-> ?
arrgh, thanks! I forgot to switch it
```cpp
CompleteArgumentSpec spec{false, ArrayRef<IValue>(inputs)};

if (cache_.find(spec) == cache_.end()) {
if (!cache_ || (cache_ && (*cache_).invalid)) {
```
if cache_ is an optional, can we get rid of the invalid attribute somehow? it's pretty confusing
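One hypothetical way to drop the separate invalid flag is to fold the "compilation failed" state into the cached value itself, e.g. with a variant. This is a sketch, not the PR's code; `CompilationFailed` is an invented marker type:

```cpp
#include <optional>
#include <variant>

struct TVMObject { /* compiled TVM functions ... */ };
struct CompilationFailed {};  // invented marker: conversion already failed once

// Three states without a boolean flag:
//   nullopt            -> compilation never attempted
//   TVMObject          -> compiled and runnable
//   CompilationFailed  -> attempted and failed; go straight to the JIT
std::optional<std::variant<TVMObject, CompilationFailed>> cache_;

bool should_try_compile() { return !cache_.has_value(); }
bool can_run_tvm() {
  return cache_.has_value() && std::holds_alternative<TVMObject>(*cache_);
}
```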
```cpp
@@ -35,6 +36,7 @@ PYBIND11_MODULE(_torch_tvm, m) {
RegisterOperators op({Operator(
    tvm_sym,
    [](const Node* node) {
      GRAPH_DUMP("A graph passed to TVMCompiler\n", node->g(attr::Subgraph));
```
can we move this to a different diff?
sure
```cpp
{
  sizes.push_back(HalideIR::Expr(static_cast<int32_t>(size)));
}
} else if (optional_ivalue.has_value()) {
```
curious: in which case will we have a value type that is not a ProfiledTensorType, since we've all switched to the profiled graph executor?
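For context, a hedged sketch of the dispatch the snippet above implies: tensor inputs arrive as ProfiledTensorType with concrete sizes, while anything else would have to arrive as a constant IValue. Variable names (`value`, `optional_ivalue`, `sizes`) follow the snippet; the surrounding function is assumed:

```cpp
// value: a torch::jit::Value*; optional_ivalue: an optional<IValue>;
// sizes: a std::vector<HalideIR::Expr> -- as in the snippet above.
if (auto ptt = value->type()->cast<ProfiledTensorType>()) {
  auto csizes = ptt->sizes().concrete_sizes();
  TORCH_INTERNAL_ASSERT(csizes.has_value());
  for (const auto& size : *csizes) {
    sizes.push_back(HalideIR::Expr(static_cast<int32_t>(size)));
  }
} else if (optional_ivalue.has_value()) {
  // Non-tensor input (e.g. a scalar constant): lower the IValue directly
  // rather than reading shapes off a tensor type.
}
```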
```cpp
} catch (const std::exception& e) {
  (*cache_).invalid = true;
```
When we fall back to the JIT, it means some operators were not converted successfully due to operator semantic mismatches or other behaviors. Re-running the conversion every time for the same inputs will always fail, so I am not sure that flag is necessary.
exactly! we don't want to re-run if we already know it's going to fail! Re-running compilations might be pretty expensive
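In other words, the flag acts as a negative cache. A minimal sketch of the early-out it enables; `runWithJITInterpreter` is a hypothetical fallback helper, not the PR's API:

```cpp
// Skip recompilation entirely when a previous attempt already failed.
if (cache_ && (*cache_).invalid) {
  return runWithJITInterpreter(inputs);  // hypothetical fallback path
}
```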