Skip to content

Commit

Permalink
fix ML setup
Browse files Browse the repository at this point in the history
  • Loading branch information
islent committed Mar 29, 2024
1 parent 41b161d commit 64e9365
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/PM/gravity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ function compute_force(sim::Simulation, GravSolver::Union{FDM, FFT, ML}, Device:

# Solve QUMOND on mesh
if sim.config.grav.model isa QUMOND
QUMOND_acc!(m, ACC0, G, Device, sim.config.grav.sparse)
QUMOND_acc!(m, ACC0, G)
end

# Assign acc to inbound particles
Expand Down
72 changes: 69 additions & 3 deletions src/base/Config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,41 @@ end

##### SolverConfig #####

struct UnitProjection
L
T
G
M
end

"""
Construct without initialization
"""
function UnitProjection(units)
return UnitProjection(
1.0*getuLength(units),
1.0*getuTime(units),
1.0,
1.0*getuMass(units),
)
end

function UnitProjection(m::MeshCartesianStatic)
L = m.config.Max[1] - m.config.Min[1]
T = 10.0*getuTime(m.config.units)
G = 1.0
M = G * L^3 / G / T^2
return UnitProjection(L, T, G, M)
end

mutable struct DataML
u::UnitProjection
tstate
best_parameters
dev_cpu
dev_gpu
end

"""
$(TYPEDEF)
Expand All @@ -204,6 +239,9 @@ struct SolverConfig{SolverG #=, SolverH=#}
grav::SolverG
#"Hydrodynamical force solver. Supported: `SPH`, `MHD`, `FEM`, `FVM`"
#hydro::SolverH

"Data specialized for the solver"
data
end

"""
Expand All @@ -215,9 +253,11 @@ function SolverConfig(;

# Keywords to override
grav::Gravity = GravitySolver,
data = nothing,
)
return SolverConfig(
grav,
data,
)
end

Expand Down Expand Up @@ -528,8 +568,30 @@ function SimConfig( ;
EnlargeMesh::Float64 = 2.01,
BoundaryCondition::BoundaryCondition = Vacuum(),
sparse::Bool = true,
)

# ML
cnn_model = nothing,
cnn_parameters = nothing,
learning_rate = 0.001f0,
optimiser = Optimisers.Adam,
solverdata = nothing,
)
# Construct here and initialize in Simulation
if GravitySolver isa ML
if !isnothing(cnn_model)
model = cnn_model
else
model = Chain()
end
opt = optimiser(learning_rate)
tstate = Lux.Experimental.TrainState(MersenneTwister(), model, opt)

if !isnothing(cnn_parameters)
tstate = setproperties!!(tstate, parameters = cnn_parameters)
end
solverdata = DataML(UnitProjection(units), tstate, cpu_device(), cnn_parameters, gpu_device())
end

return SimConfig(
name, author, daytime,
floattype,
Expand All @@ -546,6 +608,7 @@ function SimConfig( ;
),
SolverConfig(;
GravitySolver,
data = solverdata,
),
GravityConfig(;
ForceSofteningTable = MVector{length(ForceSofteningTable)}(ForceSofteningTable),
Expand Down Expand Up @@ -1071,7 +1134,7 @@ function Simulation(d;
)
end
@info "Data cuts: " * string(gather(registry[id], numlocal))
elseif config.solver.grav isa FDM || config.solver.grav isa FFT
elseif config.solver.grav isa FDM || config.solver.grav isa FFT || config.solver.grav isa ML
@info "Setting up $(traitstring(config.solver.grav)) simulation..."
dStruct = StructArray(d)
mesh = MeshCartesianStatic(dStruct, units;
Expand All @@ -1083,6 +1146,9 @@ function Simulation(d;
device,
enlarge = EnlargeMesh,
)
if config.solver.grav isa ML
config.solver.data.u = UnitProjection(mesh)
end
registry[id] = Simulation(
config, id, pids,
mesh,
Expand Down Expand Up @@ -1121,7 +1187,7 @@ function get_local_data(sim::Simulation, ::Tree, ::CPU)
return sim.simdata.tree.data
end

function get_local_data(sim::Simulation, ::Union{FDM, FFT}, ::DeviceType)
function get_local_data(sim::Simulation, ::Union{FDM, FFT, ML}, ::DeviceType)
return sim.simdata.data
end

Expand Down
2 changes: 1 addition & 1 deletion src/base/Plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function pack_pos(sim::Simulation, ::DirectSum, ::GPU)
return pack_xyz(d)
end

function pack_pos(sim::Simulation, ::Union{DirectSum, Tree, FFT, FDM}, ::CPU)
function pack_pos(sim::Simulation, ::Union{DirectSum, Tree, FFT, FDM, ML}, ::CPU)
uLength = getuLength(sim.config.units)
d = StructArray(ustrip.(uLength, get_all_data(sim).Pos))
return pack_xyz(d)
Expand Down
6 changes: 5 additions & 1 deletion src/run.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ function preprocessdata(sim::Simulation, ::DirectSum, ::GPU)
preferunits(sim.config.units)
end

function preprocessdata(sim::Simulation, ::ML, ::DeviceType)
function preprocessdata(sim::Simulation, ::ML, ::GPU)
sim.config.solver.data.u = UnitProjection(sim.simdata)

if isempty(sim.config.solver.data.tstate.model.layers)
Expand All @@ -32,6 +32,10 @@ function preprocessdata(sim::Simulation, ::ML, ::DeviceType)
end
end

function preprocessdata(sim::Simulation, ::ML, ::CPU)
@warn "ML model on CPU is not encouraged! Try use keyword `device = GPU()`"
end

"""
$TYPEDSIGNATURES
This function does all the work for you:
Expand Down

0 comments on commit 64e9365

Please sign in to comment.