SSIM loss #2165

Closed
nikopj opened this issue Jan 18, 2023 · 10 comments

@nikopj

nikopj commented Jan 18, 2023

Motivation and description

The current implementation of SSIM in ImageQualityIndexes.jl is not GPU-friendly (and probably not AD-friendly?). I'm wondering whether this is a feature that should go in Flux or in NNlib+NNlibCUDA?

In PyTorch, SSIM is not native but is implemented by several external repositories, e.g. pytorch-msssim and torchmetrics.

SSIM is a very common image quality metric, but also a common loss function (negative SSIM) in image reconstruction problems, e.g. VarNet's MRI reconstruction.

Possible Implementation

The fastmri implementation of SSIM in pytorch looks like the most straightforward to port over.

@ToucheSir
Member

On first glance, I don't see anything with the ImageQualityIndexes implementation that is inherently GPU or AD unfriendly. It may be easier to tweak it than to start again from whole cloth here. Another consideration is that metrics have to be maintained after they're added, and unless one of the other frequent Flux contributors has a good understanding of SSIM I fear we'd just be flying blind there.

@nikopj
Author

nikopj commented Jan 18, 2023

Ok, I'll take a closer look at why IQI's ssim is failing for me on GPU and consider opening an issue there. Thanks!

@nikopj
Author

nikopj commented Jan 20, 2023

It looks like CUDA compatibility comes down to Distances.jl, where the problem seems to have been deemed a project for another package (see JuliaStats/Distances.jl#223).

Considering that discussion, and that IQI is written to handle many of the Images.jl datatypes like RGB, it may be easier to implement from scratch using NNlib.conv. We could write something that operates on WHCN tensors and takes in a keyword for dimensions to reduce on. Correctness can be checked against IQI.
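As a rough sketch of what the proposed "dimensions to reduce on" keyword could look like, the hypothetical helpers below (not an existing Flux or NNlib API) take a per-pixel SSIM map in WHCN layout. Averaging over dims (1, 2, 3) yields one score per sample, and a loss can be formed as 1 - SSIM, which has the same gradient as negative SSIM.

```julia
using Statistics

# Hypothetical helper: reduce a per-pixel SSIM map (W×H×C×N) to per-sample scores.
# `dims` selects the averaged dimensions; (1, 2, 3) keeps the batch dimension N.
ssim_scores(ssim_map::AbstractArray; dims=(1, 2, 3)) = vec(mean(ssim_map; dims=dims))

# Hypothetical loss: 1 - SSIM, averaged over the batch (same gradient as -SSIM).
ssim_loss_from_map(ssim_map::AbstractArray; dims=(1, 2, 3)) =
    mean(1 .- ssim_scores(ssim_map; dims=dims))

# Usage: a batch of two 4×4 single-channel maps, constant 1.0 and 0.5
m = ones(4, 4, 1, 2)
m[:, :, 1, 2] .= 0.5
ssim_scores(m)          # → [1.0, 0.5]
ssim_loss_from_map(m)   # → 0.25
```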

For reference, even IQI's assess_psnr fails on GPU arrays because of its dependence on Distances.jl (see below):

julia> using ImageQualityIndexes, CUDA

julia> CUDA.allowscalar(false)

julia> x = randn(3,3);

julia> y = randn(3,3);

julia> assess_psnr(x, y)
-3.9922612777358273

julia> assess_psnr(cu(x), cu(y))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore /scratch/npj226/.julia/packages/GPUArraysCore/lojQM/src/GPUArraysCore.jl:87
  [3] getindex(xs::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, I::Int64)
    @ GPUArrays /scratch/npj226/.julia/packages/GPUArrays/fqD8z/src/host/indexing.jl:9
  [4] macro expansion
    @ /scratch/npj226/.julia/packages/Distances/6E33b/src/metrics.jl:253 [inlined]
  [5] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [6] _evaluate
    @ /scratch/npj226/.julia/packages/Distances/6E33b/src/metrics.jl:252 [inlined]
  [7] SqEuclidean
    @ /scratch/npj226/.julia/packages/Distances/6E33b/src/metrics.jl:377 [inlined]
  [8] (::SumSquaredDifference)(a::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, b::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ ImageDistances /scratch/npj226/.julia/packages/ImageDistances/i0iBL/src/metrics.jl:79
  [9] MeanSquaredError
    @ /scratch/npj226/.julia/packages/ImageDistances/i0iBL/src/metrics.jl:97 [inlined]
 [10] mse(a::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, b::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ ImageDistances /scratch/npj226/.julia/packages/ImageDistances/i0iBL/src/metrics.jl:101
 [11] invokelatest(::Any, ::Any, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Base ./essentials.jl:729
 [12] invokelatest(::Any, ::Any, ::Vararg{Any})
    @ Base ./essentials.jl:726
 [13] (::LazyModules.LazyFunction)(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Vararg{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ LazyModules /scratch/npj226/.julia/packages/LazyModules/d9Be6/src/LazyModules.jl:29
 [14] (::LazyModules.LazyFunction)(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Vararg{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})
    @ LazyModules /scratch/npj226/.julia/packages/LazyModules/d9Be6/src/LazyModules.jl:27
 [15] _assess_psnr(x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ref::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, peakval::Float32)
    @ ImageQualityIndexes /scratch/npj226/.julia/packages/ImageQualityIndexes/wIQc0/src/psnr.jl:39
 [16] assess_psnr
    @ /scratch/npj226/.julia/packages/ImageQualityIndexes/wIQc0/src/psnr.jl:29 [inlined]
 [17] assess_psnr(x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ref::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ ImageQualityIndexes /scratch/npj226/.julia/packages/ImageQualityIndexes/wIQc0/src/psnr.jl:30
 [18] top-level scope
    @ REPL[25]:1
 [19] top-level scope
    @ /scratch/npj226/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
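For comparison, PSNR itself only needs broadcasting plus a reduction, so a GPU-friendly version does not have to go through Distances.jl. A minimal sketch, assuming a default peak value of 1; the name `psnr_broadcast` is hypothetical and this has not been tested on CuArrays here.

```julia
using Statistics

# Hypothetical PSNR using only broadcasting and a reduction (no scalar loops),
# so in principle it avoids the scalar-indexing path shown in the stacktrace above.
function psnr_broadcast(x::AbstractArray, ref::AbstractArray; peakval=1)
    mse = mean(abs2, x .- ref)           # mean squared error
    return 10 * log10(peakval^2 / mse)   # PSNR in dB
end

psnr_broadcast(fill(0.1, 3, 3), zeros(3, 3))  # MSE = 0.01, so approximately 20 dB
```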

@nikopj
Author

nikopj commented Jan 20, 2023

An example implementation that reproduces IQI's assess_ssim with default parameters (exactly so when cropping; see the note on padding below) is given below.

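using NNlib: conv, pad_reflect
using Statistics: mean

# Normalized 1-D Gaussian with 2n+1 taps (by default roughly 6σ wide).
function gaussian_kernel(σ, n=ceil(Int, (6σ-1)/2))
    kernel = @. exp(-(-n:n)^2 / (2σ^2))
    return kernel ./ sum(kernel)
end

# Default stabilizing constants (K₁, K₂).
const SSIM_K = (0.01, 0.03)

# 11×11 Gaussian window (σ = 1.5) shaped as an 11×11×1×1 conv kernel.
const SSIM_KERNEL = let
    k = gaussian_kernel(1.5, 5)
    (k*k')[:,:,:,:]
end

# x, y: single-channel WHCN arrays of the same type (e.g. 32×32×1×N, CPU or GPU).
function ssim(x::T, y::T; peakval=1.0, K=SSIM_K, crop=false) where {T}
    kernel = T(SSIM_KERNEL)  # convert the window to the input's array type (e.g. CuArray)
    C₁, C₂ = @. (peakval * K)^2

    # crop=false pads by the window radius so the map keeps the input size;
    # crop=true uses only windows that fit entirely inside the image.
    x, y = crop ? (x, y) : pad_reflect.((x, y), size(kernel,1) ÷ 2)

    # Gaussian-weighted local means, variances and covariance.
    μx  = conv(x, kernel)
    μy  = conv(y, kernel)
    μx² = μx.^2
    μy² = μy.^2
    μxy = μx.*μy
    σx² = conv(x.^2, kernel) .- μx²
    σy² = conv(y.^2, kernel) .- μy²
    σxy = conv(x.*y, kernel) .- μxy

    ssim_map = @. (2μxy + C₁)*(2σxy + C₂)/((μx² + μy² + C₁)*(σx² + σy² + C₂))
    return mean(ssim_map)
end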
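For reference, the per-window quantity computed as `ssim_map` in the snippet above is the standard SSIM expression, where the means, variances and covariance are Gaussian-weighted local averages (written here as E) and the constants come from `peakval` (L) and `K = (K_1, K_2)`:

```latex
\mathrm{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)\,(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)\,(\sigma_x^2 + \sigma_y^2 + C_2)},
\qquad C_1 = (K_1 L)^2, \quad C_2 = (K_2 L)^2,
\\
\sigma_x^2 = E[x^2] - \mu_x^2, \qquad \sigma_y^2 = E[y^2] - \mu_y^2, \qquad \sigma_{xy} = E[xy] - \mu_x\mu_y .
```

The reported score is the mean of this map over all window positions.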

Testing it below, it matches IQI when crop=true. I believe the uncropped results differ because I'm using reflection padding while IQI uses symmetric padding (NNlib is missing symmetric padding).

julia> using ImageQualityIndexes

julia> using CUDA, NNlib, NNlibCUDA

julia> x, y = (rand(32,32,1,1), rand(32,32,1,1)) .|> cu

julia> ssim(x, y)
-0.004551592920863883

julia> ssim(x, y; crop=true)
-0.015330878832750007

julia> assess_ssim(Array(x), Array(y))
-0.004238436041182532

julia> assess_ssim(Array(x), Array(y); crop=true)
-0.015330878832750073
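To illustrate the padding difference suspected of causing the uncropped mismatch: reflection padding mirrors about the edge sample without repeating it, while symmetric padding repeats the edge sample. The helper `pad_2d` and its `symmetric` keyword below are hypothetical (not NNlib API); it is a plain-indexing sketch that pads the first two dimensions of an array with at least two dimensions.

```julia
# Hypothetical helper: pad dims 1 and 2 of `x` by `p` on each side.
#   symmetric=true  repeats the edge sample:   [1,2,3] -> [1 | 1,2,3 | 3]
#   symmetric=false mirrors without it (reflect): [1,2,3] -> [2 | 1,2,3 | 2]
# Assumes ndims(x) >= 2 and p < size(x, d) for d = 1, 2.
function pad_2d(x::AbstractArray, p::Int; symmetric::Bool=true)
    s = symmetric ? 1 : 0  # 1 = symmetric, 0 = reflect
    # Map a padded index i in (1-p):(n+p) back to a valid source index in 1:n.
    src(i, n) = i < 1 ? (2 - s) - i : (i > n ? 2n + s - i : i)
    rows = [src(i, size(x, 1)) for i in (1-p):(size(x, 1)+p)]
    cols = [src(j, size(x, 2)) for j in (1-p):(size(x, 2)+p)]
    trailing = ntuple(_ -> Colon(), ndims(x) - 2)  # leave C and N untouched
    return x[rows, cols, trailing...]
end

x = [1 4; 2 5; 3 6]                   # 3×2, first column is [1, 2, 3]
pad_2d(x, 1)[:, 2]                    # → [1, 1, 2, 3, 3]
pad_2d(x, 1; symmetric=false)[:, 2]   # → [2, 1, 2, 3, 2]
```

Swapping something like this in for `pad_reflect` would be one way to check whether padding alone explains the uncropped discrepancy; whether it is AD- and GPU-friendly enough for training would still need testing.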

I think this could be a useful addition, as ImageFiltering.jl and related JuliaImages packages seem not to be very compatible with NNlib, Flux, and GPUs... though perhaps I'm misinterpreting the issues on ImageFiltering (e.g. JuliaImages/ImageFiltering.jl#142), let me know.

Of course, a more complete/general version of ssim or ssim_loss would require a bit more work, but I'm willing to give it a go if there's a home for it.

@ToucheSir
Member

> It looks like CUDA compatibility comes down to Distances.jl, where the problem seems to have been deemed a project for another package (see JuliaStats/Distances.jl#223).

Can you link some breadcrumbs back from that discussion to ImageQualityIndexes.jl? I see some JuliaImages packages do test for CUDA support and ImageDistances could theoretically add GPU support without changes in Distances.jl, so it doesn't appear to be a hard blocker.

> I think this could be a useful addition as ImageFiltering.jl and related JuliaImages packages seem to not be very compatible with NNlib, Flux, and GPUs...

The only criterion for compatibility with Flux is differentiability. NNlib is the place for functionality that is really core to making NNs work and thus doesn't make sense to maintain in other libraries. Loss functions kind of blur the boundaries there, but I think the couple we do have are pretty domain-specific (e.g. CTC loss).

Given all this, I do think it'd be a good idea to discuss your issue and ideas with the JuliaImages maintainers. In the meantime, we can leave this issue open in case others want to express their interest in (or arguments for) adding an SSIM loss on the FluxML side.

@johnnychen94
Contributor

johnnychen94 commented Jan 24, 2023

I deliberately don't maintain an AD-able SSIM implementation, because the one I wrote in ImageQualityIndexes is for benchmarking purposes, and effort went into getting results consistent with other toolboxes.

But for deep learning training purposes, performance is more important than "how edges are processed", "what kernels are used", "is it calculated per channel, or assumed as a 3D volume".

Thus, I would suggest keeping and maintaining a separate, tailored AD-able SSIM implementation.

@ToucheSir
Member

Thanks Johnny. That brings us back to where this should live and who can maintain it.

On the former, Metalhead is usually the place for all things computer vision in FluxML but has thus far not included any losses. Almost all loss functions live in Flux, but we've been trying to re-home them for a while. NNlib was the candidate if we couldn't find/create a better repo, but are we breaking some separation of concerns by adding domain-specific functionality there? Definitely worth some discussion.

On the latter, I can already see more points that need to be addressed to make the implementation in #2165 (comment) merge-worthy (though it seems close) but I wouldn't know how to address them. With examples like Flux.AlphaDropout in mind which were added and subsequently abandoned (as evidenced by it being broken without anyone realizing for at least a year), how do we avoid the same here? Can we increase the utility of this by making it work on batched/1D/3D/GPU/non-WHCN inputs so that more users are motivated to use it + contribute fixes and it doesn't meet the same fate?

@nikopj
Author

nikopj commented Jan 25, 2023

Thanks @johnnychen94 !

@ToucheSir, I'm willing to work on making an ssim_loss implementation meet your requirements! One way to notice the code breaking is to test against Johnny's IQI ssim function on some test images with known SSIM values. As SSIM is largely statistics on sliding windows, an NNlib.conv-based implementation (as started above) may allow natural inclusion of 1D/2D/3D + GPU versions.
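To make the "statistics on sliding windows" point concrete, here is a CPU-only 1D sketch using only full ("valid") windows, i.e. the analogue of crop=true. It computes the same local means, variances and covariance as the 2D snippet earlier in the thread, with the same 11-tap, σ = 1.5 Gaussian, just along one dimension. The names `ssim_1d` and `window_mean` are hypothetical, and this is a plain-loop illustration rather than the proposed NNlib.conv implementation.

```julia
using Statistics

# Normalized 1-D Gaussian with 2n+1 taps.
function gaussian_kernel(σ, n)
    r = -n:n
    k = @. exp(-r^2 / (2σ^2))
    return k ./ sum(k)
end

# Weighted mean over every full window of length(k) (no padding).
window_mean(x::AbstractVector, k::AbstractVector) =
    [sum(k .* view(x, i:i+length(k)-1)) for i in 1:length(x)-length(k)+1]

# Hypothetical 1-D SSIM: mean of the per-window SSIM values.
function ssim_1d(x::AbstractVector, y::AbstractVector; peakval=1.0, K=(0.01, 0.03))
    k = gaussian_kernel(1.5, 5)
    C₁, C₂ = (peakval * K[1])^2, (peakval * K[2])^2
    μx, μy = window_mean(x, k), window_mean(y, k)
    σx² = window_mean(x .^ 2, k) .- μx .^ 2
    σy² = window_mean(y .^ 2, k) .- μy .^ 2
    σxy = window_mean(x .* y, k) .- μx .* μy
    ssim_map = @. (2μx * μy + C₁) * (2σxy + C₂) / ((μx^2 + μy^2 + C₁) * (σx² + σy² + C₂))
    return mean(ssim_map)
end

x = rand(64)
ssim_1d(x, x)        # identical signals give 1 (up to rounding)
ssim_1d(x, 1 .- x)   # an anti-correlated signal scores below 1
```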

SSIM is a fairly widespread loss function in the image reconstruction community, probably only beaten in popularity by L1 and L2 loss. The paper "Loss Functions for Neural Networks for Image Processing" (2015) demonstrates this by choosing to include only L1, L2, SSIM, and MS-SSIM in its empirical study. You can also see the list of recent papers that cite it, showing that SSIM loss is still on people's minds. SSIM loss is also widely used in the image-compression DNN literature (e.g. a foundational paper (2018)), as it has been shown to maintain perceptual quality better than L2.

I think ssim_loss would belong in Flux as much as (if not more than) the classification oriented losses (hinge, dice, focal, etc.) do, though SSIM is a little more involved than those losses.

@ToucheSir
Member

Thanks for the background. This must be quite common in the CV literature, because anecdotally I've never encountered it in general ML/DL lit (cf. the other losses mentioned, which appear far more frequently there). Those loss functions also have the benefit of being much simpler: most are only 1-5 lines long!

That said, it seems like there's no reason not to incorporate this function if the above requirements are met. My big caveat is that I lack the expertise to provide guidance on fulfilling said requirements, and currently nobody else has stepped up to help.

@nikopj
Author

nikopj commented Jul 12, 2023

nikopj closed this as completed Jul 12, 2023

4 participants