Skip to content

Commit

Permalink
Bump Enzyme version and API to v0.11
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Apr 14, 2023
1 parent 3985a2d commit d081b35
Show file tree
Hide file tree
Showing 7 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Checkpointing"
uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca"
authors = ["Michel Schanen <[email protected]>", "Sri Hari Krishna Narayanan <[email protected]>"]
version = "0.6.3"
version = "0.7.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -14,7 +14,7 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[compat]
ChainRulesCore = "1"
DataStructures = "0.18"
Enzyme = "0.10"
Enzyme = "0.11"
HDF5 = "0.16"
julia = "1.7"

Expand Down
2 changes: 1 addition & 1 deletion examples/adtools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function Checkpointing.jacobian(tobedifferentiated, F_H, ::EnzymeADTool)
fill!(dx, 0)
fill!(y, 0)
dy[i] = 1.0
autodiff(f, Duplicated(x,dx), Duplicated(y, dy))
autodiff(Reverse, f, Duplicated(x,dx), Duplicated(y, dy))
J[i,:] = dx[:]
end
return J
Expand Down
2 changes: 1 addition & 1 deletion src/Schemes/Online_r2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ function checkpoint_struct_while(
model_final = deepcopy(model)
# Enzyme.autodiff(body, Duplicated(model,shadowmodel))
elseif (next_action.actionflag == Checkpointing.uturn)
Enzyme.autodiff(body, Duplicated(model,shadowmodel))
Enzyme.autodiff(Reverse, body, Duplicated(model,shadowmodel))
if haskey(storemap,next_action.iteration-1-1)
push!(freeindices, storemap[next_action.iteration-1-1])
delete!(storemap,next_action.iteration-1-1)
Expand Down
2 changes: 1 addition & 1 deletion src/Schemes/Periodic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function checkpoint_struct_for(
end
for j= alg.period:-1:1
model = deepcopy(model_check_inner[j])
Enzyme.autodiff(body, Duplicated(model,shadowmodel))
Enzyme.autodiff(Reverse, body, Duplicated(model,shadowmodel))
end
end
return model_final
Expand Down
4 changes: 2 additions & 2 deletions src/Schemes/Revolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,9 @@ function checkpoint_struct_for(
elseif (next_action.actionflag == Checkpointing.firstuturn)
body(model)
model_final = deepcopy(model)
Enzyme.autodiff(body, Duplicated(model,shadowmodel))
Enzyme.autodiff(Reverse, body, Duplicated(model,shadowmodel))
elseif (next_action.actionflag == Checkpointing.uturn)
Enzyme.autodiff(body, Duplicated(model,shadowmodel))
Enzyme.autodiff(Reverse, body, Duplicated(model,shadowmodel))
if haskey(storemap,next_action.iteration-1-1)
delete!(storemap,next_action.iteration-1-1)
check=check-1
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ include("../examples/adtools.jl")
@testset "Testing Jacobian interface..." begin
include("jacobian.jl")
end
@testset "Test optcontrol" begin
@testset "Test optcontrol deprecated" begin
@testset "AD Tool $adtool" for adtool in [EnzymeADTool(), ForwardDiffADTool(), ReverseDiffADTool(), ZygoteADTool()]
@testset "Testing Revolve..." begin
include("../examples/deprecated/optcontrol.jl")
Expand Down
2 changes: 1 addition & 1 deletion test/speelpenning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function main()

dx = zeros(n)
dy = [1.0]
autodiff(speelpenning, Duplicated(y,dy), Duplicated(x,dx))
autodiff(Reverse, speelpenning, Duplicated(y,dy), Duplicated(x,dx))
y = [0.0]
speelpenning(y,x)

Expand Down

0 comments on commit d081b35

Please sign in to comment.