Fix halo implementation and tiling artefact #113

Merged · 6 commits · Apr 15, 2024
Conversation

@qin-yu (Collaborator) commented on Apr 10, 2024

Fix halo implementation and tiling artefact

I've revisited the tiling-artefact issue and found that the current halo implementation is incorrect. The halo of a patch should consist of the margins surrounding that patch in the raw volume; padding each patch with a mirror reflection of itself is what causes the tiling artefact.

Before the fix, with a patch size of 96x96x96, a stride of 96x96x96, and a halo of 32x64x64, the prediction exhibited a clear tiling artefact. With the correct halo implementation and the same configuration, the prediction improves significantly. Note that with these settings neighbouring patches do not overlap at all, yet the new implementation shows almost no artefact.
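
As a minimal sketch of the difference (illustrative NumPy only, not the actual pytorch-3dunet code): with mirror padding the context around a patch is fabricated from the patch itself, whereas with a halo the context is read from the neighbouring voxels of the raw volume.

import numpy as np

volume = np.random.rand(192, 192, 192)      # raw volume
start, stop = (0, 0, 0), (96, 96, 96)       # one patch; stride == patch, so no overlap
halo = (32, 64, 64)

# Old behaviour: the margin is a mirror reflection of the patch itself.
patch = volume[tuple(slice(b, e) for b, e in zip(start, stop))]
mirror_padded = np.pad(patch, [(h, h) for h in halo], mode='reflect')

# New behaviour: pad the whole volume once (only the outer border is reflected),
# then take the patch together with its surrounding margin from the raw data.
padded_volume = np.pad(volume, [(h, h) for h in halo], mode='reflect')
halo_padded = padded_volume[tuple(slice(b, e + 2 * h) for b, e, h in zip(start, stop, halo))]

assert mirror_padded.shape == halo_padded.shape == (160, 224, 224)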

Comparison

Left: Mirror pad; Right: Halo pad

[image: side-by-side prediction comparison]

Design Choices

  • Padding with the halo happens in the Dataset's __getitem__ method, while removal of the padding happens in the Predictor. Since the Predictor receives Datasets wrapped in loaders, I moved the halo config into SliceBuilder under the loader config, so the Predictor can read the halo directly from its input Datasets (see the sketch after this list).
  • The padding and unpadding functions have been refactored and moved into the dataset utils module (pytorch3dunet/datasets/utils.py).
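
A rough sketch of that flow (simplified, with hypothetical helper names rather than the actual pytorch-3dunet classes): the Dataset slices a halo-padded patch out of the padded volume, and the Predictor crops the halo off the network output before writing it back at the original patch location.

def get_halo_padded_patch(padded_volume, raw_idx, halo_shape):
    # raw_idx addresses the original volume; in the padded volume the same
    # window extended by the halo on both sides is [start, stop + 2 * halo)
    idx = tuple(slice(i.start, i.stop + 2 * h) for i, h in zip(raw_idx, halo_shape))
    return padded_volume[idx]

def remove_halo_from_prediction(prediction, halo_shape):
    # crop the halo so the output matches the original patch shape again
    # (assumes a 3D prediction without a channel dimension)
    return prediction[tuple(slice(h, s - h) for h, s in zip(halo_shape, prediction.shape))]

# usage inside a predictor loop (sketch):
#   prediction = model(get_halo_padded_patch(padded_raw, raw_idx, halo))
#   output[raw_idx] = remove_halo_from_prediction(prediction, halo)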

Old and New Config Files

Old config file for inference (Predictor taking halo, patch mirror-padded)
# path to the checkpoint file containing the model
model_path: /.../plantseg_original_1135_rest_rotate2d_fmaps16_max/best_checkpoint.pytorch
# model configuration
model:
  # model class
  name: UNet3D
  # number of input channels to the model
  in_channels: 1
  # number of output channels
  out_channels: 2
  # determines the order of operators in a single layer (gcr - GroupNorm+Conv3d+ReLU)
  layer_order: gcr
  # feature maps scale factor
  f_maps: 16
  # number of groups in the groupnorm
  num_groups: 8
  # apply element-wise nn.Sigmoid after the final 1x1 convolution, otherwise apply nn.Softmax
  final_sigmoid: true
  # if True applies the final normalization layer (sigmoid or softmax), otherwise the network returns the output of the final convolution layer; use False for regression problems, e.g. de-noising
  is_segmentation: true
# predictor configuration
predictor:
  # standard in memory predictor
  name: 'StandardPredictor'
  # halo around each input patch, created with mirror reflection
  patch_halo: [32, 64, 64]
# specify the test datasets
loaders:
  # batch dimension; if number of GPUs is N > 1, then a batch_size of N * batch_size will automatically be taken for DataParallel
  batch_size: 1
  # mirror pad the raw data in each axis for sharper prediction near the boundaries of the volume
  mirror_padding: [32, 32, 32]
  # path to the raw data within the H5
  raw_internal_path: raw/noisy
  # how many subprocesses to use for data loading
  num_workers: 8
  # test loaders configuration
  test:
    # paths to the test datasets; if a given path is a directory all H5 files ('*.h5', '*.hdf', '*.hdf5', '*.hd5')
    # inside this directory will be included as well (non-recursively)
    file_paths:
      - /.../1135.h5

    # SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch
    slice_builder:
      # SliceBuilder class
      name: SliceBuilder
      # train patch size given to the network (adapt to fit in your GPU mem, generally the bigger patch the better)
      patch_shape: [96, 96, 96]
      # train stride between patches
      stride_shape: [96, 96, 96]

    transformer:
      raw:
        - name: Standardize
        - name: ToTensor
          expand_dims: true
New config file for inference (SliceBuilder taking halo, patch halo-padded)
# path to the checkpoint file containing the model
model_path: /.../plantseg_original_1135_rest_rotate2d_fmaps16_max/best_checkpoint.pytorch
# model configuration
model:
  # model class
  name: UNet3D
  # number of input channels to the model
  in_channels: 1
  # number of output channels
  out_channels: 2
  # determines the order of operators in a single layer (gcr - GroupNorm+Conv3d+ReLU)
  layer_order: gcr
  # feature maps scale factor
  f_maps: 16
  # number of groups in the groupnorm
  num_groups: 8
  # apply element-wise nn.Sigmoid after the final 1x1 convolution, otherwise apply nn.Softmax
  final_sigmoid: true
  # if True applies the final normalization layer (sigmoid or softmax), otherwise the network returns the output of the final convolution layer; use False for regression problems, e.g. de-noising
  is_segmentation: true
# predictor configuration
predictor:
  # standard in memory predictor
  name: 'StandardPredictor'
# specify the test datasets
loaders:
  # batch dimension; if number of GPUs is N > 1, then a batch_size of N * batch_size will automatically be taken for DataParallel
  batch_size: 1
  # mirror pad the raw data in each axis for sharper prediction near the boundaries of the volume
  mirror_padding: [0, 0, 0]
  # path to the raw data within the H5
  raw_internal_path: raw/noisy
  # how many subprocesses to use for data loading
  num_workers: 8
  # test loaders configuration
  test:
    # paths to the test datasets; if a given path is a directory all H5 files ('*.h5', '*.hdf', '*.hdf5', '*.hd5')
    # inside this directory will be included as well (non-recursively)
    file_paths:
      - /.../1135.h5

    # SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch
    slice_builder:
      # SliceBuilder class
      name: SliceBuilder
      # train patch size given to the network (adapt to fit in your GPU mem, generally the bigger patch the better)
      patch_shape: [96, 96, 96]
      # train stride between patches
      stride_shape: [96, 96, 96]
      # raw image halo around each patch
      halo_shape: [32, 64, 64]

    transformer:
      raw:
        - name: Standardize
        - name: ToTensor
          expand_dims: true
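
Note that with patch_shape: [96, 96, 96] and halo_shape: [32, 64, 64], the tensor actually fed to the network has shape [96 + 2·32, 96 + 2·64, 96 + 2·64] = [160, 224, 224], since the halo is added on both sides of every axis, so the patch/halo combination must still fit into GPU memory.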

Required changes in config files:

[image: config diff]

@qin-yu requested a review from wolny on April 10, 2024 at 21:19
@wolny (Owner) commented on Apr 11, 2024

This is great @qin-yu! Thanks a lot for the PR. It looks good to me at first glance. I'll do a proper review tomorrow, merge it, and release the new version.

@wolny (Owner) left a comment

Overall looks good; I've left a couple of comments/questions.

Review threads (resolved):
  • pytorch3dunet/datasets/hdf5.py
  • pytorch3dunet/datasets/utils.py
@qin-yu (Collaborator, Author) commented on Apr 12, 2024

Here I illustrate why the raw volume is padded on all sides, while slicing each patch only extends its end index by 2 × halo:

# pad the whole raw volume by halo_shape on every side (mirroring only at the volume border)
self.raw_padded = mirror_pad(self.raw, self.halo_shape)
# in padded coordinates, an original window [start, stop) plus its halo becomes [start, stop + 2 * halo)
raw_idx_padded = tuple(
    slice(this_index.start, this_index.stop + 2 * this_halo, None)
    for this_index, this_halo in zip(raw_idx, self.halo_shape)
)

On the left I show how the padded inputs are predicted and mapped back to the original shape; on the right I show how the patches should be sliced correctly:

[drawing: patch slicing in padded vs. original coordinates]

Note that the image in #113 (comment) was produced with patch shape = stride shape, i.e. the tiles can only line up if the slicing is done correctly.
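
To make the index arithmetic concrete (illustrative numbers along a single axis, not taken from the PR): padding the raw volume by halo shifts every original coordinate x to x + halo in the padded array, so the window [start − halo, stop + halo) of the original volume is exactly [start, stop + 2·halo) in the padded one, which is what raw_idx_padded computes.

halo = 64
start, stop = 96, 192                  # one patch along one axis, original coordinates
wanted = (start - halo, stop + halo)   # (32, 256): the patch plus its halo in the original volume
padded = (start, stop + 2 * halo)      # (96, 320): the same window in the padded volume
assert padded[0] == wanted[0] + halo and padded[1] == wanted[1] + halo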

@qin-yu (Collaborator, Author) commented on Apr 12, 2024

  • Make sure only the test phase uses halo-padded patches in AbstractHDF5Dataset
  • Make sure the new methods have Google-style docstrings
  • Add a test for the slicer (a rough sketch of what such a test could check follows this list)
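
A minimal sketch of what such a test could assert (hypothetical test name, not the actual test in the repository): every slice produced for the padded volume must be exactly patch + 2·halo long along each axis.

def test_halo_padded_slices_have_expected_length():
    patch, halo = (96, 96, 96), (32, 64, 64)
    raw_idx = tuple(slice(0, p) for p in patch)
    padded_idx = tuple(slice(i.start, i.stop + 2 * h) for i, h in zip(raw_idx, halo))
    # each padded slice must be exactly patch + 2 * halo voxels long
    for s, p, h in zip(padded_idx, patch, halo):
        assert s.stop - s.start == p + 2 * h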

@qin-yu requested a review from wolny on April 12, 2024 at 23:46

@wolny (Owner) left a comment

LGTM! Thanks for the PR

Labels: bug, enhancement