
[Relay] Higher-order reverse-mode automatic differentiation that works with control flow #2496

Merged: 1 commit merged into apache:master on Mar 4, 2019

Conversation

MarisaKirisame
Contributor

As promised, it is simpler than the first-order case, since using references and closures in the object language (Relay) instead of the metalanguage (C++) simplifies our code.
The reference code is also here, but it is on a separate PR (#2489). We can merge this after merging #2489.
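To illustrate the style being described, here is a toy sketch in plain Python (not the Relay IR or the actual C++ pass): a mutable reference cell holds the composed backward closure, and each primitive op prepends its own gradient contribution to it.

# Toy sketch only: plain Python stand-ins for Relay references and closures.
class Ref:
    """Stands in for a Relay reference (a mutable cell)."""
    def __init__(self, value):
        self.value = value

bp = Ref(lambda: None)          # the backward pass composed so far

class Dual:
    """A value paired with a gradient accumulator cell."""
    def __init__(self, value):
        self.value = value
        self.grad = Ref(0.0)

def mul(a, b):
    out = Dual(a.value * b.value)
    old_bp = bp.value
    def new_bp():
        # add this op's contributions, then run the previously recorded steps
        a.grad.value += out.grad.value * b.value
        b.grad.value += out.grad.value * a.value
        old_bp()
    bp.value = new_bp
    return out

x = Dual(3.0)
y = mul(x, x)                   # y = x * x
y.grad.value = 1.0              # seed the output gradient
bp.value()                      # run the reverse pass
print(x.grad.value)             # 6.0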
@ZihengJiang @junrushao1994 @masahi @reminisce can you guys review?

@junrushao
Member

So excited to see we reached the point of having higher-order AD!! Thanks Marisa!

Will review the code on Friday night.

@masahi
Member

masahi commented Jan 23, 2019

Is there a usage example?

@MarisaKirisame
Contributor Author

MarisaKirisame commented Jan 23, 2019

test_ad is the usage example. The mode does not change the interface, only what is generated (the type and semantics are still the same!).
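For a concrete picture, here is a minimal usage sketch in the spirit of test_ad; it assumes the relay.ir_pass.gradient and infer_type entry points that appear later in this thread, and details of the real test file may differ.

from tvm import relay
from tvm.relay.ir_pass import gradient, infer_type

t = relay.TensorType((10, 10), "float32")
x = relay.var("x", t)
func = infer_type(relay.Function([x], relay.multiply(x, x)))  # f(x) = x * x

# The interface stays the same: back_func still takes x, and returns the
# original output together with a tuple of gradients w.r.t. the inputs.
back_func = infer_type(gradient(func))
print(back_func)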
@masahi

for (const ADValue& adval : args) {
call_args.push_back(adval->get<ADTensor>().forward);
}
auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
Contributor

Is it possible to use the real original node instead of a reconstruction? Reconstructing a node may lead to losing some information, e.g. the inferred type checked_type_.

Contributor Author

Maybe, but it would require a big change in the code structure. If such a case comes up, I will do it.

Contributor

I need checked_type_ in the integration with the tensor expression AD, mostly for finding out the number of outputs of the original operation. However, I think I can get this information from other sources. Would passing and reassigning just checked_type_ be dangerous in this case?

Contributor Author

@sgrechanik-h Can I just rerun type inference? Right now every pass destroys checked_type_ and rebuilds it from type inference.

Contributor

@MarisaKirisame Not sure what you mean, but rerunning type inference sounds like a bit of an overkill, and I'm not sure it can be done before calling the FPrimalGradient attribute. If the checked_type_ must be reset after running the differentiation pass, then one of the solutions could be setting it before calling FPrimalGradient to the original value and then resetting it to nullptr after FPrimalGradient has finished, but this feels kinda hacky.

(Also, currently I think that in my particular case the proper solution would be to fix the signature of FTVMCompute so that it accepts input types, not only the out_type. And this is not connected to the automatic differentiation pass.)

Contributor Author

@sgrechanik-h all passes (FuseOps, AD, ANF, GNF, DeadCodeElimination, FoldScaleAxis) remove the type annotation and rerun it, AFAIK. I am not sure why it is an AD-specific issue.

Contributor

@MarisaKirisame I think some passes may benefit from using type information, and, of course, they should use it before erasing it (or before recreating the node; I don't think checked_type_ gets literally erased anywhere). In the case of the code we are currently discussing, the node is recreated (and thus type information is erased) before calling the FPrimalGradient function, which could use type information if it were still there. I don't insist on fixing it if it's difficult or unnatural, because I have only one case where this might be useful, and in that single case it would be better to fix a completely different part of Relay.

Contributor Author

My other passes use type info too, but we just rerun type inference, and we are encoding that (rerunning type inference) into the pass manager as well.
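A minimal sketch of that convention, assuming the relay.ir_pass helpers used elsewhere in this thread: the caller reruns infer_type after each pass so that checked_type_ is repopulated on the freshly built nodes.

from tvm import relay
from tvm.relay.ir_pass import dead_code_elimination, gradient, infer_type

t = relay.TensorType((10, 10), "float32")
x = relay.var("x", t)
func = infer_type(relay.Function([x], relay.multiply(x, x)))

back_func = gradient(func)         # rewritten nodes carry no checked_type_ yet
back_func = infer_type(back_func)  # rerunning type inference rebuilds them
back_func = infer_type(dead_code_elimination(back_func))
print(back_func)                   # fully typed again, ready for later passes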

Contributor

@reminisce reminisce left a comment

Some minor comments.

src/relay/ir/alpha_equal.cc (outdated, resolved)
src/relay/ir/alpha_equal.cc (outdated, resolved)
src/relay/ir/alpha_equal.cc (outdated, resolved)
src/relay/pass/gradient.cc (resolved)
src/relay/pass/gradient.cc (outdated, resolved)
Member

@merrymercy merrymercy left a comment

Where is the test case for control flow?
Besides, neither this version nor the old first-order AD supports Tuple/TupleGetItem.

@tqchen
Member

tqchen commented Feb 20, 2019

@merrymercy can you open an issue to track the first order AD tuple support?

@MarisaKirisame
Contributor Author

@merrymercy this version supports TupleGetItem via ExprMutator; there is just no need for any extra code for it.
I will add a test for tuple and control flow.

@ZihengJiang ZihengJiang self-assigned this Feb 22, 2019
@merrymercy
Member

merrymercy commented Feb 23, 2019

@MarisaKirisame How about using tuples as arguments and return values? Some ops use tuples as arguments, e.g. concatenate.

This example will crash

fn (%tup: Tuple[Tensor[(10, 10), float32], Tensor[(10, 10), float32]]) {
    %tup.0
}

@MarisaKirisame
Contributor Author

@merrymercy I have written a test case using tuple, and a test case using ADT, higher-order functions, closures, pattern matching (control flow), and recursion. Does that address your issue?

@MarisaKirisame
Contributor Author

@merrymercy can you review?

@merrymercy
Member

  1. Could you add my test case? This example still crashes.

from tvm import relay
from tvm.relay.ir_pass import gradient

def test_tuple_arg():
    shape = (10, 10)
    dtype = 'float32'
    t = relay.TensorType(shape, dtype)
    x = relay.var("x", t)
    y = relay.var("y", t)
    tup = relay.var('tup', relay.TupleType([t, t]))
    func = relay.Function([tup], relay.TupleGetItem(tup, 0))
    print(func)
    back_func = relay.ir_pass.infer_type(gradient(func))
    back_func = relay.ir_pass.dead_code_elimination(back_func)
    print(back_func)

We should support tuples as arguments and return values in both first-order and higher-order AD.

  2. I found that the generated back_func is very complicated and sometimes redundant (in both the first-order and higher-order cases). How do we execute it efficiently? Do we need more optimization passes, or a more powerful runtime?

@@ -85,10 +85,10 @@ using ADValue = std::shared_ptr<ADValueNode>;

/*! \brief AD over a program which generates a tensor output. */
Member

@merrymercy merrymercy Feb 27, 2019

What if the program generates a tuple of tensors as output?
ADFunction and ADTensor cannot cover this case.

@MarisaKirisame
Contributor Author

@merrymercy

  1. The test case will not work. The interface can only use Tensor as of now, but you can use whatever you want inside. Is there any need for it? You can always flatten it before passing in (see the sketch below). I would prefer to do this in a separate issue, as I am really busy working on the partial evaluator.
  2. I am working on a partial evaluator pass which will take care of this right now.
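A hypothetical sketch of the flattening workaround from point 1, reusing the relay.ir_pass helpers from the test above; the variable names are illustrative.

from tvm import relay
from tvm.relay.ir_pass import gradient, infer_type

t = relay.TensorType((10, 10), "float32")
a = relay.var("a", t)
b = relay.var("b", t)
# Instead of fn(%tup) { %tup.0 }, pass the tuple fields as separate tensor
# parameters, which the current interface can differentiate.
flat_func = infer_type(relay.Function([a, b], a))
back_func = infer_type(gradient(flat_func))
print(back_func)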

@merrymercy
Member

Some operators use tuples as arguments (e.g. concatenate) and as return values (e.g. split).
We have to use tuples because we don't know the number of arguments in advance.

I am happy to leave it to the next PR, but this feature is necessary.

@MarisaKirisame
Contributor Author

@merrymercy can we leave it to the next PR then? I am working on the Partial Evaluator, and it needs this branch as a test case. Fixing this branch means less rebasing.

@MarisaKirisame
Contributor Author

@merrymercy can you approve if you give it a thumbs up?

add test

remove dead code

stash

do it

add more test
@ZihengJiang ZihengJiang merged commit eae76b3 into apache:master Mar 4, 2019
bwasti pushed a commit to facebookexperimental/tvm that referenced this pull request Mar 6, 2019
wweic pushed a commit to neo-ai/tvm that referenced this pull request Mar 9, 2019
wweic pushed a commit to neo-ai/tvm that referenced this pull request Mar 12, 2019
wweic pushed a commit to neo-ai/tvm that referenced this pull request Mar 12, 2019
@MarisaKirisame MarisaKirisame deleted the ad branch March 28, 2019 03:24