Skip to content

Commit

Permalink
use device keyword, fft solver support GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
islent committed Mar 21, 2024
1 parent 64d77fd commit fd58720
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 22 deletions.
84 changes: 67 additions & 17 deletions src/PM/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function fft_grid_kk(N, eps = 1e-6)
return xx
end

function fft_poisson(Δ, Len, rho::AbstractArray{T,1}, boundary::Periodic) where T
function fft_poisson(Δ, Len, rho::AbstractArray{T,1}, boundary::Periodic, Device::CPU) where T
rho_bar = fft(rho)
rho_bar[1] *= 0.0

Expand All @@ -36,15 +36,28 @@ function fft_poisson(Δ, Len, rho::AbstractArray{T,1}, boundary::Periodic) where
# solve u_bar
u_bar = similar(rho_bar);
for i in 1:Len[1]+1
u_bar[i] = rho_bar[i] / (delta2sum + delta2[1] * cos(xx[i]))
@inbounds u_bar[i] = rho_bar[i] / (delta2sum + delta2[1] * cos(xx[i]))
end

u = real(ifft(u_bar))
end

### Periodic boundary conditions
function fft_poisson(Δ, Len, rho::AbstractArray{T,2}, boundary::Periodic) where T
"""
    fft_poisson(Δ, Len, rho::AbstractArray{T,1}, boundary::Periodic, Device::GPU)

Solve the 1-D Poisson problem with periodic boundaries by FFT on the GPU.

`rho` is transformed with `fft`, its first coefficient is zeroed, every
coefficient is divided by `-sum(2 ./ Δ.^2) + (2 / Δ[1]^2) * cos(k)` with `k`
taken from `fft_grid_kk(Len[1])`, and the real part of the inverse transform
is returned. The result has the same shape as `rho`.

NOTE(review): assumes `fft_grid_kk(Len[1])` returns a vector whose length
equals `length(rho)` (the CPU method indexes it with `1:Len[1]+1`) — confirm.
"""
function fft_poisson(Δ, Len, rho::AbstractArray{T,1}, boundary::Periodic, Device::GPU) where T
    rho_bar = fft(rho)
    # Drop the first Fourier coefficient; scalar indexing on a GPU array
    # has to be allowed explicitly.
    CUDA.@allowscalar rho_bar[1] *= 0.0

    delta2 = 2 ./ (ustrip.(Δ) .^ 2)
    delta2sum = -sum(delta2)

    # The 1-D problem only needs the wavenumber vector itself. The previous
    # `xx * ones(Len[2]+1)'` outer product (copied from the 2-D method)
    # either failed on `Len[2]` or turned the result into a matrix.
    kx = CuArray(fft_grid_kk(Len[1]))

    u_bar = rho_bar ./ (delta2sum .+ delta2[1] .* cos.(kx))

    return real(ifft(u_bar))
end

function fft_poisson(Δ, Len, rho::AbstractArray{T,2}, boundary::Periodic, Device::CPU) where T
rho_bar = fft(rho)
rho_bar[1] *= 0.0

delta2 = 2 ./ (ustrip.(Δ) .^ 2)
Expand All @@ -57,14 +70,30 @@ function fft_poisson(Δ, Len, rho::AbstractArray{T,2}, boundary::Periodic) where
u_bar = similar(rho_bar);
for j in 1:Len[2]+1
for i in 1:Len[1]+1
u_bar[i,j] = rho_bar[i,j] / (delta2sum + delta2[1] * cos(xx[i]) + delta2[2] * cos(yy[j]))
@inbounds u_bar[i,j] = rho_bar[i,j] / (delta2sum + delta2[1] * cos(xx[i]) + delta2[2] * cos(yy[j]))
end
end

u = real(ifft(u_bar))
end

function fft_poisson(Δ, Len, rho::AbstractArray{T,3}, boundary::Periodic) where T
"""
    fft_poisson(Δ, Len, rho::AbstractArray{T,2}, boundary::Periodic, Device::GPU)

Solve the 2-D Poisson problem with periodic boundaries by FFT on the GPU.

`rho` is transformed with `fft`, its first coefficient is zeroed, every
coefficient `(i, j)` is divided by
`-sum(2 ./ Δ.^2) + (2 / Δ[1]^2) * cos(kx[i]) + (2 / Δ[2]^2) * cos(ky[j])`
with `kx = fft_grid_kk(Len[1])` and `ky = fft_grid_kk(Len[2])`, and the real
part of the inverse transform is returned.

NOTE(review): assumes `fft_grid_kk(Len[d])` has length `size(rho, d)` —
confirm against its definition.
"""
function fft_poisson(Δ, Len, rho::AbstractArray{T,2}, boundary::Periodic, Device::GPU) where T
    rho_bar = fft(rho)
    # Drop the first Fourier coefficient; scalar indexing on a GPU array
    # has to be allowed explicitly.
    CUDA.@allowscalar rho_bar[1] *= 0.0

    delta2 = 2 ./ (ustrip.(Δ) .^ 2)
    delta2sum = -sum(delta2)

    # Per-axis cosine terms: a column vector for x and a 1×N row for y.
    # Broadcasting expands them against `rho_bar`, so no dense outer-product
    # matrices have to be built on the host and copied to the device.
    cx = delta2[1] .* cos.(CuArray(fft_grid_kk(Len[1])))
    cy = delta2[2] .* cos.(reshape(CuArray(fft_grid_kk(Len[2])), 1, :))

    # solve u_bar
    u_bar = rho_bar ./ (delta2sum .+ cx .+ cy)

    return real(ifft(u_bar))
end

function fft_poisson(Δ, Len, rho::AbstractArray{T,3}, boundary::Periodic, Device::CPU) where T
rho_bar = fft(rho)
rho_bar[1] *= 0.0

Expand All @@ -80,16 +109,37 @@ function fft_poisson(Δ, Len, rho::AbstractArray{T,3}, boundary::Periodic) where
for k in 1:Len[3]+1
for j in 1:Len[2]+1
for i in 1:Len[1]+1
u_bar[i,j,k] = rho_bar[i,j,k] / (delta2sum + delta2[1] * cos(xx[i]) + delta2[2] * cos(yy[j]) + delta2[3] * cos(zz[k]))
@inbounds u_bar[i,j,k] = rho_bar[i,j,k] / (delta2sum + delta2[1] * cos(xx[i]) + delta2[2] * cos(yy[j]) + delta2[3] * cos(zz[k]))
end
end
end

u = real(ifft(u_bar))
end

