Skip to content

Commit

Permalink
update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed May 8, 2024
1 parent f650ac6 commit aa46fcf
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ n_in = nvars + naugs # with augmentation
n = 1024

# Model
using ContinuousNormalizingFlows, Lux #, CUDA, ComputationalResources
using ContinuousNormalizingFlows, Lux #, CUDA, ComputationalResources, Zygote
nn = Chain(Dense(n_in => 3 * n_in, tanh), Dense(3 * n_in => n_in, tanh))
# icnf = construct(RNODE, nn, nvars) # use defaults
icnf = construct(
Expand All @@ -50,6 +50,8 @@ icnf = construct(
naugs; # number of augmented dimensions
tspan = (0.0f0, 13.0f0), # have bigger time span
steer_rate = 1.0f-1, # add random noise to end of the time span
# compute_mode = DIJacVecVectorMode, # process data one by one
# autodiff_backend = AutoZygote(), # use Zygote
# resource = CUDALibs(), # process data by GPU
)

Expand All @@ -60,15 +62,15 @@ r = rand(data_dist, nvars, n)
r = convert.(Float32, r)

# Fit It
using DataFrames, MLJBase #, ForwardDiff, ADTypes, OptimizationOptimisers
using DataFrames, MLJBase #, Zygote, ADTypes, OptimizationOptimisers
df = DataFrame(transpose(r), :auto)
# model = ICNFModel(icnf) # use defaults
model = ICNFModel(
icnf;
batch_size = 256, # have bigger batchs
# n_epochs = 100, # have less epochs
# optimizers = (Adam(),), # use a different optimizer
# adtype = AutoForwardDiff(), # use ForwardDiff
# adtype = AutoZygote(), # use Zygote
)
mach = machine(model, df)
fit!(mach)
Expand Down

0 comments on commit aa46fcf

Please sign in to comment.