JIT Error encountered due to jax, hashing, and constants #1288

Closed
dpanici opened this issue Oct 2, 2024 · 29 comments · Fixed by #1441 · May be fixed by #1229
Labels
bug Something isn't working optimization Adding or improving optimization methods P3 Highest Priority, someone is/should be actively working on this

Comments

@dpanici
Collaborator

dpanici commented Oct 2, 2024

The error seems to occur when optimizing the GammaC objective on the gh/Gamma_c branch. It happens on the second optimization step and seems related to the JIT cache. The error also only occurs when attempting an optimization at a resolution that has previously been optimized at; changing the eq resolution between steps seems to avoid the issue, so I assume it is related to caching.

MWE:

from desc import set_device

set_device("gpu")

import jax
import numpy as np

import desc.examples
from desc.continuation import solve_continuation_automatic
from desc.equilibrium import EquilibriaFamily, Equilibrium
from desc.geometry import FourierRZToroidalSurface
from desc.grid import ConcentricGrid, LinearGrid
from desc.io import load
from desc.objectives import (  # FixIota,
    AspectRatio,
    Elongation,
    FixBoundaryR,
    FixBoundaryZ,
    FixCurrent,
    FixPressure,
    FixPsi,
    ForceBalance,
    GammaC,
    GenericObjective,
    ObjectiveFunction,
    QuasisymmetryTwoTerm,
)
from desc.optimize import Optimizer
from desc.plotting import plot_boozer_surface
import pdb
from desc.backend import jnp
from desc.examples import get
def run_opt_step(k, eq):
    """Run a step of the optimization example."""
    # this step will only optimize boundary modes with |m|,|n| <= k
    # we create an ObjectiveFunction, in this case made up of multiple objectives
    # which will be combined in a least squares sense

    shape_grid = LinearGrid(
        M=int(eq.M), N=int(eq.N), rho=np.array([1.0]), NFP=eq.NFP, sym=True, axis=False
    )

    ntransits = 8

    zeta_field_line = np.linspace(0, 2 * np.pi * ntransits, 64 * ntransits)
    alpha = jnp.array([0.0])
    rho = jnp.linspace(0.85, 1.0, 2)
    # rho = np.linspace(0.85, 1.0, 2)
    flux_surface_grid = LinearGrid(
        rho=rho, M=eq.M_grid, N=eq.N_grid, sym=eq.sym, NFP=eq.NFP
    )

    objective = ObjectiveFunction(
        (
            GammaC(
                eq=eq,
                rho=rho,
                alpha=alpha,
                deriv_mode="fwd",
                batch=False,
                weight=1e3,
                Nemov=False,
            ),
            Elongation(eq=eq, grid=shape_grid, target=1),  # alternative: bounds=(0.5, 2.0), weight=1e3
            GenericObjective(
                f="curvature_k2_rho",
                thing=eq,
                grid=shape_grid,
                bounds=(-75, 15),
                weight=2e3,
            ),
        ),
    )
    R_modes = np.vstack(
        (
            [0, 0, 0],
            eq.surface.R_basis.modes[
                np.max(np.abs(eq.surface.R_basis.modes), 1) > k, :
            ],
        )
    )
    Z_modes = eq.surface.Z_basis.modes[
        np.max(np.abs(eq.surface.Z_basis.modes), 1) > k, :
    ]
    constraints = (
        ForceBalance(
            eq,
            grid=ConcentricGrid(
                L=round(2 * eq.L),
                M=round(1.5 * eq.M),
                N=round(1.5 * eq.N),
                NFP=eq.NFP,
                sym=eq.sym,
            ),
        ),
        FixBoundaryR(eq=eq, modes=R_modes),
        FixBoundaryZ(eq=eq, modes=Z_modes),
        FixPressure(eq=eq),
        FixCurrent(eq=eq),
        FixPsi(eq=eq),
    )
    # this is the default optimizer, which re-solves the equilibrium at each step
    optimizer = Optimizer("proximal-lsq-exact")
    eq_new, result = optimizer.optimize(
        things=eq,
        objective=objective,
        constraints=constraints,
        maxiter=3,  # we don't need to solve to optimality at each multigrid step
        verbose=3,
        copy=True,  # don't modify original, return a new optimized copy
        options={
            # Sometimes the default initial trust radius is too big, allowing the
            # optimizer to take too large a step in a bad direction. If this happens,
            # we can manually specify a smaller starting radius. Each optimizer has a
            # number of different options that can be used to tune the performance.
            # See the documentation for more info.
            "initial_trust_ratio": 1e-2,
            "maxiter": 125,
            "ftol": 1e-3,
            "xtol": 1e-8,
        },
    )
    eq_new = eq_new[0]
   
    return eq_new 

eq = get("ESTELL")
for k in np.arange(1, eq.M + 1, 1):
    if not eq.is_nested():
        print("NOT NESTED")
        assert eq.is_nested()
        break
    jax.clear_caches()
    eq = run_opt_step(k, eq)

Error:

ValueError                                Traceback (most recent call last)
Cell In[1], line 137
    135     break
    136 jax.clear_caches()
--> 137 eq = run_opt_step(k, eq)

Cell In[1], line 107, in run_opt_step(k, eq)
    103 optimizer = Optimizer("proximal-lsq-exact")
    105 print("spot 1:", type(eq))
--> 107 eq_new, result = optimizer.optimize(
    108     things = eq,
    109     objective=objective,
    110     constraints=constraints,
    111     maxiter=3,  # we don't need to solve to optimality at each multigrid step
    112     verbose=3,
    113     copy=True,  # don't modify original, return a new optimized copy
    114     options={
    115         # Sometimes the default initial trust radius is too big, allowing the
    116         # optimizer to take too large a step in a bad direction. If this happens,
    117         # we can manually specify a smaller starting radius. Each optimizer has a
    118         # number of different options that can be used to tune the performance.
    119         # See the documentation for more info.
    120         "initial_trust_ratio": 1e-2,
    121         "maxiter": 125,
    122         "ftol": 1e-3,
    123         "xtol": 1e-8,
    124     },
    125 )
    126 eq_new = eq_new[0]
    128 return eq_new

File ~/DESC/desc/optimize/optimizer.py:311, in Optimizer.optimize(self, things, objective, constraints, ftol, xtol, gtol, ctol, x_scale, verbose, maxiter, options, copy)
    307     print("Using method: " + str(self.method))
    309 timer.start("Solution time")
--> 311 result = optimizers[method]["fun"](
    312     objective,
    313     nonlinear_constraint,
    314     x0,
    315     method,
    316     x_scale,
    317     verbose,
    318     stoptol,
    319     options,
    320 )
    322 if isinstance(objective, LinearConstraintProjection):
    323     # remove wrapper to get at underlying objective
    324     result["allx"] = [objective.recover(x) for x in result["allx"]]

File ~/DESC/desc/optimize/_desc_wrappers.py:270, in _optimize_desc_least_squares(objective, constraint, x0, method, x_scale, verbose, stoptol, options)
    267     options.setdefault("initial_trust_ratio", 0.1)
    268 options["max_nfev"] = stoptol["max_nfev"]
--> 270 result = lsqtr(
    271     objective.compute_scaled_error,
    272     x0=x0,
    273     jac=objective.jac_scaled_error,
    274     args=(objective.constants,),
    275     x_scale=x_scale,
    276     ftol=stoptol["ftol"],
    277     xtol=stoptol["xtol"],
    278     gtol=stoptol["gtol"],
    279     maxiter=stoptol["maxiter"],
    280     verbose=verbose,
    281     callback=None,
    282     options=options,
    283 )
    284 return result

File ~/DESC/desc/optimize/least_squares.py:176, in lsqtr(fun, x0, jac, bounds, args, x_scale, ftol, xtol, gtol, verbose, maxiter, callback, options)
    173 assert in_bounds(x, lb, ub), "x0 is infeasible"
    174 x = make_strictly_feasible(x, lb, ub)
--> 176 f = fun(x, *args)
    177 nfev += 1
    178 cost = 0.5 * jnp.dot(f, f)

File ~/DESC/desc/optimize/_constraint_wrappers.py:224, in LinearConstraintProjection.compute_scaled_error(self, x_reduced, constants)
    208 """Compute the objective function and apply weighting / bounds.
    209 
    210 Parameters
   (...)
    221 
    222 """
    223 x = self.recover(x_reduced)
--> 224 f = self._objective.compute_scaled_error(x, constants)
    225 return f

File ~/DESC/desc/optimize/_constraint_wrappers.py:843, in ProximalProjection.compute_scaled_error(self, x, constants)
    841 constants = setdefault(constants, self.constants)
    842 xopt, _ = self._update_equilibrium(x, store=False)
--> 843 return self._objective.compute_scaled_error(xopt, constants[0])

    [... skipping hidden 6 frame]

File ~/.conda/envs/desc-env-latest/lib/python3.11/site-packages/jax/_src/pjit.py:1339, in seen_attrs_get(fun, in_type)
   1337 cache = _seen_attrs.setdefault(fun.f, defaultdict(list))
   1338 assert fun.in_type is None or fun.in_type == in_type
-> 1339 return cache[(fun.transforms, fun.params, in_type)]

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
@dpanici dpanici added the bug Something isn't working label Oct 2, 2024
@unalmis
Collaborator

unalmis commented Oct 2, 2024

Here are the steps that should be taken to debug:

  1. Has this error ever occurred on the Gamma_c branch, where this quantity is developed? Which JAX version?
  2. If yes to 1, then check which recent changes to master (perhaps the Jacobian changes) caused this.

Unjitting the compute function tends to help with debugging.
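A generic sketch of what "unjitting" looks like, using only public JAX API rather than DESC code (`scaled_error` is a hypothetical stand-in for an objective's compute function): inside the `jax.disable_jit()` context, calls to `jax.jit`-decorated functions run eagerly, so print statements, pdb and ordinary tracebacks see concrete arrays instead of tracers.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scaled_error(x, rho):
    # Hypothetical stand-in for an objective's compute function.
    # Under jit this print runs only while tracing, and `rho` is an abstract
    # tracer, so only its shape and dtype are visible.
    print("rho inside:", rho)
    return x * rho - 0.25


x = jnp.ones(3)
rho = jnp.array([0.9, 0.95, 0.99])

scaled_error(x, rho)  # traced and compiled: prints a tracer, not numbers

with jax.disable_jit():
    # jit is a no-op here: the function runs op by op on concrete arrays,
    # so the print (or a pdb breakpoint) shows the actual values.
    out = scaled_error(x, rho)
```

DESC's own switch for this is the `use_jit=False` option mentioned later in the thread; `jax.disable_jit()` is the blunter, library-wide equivalent.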

@f0uriest
Member

f0uriest commented Oct 3, 2024

It seems to be unique to the bounce integral objectives: if I comment out GammaC it works fine, and if I change to EffectiveRipple it still happens.

Other things:

  • the bounce integral objectives take a loooong time to compile, even at very low resolution. Like a few minutes, compared to a few seconds for the other objectives
  • when either bounce integral objective is included it seems to make the optimizer stall out (rejects a lot of steps and exits)

These probably aren't related to the error above, but might be another source of concern

@f0uriest
Member

f0uriest commented Oct 3, 2024

With use_jit=False (and the jit in the constraint wrappers commented out), I'm unable to reproduce.

@unalmis
Collaborator

unalmis commented Oct 3, 2024

OK, I ran optimizations leading up to ISHW, so commit 5cd7ebd should not have this issue. State of the branch at that commit: https://github.com/PlasmaControl/DESC/tree/5cd7ebde563258f754a0401d9da6aa143bc3376f

@unalmis
Collaborator

unalmis commented Oct 3, 2024

> with use_jit=False (and commenting out the jit in constraint wrappers) I'm unable to reproduce

There is also a jit call wrapping the compute function in _compute. When you could no longer reproduce, was this JIT call still active?

> the bounce integral objectives take a loooong time to compile, even at very low resolution. Like a few minutes, compared to a few seconds for the other objectives

Aren't these compiled once? The BallooningStability objective requires less resolution than bounce integrals along a field line, but it still does a coordinate mapping inside the objective and builds transforms on the resulting grid. How do compilation time and optimization stalling compare when "low resolution" means the typical resolution for BallooningStability?

> the optimizer stall out (rejects a lot of steps and exits)

Can memory usage affect this? Is this forward or reverse mode? I ran forward-mode optimizations before ISHW and did not see the optimizer exit.

@dpanici
Collaborator Author

dpanici commented Oct 3, 2024

5cd7ebd...Gamma_c
The diff page between the commit Kaya mentioned and the current Gamma_c branch.

@dpanici
Collaborator Author

dpanici commented Oct 3, 2024

I won't have time to debug tonight/tmrw, but will look more this weekend. Thanks for starting to look into this so quickly, though. On Gamma_c I see the same bug for both the GammaC objective and EffectiveRipple.

@unalmis
Copy link
Collaborator

unalmis commented Oct 27, 2024

I think this is some JAX issue, and the caching suggests it is problem dependent. In any case, I suggest trying #1290; if the issue disappears there, we can mark this resolved.

The objectives there use a transforms grid that does not depend on the optimization step, so that might solve the caching issue you came across.

@dpanici
Copy link
Collaborator Author

dpanici commented Oct 28, 2024

The same error occurs in #1290. Once I find the specific cause, I can commit a fix.

@dpanici dpanici added the P3 Highest Priority, someone is/should be actively working on this label Nov 11, 2024
dpanici added a commit that referenced this issue Nov 23, 2024
@unalmis
Collaborator

unalmis commented Nov 23, 2024

I accidentally ran the tutorial's optimization cell (block 6) a second time after the optimization completed successfully, and I got the same error. No JIT caching is done there, and the block is self-contained, so the second run is a completely new optimization rather than a second step; I am uncertain whether it is related to this issue.

The error message suggests JAX is getting an array with a different dimension than it expects. Flattening all inputs (tuples and higher-dimensional arrays alike) into 1D arrays before they reach the objective function, in particular those in constants, avoided the issue for some reason.

The omnigenity objective also passes a 2D array in constants, so it might have the same issue. It could be worth looking into how 2D arrays are interpreted in the compute scaled error functions.
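One generic way to do that kind of flattening, assuming the constants form a JAX pytree, is `jax.flatten_util.ravel_pytree`: it concatenates every leaf into a single 1D array and returns a function that restores the original structure. This is only a sketch; the `constants` contents below are invented for illustration and are not what the Gamma_c objectives store, and nothing here says the flattening in the thread was done this way.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical constants with the kinds of structure mentioned above:
# a tuple (e.g. quadrature nodes and weights) and a 2D array.
constants = {
    "quad": (jnp.array([-0.5, 0.5]), jnp.array([1.0, 1.0])),
    "table": jnp.arange(6.0).reshape(2, 3),
}

# Concatenate every leaf into one 1D array, keeping a function to undo it.
flat, unravel = ravel_pytree(constants)
print(flat.shape)  # (10,)

# Inside the objective, the original tuple / 2D structure can be rebuilt.
restored = unravel(flat)
```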

@unalmis unalmis linked a pull request Nov 24, 2024 that will close this issue
@dpanici
Collaborator Author

dpanici commented Nov 24, 2024

Yep, the fix in #1229 is actually pretty simple: Greta changed the way rho is passed, from an array stored in constants to the nodes attribute of a LinearGrid. Once she confirms that this specific change is what fixes the bug, I asked her to make the same change in the ripple branch and in the branch where you implemented the objectives using the 2D interpolated version of the bounce functions.

@dpanici
Copy link
Collaborator Author

dpanici commented Nov 24, 2024

Hm, what is the error message actually? That seems different from ours (ours is about a numpy logical array, not really a shape mismatch, which is what yours sounds like). Could this be a separate issue?

@dpanici
Collaborator Author

dpanici commented Nov 24, 2024

> I accidentally ran the tutorial's optimization cell block 6 another time after the optimization completes successfully, and I get the same error. JIT caching is not done there, and the block is self-contained so the second run is a completely new optimization, not a second step, so I am uncertain if it is related to this issue. The error message suggests jax is getting an array with different dimension than it expects, so flattening all inputs from tuples and higher dim arrays to 1D arrays before they reach the objective function, in particular those in constants, avoided the issue for some reason.
>
> The omnigenity objective also passes in a 2D array in constants, so it might have the same issue, and could be worth looking into how 2D arrays are interpreted in the compute scaled error functions.

I think that even if the block is self-contained, running it again will still find a cached jitted version of the ObjectiveFunction.compute method and check whether it can be reused, and that cache check is where we found the error. So it is not about being a second step: it happens any time an eq of the same resolution (with the same grids, etc.) is used to build and then compile an ObjectiveFunction. The test I have here shows the kind of thing that failed before fixing the issue. The test does not re-instantiate the objective, but I know re-instantiating it also triggers the bug: if you previously ran pytest on the two tests in that file while they used the same-resolution eq, the second would fail with this bug, because JAX attempted to reuse the first test's cached jitted objective compute, and checking whether it was compatible threw the numpy logical error above.

@unalmis unalmis added the optimization Adding or improving optimization methods label Nov 24, 2024
@dpanici
Collaborator Author

dpanici commented Nov 25, 2024

Check whether @rahulgaur104's ballooning objective is also affected by this.

@gretahibbard
Collaborator

Fixed by storing rho, alpha, and zeta in a LinearGrid instead of as separate arrays. Wherever rho, alpha, zeta, etc. are accessed, they are now obtained by indexing the grid's node array.
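The thread does not spell out why moving rho into a grid object avoids the error. One plausible mechanism, sketched here as an assumption only: if the object that ends up in a cache key compares by identity (Python's default `__eq__`), then comparing two different instances returns a plain `False` and never triggers an elementwise array comparison, so a lookup with a freshly built object is an ordinary cache miss rather than a ValueError. `GridLike` and `CacheKey` are hypothetical classes, not DESC or JAX APIs, and whether DESC's LinearGrid actually compares by identity is not established here.

```python
import numpy as np


class GridLike:
    """Hypothetical stand-in for a grid holding (rho, theta, zeta) nodes.

    It keeps Python's default __eq__/__hash__, which compare by identity.
    """

    def __init__(self, rho):
        zeros = np.zeros_like(rho)
        self.nodes = np.column_stack([rho, zeros, zeros])


class CacheKey:
    """Hypothetical hashable key, standing in for a jit cache entry."""

    def __init__(self, grid):
        self.grid = grid

    def __hash__(self):
        return hash(self.grid.nodes.shape)

    def __eq__(self, other):
        # Identity-based objects compare to a plain bool, never an array.
        return (self.grid,) == (other.grid,)


rho = np.array([0.9, 0.95, 0.99])
cache = {CacheKey(GridLike(rho)): "compiled"}

new_key = CacheKey(GridLike(rho.copy()))
print(new_key in cache)  # False: an ordinary cache miss, no ValueError

# rho is recovered by indexing the node array.
print(new_key.grid.nodes[:, 0])
```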

@dpanici dpanici reopened this Dec 2, 2024
@dpanici
Copy link
Collaborator Author

dpanici commented Dec 2, 2024

@gretahibbard Can you wait to close this until the necessary changes are made to #1003, #1042, and #1290?

dpanici added a commit that referenced this issue Dec 3, 2024
gretahibbard pushed a commit that referenced this issue Dec 3, 2024
@dpanici
Collaborator Author

dpanici commented Dec 4, 2024

Some other things I've noticed (running on the ripple branch, #1003):

This is a cell I run in Jupyter. When I run it from a clean start, it actually runs fine (this is without the fix Greta used):

import desc.examples
from desc.objectives import ObjectiveFunction, EffectiveRipple
from desc.grid import LinearGrid
from desc.optimize import Optimizer
import numpy as np
eq = desc.examples.get("HELIOTRON")

constraints = ()  
# circular surface
a = 0.5
R0 = 10
surf = eq.surface.copy()
surf.change_resolution(M=1, N=1)

grid = LinearGrid(M=eq.M , N=eq.N, NFP=eq.NFP)
obj = EffectiveRipple(
    eq=eq,
    target=0.25,
    num_transit=1,
    num_pitch=2,
    num_quad=2,
    knots_per_transit=4,
    rho=np.array([0.9, 0.95]),
    batch=False,
)
objective = ObjectiveFunction((obj,))

optimizer = Optimizer("lsq-exact")
(eq, ), _ = optimizer.optimize(
    (eq, ), objective, constraints, verbose=1, maxiter=2, ftol=0, xtol=1e-9
)
(eq, ), _ = optimizer.optimize(
    (eq, ), objective, constraints, verbose=1, maxiter=2, ftol=0, xtol=1e-9
)

However, if I run the cell a second time, I get the error we expect (and which I had expected to see on the second optimizer.optimize call):

Building objective: Effective ripple
Precomputing transforms
Building objective: self_consistency R
Building objective: self_consistency Z
Building objective: lambda gauge
Building objective: axis R self consistency
Building objective: axis Z self consistency
Number of parameters: 1795
Number of objectives: 2

Starting optimization
Using method: lsq-exact
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[2], line 28
     25 objective = ObjectiveFunction((obj,))
     27 optimizer = Optimizer("lsq-exact")
---> 28 (eq, ), _ = optimizer.optimize(
     29     (eq, ), objective, constraints, verbose=1, maxiter=2, ftol=0, xtol=1e-9
     30 )
     31 (eq, ), _ = optimizer.optimize(
     32     (eq, ), objective, constraints, verbose=1, maxiter=2, ftol=0, xtol=1e-9
     33 )

File ~/Research/DESC/desc/optimize/optimizer.py:308, in Optimizer.optimize(self, things, objective, constraints, ftol, xtol, gtol, ctol, x_scale, verbose, maxiter, options, copy)
    304     print("Using method: " + str(self.method))
    306 timer.start("Solution time")
--> 308 result = optimizers[method]["fun"](
    309     objective,
    310     nonlinear_constraint,
    311     x0,
    312     method,
    313     x_scale,
    314     verbose,
    315     stoptol,
    316     options,
    317 )
    319 if isinstance(objective, LinearConstraintProjection):
    320     # remove wrapper to get at underlying objective
    321     result["allx"] = [objective.recover(x) for x in result["allx"]]

File ~/Research/DESC/desc/optimize/_desc_wrappers.py:270, in _optimize_desc_least_squares(objective, constraint, x0, method, x_scale, verbose, stoptol, options)
    267     options.setdefault("initial_trust_ratio", 0.1)
    268 options["max_nfev"] = stoptol["max_nfev"]
--> 270 result = lsqtr(
    271     objective.compute_scaled_error,
    272     x0=x0,
    273     jac=objective.jac_scaled_error,
    274     args=(objective.constants,),
    275     x_scale=x_scale,
    276     ftol=stoptol["ftol"],
    277     xtol=stoptol["xtol"],
    278     gtol=stoptol["gtol"],
    279     maxiter=stoptol["maxiter"],
    280     verbose=verbose,
    281     callback=None,
    282     options=options,
    283 )
    284 return result

File ~/Research/DESC/desc/optimize/least_squares.py:176, in lsqtr(fun, x0, jac, bounds, args, x_scale, ftol, xtol, gtol, verbose, maxiter, callback, options)
    173 assert in_bounds(x, lb, ub), "x0 is infeasible"
    174 x = make_strictly_feasible(x, lb, ub)
--> 176 f = fun(x, *args)
    177 nfev += 1
    178 cost = 0.5 * jnp.dot(f, f)

File ~/Research/DESC/desc/optimize/_constraint_wrappers.py:224, in LinearConstraintProjection.compute_scaled_error(self, x_reduced, constants)
    208 """Compute the objective function and apply weighting / bounds.
    209 
    210 Parameters
   (...)
    221 
    222 """
    223 x = self.recover(x_reduced)
--> 224 f = self._objective.compute_scaled_error(x, constants)
    225 return f

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

More weirdly, if I run jax.clear_caches() before running the cell a second time, I get the original (more informative) cache error instead: the error still happens during the .optimize call, but the traceback reveals the JAX cache calls.

@dpanici
Collaborator Author

dpanici commented Dec 4, 2024

Some more testing (all still on #1003 without the fix Greta has, i.e. #1003 as of commit hash 533793e): running the script below gives the error

# running this has error
import desc.examples
from desc.objectives import ObjectiveFunction, EffectiveRipple
from desc.grid import LinearGrid
from desc.optimize import Optimizer
import numpy as np
import jax
eq = desc.examples.get("HELIOTRON")
eq.change_resolution(3,3,3,4,4,4)

constraints = ()  
# circular surface
a = 0.5
R0 = 10
surf = eq.surface.copy()
surf.change_resolution(M=1, N=1)

grid = LinearGrid(M=eq.M , N=eq.N, NFP=eq.NFP)
obj = EffectiveRipple(
    eq=eq,
    target=0.25,
    num_transit=1,
    num_pitch=2,
    num_quad=2,
    knots_per_transit=4,
    rho=np.array([0.9, 0.95, 0.99]),
)
objective = ObjectiveFunction((obj,))

optimizer = Optimizer("lsq-exact")
print("start first optimize call")
(eq, ), _ = optimizer.optimize(
    (eq, ), objective, constraints, verbose=0, maxiter=2, ftol=0, xtol=1e-9
)


# jax.clear_caches()
####################
eq = desc.examples.get("HELIOTRON")
eq.change_resolution(3,3,3,4,4,4)

constraints = ()  
# circular surface
a = 0.5
R0 = 10
surf = eq.surface.copy()
surf.change_resolution(M=1, N=1)

# grid = LinearGrid(M=eq.M , N=eq.N, NFP=eq.NFP)
obj = EffectiveRipple(
    eq=eq,
    target=0.25,
    num_transit=1,
    num_pitch=2,
    num_quad=2,
    knots_per_transit=4,
    rho=np.array([0.9, 0.95, 0.99]),
)
objective = ObjectiveFunction((obj,))

optimizer = Optimizer("lsq-exact")
print("start second optimize call")
(eq, ), _ = optimizer.optimize(
    (eq, ), objective, constraints, verbose=0, maxiter=2, ftol=0, xtol=1e-9
)

BUT, running the same script (on a clean python process) after commenting out the second eq definition (eq = desc.examples.get("HELIOTRON")) does NOT yield an error

# running this cell has NO error
import desc.examples
from desc.objectives import ObjectiveFunction, EffectiveRipple
from desc.grid import LinearGrid
from desc.optimize import Optimizer
import numpy as np
import jax
eq = desc.examples.get("HELIOTRON")
eq.change_resolution(3,3,3,4,4,4)

constraints = ()  
# circular surface
a = 0.5
R0 = 10
surf = eq.surface.copy()
surf.change_resolution(M=1, N=1)

grid = LinearGrid(M=eq.M , N=eq.N, NFP=eq.NFP)
obj = EffectiveRipple(
    eq=eq,
    target=0.25,
    num_transit=1,
    num_pitch=2,
    num_quad=2,
    knots_per_transit=4,
    rho=np.array([0.9, 0.95, 0.99]),
)
objective = ObjectiveFunction((obj,))

optimizer = Optimizer("lsq-exact")
print("start first optimize call")
(eq, ), _ = optimizer.optimize(
    (eq, ), objective, constraints, verbose=0, maxiter=2, ftol=0, xtol=1e-9
)


# jax.clear_caches()
####################
# eq = desc.examples.get("HELIOTRON")
eq.change_resolution(3,3,3,4,4,4)

constraints = ()  
# circular surface
a = 0.5
R0 = 10
surf = eq.surface.copy()
surf.change_resolution(M=1, N=1)

# grid = LinearGrid(M=eq.M , N=eq.N, NFP=eq.NFP)
obj = EffectiveRipple(
    eq=eq,
    target=0.25,
    num_transit=1,
    num_pitch=2,
    num_quad=2,
    knots_per_transit=4,
    rho=np.array([0.9, 0.95, 0.99]),
)
objective = ObjectiveFunction((obj,))

optimizer = Optimizer("lsq-exact")
print("start second optimize call")
(eq, ), _ = optimizer.optimize(
    (eq, ), objective, constraints, verbose=0, maxiter=2, ftol=0, xtol=1e-9
)

Since the error is about comparing numpy arrays to each other, and I assume this comparison happens somewhere related to hashing, maybe some part of the eq is actually the issue? Here, when I reload the eq, the objective should be able to reuse the jitted function, assuming it checks the objective's current things against the cached jit function and finds that they are compatible shape-wise.

BUT, the error does not happen if I replace EffectiveRipple with ForceBalance or with PlasmaVesselDistance (the two I tried; the latter because of issue #1412, since it contains a 2D array in constants). So it must be something unique to EffectiveRipple.

@dpanici
Collaborator Author

dpanici commented Dec 4, 2024

And like @unalmis had already mentioned, the same error occurs even if I replace the entirety of EffectiveRipple.compute with return np.ones_like(self._dim_f)

@unalmis
Collaborator

unalmis commented Dec 4, 2024

> BUT, the error does not happen if I replace EffectiveRipple with ForceBalance

Just in case, I should point out that the "fix" from #1412 is not on #1003, as the objective there still stores a tuple in constants for the quadrature. In terms of actually finding the source of the error, I think the equilibrium hashing explanation makes more sense: I recently checked how we compute derivatives and objectives, and I didn't see anything indicating that constants was being used incorrectly.

@dpanici
Copy link
Collaborator Author

dpanici commented Dec 4, 2024

Some further sleuthing: I remember now why I pointed my finger at rho. If you just change the np calls in EffectiveRipple (in its current state) to jnp calls, the error is more readable and you can actually see the size of the offending object, which is exactly the size of the rho array. This holds when I change the size of rho, and the size does not match any other array, at least in self._constants.

# this is with original ripple BUT replaces all np with jnp calls. I expect an error and the arr.size to be 3
# running this cell has error
import desc.examples
from desc.objectives import ObjectiveFunction, EffectiveRipple
from desc.grid import LinearGrid
from desc.optimize import Optimizer
import numpy as np
import jax
eq = desc.examples.get("HELIOTRON")
eq.change_resolution(3,3,3,4,4,4)

constraints = ()  
# circular surface
a = 0.5
R0 = 10
surf = eq.surface.copy()
surf.change_resolution(M=1, N=1)
jax.clear_caches()
grid = LinearGrid(M=eq.M , N=eq.N, NFP=eq.NFP)
obj = EffectiveRipple(
    eq=eq,
    target=0.25,
    num_transit=1,
    num_pitch=2,
    num_quad=2,
    knots_per_transit=4,
    rho=np.array([0.9, 0.95, 0.99]),
)
obj = ObjectiveFunction(obj)
obj.build()
print(obj.compute_scaled_error(obj.x(eq)))


jax.clear_caches()
####################
eq = eq.copy()
print("start second optimize call")
obj = EffectiveRipple(
    eq=eq,
    target=0.25,
    num_transit=1,
    num_pitch=2,
    num_quad=2,
    knots_per_transit=4,
    rho=np.array([0.9, 0.95, 0.99]),
)
obj = ObjectiveFunction(obj)
obj.build()
obj.compute_scaled_error(obj.x(eq))

This has the error (I called %pdb before running the cell so I could use the debugger to inspect arr.size):

DESC version 0.12.3+1138.g533793e27.dirty,using JAX backend, jax version=0.4.31, jaxlib version=0.4.30, dtype=float64
Using device: CPU, with 2.95 GB available memory
/Users/dpanici/Research/DESC/desc/utils.py:554: UserWarning: Reducing radial (L) resolution can make plasma boundary inconsistent. Recommend calling `eq.surface = eq.get_surface_at(rho=1.0)`
  warnings.warn(colored(msg, "yellow"), err)
Building objective: Effective ripple
Precomputing transforms
[-0.25 -0.25 -0.25]
start second optimize call
Building objective: Effective ripple
Precomputing transforms
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[2], line 49
     47 obj = ObjectiveFunction(obj)
     48 obj.build()
---> 49 obj.compute_scaled_error(obj.x(eq))

    [... skipping hidden 7 frame]

File ~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/array.py:278, in ArrayImpl.__bool__(self)
    277 def __bool__(self):
--> 278   core.check_bool_conversion(self)
    279   return bool(self._value)

File ~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/core.py:667, in check_bool_conversion(arr)
    664   raise ValueError("The truth value of an empty array is ambiguous. Use"
    665                    " `array.size > 0` to check that an array is not empty.")
    666 if arr.size > 1:
--> 667   raise ValueError("The truth value of an array with more than one element"
    668                    " is ambiguous. Use a.any() or a.all()")

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
> /Users/dpanici/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/core.py(667)check_bool_conversion()
    665                      " `array.size > 0` to check that an array is not empty.")
    666   if arr.size > 1:
--> 667     raise ValueError("The truth value of an array with more than one element"
    668                      " is ambiguous. Use a.any() or a.all()")
    669 

ipdb>  arr.size
3
ipdb>  exit

It is still not clear to me why the error only happens when I build with a new equilibrium object, unless that is what triggers the cache check, i.e. what I assume is an equality check on a jnp.array object that results in this error.
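A possible explanation, sketched in plain Python and not confirmed anywhere in this thread: the first traceback shows JAX indexing a dict with the key (fun.transforms, fun.params, in_type). When two keys hash alike, Python falls back to ==, and tuple equality short-circuits on identity. If a key holds a raw array (here rho), comparing it with the same array object succeeds, but comparing it with an equal new array calls bool() on an elementwise result and raises exactly this ValueError. That would be consistent with reusing the same eq working while rebuilding it fails. CacheKey is a hypothetical class that only mimics a hashable static argument carrying an array; JAX arrays raise the same message, as the check_bool_conversion frame in the trace shows.

```python
import numpy as np


class CacheKey:
    """Hypothetical hashable key holding a raw array, standing in for a
    static jit argument (e.g. an objective that stores ``rho`` directly)."""

    def __init__(self, rho):
        self.rho = rho

    def __hash__(self):
        # Equal-shaped keys land in the same hash bucket ...
        return hash(self.rho.shape)

    def __eq__(self, other):
        # ... so the dict falls back to __eq__. Tuple equality first checks
        # identity; only for different array objects does it call
        # bool(a == b), which is ambiguous for more than one element.
        return (self.rho,) == (other.rho,)


rho = np.array([0.9, 0.95, 0.99])
cache = {CacheKey(rho): "compiled"}

# Same array object: identity shortcut, plain cache hit.
print(cache[CacheKey(rho)])  # compiled

# Equal values but a new array object (as after rebuilding the equilibrium):
try:
    cache[CacheKey(rho.copy())]
except ValueError as err:
    print(err)
```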

@dpanici
Collaborator Author

dpanici commented Dec 4, 2024

with no hidden frames:

env: JAX_TRACEBACK_FILTERING=off
DESC version 0.12.3+1138.g533793e27.dirty,using JAX backend, jax version=0.4.31, jaxlib version=0.4.30, dtype=float64
Using device: CPU, with 3.04 GB available memory
/Users/dpanici/Research/DESC/desc/utils.py:554: UserWarning: Reducing radial (L) resolution can make plasma boundary inconsistent. Recommend calling `eq.surface = eq.get_surface_at(rho=1.0)`
  warnings.warn(colored(msg, "yellow"), err)
Building objective: Effective ripple
Precomputing transforms
[-0.25 -0.25 -0.25]
start second optimize call
Building objective: Effective ripple
Precomputing transforms
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 51
     49 obj = ObjectiveFunction(obj)
     50 obj.build()
---> 51 obj.compute_scaled_error(obj.x(eq))

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py:332, in _cpp_pjit.<locals>.cache_miss(*args, **kwargs)
    330 @api_boundary
    331 def cache_miss(*args, **kwargs):
--> 332   outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
    333       fun, jit_info, *args, **kwargs)
    334   executable = _read_most_recent_pjit_call_executable(jaxpr)
    335   pgle_profiler = _read_pgle_profiler(jaxpr)

File ~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py:180, in _python_pjit_helper(fun, jit_info, *args, **kwargs)
    179 def _python_pjit_helper(fun, jit_info, *args, **kwargs):
--> 180   p, args_flat = _infer_params(fun, jit_info, args, kwargs)
    182   for arg in args_flat:
    183     dispatch.check_arg(arg)

File ~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py:729, in _infer_params(fun, ji, args, kwargs)
    726     skip_cache = True
    728 if skip_cache:
--> 729   p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, args,
    730                                     kwargs, in_avals=None)
    731   return p, p.consts + args_flat
    733 entry = _infer_params_cached(
    734     fun, ji, signature, avals, pjit_mesh, resource_env)

File ~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py:632, in _infer_params_impl(***failed resolving arguments***)
    625   in_type = in_avals
    627 in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
    628     in_shardings_treedef, in_shardings_leaves,
    629     ji.in_layouts_treedef, ji.in_layouts_leaves,
    630     in_avals, in_tree, dbg, device_or_backend_set, have_kwargs)
--> 632 attr_token = _attr_token(flat_fun, in_type)
    633 jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
    634     flat_fun, in_type, attr_token, dbg,
    635     HashableFunction(res_paths, closure=()),
    636     IgnoreKey(ji.inline))
    637 _attr_update(flat_fun, in_type, attr_token, attrs_tracked)

File ~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py:1346, in _attr_token(fun, in_type)
   1341 def _attr_token(
   1342     fun: lu.WrappedFun,
   1343     in_type: core.InputType | tuple[core.AbstractValue, ...]
   1344 ) -> int:
   1345   from jax.experimental.attrs import jax_getattr
-> 1346   cases = seen_attrs_get(fun, in_type)
   1347   for i, records in enumerate(cases):
   1348     for obj, attr, treedef, avals in records:

File ~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py:1339, in seen_attrs_get(fun, in_type)
   1337 cache = _seen_attrs.setdefault(fun.f, defaultdict(list))
   1338 assert fun.in_type is None or fun.in_type == in_type
-> 1339 return cache[(fun.transforms, fun.params, in_type)]

File ~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/array.py:278, in ArrayImpl.__bool__(self)
    277 def __bool__(self):
--> 278   core.check_bool_conversion(self)
    279   return bool(self._value)

File ~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/core.py:667, in check_bool_conversion(arr)
    664   raise ValueError("The truth value of an empty array is ambiguous. Use"
    665                    " `array.size > 0` to check that an array is not empty.")
    666 if arr.size > 1:
--> 667   raise ValueError("The truth value of an array with more than one element"
    668                    " is ambiguous. Use a.any() or a.all()")

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

@dpanici
Collaborator Author

dpanici commented Dec 4, 2024

Perhaps more confusingly, if you replace ONLY self._constants["rho"] with LinearGrid(rho=rho) (and change subsequent accesses to self._constants["rho"].nodes[:, 0]), everything works and the error above does not occur, even though alpha and zeta are still arrays.

@gretahibbard So the minimal fix for this error seems to be the above. Perhaps that is what we should do anyway, since it avoids wastefully creating a tensor-product LinearGrid in rho, alpha, and zeta that we only use to carry around the unique rho array.

I'm still not satisfied that I don't understand this bug exactly, in particular why rho is the only array that matters and not alpha or zeta, even though both are present in the init and in the constants, and alpha is passed into the init just like rho is.

@dpanici
Collaborator Author

dpanici commented Dec 4, 2024

Last comment for tonight: even if I replace every instance of constants["rho"] in the objective with a hard-coded jnp.array, the error still occurs as long as the passed-in rho array is stored in constants["rho"] in the init. If I don't put rho in the constants at all, the error still occurs, but the offending array then has the same length as alpha, or, if alpha is a length-1 array, the same length as zeta (even though, with the rho grid fix, these can be arrays without triggering the error).

I give up on fully understanding this for tonight, but at least we now have an even more minimal fix, evidence pointing at hashing, and a documented way to get more information when debugging this sort of thing.
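For reference, the unfiltered traceback above came from turning off JAX's traceback filtering (the `env: JAX_TRACEBACK_FILTERING=off` line in the notebook output). A minimal sketch of doing the same from a plain script, assuming the variable is set before `jax` is first imported; in an already-running session, `jax.config.update("jax_traceback_filtering", "off")` should have the same effect.

```python
import os

# Disable JAX's filtering of its own internal frames so that the full
# pjit / cache-lookup call chain appears in tracebacks. This should be set
# before the first `import jax`, because JAX reads its config defaults from
# the environment when it is imported.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

# import jax  # import only after the variable is set
```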

@dpanici
Collaborator Author

dpanici commented Dec 4, 2024

This is a seemingly similar error from JAX: jax-ml/jax#20466

@unalmis changed the title from "JIT Error encountered when optimizing GammaC" to "JIT Error encountered due to jax, hashing, and constants" on Dec 6, 2024
@dpanici
Collaborator Author

dpanici commented Dec 9, 2024

Try this again without cache clearing

@f0uriest
Member

f0uriest commented Dec 9, 2024

I think the issue is that quad_weights is 1, not 1.0.

When it's an integer, it gets marked as static by default, and because it's in a dict, the entire constants dict gets marked as static and is treated as part of the hashable pytree structure. But the arrays (rho, etc.) aren't hashable, hence the issue.
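A toy illustration of the mechanism described above. The rule in `split_attributes` is a hypothetical stand-in for DESC's actual pytree flattening, which is not shown in this thread: arrays and floats are treated as dynamic leaves, while any integer entry makes its whole attribute static. With `quad_weights=1` the constants dict lands in the static part, so two builds get compared by value as part of the tree structure; with `1.0` it stays dynamic. NumPy stands in for jnp, and `build_attrs` / `compare` are illustrative helpers.

```python
import numbers

import numpy as np


def is_dynamic_leaf(value):
    """Hypothetical rule: arrays and floats are traced; ints and bools are static."""
    if isinstance(value, numbers.Integral):  # also catches bool and NumPy ints
        return False
    return isinstance(value, (np.ndarray, float, complex))


def split_attributes(attrs):
    """Split attributes into dynamic (traced) and static (treedef) parts."""
    dynamic, static = {}, {}
    for name, value in attrs.items():
        leaves = value.values() if isinstance(value, dict) else [value]
        if all(is_dynamic_leaf(leaf) for leaf in leaves):
            dynamic[name] = value
        else:
            # A single static entry drags the whole dict into the static part.
            static[name] = value
    return dynamic, static


def build_attrs(quad_weights):
    # A fresh rho array per build, as when the objective is rebuilt.
    return {"_constants": {"rho": np.array([0.25, 0.5, 0.75]), "quad_weights": quad_weights}}


def compare(a, b):
    """Compare two static parts the way a treedef equality check would."""
    try:
        return a == b
    except ValueError as err:
        return f"ValueError: {err}"


dyn_int, static_int = split_attributes(build_attrs(1))
dyn_float, static_float = split_attributes(build_attrs(1.0))
print(sorted(static_int), sorted(dyn_float))  # ['_constants'] ['_constants']

# Two integer-weighted builds: their static parts contain distinct rho arrays.
_, static_other = split_attributes(build_attrs(1))
print(compare(static_int, static_other))
```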

The simple fix is just to make sure it's a float. I'm also looking at simplifying some of the pytree logic to avoid this in the future.
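The direct fix described above is simply writing `1.0` instead of `1` for quad_weights. A slightly more defensive sketch, assuming a hypothetical helper `as_float_constants` applied when the constants dict is built, coerces any integer scalar to a float so no entry can trip a static-by-default rule. Bools are deliberately left alone, on the assumption that flags are meant to stay static.

```python
import numbers

import numpy as np


def as_float_constants(constants):
    """Return a copy of a constants dict with integer scalars converted to floats.

    Arrays and bools are left untouched.
    """
    fixed = {}
    for key, value in constants.items():
        if isinstance(value, numbers.Integral) and not isinstance(value, bool):
            fixed[key] = float(value)
        else:
            fixed[key] = value
    return fixed


constants = {"rho": np.array([0.25, 0.5, 0.75]), "quad_weights": 1}
constants = as_float_constants(constants)
print(type(constants["quad_weights"]).__name__)  # float
```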

@dpanici
Collaborator Author

dpanici commented Dec 9, 2024

@gretahibbard

@unalmis
Collaborator

unalmis commented Dec 20, 2024

@f0uriest after making the quad weights change, the issue is now

@unalmis closed this as completed on Dec 20, 2024