"""
    fft_poisson(Δ, Len, rho::AbstractArray{T,3}, boundary::Periodic, Device::GPU)

Solve the 3-D Poisson problem with periodic boundaries by FFT on the GPU.

`rho` is transformed with `fft`, its first coefficient is zeroed, every
coefficient `(i, j, k)` is divided by
`-sum(2 ./ Δ.^2) + Σ_d (2 / Δ[d]^2) * cos(k_d)` with the per-axis values from
`fft_grid_kk(Len[d])`, the first coefficient of the quotient is set to zero,
and the real part of the inverse transform is returned.

The wavenumber vectors are uploaded with `cu`, as before.
NOTE(review): `cu` converts Float64 input to Float32 by default, so the cosine
arguments are presumably single precision — confirm this is intended.
"""
function fft_poisson(Δ, Len, rho::AbstractArray{T,3}, boundary::Periodic, Device::GPU) where T
    rho_bar = fft(rho)
    # Drop the first Fourier coefficient; scalar indexing on a GPU array
    # has to be allowed explicitly.
    CUDA.@allowscalar rho_bar[1] *= 0.0

    delta2 = 2 ./ (ustrip.(Δ) .^ 2)
    delta2sum = -sum(delta2)

    xx = fft_grid_kk(Len[1])
    yy = fft_grid_kk(Len[2])
    zz = fft_grid_kk(Len[3])

    # Keep each cosine term in its own axis-shaped array (N1, 1×N2, 1×1×N3).
    # Broadcasting expands them in the final expression, so the full-size
    # `ones` array and three full-size temporaries are no longer needed.
    dcx = delta2[1] .* cos.(cu(xx))
    dcy = delta2[2] .* cos.(cu(yy'))
    dcz = delta2[3] .* cos.(cu(reshape(zz, 1, 1, Len[3] + 1)))

    u_bar = rho_bar ./ (delta2sum .+ dcx .+ dcy .+ dcz)
    # Force the first coefficient of the solution to zero, whatever the
    # division produced there; use the array's own element type rather than
    # a hard-coded ComplexF32 literal.
    CUDA.@allowscalar u_bar[1] = zero(eltype(u_bar))

    return real(ifft(u_bar))
end

### Homogeneous Dirichlet boundary conditions - fast sine transform
function fft_poisson(Δ, Len, rho::AbstractArray{T,1}, boundary::Dirichlet) where T
function fft_poisson(Δ, Len, rho::AbstractArray{T,1}, boundary::Dirichlet, Device::CPU) where T
#rho_bar = fft(mesh.rho)
#rho_bar[1] *= 0.0
rho_bar = FFTW.r2r(complex(rho[2:end]), FFTW.RODFT00)
Expand All @@ -102,14 +152,14 @@ function fft_poisson(Δ, Len, rho::AbstractArray{T,1}, boundary::Dirichlet) wher
# solve u_bar
u_bar = similar(rho_bar);
for i in 1:Len[1]
u_bar[i] = rho_bar[i] / (delta2sum + delta2[1] * cos(hx * i))
@inbounds u_bar[i] = rho_bar[i] / (delta2sum + delta2[1] * cos(hx * i))
end

u = real(FFTW.r2r(u_bar, FFTW.RODFT00)/((2*(Len[1] + 1))))
#mesh.phi .= real(ifft(u_bar))
end

function fft_poisson(Δ, Len, rho::AbstractArray{T,2}, boundary::Dirichlet) where T
function fft_poisson(Δ, Len, rho::AbstractArray{T,2}, boundary::Dirichlet, Device::CPU) where T
#rho_bar = fft(mesh.rho)
#rho_bar[1] *= 0.0
rho_bar = FFTW.r2r(complex(rho[2:end, 2:end]), FFTW.RODFT00)
Expand All @@ -124,15 +174,15 @@ function fft_poisson(Δ, Len, rho::AbstractArray{T,2}, boundary::Dirichlet) wher
u_bar = similar(rho_bar);
for j in 1:Len[2]
for i in 1:Len[1]
u_bar[i,j] = rho_bar[i,j] / (delta2sum + delta2[1] * cos(hx * i) + delta2[2] * cos(hy * j))
@inbounds u_bar[i,j] = rho_bar[i,j] / (delta2sum + delta2[1] * cos(hx * i) + delta2[2] * cos(hy * j))
end
end

u = real(FFTW.r2r(u_bar, FFTW.RODFT00)/((2*(Len[1] + 1)) * (2*(Len[2] + 1))))
#mesh.phi .= real(ifft(u_bar))
end

function fft_poisson(Δ, Len, rho::AbstractArray{T,3}, boundary::Dirichlet) where T
function fft_poisson(Δ, Len, rho::AbstractArray{T,3}, boundary::Dirichlet, Device::CPU) where T
#rho_bar = fft(mesh.rho)
#rho_bar[1] *= 0.0
rho_bar = FFTW.r2r(complex(rho[2:end, 2:end, 2:end]), FFTW.RODFT00)
Expand All @@ -149,7 +199,7 @@ function fft_poisson(Δ, Len, rho::AbstractArray{T,3}, boundary::Dirichlet) wher
for k in 1:Len[3]
for j in 1:Len[2]
for i in 1:Len[1]
u_bar[i,j,k] = rho_bar[i,j,k] / (delta2sum + delta2[1] * cos(hx * i) + delta2[2] * cos(hy * j) + delta2[3] * cos(hz * k))
@inbounds u_bar[i,j,k] = rho_bar[i,j,k] / (delta2sum + delta2[1] * cos(hx * i) + delta2[2] * cos(hy * j) + delta2[3] * cos(hz * k))
end
end
end
Expand Down Expand Up @@ -183,19 +233,19 @@ end

# Dirichlet BC returns a smaller array
function fft_poisson!(m::MeshCartesianStatic, rho::AbstractArray, boundary::Periodic)
m.phi .= fft_poisson(m.config.Δ, m.config.Len, rho, boundary) .* unit(eltype(m.phi))
m.phi .= fft_poisson(m.config.Δ, m.config.Len, rho, boundary, m.config.device) .* unit(eltype(m.phi))
end

function fft_poisson!(m::MeshCartesianStatic, rho::AbstractArray{T,1}, boundary::Dirichlet) where T
m.phi[2:end] .= fft_poisson(m.config.Δ, m.config.Len, rho, boundary) .* unit(eltype(m.phi))
m.phi[2:end] .= fft_poisson(m.config.Δ, m.config.Len, rho, boundary, m.config.device) .* unit(eltype(m.phi))
end

function fft_poisson!(m::MeshCartesianStatic, rho::AbstractArray{T,2}, boundary::Dirichlet) where T
m.phi[2:end,2:end] .= fft_poisson(m.config.Δ, m.config.Len, rho, boundary) .* unit(eltype(m.phi))
m.phi[2:end,2:end] .= fft_poisson(m.config.Δ, m.config.Len, rho, boundary, m.config.device) .* unit(eltype(m.phi))
end

function fft_poisson!(m::MeshCartesianStatic, rho::AbstractArray{T,3}, boundary::Dirichlet) where T
m.phi[2:end,2:end,2:end] .= fft_poisson(m.config.Δ, m.config.Len, rho, boundary) .* unit(eltype(m.phi))
m.phi[2:end,2:end,2:end] .= fft_poisson(m.config.Δ, m.config.Len, rho, boundary, m.config.device) .* unit(eltype(m.phi))
end

function fft_poisson(m::AbstractMesh, G::Number)
Expand Down
2 changes: 1 addition & 1 deletion src/PM/timestep.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function outbound_rule(sim::Simulation, m::MeshCartesianStatic, ::CoarseMesh)
Ny = 10,
Nz = 10,
assign = true,
gpu = device isa GPU ? true : false,
device,
enlarge = 1.2,
)

Expand Down
2 changes: 1 addition & 1 deletion src/base/Config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,7 @@ function Simulation(d;
Nx, Ny, Nz, NG,
xMin, xMax, yMin, yMax, zMin, zMax,
mode = meshmode,
gpu = device isa GPU ? true : false,
device,
enlarge = EnlargeMesh,
)
registry[id] = Simulation(
Expand Down
6 changes: 3 additions & 3 deletions test/PM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ end
NG = 0,
dim = 1,
boundary,
gpu = device isa GPU ? true : false,
device,
)
m.rho .= init_rho.(m.pos)
fdm_poisson(m, Val(1), 1/(4*pi), m.config.mode, device, boundary, sparse)
Expand Down Expand Up @@ -552,7 +552,7 @@ end
NG = 0,
dim = 2,
boundary,
gpu = device isa GPU ? true : false,
device,
)
m.rho .= init_rho.(m.pos)
fdm_poisson(m, Val(2), 1/(4*pi), m.config.mode, device, boundary, sparse)
Expand Down Expand Up @@ -621,7 +621,7 @@ end
NG = 0,
dim = 3,
boundary,
gpu = device isa GPU ? true : false,
device,
)
m.rho .= init_rho.(m.pos)
fdm_poisson(m, Val(3), 1/(4*pi), m.config.mode, device, boundary, sparse)
Expand Down

0 comments on commit fd58720

Please sign in to comment.