DiffVox is a self-supervised framework for Cone-Beam Computed Tomography (CBCT) reconstruction that directly optimizes a voxel grid representation using physics-based differentiable X-ray rendering.
To install the latest release, use Conda:
git clone https://github.com/hossein-momeni/DiffVox.git
cd DiffVox
conda env create -f environment.yml
conda activate diffvox
pip install .
To download the dataset, run the data.sh script located in the data/ directory:
pip install zenodo_get
cd data
./data.sh
Note: This dataset is sourced from the study "A cone-beam X-ray computed tomography data collection designed for machine learning". It comprises 48 walnuts, each with approximately 3,600 high-resolution X-ray projections. The download requires around 300 GB of storage and may take approximately 10 hours, depending on your internet speed.
After downloading the dataset, you can reconstruct the ground-truth volumes with Slurm:
srun python utils/construct_ground_truth.py -d data
This takes about 4 minutes per walnut on an NVIDIA TITAN Xp.
To reconstruct the walnuts with DiffVox, use the walnut_recon.py script. For example, to reconstruct walnut 3 from 15 views using trilinear interpolation (the default renderer), run:
python walnut_recon.py --walnut_id 3 --n_views 15
You can customize the reconstruction experiments further by using the following flags:
- --walnut_id (int): ID of the walnut in the dataset to use for reconstruction. Default: 3.
- --n_views (int): Number of X-ray views to use for reconstruction. Increasing it can improve reconstruction quality but also increases computation time. Default: 15.
- --downsample (int): Factor by which to downsample the supplied X-ray images. Use it to reduce the computational load. Default: 1 (no downsampling).
- --batch_size (int): Number of rays loaded into memory for each gradient step. Default: 1_800_000. Adjust it to your GPU's memory capacity; for example, on an NVIDIA RTX A6000 with 48 GB of memory:
  - Trilinear method: up to 1,800,000 rays.
  - Siddon's method: up to 500,000 rays.
- --n_itr (int): Number of optimization iterations to perform. Default: 50.
- --lr (float): Learning rate for the optimizer. Default: 0.01.
- --tv_coeff (float): Weight of the total variation (TV) norm used to regularize the density map. Higher values encourage smoother reconstructions. Default: 15.
- --shift (float): Shift applied to the input of the density regulator. Default: 0.9.
  - The input to softplus becomes x - shift, allowing fine-tuning of the density's baseline value.
  - Useful for controlling where the density values start in the optimization process.
- --beta (float): Smoothing parameter for the softplus function in the density regulator. Default: 20.
  - A higher beta makes the softplus function sharper, approaching the behavior of a ReLU.
  - Lower values smooth the transition, which can help with optimization stability.
  - Usage: the density regulator is defined as:
    torch.nn.functional.softplus(x - shift, beta=self.beta, threshold=20)
- --loss_fn (str): Loss function to use for optimization. Default: "l1". Options:
  - "l1": L1 loss.
  - "l2": L2 loss.
  - "pcc": Pearson correlation coefficient loss (work in progress).
  - "ncc": normalized cross-correlation loss.
- --renderer (str): Rendering method used to generate the DRRs (digitally reconstructed radiographs). Default: "trilinear". Options:
  - "trilinear": faster but less accurate.
  - "siddon": physics-based rendering method; slower but more accurate.
- --n_points (int): Number of sampling points per ray in the volume. Default: 500.
  - Relevance: used only by the trilinear renderer to determine how many points are sampled along each ray.
  - Ignored by the siddon renderer, since Siddon's method computes the ray's intersections directly from the voxel grid structure.
  - More points may improve reconstruction quality with trilinear but increase memory and computational costs.
- --drr_params (dict): Parameters for the DRR generator (DiffDrr). Keys:
  - sdd (float): source-to-detector distance. Default: 199.006188.
  - height (int): height of the DRR image in pixels. Default: 768.
  - width (int): width of the DRR image in pixels. Default: 972.
  - delx (float): detector pixel spacing. Default: 0.074800.
  Note: these defaults are calibrated specifically for reconstructing the walnut dataset.
- --density_regulator (str): Regularization method for the density function. Default: "softplus". Options:
  - "softplus": applies a softplus transformation.
  - "sigmoid": applies a sigmoid transformation.
- --tv_type (str): Type of total variation regularization to apply. Default: "vl1". Options:
  - "vl1": variation L1 norm.
  - "vl2": variation L2 norm.
- --half_orbit (bool): Whether to use X-ray views from a half orbit instead of a full orbit, which reduces the number of views required. Default: False.
- --drr_scale (float): Scale factor applied to the generated DRRs. Default: 1.0.
- --proj_name (str): Project name for organizing experiments, particularly when logging with WandB. Default: "walnut_recon".
- --initialize_alg (str): Initialization algorithm for the voxel grid. Default: "None". Options:
  - "None": no specific initialization; the grid is initialized to zeros.
  - "fdk": Feldkamp-Davis-Kress (FDK) filtered backprojection, the standard analytic cone-beam method, for a quick approximate initialization.
  - "cgls": Conjugate Gradient Least Squares (CGLS), an iterative reconstruction method.
  - "sirt": Simultaneous Iterative Reconstruction Technique (SIRT), an iterative method known for robust refinement.
  - "nesterov": Nesterov-accelerated gradient descent, which typically converges faster than plain gradient descent.
- --log_wandb (bool): Whether to log experiment results to WandB. Default: False.
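The --shift and --beta flags both act inside the density regulator quoted in the parameter list. For intuition, here is a scalar Python re-implementation of torch.nn.functional.softplus(x - shift, beta=beta, threshold=20), written from PyTorch's documented formula ((1 / beta) * log(1 + exp(beta * z)), switching to the identity when beta * z exceeds the threshold). It is a sketch for exploring the two flags, not code taken from DiffVox; softplus_regulator is a hypothetical name.

```python
import math


def softplus_regulator(x: float, shift: float = 0.9, beta: float = 20.0,
                       threshold: float = 20.0) -> float:
    """Scalar version of softplus(x - shift, beta=beta, threshold=threshold).

    Hypothetical helper that follows PyTorch's documented softplus formula;
    it is not DiffVox's implementation.
    """
    z = x - shift
    if beta * z > threshold:
        return z  # linear regime, as in PyTorch: avoids overflow in exp()
    return math.log1p(math.exp(beta * z)) / beta


if __name__ == "__main__":
    # Compare a soft, the default, and a very sharp transition around the shift.
    for beta in (1.0, 20.0, 100.0):
        values = [round(softplus_regulator(x, beta=beta), 4) for x in (0.0, 0.9, 1.5)]
        print(f"beta={beta:>5}: {values}")
```

Printing a few values shows the trend described for --beta: with beta = 20 the output is close to zero below the shift and close to x - shift above it, while beta = 1 gives a much softer curve.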
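The --tv_coeff and --tv_type flags control a total variation penalty on the voxel grid. The NumPy sketch below shows one common way to compute such a penalty from finite differences between neighboring voxels. It assumes, without confirmation from the DiffVox source, that "vl1" means the mean absolute difference and "vl2" the mean squared difference, summed over the three axes; total_variation and regularized_loss are hypothetical names.

```python
import numpy as np


def total_variation(volume: np.ndarray, tv_type: str = "vl1") -> float:
    """Anisotropic total variation of a 3D voxel grid (illustrative sketch).

    Assumption: "vl1" averages |forward differences| and "vl2" averages
    squared forward differences along each axis; DiffVox's exact
    normalization may differ.
    """
    # Differences between neighboring voxels along x, y and z.
    diffs = [np.diff(volume, axis=axis) for axis in range(volume.ndim)]
    if tv_type == "vl1":
        return float(sum(np.abs(d).mean() for d in diffs))
    if tv_type == "vl2":
        return float(sum((d ** 2).mean() for d in diffs))
    raise ValueError(f"unknown tv_type: {tv_type!r}")


def regularized_loss(data_loss: float, volume: np.ndarray,
                     tv_coeff: float = 15.0, tv_type: str = "vl1") -> float:
    """Hypothetical helper: data term plus tv_coeff times the TV penalty."""
    return data_loss + tv_coeff * total_variation(volume, tv_type)
```

The L1 variant penalizes every intensity jump linearly and tends to preserve sharp boundaries, while the squared variant penalizes large jumps more heavily and favors smoother transitions.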
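The --loss_fn options compare rendered DRRs with the measured projections. Below are NumPy versions of three of them, written under the assumption that "l1" and "l2" are pixel-wise means and that "ncc" is one minus the normalized cross-correlation over the whole image; DiffVox's exact reductions and normalization may differ, and the work-in-progress "pcc" loss is omitted. The function names are hypothetical.

```python
import numpy as np


def l1_loss(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean absolute pixel difference."""
    return float(np.abs(pred - target).mean())


def l2_loss(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean squared pixel difference."""
    return float(((pred - target) ** 2).mean())


def ncc_loss(pred: np.ndarray, target: np.ndarray, eps: float = 1e-8) -> float:
    """1 - normalized cross-correlation (assumed form, not DiffVox's code)."""
    # Standardize both images; eps guards against constant images.
    p = (pred - pred.mean()) / (pred.std() + eps)
    t = (target - target.mean()) / (target.std() + eps)
    return float(1.0 - (p * t).mean())
```

Because both images are standardized before comparison, the NCC loss ignores global gain and offset differences between simulated and measured intensities, whereas the L1 and L2 losses do not.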
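The --n_points flag matters only for the trilinear renderer, which approximates each ray's line integral by sampling the volume at evenly spaced points. The NumPy sketch below illustrates that idea, assuming the ray endpoints are given in voxel-index coordinates and that the volume is zero outside the grid; it is not DiffVox's renderer, and trilinear_sample and render_ray are hypothetical names. Siddon's method instead computes the exact length of the ray segment inside every voxel it crosses, which is why it needs no sampling parameter.

```python
import numpy as np


def trilinear_sample(volume: np.ndarray, points: np.ndarray) -> np.ndarray:
    """Trilinearly interpolate `volume` at (N, 3) voxel-index coordinates.

    Corners that fall outside the grid contribute zero (zero padding).
    """
    shape = np.array(volume.shape)
    p0 = np.floor(points).astype(int)
    frac = points - p0
    out = np.zeros(len(points))
    # Accumulate the 8 surrounding voxels, each weighted by its trilinear weight.
    for corner in range(8):
        offset = np.array([(corner >> 2) & 1, (corner >> 1) & 1, corner & 1])
        idx = p0 + offset
        weight = np.prod(np.where(offset == 1, frac, 1.0 - frac), axis=1)
        valid = np.all((idx >= 0) & (idx < shape), axis=1)
        out[valid] += weight[valid] * volume[tuple(idx[valid].T)]
    return out


def render_ray(volume: np.ndarray, source: np.ndarray, target: np.ndarray,
               n_points: int = 500) -> float:
    """Approximate the line integral of `volume` from source to target."""
    t = (np.arange(n_points) + 0.5) / n_points  # midpoints of n_points segments
    points = source + t[:, None] * (target - source)
    step = np.linalg.norm(target - source) / n_points
    return float(trilinear_sample(volume, points).sum() * step)
```

Doubling n_points doubles the number of interpolations per ray, consistent with the parameter list's note that more points increase memory and computational costs.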
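To get a feel for how --batch_size, --n_views, --downsample and the detector size in --drr_params interact, the arithmetic below counts the rays in one full pass over the selected views. It assumes one ray per detector pixel and that --downsample divides both image dimensions; whether DiffVox visits every ray in each iteration is not stated here, so treat the batch count as illustrative. Both function names are hypothetical.

```python
def rays_per_pass(n_views: int, height: int = 768, width: int = 972,
                  downsample: int = 1) -> int:
    """Total detector rays across all views (assumes one ray per pixel)."""
    return n_views * (height // downsample) * (width // downsample)


def batches_per_pass(n_views: int, batch_size: int = 1_800_000, **kwargs) -> int:
    """Number of ray batches needed to cover every ray once."""
    total = rays_per_pass(n_views, **kwargs)
    return -(-total // batch_size)  # integer ceiling division
```

With the defaults (15 views, 768 x 972 detector), one pass covers 11,197,440 rays, which is 7 batches at the trilinear limit of 1,800,000 rays or 23 batches at the Siddon limit of 500,000.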
To use your own dataset with DiffVox, create a subclass of Dataset_DiffVox whose constructor (__init__()) loads and prepares your data. The dataset must define the following attributes:
- gt_projs: ground-truth projections.
- sources: source positions for the projections (in world coordinates).
- targets: target positions for the projections (in world coordinates).
- subject: a torchio.Subject instance representing the dataset subject.
By defining these attributes, you ensure that your dataset is compatible with DiffVox's processing pipeline.
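Computing sources and targets is usually the geometry-heavy part of such a subclass. As a hedged starting point, the NumPy sketch below computes them for an idealized circular cone-beam orbit. The function circular_orbit and its sod (source-to-object distance) and arc_degrees arguments are hypothetical, the geometry conventions are assumptions, and the exact tensor shapes and coordinate frame that Dataset_DiffVox expects should be checked against the DiffVox source. The sdd, height, width and delx defaults mirror the --drr_params defaults.

```python
import numpy as np


def circular_orbit(n_views, sod, sdd=199.006188, height=768, width=972,
                   delx=0.0748, arc_degrees=360.0):
    """Source and detector-pixel positions for an idealized circular CBCT orbit.

    Assumptions (hypothetical helper, not DiffVox's API): rotation about the
    z-axis through the origin, a flat detector perpendicular to the central
    ray, square pixels of size delx, and no detector offsets or tilts.
    Returns sources (n_views, 3) and targets (n_views, height * width, 3).
    """
    angles = np.deg2rad(np.linspace(0.0, arc_degrees, n_views, endpoint=False))
    # Unit vector from the rotation axis toward the source for each view.
    radial = np.stack([np.cos(angles), np.sin(angles), np.zeros_like(angles)], axis=1)
    # Detector axes: tangent to the orbit (u) and along the rotation axis (v).
    u_axis = np.stack([-np.sin(angles), np.cos(angles), np.zeros_like(angles)], axis=1)
    v_axis = np.array([0.0, 0.0, 1.0])

    sources = sod * radial
    # Detector center lies sdd from the source, beyond the axis (assumes sdd > sod).
    detector_center = sources - sdd * radial

    # Pixel-center offsets, centered on the central ray; row-major flattening.
    u = (np.arange(width) - (width - 1) / 2) * delx
    v = (np.arange(height) - (height - 1) / 2) * delx
    vv, uu = np.meshgrid(v, u, indexing="ij")  # each (height, width)
    targets = (detector_center[:, None, :]
               + uu.reshape(1, -1, 1) * u_axis[:, None, :]
               + vv.reshape(1, -1, 1) * v_axis[None, None, :])
    return sources, targets
```

In a subclass you would then convert the arrays, for example with torch.from_numpy(...), and assign them to self.sources and self.targets in the constructor alongside gt_projs and a torchio.Subject.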
If you find DiffVox useful in your work, please cite our paper:
@inproceedings{momeni2024voxel,
title={Differentiable Voxel-based X-ray Rendering Improves Sparse-View 3D CBCT Reconstruction},
author={Momeni, Mohammadhossein and Gopalakrishnan, Vivek and Dey, Neel and Golland, Polina and Frisken, Sarah},
booktitle={Machine Learning and the Physical Sciences, NeurIPS 2024},
year={2024}
}