Skip to content

Commit

Permalink
EnzymeRules (#26)
Browse files Browse the repository at this point in the history
Adding EnzymeRules support
  • Loading branch information
michel2323 authored May 8, 2023
1 parent 6e8e95b commit 15c9e28
Show file tree
Hide file tree
Showing 29 changed files with 513 additions and 606 deletions.
1 change: 0 additions & 1 deletion .github/workflows/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ jobs:
- uses: julia-actions/setup-julia@latest
with:
version: ${{ matrix.julia-version }}
- run: julia --project deps/deps.jl
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/julia-runtest@latest
- run: julia --project=docs/ docs/make.jl
9 changes: 3 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Checkpointing"
uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca"
authors = ["Michel Schanen <[email protected]>", "Sri Hari Krishna Narayanan <[email protected]>"]
version = "0.7.1"
version = "0.8.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -16,14 +16,11 @@ ChainRulesCore = "1"
DataStructures = "0.18"
Enzyme = "0.11"
HDF5 = "0.16"
julia = "1.7"
julia = "1.8"

[extras]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Enzyme", "ForwardDiff", "ReverseDiff", "Test", "Zygote"]
test = ["Test", "Zygote"]
3 changes: 0 additions & 3 deletions deps/deps.jl

This file was deleted.

2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[deps]
Checkpointing = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca"
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2 changes: 0 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using Pkg

diffractorspec = PackageSpec(url="https://github.com/JuliaDiff/Diffractor.jl", rev="main")
Pkg.add([diffractorspec])
checkpointingspec = PackageSpec(path=joinpath(dirname(@__FILE__), ".."))
Pkg.develop(checkpointingspec)

Expand Down
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

The schemes are agnostic to the AD tool being used and can be easily interfaced with any Julia AD tool. Currently, the package provides the following checkpointing schemes:

* Online checkpointing schemes for adaptive timestepping
* Revolve/Binomial checkpointing [1]
* Periodic checkpointing

## Future
The following features are planned for development:

* Online checkpointing schemes for adaptive timestepping
* Composition of checkpointing schemes
* Multi-level checkpointing schemes

Expand Down
11 changes: 7 additions & 4 deletions docs/src/lib/checkpointing.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@ CurrentModule = Checkpointing
```@docs
@checkpoint_struct
```
## Function API
```@docs
checkpoint_struct
```

## Supported Schemes
Expand All @@ -20,6 +16,13 @@ Periodic
```

## Supported Storages
```@docs
ArrayStorage
HDF5Storage
```

## Developer variables for implementing new schemes
```@docs
Scheme
Expand Down
4 changes: 2 additions & 2 deletions docs/src/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ function advance(heat)
end
function sumheat(heat::Heat, chkpt::Scheme)
@checkpoint_struct revolve heat for i in 1:tsteps
function sumheat(heat::Heat, scheme::Scheme)
@checkpoint_struct scheme heat for i in 1:tsteps
heat.Tlast .= heat.Tnext
advance(heat)
end
Expand Down
50 changes: 0 additions & 50 deletions examples/adtools.jl

This file was deleted.

188 changes: 188 additions & 0 deletions examples/box_model_enzyme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
const blength = [5000.0e5; 1000.0e5; 5000.0e5] ## north-south size of boxes, centimeters

const bdepth = [1.0e5; 5.0e5; 4.0e5] ## depth of boxes, centimeters

const delta = bdepth[1]/(bdepth[1] + bdepth[3]) ## constant ratio of two depths

const bwidth = 4000.0*1e5 ## box width, centimeters

# box areas
const barea = [blength[1]*bwidth;
blength[2]*bwidth;
blength[3]*bwidth]

# box volumes
const bvol = [barea[1]*bdepth[1];
barea[2]*bdepth[2];
barea[3]*bdepth[3]]

# parameters that are used to ensure units are in CGS (cent-gram-sec)

const hundred = 100.0
const thousand = 1000.0
const day = 3600.0*24.0
const year = day*365.0
const Sv = 1e12 ## one Sverdrup (a unit of ocean transport), 1e6 meters^3/second

# parameters that appear in box model equations
const u0 = 16.0*Sv/0.0004
const alpha = 1668e-7
const beta = 0.7811e-3

const gamma = 1/(300*day)

# robert filter coefficient for the smoother part of the timestep
const robert_filter_coeff = 0.25

# freshwater forcing
const FW = [(hundred/year) * 35.0 * barea[1]; -(hundred/year) * 35.0 * barea[1]]

# restoring atmospheric temperatures
const Tstar = [22.0; 0.0]
const Sstar = [36.0; 34.0];

# function to compute transport
# Input: rho - the density vector
# Output: U - transport value

function U_func(dens)

U = u0*(dens[2] - (delta * dens[1] + (1 - delta)*dens[3]))
return U

end

# function to compute density
# Input: state = [T1; T2; T3; S1; S2; S3]
# Output: rho

function rho_func(state)

rho = zeros(3)

rho[1] = -alpha * state[1] + beta * state[4]
rho[2] = -alpha * state[2] + beta * state[5]
rho[3] = -alpha * state[3] + beta * state[6]

return rho

end

# lastly our timestep function
# Input: fld_now = [T1(t), T2(t), ..., S3(t)]
# fld_old = [T1(t-dt), ..., S3(t-dt)]
# u = transport(t)
# dt = time step
# Output: fld_new = [T1(t+dt), ..., S3(t+dt)]

function timestep_func(fld_now, fld_old, u, dt)

temp = zeros(6)
fld_new = zeros(6)

# first computing the time derivatives of the various temperatures and salinities
if u > 0

temp[1] = u * (fld_now[3] - fld_now[1]) / bvol[1] + gamma * (Tstar[1] - fld_now[1])
temp[2] = u * (fld_now[1] - fld_now[2]) / bvol[2] + gamma * (Tstar[2] - fld_now[2])
temp[3] = u * (fld_now[2] - fld_now[3]) / bvol[3]

temp[4] = u * (fld_now[6] - fld_now[4]) / bvol[1] + FW[1] / bvol[1]
temp[5] = u * (fld_now[4] - fld_now[5]) / bvol[2] + FW[2] / bvol[2]
temp[6] = u * (fld_now[5] - fld_now[6]) / bvol[3]

elseif u <= 0

temp[1] = u * (fld_now[2] - fld_now[1]) / bvol[1] + gamma * (Tstar[1] - fld_now[1])
temp[2] = u * (fld_now[3] - fld_now[2]) / bvol[2] + gamma * (Tstar[2] - fld_now[2])
temp[3] = u * (fld_now[1] - fld_now[3]) / bvol[3]

temp[4] = u * (fld_now[5] - fld_now[4]) / bvol[1] + FW[1] / bvol[1]
temp[5] = u * (fld_now[6] - fld_now[5]) / bvol[2] + FW[2] / bvol[2]
temp[6] = u * (fld_now[4] - fld_now[6]) / bvol[3]

end

# update fldnew using a version of Euler's method

for j = 1:6
fld_new[j] = fld_old[j] + 2.0 * dt * temp[j]
end

return fld_new
end

mutable struct Box
in_now::Vector{Float64}
in_old::Vector{Float64}
out_now::Vector{Float64}
out_old::Vector{Float64}
i::Int
end

function forward_func_4_AD(in_now, in_old, out_old, out_now)
rho_now = rho_func(in_now) ## compute density
u_now = U_func(rho_now) ## compute transport
in_new = timestep_func(in_now, in_old, u_now, 10*day) ## compute new state values
for j = 1:6
in_now[j] = in_now[j] + robert_filter_coeff * (in_new[j] - 2.0 * in_now[j] + in_old[j])
end
out_old[:] = in_now
out_now[:] = in_new
return nothing
end


function advance(box::Box)
forward_func_4_AD(box.in_now, box.in_old, box.out_now, box.out_old)
end

function timestepper_for(box::Box, scheme::Scheme, tsteps::Int)
@checkpoint_struct scheme box for i in 1:tsteps
advance(box)
box.in_now[:] = box.out_old
box.in_old[:] = box.out_now
nothing
end
return box.out_now[1]
end


function box_for(scheme::Scheme, tsteps::Int)
Tbar = [20.0; 1.0; 1.0]
Sbar = [35.5; 34.5; 34.5]

# Create object from struct. tsteps is not needed for a for-loop
box = Box(copy([Tbar; Sbar]), copy([Tbar; Sbar]), zeros(6), zeros(6), 0)
dbox = Box(zeros(6), zeros(6), zeros(6), zeros(6), 0)

# Compute gradient
autodiff(Enzyme.ReverseWithPrimal, timestepper_for, Duplicated(box, dbox), scheme, tsteps)
return box.out_now[1], dbox.in_old
end

function timestepper_while(box::Box, scheme::Scheme, tsteps::Int)
box.i=1
@checkpoint_struct scheme box while box.i <= tsteps
advance(box)
box.in_now[:] = box.out_old
box.in_old[:] = box.out_now
box.i = box.i+1
nothing
end
return box.out_now[1]
end


function box_while(scheme::Scheme, tsteps::Int)
Tbar = [20.0; 1.0; 1.0]
Sbar = [35.5; 34.5; 34.5]

# Create object from struct. tsteps is not needed for a for-loop
box = Box(copy([Tbar; Sbar]), copy([Tbar; Sbar]), zeros(6), zeros(6), 0)
dbox = Box(zeros(6), zeros(6), zeros(6), zeros(6), 0)

# Compute gradient
autodiff(Enzyme.ReverseWithPrimal, timestepper_while, Duplicated(box, dbox), scheme, tsteps)
return box.out_now[1], dbox.in_old
end
4 changes: 2 additions & 2 deletions examples/box_model.jl → examples/box_model_zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ function box_for(scheme::Scheme, tsteps::Int)

# Compute gradient
g = Zygote.gradient(timestepper_for, box, scheme, tsteps)
return box.out_now[1], g[1]
return box.out_now[1], g[1][2]
end

function timestepper_while(box::Box, scheme::Scheme, tsteps::Int)
Expand All @@ -182,5 +182,5 @@ function box_while(scheme::Scheme, tsteps::Int)

# Compute gradient
g = Zygote.gradient(timestepper_while, box, scheme, tsteps)
return box.out_now[1], g[1]
return box.out_now[1], g[1][2]
end
Loading

0 comments on commit 15c9e28

Please sign in to comment.