JIT Error encountered due to jax, hashing, and constants
#1288
Comments
Here are the steps that should be taken to debug
Unjitting the compute function tends to help for debugging |
It seems to be unique to the bounce integral objectives: if I comment out GammaC it works fine, and if I change to EffectiveRipple it still happens. Other things:
These probably aren't related to the error above, but might be another source of concern |
with |
Ok I ran optimizations leading up to ISHW, so commit 5cd7ebd should not have this issue. State of branch at that commit https://github.com/PlasmaControl/DESC/tree/5cd7ebde563258f754a0401d9da6aa143bc3376f |
There is also a jit call wrapping the compute function in
Aren't these compiled once? The
Can memory usage affect this? Is this forward or reverse mode? I ran forward optimizations before ISHW and did not see the optimizer exit |
5cd7ebd...Gamma_c |
I won't have time to debug tonight/tmrw, but will look more this weekend. thanks for starting to look into this so quickly though. on |
I think this is some jax issue, and the caching suggests this is problem dependent. In any case, I suggest trying on #1290, and if the issue disappears then we can mark this resolved. The objectives there use an optimization step independent transforms grid, so that might solve that caching issue you came across. |
Same error occurs in #1290; once I find the specific cause I can commit a fix |
I accidentally ran the tutorial's optimization cell block 6 another time after the optimization completes successfully, and I get the same error. JIT caching is not done there, and the block is self-contained so the second run is a completely new optimization, not a second step, so I am uncertain if it is related to this issue. The error message suggests jax is getting an array with different dimension than it expects, so flattening all inputs from tuples and higher dim arrays to 1D arrays before they reach the objective function, in particular those in The omnigenity objective also passes in a 2D array in constants, so it might have the same issue, and could be worth looking into how 2D arrays are interpreted in the compute scaled error functions. |
Yep the fix in #1229 is actually pretty simple, Greta basically changed the way |
Hm, what is the error message actually? That seems different from what we get (ours is an np logical array, not quite related to shape mismatches, which is what yours sounds like). This could be a separate issue? |
I think even if the block is self-contained, running it again will still find that there is a cached jitted version of the |
Check if @rahulgaur104 's ballooning objective is also affected by this |
Fixed by storing rho, alpha, zeta in a LinearGrid instead of separate arrays. Where rho, alpha, zeta, etc are accessed, now passed by indexing array |
@gretahibbard Can you wait to close until the necessary changes are made to #1003, #1042 and #1290? |
Some other things I've noticed (running on This is a cell I run in Jupyter; when I run it from a clean start it actually runs fine (this is without the fix Greta had used):
import desc.examples
from desc.objectives import ObjectiveFunction, EffectiveRipple
from desc.grid import LinearGrid
from desc.optimize import Optimizer
import numpy as np
eq = desc.examples.get("HELIOTRON")
constraints = ()
# circular surface
a = 0.5
R0 = 10
surf = eq.surface.copy()
surf.change_resolution(M=1, N=1)
grid = LinearGrid(M=eq.M , N=eq.N, NFP=eq.NFP)
obj = EffectiveRipple(
eq=eq,
target=0.25,
num_transit=1,
num_pitch=2,
num_quad=2,
knots_per_transit=4,rho=np.array([0.9,0.95]),batch=False
)
objective = ObjectiveFunction((obj,))
optimizer = Optimizer("lsq-exact")
(eq, ), _ = optimizer.optimize(
(eq, ), objective, constraints, verbose=1, maxiter=2, ftol=0, xtol=1e-9
)
(eq, ), _ = optimizer.optimize(
(eq, ), objective, constraints, verbose=1, maxiter=2, ftol=0, xtol=1e-9
)
However, if I run the cell a second time, I get the error we expect (and which I had expected to see on the second
Building objective: Effective ripple
Precomputing transforms
Building objective: self_consistency R
Building objective: self_consistency Z
Building objective: lambda gauge
Building objective: axis R self consistency
Building objective: axis Z self consistency
Number of parameters: 1795
Number of objectives: 2
Starting optimization
Using method: lsq-exact
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[2], line 28
25 objective = ObjectiveFunction((obj,))
27 optimizer = Optimizer("lsq-exact")
---> 28 (eq, ), _ = optimizer.optimize(
29 (eq, ), objective, constraints, verbose=1, maxiter=2, ftol=0, xtol=1e-9
30 )
31 (eq, ), _ = optimizer.optimize(
32 (eq, ), objective, constraints, verbose=1, maxiter=2, ftol=0, xtol=1e-9
33 )
File [~/Research/DESC/desc/optimize/optimizer.py:308](http://localhost:8888/~/Research/DESC/desc/optimize/optimizer.py#line=307), in Optimizer.optimize(self, things, objective, constraints, ftol, xtol, gtol, ctol, x_scale, verbose, maxiter, options, copy)
304 print("Using method: " + str(self.method))
306 timer.start("Solution time")
--> 308 result = optimizers[method]["fun"](
309 objective,
310 nonlinear_constraint,
311 x0,
312 method,
313 x_scale,
314 verbose,
315 stoptol,
316 options,
317 )
319 if isinstance(objective, LinearConstraintProjection):
320 # remove wrapper to get at underlying objective
321 result["allx"] = [objective.recover(x) for x in result["allx"]]
File [~/Research/DESC/desc/optimize/_desc_wrappers.py:270](http://localhost:8888/~/Research/DESC/desc/optimize/_desc_wrappers.py#line=269), in _optimize_desc_least_squares(objective, constraint, x0, method, x_scale, verbose, stoptol, options)
267 options.setdefault("initial_trust_ratio", 0.1)
268 options["max_nfev"] = stoptol["max_nfev"]
--> 270 result = lsqtr(
271 objective.compute_scaled_error,
272 x0=x0,
273 jac=objective.jac_scaled_error,
274 args=(objective.constants,),
275 x_scale=x_scale,
276 ftol=stoptol["ftol"],
277 xtol=stoptol["xtol"],
278 gtol=stoptol["gtol"],
279 maxiter=stoptol["maxiter"],
280 verbose=verbose,
281 callback=None,
282 options=options,
283 )
284 return result
File [~/Research/DESC/desc/optimize/least_squares.py:176](http://localhost:8888/~/Research/DESC/desc/optimize/least_squares.py#line=175), in lsqtr(fun, x0, jac, bounds, args, x_scale, ftol, xtol, gtol, verbose, maxiter, callback, options)
173 assert in_bounds(x, lb, ub), "x0 is infeasible"
174 x = make_strictly_feasible(x, lb, ub)
--> 176 f = fun(x, *args)
177 nfev += 1
178 cost = 0.5 * jnp.dot(f, f)
File [~/Research/DESC/desc/optimize/_constraint_wrappers.py:224](http://localhost:8888/~/Research/DESC/desc/optimize/_constraint_wrappers.py#line=223), in LinearConstraintProjection.compute_scaled_error(self, x_reduced, constants)
208 """Compute the objective function and apply weighting / bounds.
209
210 Parameters
(...)
221
222 """
223 x = self.recover(x_reduced)
--> 224 f = self._objective.compute_scaled_error(x, constants)
225 return f
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
More weirdly, I notice that I get the original (more informative) cache error if, before I run the cell a second time, I run |
Some more testing (all still on #1003, without the fix Greta has, i.e. #1003 as of commit hash 533793e): running the script below gets the error
# running this has error
import desc.examples
from desc.objectives import ObjectiveFunction, EffectiveRipple
from desc.grid import LinearGrid
from desc.optimize import Optimizer
import numpy as np
import jax
eq = desc.examples.get("HELIOTRON")
eq.change_resolution(3,3,3,4,4,4)
constraints = ()
# circular surface
a = 0.5
R0 = 10
surf = eq.surface.copy()
surf.change_resolution(M=1, N=1)
grid = LinearGrid(M=eq.M , N=eq.N, NFP=eq.NFP)
obj = EffectiveRipple(
eq=eq,
target=0.25,
num_transit=1,
num_pitch=2,
num_quad=2,
knots_per_transit=4,rho=np.array([0.9,0.95,0.99])
)
objective = ObjectiveFunction((obj,))
optimizer = Optimizer("lsq-exact")
print("start first optimize call")
(eq, ), _ = optimizer.optimize(
(eq, ), objective, constraints, verbose=0, maxiter=2, ftol=0, xtol=1e-9
)
# jax.clear_caches()
####################
eq = desc.examples.get("HELIOTRON")
eq.change_resolution(3,3,3,4,4,4)
constraints = ()
# circular surface
a = 0.5
R0 = 10
surf = eq.surface.copy()
surf.change_resolution(M=1, N=1)
# grid = LinearGrid(M=eq.M , N=eq.N, NFP=eq.NFP)
obj = EffectiveRipple(
eq=eq,
target=0.25,
num_transit=1,
num_pitch=2,
num_quad=2,
knots_per_transit=4,rho=np.array([0.9,0.95,0.99])
)
objective = ObjectiveFunction((obj,))
optimizer = Optimizer("lsq-exact")
print("start second optimize call")
(eq, ), _ = optimizer.optimize(
(eq, ), objective, constraints, verbose=0, maxiter=2, ftol=0, xtol=1e-9
)
BUT, running the same script (on a clean python process) after commenting out the second eq definition (
# running this cell has NO error
import desc.examples
from desc.objectives import ObjectiveFunction, EffectiveRipple
from desc.grid import LinearGrid
from desc.optimize import Optimizer
import numpy as np
import jax
eq = desc.examples.get("HELIOTRON")
eq.change_resolution(3,3,3,4,4,4)
constraints = ()
# circular surface
a = 0.5
R0 = 10
surf = eq.surface.copy()
surf.change_resolution(M=1, N=1)
grid = LinearGrid(M=eq.M , N=eq.N, NFP=eq.NFP)
obj = EffectiveRipple(
eq=eq,
target=0.25,
num_transit=1,
num_pitch=2,
num_quad=2,
knots_per_transit=4,rho=np.array([0.9,0.95,0.99])
)
objective = ObjectiveFunction((obj,))
optimizer = Optimizer("lsq-exact")
print("start first optimize call")
(eq, ), _ = optimizer.optimize(
(eq, ), objective, constraints, verbose=0, maxiter=2, ftol=0, xtol=1e-9
)
# jax.clear_caches()
####################
# eq = desc.examples.get("HELIOTRON")
eq.change_resolution(3,3,3,4,4,4)
constraints = ()
# circular surface
a = 0.5
R0 = 10
surf = eq.surface.copy()
surf.change_resolution(M=1, N=1)
# grid = LinearGrid(M=eq.M , N=eq.N, NFP=eq.NFP)
obj = EffectiveRipple(
eq=eq,
target=0.25,
num_transit=1,
num_pitch=2,
num_quad=2,
knots_per_transit=4,rho=np.array([0.9,0.95,0.99])
)
objective = ObjectiveFunction((obj,))
optimizer = Optimizer("lsq-exact")
print("start second optimize call")
(eq, ), _ = optimizer.optimize(
(eq, ), objective, constraints, verbose=0, maxiter=2, ftol=0, xtol=1e-9
)
Since the error is something about comparing numpy arrays to each other, and I assume this comparison happens somewhere relating to hashing, maybe it is actually some part of the eq that is the issue? Here, when I re-load the eq, the objective should be able to re-use the jitted function, assuming it checks the current things for the objective against the cached jit function and sees that they are compatible shape-wise. BUT, the error does not happen if I replace |
And like @unalmis had already mentioned, the same error occurs even if I replace the entirety of |
Just in case I should point out that the "fix" from #1412 is not on #1003 ; as the objective there still stores a tuple in |
Some further sleuthing: I remember now why I pointed my finger at
# this is with original ripple BUT replaces all np with jnp calls. I expect an error and the arr.size to be 3
# running this cell has error
import desc.examples
from desc.objectives import ObjectiveFunction, EffectiveRipple
from desc.grid import LinearGrid
from desc.optimize import Optimizer
import numpy as np
import jax
eq = desc.examples.get("HELIOTRON")
eq.change_resolution(3,3,3,4,4,4)
constraints = ()
# circular surface
a = 0.5
R0 = 10
surf = eq.surface.copy()
surf.change_resolution(M=1, N=1)
jax.clear_caches()
grid = LinearGrid(M=eq.M , N=eq.N, NFP=eq.NFP)
obj = EffectiveRipple(
eq=eq,
target=0.25,
num_transit=1,
num_pitch=2,
num_quad=2,
knots_per_transit=4,rho=np.array([0.9,0.95,0.99])
)
obj = ObjectiveFunction(obj)
obj.build()
print(obj.compute_scaled_error(obj.x(eq)))
jax.clear_caches()
####################
eq = eq.copy()
print("start second optimize call")
obj = EffectiveRipple(
eq=eq,
target=0.25,
num_transit=1,
num_pitch=2,
num_quad=2,
knots_per_transit=4,rho=np.array([0.9,0.95,0.99])
)
obj = ObjectiveFunction(obj)
obj.build()
obj.compute_scaled_error(obj.x(eq))
has error (and I called %pdb before running the cell so I could use the debugger to see
DESC version 0.12.3+1138.g533793e27.dirty,using JAX backend, jax version=0.4.31, jaxlib version=0.4.30, dtype=float64
Using device: CPU, with 2.95 GB available memory
[/Users/dpanici/Research/DESC/desc/utils.py:554](http://localhost:8888/Users/dpanici/Research/DESC/desc/utils.py#line=553): UserWarning: Reducing radial (L) resolution can make plasma boundary inconsistent. Recommend calling `eq.surface = eq.get_surface_at(rho=1.0)`
warnings.warn(colored(msg, "yellow"), err)
Building objective: Effective ripple
Precomputing transforms
[-0.25 -0.25 -0.25]
start second optimize call
Building objective: Effective ripple
Precomputing transforms
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[2], line 49
47 obj = ObjectiveFunction(obj)
48 obj.build()
---> 49 obj.compute_scaled_error(obj.x(eq))
[... skipping hidden 7 frame]
File [~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/array.py:278](http://localhost:8888/lab/tree/~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/array.py#line=277), in ArrayImpl.__bool__(self)
277 def __bool__(self):
--> 278 core.check_bool_conversion(self)
279 return bool(self._value)
File [~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/core.py:667](http://localhost:8888/lab/tree/~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/core.py#line=666), in check_bool_conversion(arr)
664 raise ValueError("The truth value of an empty array is ambiguous. Use"
665 " `array.size > 0` to check that an array is not empty.")
666 if arr.size > 1:
--> 667 raise ValueError("The truth value of an array with more than one element"
668 " is ambiguous. Use a.any() or a.all()")
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
> /Users/dpanici/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/core.py(667)check_bool_conversion()
665 " `array.size > 0` to check that an array is not empty.")
666 if arr.size > 1:
--> 667 raise ValueError("The truth value of an array with more than one element"
668 " is ambiguous. Use a.any() or a.all()")
669
ipdb> arr.size
3
ipdb> exit
Still not clear to me really why the error only happens if I build with a new equilibrium object though, unless that is what triggers the cache checking and what I assume is an equality check on what is a jnp.array object, which results in this error |
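One possible explanation for why only a new equilibrium object triggers the error, offered as a sketch and not a confirmed diagnosis: a Python dict lookup checks object identity before equality. A cache key that is the very same object is found without `__eq__` ever being called, while a distinct key with the same hash is compared with `==`. If that comparison returns an array, the implicit `bool()` on it raises the ambiguity error. `FakeArray` and `StaticKey` below are hypothetical stand-ins written in pure Python so the example runs without dependencies; they are not JAX or DESC classes.

```python
class FakeArray:
    """Pure-Python stand-in for a NumPy/JAX array (hypothetical)."""

    def __init__(self, values):
        self.values = tuple(values)

    def __eq__(self, other):
        # Elementwise comparison, like `array == array`.
        return FakeArray(a == b for a, b in zip(self.values, other.values))

    def __bool__(self):
        # Mimics the ambiguity check seen in the traceback.
        if len(self.values) > 1:
            raise ValueError(
                "The truth value of an array with more than one element "
                "is ambiguous. Use a.any() or a.all()"
            )
        return bool(self.values[0]) if self.values else False


class StaticKey:
    """Hypothetical hashable wrapper, standing in for a static jit argument."""

    def __init__(self, array):
        self.array = array

    def __hash__(self):
        # Hash on size only, so equal-sized arrays land in the same bucket.
        return hash(len(self.array.values))

    def __eq__(self, other):
        # Returns an array-like object instead of a bool.
        return self.array == other.array


def lookup(cache, key):
    """Return the cached value, or a description of the ValueError raised."""
    try:
        return cache[key]
    except ValueError as err:
        return f"ValueError: {err}"


cache = {}
first = StaticKey(FakeArray([0.9, 0.95, 0.99]))
cache[first] = "compiled"

# Same object: the identity shortcut applies, __eq__ is never called.
same_object = lookup(cache, first)
# New object with equal contents: __eq__ runs and bool() of the result raises.
new_object = lookup(cache, StaticKey(FakeArray([0.9, 0.95, 0.99])))

print(same_object)
print(new_object)
```

If something like this is happening inside the jit cache, it would be consistent with the observation that re-using the existing `eq` works while re-loading or copying it fails, since only the latter produces a distinct key object.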
with no hidden frames:
env: JAX_TRACEBACK_FILTERING=off
DESC version 0.12.3+1138.g533793e27.dirty,using JAX backend, jax version=0.4.31, jaxlib version=0.4.30, dtype=float64
Using device: CPU, with 3.04 GB available memory
[/Users/dpanici/Research/DESC/desc/utils.py:554](http://localhost:8888/Users/dpanici/Research/DESC/desc/utils.py#line=553): UserWarning: Reducing radial (L) resolution can make plasma boundary inconsistent. Recommend calling `eq.surface = eq.get_surface_at(rho=1.0)`
warnings.warn(colored(msg, "yellow"), err)
Building objective: Effective ripple
Precomputing transforms
[-0.25 -0.25 -0.25]
start second optimize call
Building objective: Effective ripple
Precomputing transforms
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[1], line 51
49 obj = ObjectiveFunction(obj)
50 obj.build()
---> 51 obj.compute_scaled_error(obj.x(eq))
[... skipping hidden 1 frame]
File [~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py:332](http://localhost:8888/lab/tree/~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py#line=331), in _cpp_pjit.<locals>.cache_miss(*args, **kwargs)
330 @api_boundary
331 def cache_miss(*args, **kwargs):
--> 332 outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
333 fun, jit_info, *args, **kwargs)
334 executable = _read_most_recent_pjit_call_executable(jaxpr)
335 pgle_profiler = _read_pgle_profiler(jaxpr)
File [~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py:180](http://localhost:8888/lab/tree/~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py#line=179), in _python_pjit_helper(fun, jit_info, *args, **kwargs)
179 def _python_pjit_helper(fun, jit_info, *args, **kwargs):
--> 180 p, args_flat = _infer_params(fun, jit_info, args, kwargs)
182 for arg in args_flat:
183 dispatch.check_arg(arg)
File [~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py:729](http://localhost:8888/lab/tree/~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py#line=728), in _infer_params(fun, ji, args, kwargs)
726 skip_cache = True
728 if skip_cache:
--> 729 p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, args,
730 kwargs, in_avals=None)
731 return p, p.consts + args_flat
733 entry = _infer_params_cached(
734 fun, ji, signature, avals, pjit_mesh, resource_env)
File [~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py:632](http://localhost:8888/lab/tree/~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py#line=631), in _infer_params_impl(***failed resolving arguments***)
625 in_type = in_avals
627 in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
628 in_shardings_treedef, in_shardings_leaves,
629 ji.in_layouts_treedef, ji.in_layouts_leaves,
630 in_avals, in_tree, dbg, device_or_backend_set, have_kwargs)
--> 632 attr_token = _attr_token(flat_fun, in_type)
633 jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
634 flat_fun, in_type, attr_token, dbg,
635 HashableFunction(res_paths, closure=()),
636 IgnoreKey(ji.inline))
637 _attr_update(flat_fun, in_type, attr_token, attrs_tracked)
File [~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py:1346](http://localhost:8888/lab/tree/~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py#line=1345), in _attr_token(fun, in_type)
1341 def _attr_token(
1342 fun: lu.WrappedFun,
1343 in_type: core.InputType | tuple[core.AbstractValue, ...]
1344 ) -> int:
1345 from jax.experimental.attrs import jax_getattr
-> 1346 cases = seen_attrs_get(fun, in_type)
1347 for i, records in enumerate(cases):
1348 for obj, attr, treedef, avals in records:
File [~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py:1339](http://localhost:8888/lab/tree/~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/pjit.py#line=1338), in seen_attrs_get(fun, in_type)
1337 cache = _seen_attrs.setdefault(fun.f, defaultdict(list))
1338 assert fun.in_type is None or fun.in_type == in_type
-> 1339 return cache[(fun.transforms, fun.params, in_type)]
File [~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/array.py:278](http://localhost:8888/lab/tree/~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/array.py#line=277), in ArrayImpl.__bool__(self)
277 def __bool__(self):
--> 278 core.check_bool_conversion(self)
279 return bool(self._value)
File [~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/core.py:667](http://localhost:8888/lab/tree/~/miniconda3/envs/desc-env-latest/lib/python3.12/site-packages/jax/_src/core.py#line=666), in check_bool_conversion(arr)
664 raise ValueError("The truth value of an empty array is ambiguous. Use"
665 " `array.size > 0` to check that an array is not empty.")
666 if arr.size > 1:
--> 667 raise ValueError("The truth value of an array with more than one element"
668 " is ambiguous. Use a.any() or a.all()")
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() |
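For anyone wanting to reproduce the unfiltered traceback above outside Jupyter, here is a minimal sketch. The `env: JAX_TRACEBACK_FILTERING=off` line in the output is what the `%env` magic prints; in a plain script the same environment variable can be set from Python. The helper name `enable_full_jax_tracebacks` is made up for illustration, and the sketch assumes the variable is read when `jax` is first imported, so it should run before that import.

```python
import os


def enable_full_jax_tracebacks(environ=os.environ):
    """Ask JAX to show its internal frames in tracebacks.

    Assumption: this must run before `import jax` so the setting is
    picked up when JAX initializes its configuration.
    """
    environ["JAX_TRACEBACK_FILTERING"] = "off"
    return environ["JAX_TRACEBACK_FILTERING"]


setting = enable_full_jax_tracebacks()
print(setting)

# Only after this point would the script do `import jax` and `import desc`.
```

With filtering off, the frames that were collapsed into "skipping hidden 7 frame" become visible, which is how the `seen_attrs_get` cache lookup shows up in the traceback.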
Perhaps more confusingly, if you replace ONLY
@gretahibbard So the minimal fix to avoid this error seems to be doing the above; perhaps this is then what we should do, as it avoids wastefully creating a tensor product LinearGrid in rho, alpha, zeta that we only use to ferry around the unique rho array.
Still am not satisfied with not understanding this bug exactly, and why |
Last comment for tonight: Even if I replace every instance of
I give up on understanding for tonight, but at least there is an even more minimal fix identified, a finger to point at hashing, and a way to get more info when debugging this sort of thing documented now |
This is a seemingly similar error from JAX: jax-ml/jax#20466 |
GammaC
constants
Try this again without cache clearing |
I think the issue is that
When it's an integer it gets marked as static by default, and because it's in a dict the entire
Simple fix is just to make sure it's a float. I'm also looking at simplifying some of the pytree stuff to avoid this in the future. |
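The "make sure it's a float" fix could be sketched as below. This is only an illustration of the idea: the comment does not say which entry was an integer, so the key names in the example dict are invented, and `coerce_ints_to_floats` is a hypothetical helper that is not part of DESC.

```python
def coerce_ints_to_floats(constants):
    """Return a copy of a (possibly nested) dict with plain ints cast to float.

    Hypothetical helper for illustration. Booleans are left alone even
    though `bool` is a subclass of `int` in Python.
    """
    out = {}
    for key, value in constants.items():
        if isinstance(value, dict):
            out[key] = coerce_ints_to_floats(value)
        elif isinstance(value, int) and not isinstance(value, bool):
            out[key] = float(value)
        else:
            out[key] = value
    return out


# Invented example; the real constants dict holds other entries.
constants = {"quad_weights": 1, "nested": {"scale": 2}, "batch": False}
cleaned = coerce_ints_to_floats(constants)
print(cleaned)
```

The helper returns a new dict instead of mutating its input, so the original constants stay available for comparison while debugging.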
@f0uriest after making the quad weights change, the issue is now |
Error seems to occur when optimizing the `GammaC` objective on the `gh/Gamma_c` branch. It happens on the second optimization step and seems related to the JIT cache. The error also only occurs when attempting an optimization at a resolution that you have previously optimized at; changing the eq resolution between steps seems to avoid this issue, so I assume it is related to the caching.
MWE:
Error: