Skip to content

Commit

Permalink
ML solver first run
Browse files Browse the repository at this point in the history
  • Loading branch information
islent committed Mar 30, 2024
1 parent 64e9365 commit a11e7e3
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 21 deletions.
17 changes: 8 additions & 9 deletions src/AstroNbodySim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module AstroNbodySim

__precompile__(true)

using Reexport
# using Reexport
using PrecompileTools

# basic
Expand Down Expand Up @@ -55,13 +55,13 @@ using Zygote

# JuliaAstroSim
using ParallelOperations
@reexport using AstroSimBase
@reexport using PhysicalParticles
@reexport using PhysicalMeshes
@reexport using PhysicalTrees
@reexport using AstroIO
@reexport using PhysicalFDM
@reexport using PhysicalFFT
using AstroSimBase
using PhysicalParticles
using PhysicalMeshes
using PhysicalTrees
using AstroIO
using PhysicalFDM
using PhysicalFFT
using PhysicalParticles.NumericalIntegration

# GPU
Expand Down Expand Up @@ -255,7 +255,6 @@ include("directsumgpu/gravity.jl")
include("directsumgpu/timestep.jl")

include("PM/gravity.jl")
include("PM/fft.jl")
include("PM/output.jl")
include("PM/timestep.jl")

Expand Down
9 changes: 5 additions & 4 deletions src/ML/gravity.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
function cnn_poisson(m, u, dataML)
unit_rho = m.rho / u.M * u.L^3
unit_phi = dataML.dev_cpu(Lux.apply(
result = Lux.apply(
dataML.tstate.model,
dataML.dev_gpu(reshape(unit_rho, size(unit_rho)..., 1, 1)),
dataML.tstate.parameters,
dataML.tstate.states
)[1][:,:,:,1,1])
m.phi .= unit_phi * u.unitL^2 / unitT^2
dataML.tstate.states,
)[1]
unit_phi = reshape(unit_rho, size(unit_rho)...)
m.phi .= unit_phi * u.L^2 / u.T^2
end
17 changes: 9 additions & 8 deletions src/base/Config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,16 +208,16 @@ function UnitProjection(units)
return UnitProjection(
1.0*getuLength(units),
1.0*getuTime(units),
1.0,
1.0*getuLength(units)^3 / getuMass(units) / getuTime(units)^2,
1.0*getuMass(units),
)
end

function UnitProjection(m::MeshCartesianStatic)
L = m.config.Max[1] - m.config.Min[1]
T = 10.0*getuTime(m.config.units)
G = 1.0
M = G * L^3 / G / T^2
M = mean(m.rho) * L^m.config.dim
G = Constant(m.config.units).G
T = sqrt(L^3 / (4π*G*M))
return UnitProjection(L, T, G, M)
end

Expand Down Expand Up @@ -587,7 +587,7 @@ function SimConfig( ;
tstate = Lux.Experimental.TrainState(MersenneTwister(), model, opt)

if !isnothing(cnn_parameters)
tstate = setproperties!!(tstate, parameters = cnn_parameters)
tstate = setproperties!!(tstate, parameters = gpu_device()(cnn_parameters))
end
solverdata = DataML(UnitProjection(units), tstate, cpu_device(), cnn_parameters, gpu_device())
end
Expand Down Expand Up @@ -1018,6 +1018,7 @@ function Simulation(d;
zMin = nothing,
zMax = nothing,
device = CPU(),
data_on_cpu = false,
EnlargeMesh = 2.01,
BoundaryCondition = Vacuum(),

Expand Down Expand Up @@ -1143,7 +1144,7 @@ function Simulation(d;
Nx, Ny, Nz, NG,
xMin, xMax, yMin, yMax, zMin, zMax,
mode = meshmode,
device,
device, data_on_cpu,
enlarge = EnlargeMesh,
)
if config.solver.grav isa ML
Expand Down Expand Up @@ -1216,7 +1217,7 @@ function get_all_data(sim::Simulation, ::Union{DirectSum, Tree}, ::CPU)
return d
end

function get_all_data(sim::Simulation, ::Union{FDM, FFT}, ::CPU)
function get_all_data(sim::Simulation, ::Union{FDM, FFT, ML}, ::CPU)
return sim.simdata.data
end

Expand All @@ -1226,7 +1227,7 @@ function get_all_data(sim::Simulation, ::DirectSum, ::GPU)
end

"Copy data to CPU"
function get_all_data(sim::Simulation, ::Union{FDM, FFT}, ::GPU)
function get_all_data(sim::Simulation, ::Union{FDM, FFT, ML}, ::GPU)
CUDA.@allowscalar return StructArray(Array(sim.simdata.data))
end

Expand Down

0 comments on commit a11e7e3

Please sign in to comment.