As noted in #402, there is an issue with jitting objectives. This functionality should be removed.
Option (a)
Remove the following lines of code from `AbstractObjective`:
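```python
import jax.tree_util as jtu

# Excerpt of methods on AbstractObjective (ScalarFloat is GPJax's scalar type alias).
...

def __hash__(self):
    return hash(tuple(jtu.tree_leaves(self)))  # Probably put this on the Module!

def __call__(self, *args, **kwargs) -> ScalarFloat:
    return self.step(*args, **kwargs)
```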
The aim is that `jax.jit(gpx.ConjugateMLL(negative=True))` then raises an error. This hashing code is dodgy.

Objectives could still be passed as `objective=gpx.ConjugateMLL(negative=True)`, just without the `jit`, which is not really needed in the first place because the code is already traced within `lax.scan`.
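A minimal sketch of why the objective itself does not need `jax.jit`. It assumes a hypothetical `ToyObjective` class standing in for a GPJax objective (it is not the GPJax API) and a hand-rolled gradient-descent loop in place of GPJax's own fitting routine. The objective is closed over by the `lax.scan` body, so it is traced and compiled together with the loop when the loop is jitted.

```python
import jax
import jax.numpy as jnp


class ToyObjective:
    """Hypothetical stand-in for an objective: no custom __hash__, just step/__call__."""

    def __init__(self, negative: bool = True):
        # Mirror the `negative` flag: flip the sign so the value can be minimised.
        self.constant = -1.0 if negative else 1.0

    def step(self, params, data):
        # Toy "log-likelihood": largest when params equal data.
        log_lik = -jnp.sum((data - params) ** 2)
        return self.constant * log_lik

    def __call__(self, params, data):
        return self.step(params, data)


def fit(objective, params, data, num_iters=100, learning_rate=0.1):
    """Plain gradient descent; the objective is only ever called inside the scan body."""
    grad_fn = jax.grad(objective)

    def body(p, _):
        p = p - learning_rate * grad_fn(p, data)
        return p, objective(p, data)

    # `objective` is closed over, so it is traced as part of the scan body.
    return jax.lax.scan(body, params, None, length=num_iters)


objective = ToyObjective(negative=True)
data = jnp.array([1.0, 2.0, 3.0])

# Jit the whole training loop; the objective is captured in the closure, not jitted on its own.
run = jax.jit(lambda p, d: fit(objective, p, d))
final_params, history = run(jnp.zeros(3), data)
```

Because only the training loop is wrapped in `jax.jit`, the objective instance never has to be hashable in this sketch.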
Option (b)
Revert to the previous GPJax objective design, which comprised an outer and an inner function:
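The earlier GPJax code is not reproduced in this issue. As a hypothetical illustration of the outer/inner pattern only (the name `make_negative_log_likelihood` and the toy Gaussian likelihood are assumptions, not the previous GPJax API), an outer function can take the data and configuration and return an inner function of the parameters alone. That inner function is an ordinary Python function, so it can be handed to `jax.jit` or `jax.grad` directly.

```python
import jax
import jax.numpy as jnp


def make_negative_log_likelihood(data, negative: bool = True):
    """Outer function (hypothetical): captures data and configuration, returns the inner objective."""
    constant = -1.0 if negative else 1.0

    def inner(params):
        # Inner function: depends only on params, so it is a plain function JAX can jit/grad.
        mean, log_scale = params
        scale = jnp.exp(log_scale)
        # Gaussian log-likelihood of `data` under N(mean, scale^2).
        log_lik = jnp.sum(
            -0.5 * jnp.log(2.0 * jnp.pi) - log_scale - 0.5 * ((data - mean) / scale) ** 2
        )
        return constant * log_lik

    return inner


data = jnp.array([0.5, 1.5, 1.0])
objective = make_negative_log_likelihood(data, negative=True)

# The inner function is an ordinary Python function, so jitting it is unproblematic.
value, grads = jax.jit(jax.value_and_grad(objective))((jnp.array(1.0), jnp.array(0.0)))
```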
Or even just have objectives defined from a minimisation perspective.

Bug Report
GPJax version: 0.7.2
Tagging @henrymoss.