
Confusing result #25

Open
alecjacobson opened this issue Mar 27, 2024 · 8 comments

@alecjacobson

Love this paper :-) first time trying the code.

Here's my problem setup.

E = ½ ∫₀¹ ([x<a] - [x<t])² dx
∂E/∂t = t>a?+1:-1 = [0<t<1] - 2 [t<a]

I claim that ∂E/∂t = [0<t<1] - 2 [t<a]: if t is smaller (bigger) than a, then I grow my energy by making t even smaller (bigger) than a. This is confirmed with finite differences.
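As a quick sanity check of this setup outside Teg, here is a plain-Python sketch that approximates E(t) with a midpoint rule and takes central finite differences at the two values of t used in the Teg script below. It assumes a = 0.5, n = 100,000 cells and a step of 0.01, all chosen only for illustration. Because the integrand is 1 exactly between a and t and 0 elsewhere, E(t) = ½|t - a|, so with the ½ prefactor the derivative comes out as ±0.5 rather than ±1.

```python
# Plain-Python sanity check (no Teg): central finite differences of
#   E(t) = 1/2 * integral_0^1 ([x<a] - [x<t])^2 dx,
# with the integral approximated by the midpoint rule on n cells.

def indicator(cond: bool) -> float:
    """Iverson bracket: 1.0 if cond holds, else 0.0."""
    return 1.0 if cond else 0.0


def energy(t: float, a: float = 0.5, n: int = 100_000) -> float:
    h = 1.0 / n
    total = 0.0
    for i in range(n):
        x = (i + 0.5) * h  # midpoint of the i-th cell
        diff = indicator(x < a) - indicator(x < t)
        total += diff * diff
    return 0.5 * total * h


def central_difference(f, t: float, eps: float = 1e-2) -> float:
    return (f(t + eps) - f(t - eps)) / (2.0 * eps)


results = {t: central_difference(energy, t) for t in (0.25, 0.75)}
for t, d in results.items():
    print(f"t = {t}: dE/dt ~ {d:.3f}")
```

This prints approximately -0.5 for t = 0.25 and 0.5 for t = 0.75, the same values reported later in the thread for the manually simplified Teg expression.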

When I tried this in Teg, I wrote:

from teg import TegVar, Var, Teg, IfElse
from teg.derivs import FwdDeriv
from teg.eval.numpy_eval import evaluate
x, a, t = TegVar('x'), Var('a', 0.5), Var('t', 0.25)
expr = 0.5*Teg(0, 
               1,
               (IfElse(x<t,1,0) - IfElse(x<a,1,0))**2,
               x)
deriv_expr = FwdDeriv(expr, [(t, 1)])
print(evaluate(deriv_expr))

x, a, t = TegVar('x'), Var('a', 0.5), Var('t', 0.75)
expr = 0.5*Teg(0, 
               1,
               (IfElse(x<t,1,0) - IfElse(x<a,1,0))**2,
               x)
deriv_expr = FwdDeriv(expr, [(t, 1)])
print(evaluate(deriv_expr))

Which prints

-1.0
0

I'm confused by the 0, which I expected to be 1.0.

cc @squidrice21

@gilbo

gilbo commented Mar 27, 2024 via email

@gilbo

gilbo commented Mar 27, 2024 via email

@alecjacobson
Author

Thanks, Gilbert. What's a good rule I can follow to be sure that I provide valid input to Teg? No nested IfElse?

@squidrice21

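Just to verify that the issue comes from products of degenerate conditionals: I manually simplified such conditionals, expanding the square with [x<t]² = [x<t] and [x<a]² = [x<a] (indicators only take the values 0 and 1), into: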

from teg import TegVar, Var, Teg, IfElse
from teg.derivs import FwdDeriv
from teg.eval.numpy_eval import evaluate

x, a, t = TegVar('x'), Var('a', 0.5), Var('t', 0.25)
expr = 0.5*Teg(0, 
               1,
               IfElse(x<t,1,0) - 2 * IfElse(x<t,1,0) * IfElse(x<a,1,0) + IfElse(x<a,1,0),
               x)
deriv_expr = FwdDeriv(expr, [(t, 1)])
print(evaluate(deriv_expr))

x, a, t = TegVar('x'), Var('a', 0.5), Var('t', 0.75)
expr = 0.5*Teg(0, 
               1,
               IfElse(x<t,1,0) - 2 * IfElse(x<t,1,0) * IfElse(x<a,1,0) + IfElse(x<a,1,0),
               x)
deriv_expr = FwdDeriv(expr, [(t, 1)])
print(evaluate(deriv_expr))

The script now gives the correct results:

-0.5
0.5

So a solution within the current Teg scope is to manually simplify such conditionals.
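Why the manual rewrite helps can be checked by hand. This plain-Python sketch (assuming a = 0.5, as in the scripts above) first confirms the pointwise identity (A - B)² = A - 2AB + B for 0/1 values, then integrates each term in closed form: ∫₀¹[x<t]dx = t, ∫₀¹[x<t][x<a]dx = min(t, a) and ∫₀¹[x<a]dx = a, giving E(t) = ½(t - 2 min(t, a) + a) = ½|t - a| and dE/dt = ½(1 - 2[t<a]).

```python
import itertools

# 1) Pointwise identity behind the rewrite: indicators are 0/1, so u**2 == u
#    and (A - B)**2 == A - 2*A*B + B for every combination of A and B.
for A, B in itertools.product((0, 1), repeat=2):
    assert (A - B) ** 2 == A - 2 * A * B + B


# 2) Closed-form integral of the expanded integrand, for a, t in [0, 1]:
#    integral [x<t] = t, integral [x<t][x<a] = min(t, a), integral [x<a] = a.
def expanded_energy(t: float, a: float) -> float:
    return 0.5 * (t - 2 * min(t, a) + a)


def expanded_derivative(t: float, a: float) -> float:
    # d/dt t = 1 and d/dt min(t, a) = [t < a] for t != a. The only condition
    # evaluated at x == t is [x < a], which does not jump there when t != a.
    return 0.5 * (1 - 2 * (1 if t < a else 0))


for t in (0.25, 0.75):
    assert abs(expanded_energy(t, 0.5) - 0.5 * abs(t - 0.5)) < 1e-12
    print(f"t = {t}: dE/dt = {expanded_derivative(t, 0.5)}")
```

The printed derivatives are -0.5 and 0.5, matching the Teg output for the simplified expression.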

@martinjm97
Collaborator

martinjm97 commented Mar 29, 2024

Hi @alecjacobson and @squidrice21,

Thank you for the wonderful question.

  1. Why is the answer wrong and what's a way to check this?

The derivative violates the transversality condition, so there is no guarantee that the result is correct.

In particular, when the derivative contains a product involving the same condition more than once, the answer may be wrong.
The problem is that when you differentiate f([t > x]), you get f'([t > x])delta(t - x), which evaluates the condition exactly at the location of its jump. If you're theoretically inclined, the reason is that Leibniz's product rule only holds for distributions that satisfy the transversality condition.

In general, that's how I think about the problem: if Teg evaluates a condition exactly at its jump, the derivative might be wrong.
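To make "evaluated at a jump" concrete for the original example, here is a small plain-Python sketch. It assumes, purely as an illustration and not as a description of Teg's internals, that the naive chain rule is applied under the integral: d/dt ½∫₀¹([x<t] - [x<a])² dx = ∫₀¹([x<t] - [x<a]) delta(x - t) dx = [t<t] - [t<a]. The term [t<t] is the condition evaluated exactly at its own jump, and its value is ambiguous.

```python
def naive_chain_rule(t: float, a: float, value_at_jump: float) -> float:
    """[t<t] - [t<a], where [t<t] (the condition x < t evaluated exactly at
    its own jump x == t) is set to the ambiguous value_at_jump."""
    return value_at_jump - (1.0 if t < a else 0.0)


def exact(t: float, a: float) -> float:
    """dE/dt for E(t) = 1/2 |t - a|, valid for t != a."""
    return 0.5 if t > a else -0.5


a = 0.5
for t in (0.25, 0.75):
    print(f"t = {t}: [t<t] = 0 -> {naive_chain_rule(t, a, 0.0)}, "
          f"[t<t] = 1/2 -> {naive_chain_rule(t, a, 0.5)}, "
          f"exact -> {exact(t, a)}")
```

Treating [t<t] as 0 reproduces the -1.0 and 0 printed in the first post, while treating it as ½ gives the expected -0.5 and 0.5. The ½ convention only happens to work here because f(u) = u² satisfies f'(½) = f(1) - f(0); it is not a general fix.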

  2. How can a user/compiler systematically resolve this problem?

The key idea is to pull the conditional outside of the composition prior to differentiation:
f([t > x]) = [t > x]f(1) + [t <= x]f(0).
The derivative is then delta(t - x)f(1) - delta(x - t)f(0), which I believe is correct.
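Since delta is even, delta(t - x)f(1) - delta(x - t)f(0) = delta(t - x)(f(1) - f(0)). The plain-Python sketch below compares that hoisted coefficient of delta(t - x) with the one from the naive chain rule, f'([t > x]), for two example functions chosen only for illustration (f and f' are written by hand; the helper names are hypothetical).

```python
def naive_coefficient(fprime, value_at_jump: float) -> float:
    """Coefficient of delta(t - x) from the naive chain rule
    f'([t > x]) delta(t - x), with [t > x] evaluated exactly at x == t."""
    return fprime(value_at_jump)


def hoisted_coefficient(f) -> float:
    """Coefficient after hoisting f([t > x]) = [t > x] f(1) + [t <= x] f(0):
    d/dt gives delta(t - x) f(1) - delta(x - t) f(0) = delta(t - x)(f(1) - f(0))."""
    return f(1) - f(0)


examples = {
    "u**2": (lambda u: u ** 2, lambda u: 2 * u),
    "u**3": (lambda u: u ** 3, lambda u: 3 * u ** 2),
}
for name, (f, fprime) in examples.items():
    # The hoisted form agrees with f on both values a condition can take.
    assert all(f(u) == u * f(1) + (1 - u) * f(0) for u in (0, 1))
    naive = [naive_coefficient(fprime, v) for v in (0, 0.5, 1)]
    print(f"{name}: hoisted {hoisted_coefficient(f)}, naive {naive}")
```

For both functions the hoisted coefficient is 1, while the naive one depends on the value assigned at the jump (0, 1.0, 2 for u² and 0, 0.75, 3 for u³), so no single convention is right for every f.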

I'd call this hoisting conditionals. Currently, it's on the programmer to do this, but this process could be automated. Some care would need to be taken to avoid exponential explosion when hoisting terms like f([t > x] + [t + 1 > x]).
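A hypothetical sketch of hoisting several conditions at once: f(c_1 + ... + c_k), where each c_i is a 0/1 condition, becomes one constant f(value) per assignment of the k conditions, which makes the 2^k growth visible. The function names are made up for illustration. In the example f([t > x] + [t + 1 > x]), the assignment [t > x] = 1, [t + 1 > x] = 0 can never occur (t > x implies t + 1 > x), so an automated version could prune such terms; that is one possible kind of care, not a description of an existing tool.

```python
import itertools


def hoist(f, k: int):
    """Expand f(c_1 + ... + c_k) into (assignment, f(sum of assignment))
    pairs, one for each of the 2**k ways to set the k conditions to 0/1."""
    return [(bits, f(sum(bits))) for bits in itertools.product((0, 1), repeat=k)]


def evaluate_hoisted(terms, conditions) -> float:
    # Exactly one term is active: the one whose assignment matches the conditions.
    return sum(value for bits, value in terms
               if all(c == b for c, b in zip(conditions, bits)))


def square(u):
    return u ** 2


terms = hoist(square, 2)  # f([t > x] + [t + 1 > x]) with f(u) = u**2
for conds in itertools.product((0, 1), repeat=2):
    assert evaluate_hoisted(terms, conds) == square(sum(conds))
print(len(terms), [len(hoist(square, k)) for k in range(1, 6)])
```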

I'm happy to think about it more or help out. Feel free to lmk if you have more questions.

@gilbo

gilbo commented Mar 29, 2024 via email

@martinjm97
Collaborator

martinjm97 commented Mar 29, 2024

@gilbo I'm talking explicitly about the case of f([x < t]), where the degeneracy arises from the chain rule. I'm not talking about taking a pair of arbitrary conditionals and checking if they're transverse.

I believe the former case can be automated while the latter case is not computable.

This is not a complete solution to all possible degeneracies, but it is a resolution for a class of degeneracies that seem to show up.

@martinjm97
Collaborator

@alecjacobson,

To answer your question directly:

What's a good rule I can follow to be sure that I provide valid input to Teg? No nested IfElse?

For the current implementation, don't pass parametric discontinuities as inputs to a function (e.g., the squaring function in your example). Luckily, there's a manual rewrite for this case that can be automated (hoisting conditionals).
