Skip to content

Commit

Permalink
ENH: Add local weights map for conjugate gradient regularization
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon Rit authored and SimonRit committed Dec 4, 2023
1 parent f17a8ad commit 8994dfe
Show file tree
Hide file tree
Showing 10 changed files with 160 additions and 31 deletions.
28 changes: 26 additions & 2 deletions applications/rtkconjugategradient/rtkconjugategradient.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,18 @@ main(int argc, char * argv[])
WeightsReaderType::Pointer weightsReader = WeightsReaderType::New();
weightsReader->SetFileName(args_info.weights_arg);
inputWeights = weightsReader->GetOutput();
inputWeights->Update();
TRY_AND_EXIT_ON_ITK_EXCEPTION(inputWeights->Update())
}

// Read regularization weights if given
OutputImageType::Pointer localRegWeights;
if (args_info.regweights_given)
{
using WeightsReaderType = itk::ImageFileReader<OutputImageType>;
WeightsReaderType::Pointer localRegWeightsReader = WeightsReaderType::New();
localRegWeightsReader->SetFileName(args_info.regweights_arg);
localRegWeights = localRegWeightsReader->GetOutput();
localRegWeights->Update();
}

// Read Support Mask if given
Expand All @@ -103,6 +114,7 @@ main(int argc, char * argv[])
conjugategradient->SetInputVolume(inputFilter->GetOutput());
conjugategradient->SetInputProjectionStack(reader->GetOutput());
conjugategradient->SetInputWeights(inputWeights);
conjugategradient->SetLocalRegularizationWeights(localRegWeights);
conjugategradient->SetCudaConjugateGradient(!args_info.nocudacg_flag);
if (args_info.mask_given)
{
Expand All @@ -118,10 +130,22 @@ main(int argc, char * argv[])
conjugategradient->SetNumberOfIterations(args_info.niterations_arg);
conjugategradient->SetDisableDisplacedDetectorFilter(args_info.nodisplaced_flag);

REPORT_ITERATIONS(conjugategradient, ConjugateGradientFilterType, OutputImageType)
itk::TimeProbe readerProbe;
if (args_info.time_flag)
{
std::cout << "Recording elapsed time... " << std::flush;
readerProbe.Start();
}

TRY_AND_EXIT_ON_ITK_EXCEPTION(conjugategradient->Update())

if (args_info.time_flag)
{
// conjugategradient->PrintTiming(std::cout);
readerProbe.Stop();
std::cout << "It took... " << readerProbe.GetMean() << ' ' << readerProbe.GetUnit() << std::endl;
}

// Write
TRY_AND_EXIT_ON_ITK_EXCEPTION(itk::WriteImage(conjugategradient->GetOutput(), args_info.output_arg))

Expand Down
2 changes: 2 additions & 0 deletions applications/rtkconjugategradient/rtkconjugategradient.ggo
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ option "config" - "Config file"
option "geometry" g "XML geometry file name" string yes
option "output" o "Output file name" string yes
option "niterations" n "Number of iterations" int no default="5"
option "time" t "Records elapsed time during the process" flag off
option "input" i "Input volume" string no
option "weights" w "Weights file for Weighted Least Squares (WLS)" string no
option "regweights" - "Local regularization weights file" string no
option "gamma" - "Laplacian regularization weight" float no default="0"
option "tikhonov" - "Tikhonov regularization weight" float no default="0"
option "nocudacg" - "Do not perform conjugate gradient calculations on GPU" flag off
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,26 +81,26 @@ main(int argc, char * argv[])
input = constantImageSource->GetOutput();
}

// Read weights if given, otherwise default to weights all equal to one
WeightsImageType::Pointer weightsSource;
// Read weights if given
WeightsImageType::Pointer inputWeights;
if (args_info.weights_given)
{
TRY_AND_EXIT_ON_ITK_EXCEPTION(weightsSource = itk::ReadImage<WeightsImageType>(args_info.weights_arg))
using WeightsReaderType = itk::ImageFileReader<WeightsImageType>;
WeightsReaderType::Pointer weightsReader = WeightsReaderType::New();
weightsReader->SetFileName(args_info.weights_arg);
inputWeights = weightsReader->GetOutput();
TRY_AND_EXIT_ON_ITK_EXCEPTION(inputWeights->Update())
}
else

// Read regularization weights if given
SingleComponentImageType::Pointer localRegWeights;
if (args_info.regweights_given)
{
using ConstantWeightsSourceType = rtk::ConstantImageSource<WeightsImageType>;
ConstantWeightsSourceType::Pointer constantWeightsSource = ConstantWeightsSourceType::New();

// Set the weights to the identity matrix
constantWeightsSource->SetInformationFromImage(projections);
WeightsType constantWeight = itk::NumericTraits<WeightsType>::ZeroValue(constantWeight);
for (unsigned int i = 0; i < nMaterials; i++)
constantWeight[i + i * nMaterials] = 1;

constantWeightsSource->SetConstant(constantWeight);
TRY_AND_EXIT_ON_ITK_EXCEPTION(constantWeightsSource->Update())
weightsSource = constantWeightsSource->GetOutput();
using WeightsReaderType = itk::ImageFileReader<SingleComponentImageType>;
WeightsReaderType::Pointer localRegWeightsReader = WeightsReaderType::New();
localRegWeightsReader->SetFileName(args_info.regweights_arg);
localRegWeights = localRegWeightsReader->GetOutput();
localRegWeights->Update();
}

// Read Support Mask if given
Expand All @@ -114,19 +114,20 @@ main(int argc, char * argv[])
using ConjugateGradientFilterType =
rtk::ConjugateGradientConeBeamReconstructionFilter<OutputImageType, SingleComponentImageType, WeightsImageType>;
ConjugateGradientFilterType::Pointer conjugategradient = ConjugateGradientFilterType::New();
// conjugategradient->SetForwardProjectionFilter(ConjugateGradientFilterType::JOSEPH);
// conjugategradient->SetBackProjectionFilter(ConjugateGradientFilterType::JOSEPH);
SetForwardProjectionFromGgo(args_info, conjugategradient.GetPointer());
SetBackProjectionFromGgo(args_info, conjugategradient.GetPointer());
conjugategradient->SetInputVolume(input);
conjugategradient->SetInputProjectionStack(projections);
conjugategradient->SetInputWeights(weightsSource);
conjugategradient->SetInputWeights(inputWeights);
conjugategradient->SetLocalRegularizationWeights(localRegWeights);
conjugategradient->SetCudaConjugateGradient(!args_info.nocudacg_flag);
if (args_info.mask_given)
{
conjugategradient->SetSupportMask(supportmask);
}

if (args_info.gamma_given)
conjugategradient->SetGamma(args_info.gamma_arg);
if (args_info.tikhonov_given)
conjugategradient->SetTikhonov(args_info.tikhonov_arg);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ option "niterations" n "Number of iterations"
option "time" t "Records elapsed time during the process" flag off
option "input" i "Input volume" string no
option "weights" w "Weights file for Weighted Least Squares (WLS)" string no
option "regweights" - "Local regularization weights file" string no
option "gamma" - "Laplacian regularization weight" float no default="0"
option "tikhonov" - "Tikhonov regularization weight" float no default="0"
option "nocudacg" - "Do not perform conjugate gradient calculations on GPU" flag off
option "mask" m "Apply a support binary mask: reconstruction kept null outside the mask)" string no
Expand Down
4 changes: 4 additions & 0 deletions include/rtkConjugateGradientConeBeamReconstructionFilter.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ class ITK_TEMPLATE_EXPORT ConjugateGradientConeBeamReconstructionFilter
SetInputProjectionStack(const TOutputImage * projs);
void
SetInputWeights(const TWeightsImage * weights);
void
SetLocalRegularizationWeights(const TSingleComponentImage * weights);

using ForwardProjectionFilterType = ForwardProjectionImageFilter<TOutputImage, TOutputImage>;
using ForwardProjectionFilterPointer = typename ForwardProjectionFilterType::Pointer;
Expand Down Expand Up @@ -249,6 +251,8 @@ class ITK_TEMPLATE_EXPORT ConjugateGradientConeBeamReconstructionFilter
GetInputProjectionStack();
typename TWeightsImage::ConstPointer
GetInputWeights();
typename TSingleComponentImage::ConstPointer
GetLocalRegularizationWeights();

template <typename ImageType,
typename IterativeConeBeamReconstructionFilter<TOutputImage>::template EnableCudaScalarAndVectorType<
Expand Down
27 changes: 27 additions & 0 deletions include/rtkConjugateGradientConeBeamReconstructionFilter.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImag
this->SetInput("InputWeights", const_cast<TWeightsImage *>(weights));
}

template <typename TOutputImage, typename TSingleComponentImage, typename TWeightsImage>
void
ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImage, TWeightsImage>::
SetLocalRegularizationWeights(const TSingleComponentImage * weights)
{
this->SetInput("LocalRegularizationWeights", const_cast<TSingleComponentImage *>(weights));
}

template <typename TOutputImage, typename TSingleComponentImage, typename TWeightsImage>
void
ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImage, TWeightsImage>::SetSupportMask(
Expand Down Expand Up @@ -108,6 +116,14 @@ ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImag
return static_cast<const TWeightsImage *>(this->itk::ProcessObject::GetInput("InputWeights"));
}

template <typename TOutputImage, typename TSingleComponentImage, typename TWeightsImage>
typename TSingleComponentImage::ConstPointer
ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImage, TWeightsImage>::
GetLocalRegularizationWeights()
{
return static_cast<const TSingleComponentImage *>(this->itk::ProcessObject::GetInput("LocalRegularizationWeights"));
}

template <typename TOutputImage, typename TSingleComponentImage, typename TWeightsImage>
typename TSingleComponentImage::ConstPointer
ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImage, TWeightsImage>::GetSupportMask()
Expand Down Expand Up @@ -152,6 +168,16 @@ ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImag
inputWeights->SetRequestedRegion(inputWeights->GetLargestPossibleRegion());
}

// Input LocalRegularizationWeights is the optional weights map on regularization
if (this->GetLocalRegularizationWeights().IsNotNull())
{
typename TSingleComponentImage::Pointer localRegWeights =
const_cast<TSingleComponentImage *>(this->GetLocalRegularizationWeights().GetPointer());
if (!localRegWeights)
return;
localRegWeights->SetRequestedRegion(localRegWeights->GetLargestPossibleRegion());
}

// Input "SupportMask" is the support constraint mask on volume, if any
if (this->GetSupportMask().IsNotNull())
{
Expand Down Expand Up @@ -215,6 +241,7 @@ ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImag
m_MultiplyWithWeightsFilter->SetInput1(this->GetInputProjectionStack());
m_MultiplyWithWeightsFilter->SetInput2(m_DisplacedDetectorFilter->GetOutput());
m_CGOperator->SetInputWeights(m_DisplacedDetectorFilter->GetOutput());
m_CGOperator->SetLocalRegularizationWeights(this->GetLocalRegularizationWeights());
m_BackProjectionFilterForB->SetInput(1, m_MultiplyWithWeightsFilter->GetOutput());

// If a support mask is used, it serves as preconditioning weights
Expand Down
12 changes: 10 additions & 2 deletions include/rtkLaplacianImageFilter.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include "rtkForwardDifferenceGradientImageFilter.h"
#include "rtkBackwardDifferenceDivergenceImageFilter.h"
#include "itkMultiplyImageFilter.h"

namespace rtk
{
Expand Down Expand Up @@ -53,13 +54,19 @@ class ITK_TEMPLATE_EXPORT LaplacianImageFilter : public itk::ImageToImageFilter<
typename TOutputImage::ValueType,
TGradientImage>;
using DivergenceFilterType = rtk::BackwardDifferenceDivergenceImageFilter<TGradientImage, TOutputImage>;
using MultiplyImageFilterType = itk::MultiplyImageFilter<TGradientImage, TOutputImage>;

/** Method for creation through the object factory. */
itkNewMacro(Self);

/** Run-time type information (and related methods). */
itkTypeMacro(LaplacianImageFilter, itk::ImageToImageFilter);

void
SetWeights(const TOutputImage * weights);
typename TOutputImage::ConstPointer
GetWeights();

protected:
LaplacianImageFilter();
~LaplacianImageFilter() override = default;
Expand All @@ -72,8 +79,9 @@ class ITK_TEMPLATE_EXPORT LaplacianImageFilter : public itk::ImageToImageFilter<
void
GenerateOutputInformation() override;

typename GradientFilterType::Pointer m_Gradient;
typename DivergenceFilterType::Pointer m_Divergence;
typename GradientFilterType::Pointer m_Gradient;
typename DivergenceFilterType::Pointer m_Divergence;
typename MultiplyImageFilterType::Pointer m_Multiply;
};
} // namespace rtk

Expand Down
28 changes: 25 additions & 3 deletions include/rtkLaplacianImageFilter.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ LaplacianImageFilter<TOutputImage, TGradientImage>::LaplacianImageFilter()
m_Gradient = GradientFilterType::New();
m_Divergence = DivergenceFilterType::New();

// Set permanent connections between filters
m_Divergence->SetInput(m_Gradient->GetOutput());

// Set memory management parameters
m_Gradient->ReleaseDataFlagOn();
}
Expand All @@ -46,6 +43,17 @@ LaplacianImageFilter<TOutputImage, TGradientImage>::GenerateOutputInformation()
// Set runtime connections
m_Gradient->SetInput(this->GetInput());

// Set internal connection between filter
if (this->GetWeights().IsNotNull())
{
m_Multiply = MultiplyImageFilterType::New();
m_Multiply->SetInput1(m_Gradient->GetOutput());
m_Multiply->SetInput2(this->GetWeights());
m_Divergence->SetInput(m_Multiply->GetOutput());
}
else
m_Divergence->SetInput(m_Gradient->GetOutput());

// Update the last filter
m_Divergence->UpdateOutputInformation();

Expand All @@ -64,6 +72,20 @@ LaplacianImageFilter<TOutputImage, TGradientImage>::GenerateData()
this->GraftOutput(m_Divergence->GetOutput());
}

template <typename TOutputImage, typename TGradientImage>
void
LaplacianImageFilter<TOutputImage, TGradientImage>::SetWeights(const TOutputImage * weights)
{
this->SetInput("Weights", const_cast<TOutputImage *>(weights));
}

template <typename TOutputImage, typename TGradientImage>
typename TOutputImage::ConstPointer
LaplacianImageFilter<TOutputImage, TGradientImage>::GetWeights()
{
return static_cast<const TOutputImage *>(this->itk::ProcessObject::GetInput("Weights"));
}

} // namespace rtk


Expand Down
9 changes: 9 additions & 0 deletions include/rtkReconstructionConjugateGradientOperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,14 @@ class ITK_TEMPLATE_EXPORT ReconstructionConjugateGradientOperator : public Conju
typename TSingleComponentImage::ConstPointer
GetSupportMask();

/** Set local regularization weights. The map should have the same
* information (size, spacing, origin etc.) as the reconstructed volume. The
* same map is used in Laplacian and Tikhonov regularization. */
void
SetLocalRegularizationWeights(const TSingleComponentImage * localRegularizationWeights);
typename TSingleComponentImage::ConstPointer
GetLocalRegularizationWeights();

/** Set the geometry of both m_BackProjectionFilter and m_ForwardProjectionFilter */
itkSetConstObjectMacro(Geometry, ThreeDCircularProjectionGeometry);

Expand Down Expand Up @@ -223,6 +231,7 @@ class ITK_TEMPLATE_EXPORT ReconstructionConjugateGradientOperator : public Conju
typename MultiplyFilterType::Pointer m_MultiplyInputVolumeFilter;
typename MultiplyFilterType::Pointer m_MultiplyLaplacianFilter;
typename MultiplyFilterType::Pointer m_MultiplyTikhonovFilter;
typename MultiplyFilterType::Pointer m_MultiplyTikhonovWeightsFilter;
typename AddFilterType::Pointer m_AddLaplacianFilter;
typename AddFilterType::Pointer m_AddTikhonovFilter;
typename itk::ImageToImageFilter<TOutputImage, TOutputImage>::Pointer m_LaplacianFilter;
Expand Down
Loading

0 comments on commit 8994dfe

Please sign in to comment.