Skip to content

Commit

Permalink
minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Korbinian Eckstein committed Mar 15, 2024
1 parent b7e6703 commit adf437a
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 18 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ julia --threads=auto <folder>/tgv_qsm.jl <arguments>

```julia
# Change regularization strength (1-4)
chi = qsm_tgv(phase, mask, res; TE, fieldstrength, regularization=1);
chi = qsm_tgv(phase, mask, voxel_size; TE, fieldstrength, regularization=1);
```

```julia
Expand Down
16 changes: 8 additions & 8 deletions src/oblique_stencil.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ function dipole(x, y, z, r_treshold, direction=(0, 0, 1))
return kappa
end

function stencil(; st=27, res=(1.0, 1.0, 1.0), res0=1.0, direction=(0, 0, 1), gridsize=(64, 64, 64))
function stencil(; st=27, res=(1.0, 1.0, 1.0), res0=1.0, singularity_cutout=4, direction=(0, 0, 1), gridsize=(64, 64, 64))
middle = floor.(Int, gridsize ./ 2) .+ 1

coord = [((1:gridsize[i]) .- middle[i]) * res0 for i in 1:3]

d = [dipole(x, y, z, 4res0, direction) for x in coord[1], y in coord[2], z in coord[3]]
d = [dipole(x, y, z, singularity_cutout * res0, direction) for x in coord[1], y in coord[2], z in coord[3]]
d_mask = isfinite.(d)

# stencil mask
Expand All @@ -28,7 +28,7 @@ function stencil(; st=27, res=(1.0, 1.0, 1.0), res0=1.0, direction=(0, 0, 1), gr
mask = centered(falses((3, 3, 3)))
mask[0, 0, :] .= mask[0, :, 0] .= mask[:, 0, 0] .= true
end
mask[0,0,0] = false
mask[0, 0, 0] = false

midInd = CartesianIndex(middle)

Expand Down Expand Up @@ -65,8 +65,8 @@ function stencil(; st=27, res=(1.0, 1.0, 1.0), res0=1.0, direction=(0, 0, 1), gr
x = F.U * (y .* s_inv)
x_corr = x * res0^3

stencil = zeros(Float32, 3, 3, 3)
stencil = zeros(Float32, 3, 3, 3)

ind = 1
for i in eachindex(stencil)
if mask[i]
Expand All @@ -76,9 +76,9 @@ function stencil(; st=27, res=(1.0, 1.0, 1.0), res0=1.0, direction=(0, 0, 1), gr
end
weights = [(i^2 / res[1]^2 + j^2 / res[2]^2 + k^2 / res[3]^2) / (i^2 + j^2 + k^2) for i in -1:1, j in -1:1, k in -1:1]
stencil .*= weights
stencil[2,2,2] = 0
stencil[2,2,2] = -sum(stencil)

stencil[2, 2, 2] = 0
stencil[2, 2, 2] = -sum(stencil)

return stencil
end
15 changes: 12 additions & 3 deletions src/tgv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,20 @@ function dipole_kernel_orig(res)
end

function set_parameters(alpha, res, B0_dir, cu; orig_kernel=false)
if orig_kernel
dipole_kernel = dipole_kernel_orig(res)
if orig_kernel isa AbstractArray
dipole_kernel = orig_kernel

grad_norm_sqr = 4 * sum(res .^ -2)
grad_norm = sqrt(grad_norm_sqr)
wave_norm = sum(abs.(dipole_kernel))
norm_matrix = [0 grad_norm 1; 0 0 grad_norm; grad_norm_sqr wave_norm 0]
F = svd(norm_matrix)
norm_sqr = first(F.S)^2
elseif orig_kernel
dipole_kernel = dipole_kernel_orig(res)
grad_norm_sqr = 4 * sum(res .^ -2)
norm_sqr = 2 * grad_norm_sqr^2 + 1
else
else # default
dipole_kernel = stencil(; st=27, direction=B0_dir, res=res)

grad_norm_sqr = 4 * sum(res .^ -2)
Expand Down
10 changes: 4 additions & 6 deletions src/tgv_helper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
R = @ndrange

laplace = filter_local((I, R), phi, laplace_kernel)
# wave = filter_local((I, R), chi, dipole_kernel)
wave = wave_local((I, R), chi, dipole_kernel)

@inbounds eta[I] += sigma * mask[I] * (-laplace + wave - laplace_phi0[I])
Expand Down Expand Up @@ -80,7 +79,6 @@ end
R = @ndrange

div = div_local((I, R), p, resinv, mask0)
# wave = filter_local((I, R), v, dipole_kernel, mask0)
wave = wave_local((I, R), v, dipole_kernel, mask0)

@inbounds chi_dest[I] = chi[I] + tau * (div - wave)
Expand Down Expand Up @@ -135,17 +133,17 @@ end

@inline function wave_local((I, (x, y, z)), A::AbstractArray{T}, kernel, mask=nothing) where {T}
i, j, k = Tuple(I)
result = zero(T)
wave = zero(T)
if i > 1 && j > 1 && k > 1 && i < x && j < y && k < z
for di in -1:1, dj in -1:1, dk in -1:1
if isnothing(mask)
result += A[i+di, j+dj, k+dk] * kernel[di+2, dj+2, dk+2]
wave += A[i+di, j+dj, k+dk] * kernel[di+2, dj+2, dk+2]
else
result += mask[i+di, j+dj, k+dk] * A[i+di, j+dj, k+dk] * kernel[di+2, dj+2, dk+2]
wave += mask[i+di, j+dj, k+dk] * A[i+di, j+dj, k+dk] * kernel[di+2, dj+2, dk+2]
end
end
end
return result
return wave
end

@inline function filter_local(I, A, w, mask=nothing)
Expand Down

0 comments on commit adf437a

Please sign in to comment.