Skip to content

Commit

Permalink
fix pabs bug
Browse files Browse the repository at this point in the history
  • Loading branch information
korbinian90 committed Jul 1, 2023
1 parent 5331b84 commit 0dfe717
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 22 deletions.
42 changes: 22 additions & 20 deletions src/tgv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ end

function qsm_tgv(laplace_phi0, mask, res; TE, fieldstrength=3, omega=[0, 0, 1], alpha=(0.0015, 0.0005), iterations=1000, erosions=3, type=Float32, gpu=false)
device, cu = if gpu
CUDADevice(), CUDA.cu
CUDA.CUDAKernels.CUDABackend(), CUDA.cu
else
CPU(), identity
end
Expand Down Expand Up @@ -66,46 +66,48 @@ function qsm_tgv(laplace_phi0, mask, res; TE, fieldstrength=3, omega=[0, 0, 1],
res_inv_dim4 = cu(reshape(res .^ -1, 1, 1, 1, 3))

if gpu
synchronize()
KernelAbstractions.synchronize(device)
end
for k in 1:iterations
@time for k in 1:iterations

tau = 1 / sqrt(norm_sqr)
sigma = (1 / norm_sqr) / tau # TODO they are always identical

#############
# dual update

thread_eta = tgv_update_eta!(eta, phi_, chi_, laplace_phi0, mask0, sigma, res, omega; cu, device)
thread_p = tgv_update_p!(p, chi_, w_, mask, mask0, sigma, alpha1, res; cu, device)
thread_q = tgv_update_q!(q, w_, mask0, sigma, alpha0, res; cu, device)
wait(thread_eta)
wait(thread_p)
wait(thread_q)
# @sync begin
# Update eta <- eta + sigma*mask*(-laplace(phi) + wave(chi) - laplace_phi0).
tgv_update_eta!(eta, phi_, chi_, laplace_phi0, mask0, sigma, res, omega; cu, device)
# Update p <- P_{||.||_\infty <= alpha}(p + sigma*(mask0*grad(phi_f) - mask*w).
tgv_update_p!(p, chi_, w_, mask, mask0, sigma, alpha1, res; cu, device)
# Update q <- P_{||.||_\infty <= alpha}(q + sigma*weight*symgrad(u)).
tgv_update_q!(q, w_, mask0, sigma, alpha0, res; cu, device)
# end

#######################
# swap primal variables

(phi_, phi) = (phi, phi_)
(chi_, chi) = (chi, chi_)
(w_, w) = (w, w_)

###############
# primal update
thread_phi = tgv_update_phi!(phi, phi_, eta, mask, mask0, tau, res; cu, device)
thread_chi = tgv_update_chi!(chi, chi_, eta, p, mask0, tau, res, omega; cu, device)
thread_w = tgv_update_w!(w, w_, p, q, mask, mask0, tau, res, res_inv_dim4, qx_alloc, qy_alloc, qz_alloc; cu, device)
wait(thread_phi)
wait(thread_chi)
wait(thread_w)
# @sync begin
# Update phi_dest <- (phi + tau*laplace(mask0*eta))/(1+mask*tau).
tgv_update_phi!(phi, phi_, eta, mask, mask0, tau, res; cu, device)
# Update chi_dest <- chi + tau*(div(p) - wave(mask*v)).
tgv_update_chi!(chi, chi_, eta, p, mask0, tau, res, omega; cu, device)
# Update w_dest <- w + tau*(mask*p + div(mask0*q)).
tgv_update_w!(w, w_, p, q, mask, mask0, tau, res, res_inv_dim4, qx_alloc, qy_alloc, qz_alloc; cu, device)
# end

######################
# extragradient update

@sync begin
@async extragradient_update(phi_, phi)
@async extragradient_update(chi_, chi)
@async extragradient_update(w_, w)
@async extragradient_update!(phi_, phi)
@async extragradient_update!(chi_, chi)
@async extragradient_update!(w_, w)
end
end

Expand Down
4 changes: 2 additions & 2 deletions src/tgv_helper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ end
px = p[i, j, k, 1] + sigmaw0 * dxp - sigmaw * w[i, j, k, 1]
py = p[i, j, k, 2] + sigmaw0 * dyp - sigmaw * w[i, j, k, 2]
pz = p[i, j, k, 3] + sigmaw0 * dzp - sigmaw * w[i, j, k, 3]
pabs = sqrt(px * px + py * py * pz * pz) * alphainv
pabs = sqrt(px * px + py * py + pz * pz) * alphainv
pabs = (pabs > 1) ? 1 / pabs : one(type)

p[i, j, k, 1] = px * pabs
Expand Down Expand Up @@ -250,6 +250,6 @@ end
w_dest[i, j, k, 3] = w[i, j, k, 3] + tau * (m0 * p[i, j, k, 3] + q2x + q4y + q5z)
end

function extragradient_update(u_, u)
function extragradient_update!(u_, u)
u_ .= 2 .* u .- u_
end

1 comment on commit 0dfe717

@korbinian90
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The interesting change here is line 68 in src/tgv_helper.jl:
pabs = sqrt(px * px + py * py * pz * pz) * alphainv
->
pabs = sqrt(px * px + py * py + pz * pz) * alphainv

Please sign in to comment.