[Relay] Higher order reverse mode automatic differentiation that work with control flow #2496
Conversation
So excited to see we reached the point of having higher-order AD!! Thanks Marisa! Will review the code on Friday night.
Is there a usage example?
test_ad is the usage example. The mode does not change the interface, only what is generated (the type and semantics are still the same!).
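For concreteness, a hedged usage sketch in the spirit of test_ad; the relay.transform.gradient entry point and the mode="higher_order" argument are taken from current TVM and are assumptions here, not necessarily the exact API at the time of this PR:

import tvm
from tvm import relay

x = relay.var("x", shape=(10, 10))
func = relay.Function([x], x * x)

# Run type inference first so the pass sees checked types.
mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))

# Produces a function returning (original output, gradients w.r.t. the inputs).
grad_func = relay.transform.gradient(mod["main"], mod=mod, mode="higher_order")
print(grad_func)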
for (const ADValue& adval : args) {
  call_args.push_back(adval->get<ADTensor>().forward);
}
auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
Is it possible to use the real original node instead of a reconstruction? Reconstructing a node may lead to losing some information, e.g. the inferred type checked_type_.
Maybe, but it would require a big change in the code structure. If such a case comes up I will do it.
I need checked_type_ in the integration with the tensor expression AD, mostly for finding out the number of outputs of the original operation. However, I think I can get this information from other sources. Would passing and reassigning just checked_type_ be dangerous in this case?
@sgrechanik-h can I just rerun type inference? Right now every pass destroys checked_type_ and rebuilds it by rerunning type inference.
@MarisaKirisame Not sure what you mean, but rerunning type inference sounds like a bit of overkill, and I'm not sure it can be done before calling the FPrimalGradient attribute. If checked_type_ must be reset after running the differentiation pass, then one solution could be setting it to the original value before calling FPrimalGradient and then resetting it to nullptr after FPrimalGradient has finished, but this feels kind of hacky.
(Also, I currently think that in my particular case the proper solution would be to fix the signature of FTVMCompute so that it accepts input types, not only the out_type. And that is not connected to the automatic differentiation pass.)
@sgrechanik-h all passes (FuseOps, AD, ANF, GNF, DeadCodeElimination, FoldScaleAxis) remove the type annotation and rerun type inference, AFAIK. I am not sure why this is an AD-specific issue.
@MarisaKirisame I think some passes may benefit from using type information, and, of course, they should use it before erasing it (or before recreating the node; I don't think checked_type_ gets literally erased anywhere). In the case of the code we are currently discussing, the node is recreated (and thus the type information is erased) before calling the FPrimalGradient function, which could use the type information if it were still there. I don't insist on fixing it if it's difficult or unnatural, because I have only one case where this might be useful, and moreover in that single case it would be better to fix a completely different part of Relay.
My other passes use type info too. But we just rerun type inference, and we are encoding that (rerunning type inference) into the pass manager as well.
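For reference, a minimal sketch of that pattern, assuming the current relay.transform.InferType pass (the ir_pass-era API at the time of this PR was different):

import tvm
from tvm import relay

x = relay.var("x", shape=(10, 10))
mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, x)))

# After a pass rebuilds expressions, checked_type_ is repopulated by rerunning
# type inference over the whole module.
mod = relay.transform.InferType()(mod)
print(mod["main"].body.checked_type)  # Tensor[(10, 10), float32]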
Some minor comments.
Where is the test case for control flow?
Besides, neither this version nor the old first-order AD supports Tuple/TupleGetItem.
@merrymercy can you open an issue to track first-order AD tuple support?
@merrymercy this version supports TupleGetItem via ExprMutator; there is simply no extra code needed for it.
@MarisaKirisame How about using tuples as arguments and return values? Some ops use tuples as arguments, e.g. concatenate. This example will crash:
fn (%tup: Tuple[Tensor[(10, 10), float32], Tensor[(10, 10), float32]]) {
  %tup.0
}
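A hedged sketch of how one might build the failing program from Python; relay.transform.gradient, mode="higher_order", and the exact failure are assumptions (the comment above only says the program crashes the pass at the time of this PR):

import tvm
from tvm import relay

ttype = relay.TensorType((10, 10), "float32")
tup = relay.var("tup", type_annotation=relay.TupleType([ttype, ttype]))
func = relay.Function([tup], relay.TupleGetItem(tup, 0))

mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))
# Per the discussion, differentiating a tuple-typed parameter was not yet supported here.
grad = relay.transform.gradient(mod["main"], mod=mod, mode="higher_order")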
@merrymercy I have written a test case using tuples, and a test case using ADTs, higher-order functions, closures, pattern matching (control flow), and recursion. Does that address your issue?
@merrymercy can you review?
We should support tuples as arguments and return values in both first-order and higher-order AD.
@@ -85,10 +85,10 @@ using ADValue = std::shared_ptr<ADValueNode>;
/*! \brief AD over a program which generates a tensor output. */
What if the program generates a tuple of tensors as output? ADFunction and ADTensor cannot cover this case.
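A conceptual sketch of that concern in plain Python, not the actual C++ ADValueNode hierarchy; the ADTuple variant is a hypothetical addition illustrating one way to cover tuple-typed values, not something in this PR:

from dataclasses import dataclass
from typing import Any, Callable, List

class ADValue:
    """Base class for values flowing through the reverse-mode interpreter."""

@dataclass
class ADTensor(ADValue):
    forward: Any          # the primal tensor expression
    reverse: Any = None   # a mutable reference accumulating the adjoint

@dataclass
class ADFunction(ADValue):
    func: Callable[..., ADValue]   # a closure that performs AD when applied

# Hypothetical extension: a tuple of ADValues, so programs whose output (or
# operator arguments) are tuples could be handled uniformly.
@dataclass
class ADTuple(ADValue):
    fields: List[ADValue]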
Some operators use tuples as arguments (e.g. concatenate) and as return values (e.g. split). I am happy to leave it to the next PR, but this feature is necessary.
@merrymercy can we leave it to the next PR then? I am working on the Partial Evaluator, and it needs this branch as a test case. Fixing this branch means less rebasing.
@merrymercy can you approve if you give a thumbs up?
… with control flow (apache#2496) add test remove dead code stash do it add more test
As promised, it is simpler than the first-order case: using references and closures in the object language (Relay) instead of the metalanguage (C++) simplifies our code.
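To make that concrete, here is a plain-Python analogue of the idea (not TVM code): each value carries a mutable cell for its adjoint, and a single reference holds a backward-pass closure that every operation extends.

class Dual:
    def __init__(self, value):
        self.value = value   # primal value
        self.grad = 0.0      # adjoint, accumulated through this mutable cell

backprop = [lambda: None]    # reference holding the backward-pass closure

def mul(a, b):
    out = Dual(a.value * b.value)
    prev = backprop[0]
    def bp():
        # push the output adjoint to the inputs, then run the earlier closure
        a.grad += out.grad * b.value
        b.grad += out.grad * a.value
        prev()
    backprop[0] = bp
    return out

x = Dual(3.0)
y = mul(x, x)
y.grad = 1.0     # seed the output adjoint
backprop[0]()    # run the reverse pass
print(x.grad)    # 6.0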
The reference code is also here, but it is in a separate PR (#2489). We can merge this after merging #2489.
@ZihengJiang @junrushao1994 @masahi @reminisce can you guys review?