From f110e3bf22c1bea74e39cbf97b44610cfe7062ee Mon Sep 17 00:00:00 2001 From: Fabian Wagner Date: Thu, 15 Dec 2022 14:45:48 +0100 Subject: [PATCH 1/9] Integrated trainable bilateral filter layer with unit tests. Coding style and unit tests passed. Signed-off-by: Fabian Wagner --- monai/csrc/ext.cpp | 2 + monai/csrc/filtering/filtering.h | 1 + .../bf_layer_cpu_backward.cpp | 260 ++++++++++ .../bf_layer_cpu_forward.cpp | 306 +++++++++++ .../bf_layer_gpu_backward.cu | 272 ++++++++++ .../bf_layer_gpu_forward.cu | 319 ++++++++++++ .../trainable_bilateral.cpp | 103 ++++ .../trainable_bilateral/trainable_bilateral.h | 79 +++ monai/networks/layers/__init__.py | 2 +- monai/networks/layers/filtering.py | 149 +++++- tests/test_trainable_bilateral.py | 476 ++++++++++++++++++ 11 files changed, 1967 insertions(+), 2 deletions(-) create mode 100644 monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp create mode 100644 monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_forward.cpp create mode 100644 monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu create mode 100644 monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_forward.cu create mode 100644 monai/csrc/filtering/trainable_bilateral/trainable_bilateral.cpp create mode 100644 monai/csrc/filtering/trainable_bilateral/trainable_bilateral.h create mode 100644 tests/test_trainable_bilateral.py diff --git a/monai/csrc/ext.cpp b/monai/csrc/ext.cpp index ac43e6fd3e..fc47247473 100644 --- a/monai/csrc/ext.cpp +++ b/monai/csrc/ext.cpp @@ -22,6 +22,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // filtering m.def("bilateral_filter", &BilateralFilter, "Bilateral Filter"); m.def("phl_filter", &PermutohedralFilter, "Permutohedral Filter"); + m.def("tbf_forward", &TrainableBilateralFilterForward, "Trainable Bilateral Filter Forward"); + m.def("tbf_backward", &TrainableBilateralFilterBackward, "Trainable Bilateral Filter Backward"); // lltm m.def("lltm_forward", &lltm_forward, "LLTM forward"); diff --git a/monai/csrc/filtering/filtering.h b/monai/csrc/filtering/filtering.h index 3e680010ed..2b0c499259 100644 --- a/monai/csrc/filtering/filtering.h +++ b/monai/csrc/filtering/filtering.h @@ -15,3 +15,4 @@ limitations under the License. #include "bilateral/bilateral.h" #include "permutohedral/permutohedral.h" +#include "trainable_bilateral/trainable_bilateral.h" diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp new file mode 100644 index 0000000000..b91e24458c --- /dev/null +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp @@ -0,0 +1,260 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "trainable_bilateral.h" + +struct Indexer { +public: + Indexer(int dimensions, int* sizes) { + m_dimensions = dimensions; + m_sizes = sizes; + m_index = new int[dimensions]{0}; + } + ~Indexer(){ + delete [] m_index; + } + + bool operator++(int) { + for (int i = 0; i < m_dimensions; i++) { + m_index[i] += 1; + + if (m_index[i] < m_sizes[i]) { + return true; + } else { + m_index[i] = 0; + } + } + + return false; + } + + int& operator[](int dimensionIndex) { + return m_index[dimensionIndex]; + } + +private: + int m_dimensions; + int* m_sizes; + int* m_index; +}; + +template +void BilateralFilterCpuBackward_3d(torch::Tensor gradientInputTensor, + torch::Tensor gradientOutputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(gradientInputTensor); + + // Raw tensor data pointers. + scalar_t* gradientInputTensorData = gradientInputTensor.data_ptr(); + scalar_t* gradientOutputTensorData = gradientOutputTensor.data_ptr(); + scalar_t* inputTensorData = inputTensor.data_ptr(); + scalar_t* outputTensorData = outputTensor.data_ptr(); + scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr(); + scalar_t* dO_dx_kiData = dO_dx_ki.data_ptr(); + + // Pre-calculate common values + int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size + int halfWindowSize_x = floor(0.5f * windowSize_x); + int halfWindowSize_y = floor(0.5f * windowSize_y); + int halfWindowSize_z = floor(0.5f * windowSize_z); + int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; + scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); + scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); + scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); + scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + // Set kernel sizes with respect to the defined spatial sigmas. + int* kernelSizes = new int[desc.dimensions]; + + kernelSizes[0] = windowSize_x; + kernelSizes[1] = windowSize_y; + kernelSizes[2] = windowSize_z; + + // Pre-calculate gaussian kernel and distance map in 1D. + scalar_t* gaussianKernel_x = new scalar_t[windowSize_x]; + scalar_t* gaussianKernel_y = new scalar_t[windowSize_y]; + scalar_t* gaussianKernel_z = new scalar_t[windowSize_z]; + scalar_t* xDistanceSquared = new scalar_t[windowSize_x]; + scalar_t* yDistanceSquared = new scalar_t[windowSize_y]; + scalar_t* zDistanceSquared = new scalar_t[windowSize_z]; + + for (int i = 0; i < windowSize_x; i++) { + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_y; i++) { + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_z; i++) { + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; + } + + // Looping over the batches + for (int b = 0; b < desc.batchCount; b++) { + int batchOffset = b * desc.batchStride; + + // Looping over all dimensions for the home element + for (int z = 0; z < desc.sizes[2]; z++) +#pragma omp parallel for + for (int y = 0; y < desc.sizes[1]; y++) { + for (int x = 0; x < desc.sizes[0]; x++) { + // Calculating indexing offset for the home element + int homeOffset = batchOffset; + + int homeIndex[] = {x, y, z}; + homeOffset += x * desc.strides[0]; + homeOffset += y * desc.strides[1]; + homeOffset += z * desc.strides[2]; + + // Zero kernel aggregates. + scalar_t filter_kernel = 0; + scalar_t valueSum = 0; + + scalar_t weightSum = 0.0f; + + // Looping over all dimensions for the neighbour element + Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); + do // while(kernelIndex++) + { + // Calculating buffer offset for the neighbour element + // Index is clamped to the border in each dimension. + int neighbourOffset = batchOffset; + bool flagNotClamped = true; + + for (int i = 0; i < desc.dimensions; i++) { + int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i]; + int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex)); + neighbourOffset += neighbourIndexClamped * desc.strides[i]; + if (neighbourIndex != neighbourIndexClamped) { flagNotClamped = false; } + } + + // Euclidean color distance. + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; + + for (int i = 0; i < desc.channelCount; i++) { + scalar_t diff = inputTensorData[neighbourOffset + i * desc.channelStride] - + inputTensorData[homeOffset + i * + desc.channelStride]; // Be careful: Here it is (X_k - X_i) and not (X_i - X_q) + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + // Calculating and combining the spatial + // and color weights. + scalar_t spatialWeight = 1; + + spatialWeight = gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * + gaussianKernel_z[kernelIndex[2]]; + + scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { + for (int i = 0; i < desc.channelCount; i++) { + + // Distinguish cases for k!=i (calculation is done here) + // and k==i (partial derivatives are precalculated). + // If statement replaces center element of neighborhood/kernel. + if (kernelIndex[0] != halfWindowSize_x || kernelIndex[1] != halfWindowSize_y || + kernelIndex[2] != halfWindowSize_z) { + + filter_kernel = + -(1 / + outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * + outputTensorData[neighbourOffset + i * desc.channelStride] * + totalWeight * + colorDistance / (colorSigma * colorSigma) + + (1 / + outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * + totalWeight * + (1 + + inputTensorData[homeOffset + i * desc.channelStride] * colorDistance / + (colorSigma * colorSigma)); // inputTensorData[homeOffset] !! + } else { + + filter_kernel = dO_dx_kiData[homeOffset + i * desc.channelStride]; + } + + valueSum += + gradientInputTensorData[neighbourOffset + i * desc.channelStride] * + filter_kernel; + + } + + weightSum += totalWeight; + } + } while (kernelIndex++); + + // Do the filtering and calculate the values for the backward pass. + for (int i = 0; i < desc.channelCount; i++) { + // Filtering: + gradientOutputTensorData[homeOffset + i * desc.channelStride] = valueSum; + } + } + } + } + + delete[] kernelSizes; + delete[] gaussianKernel_x; + delete[] gaussianKernel_y; + delete[] gaussianKernel_z; + delete[] xDistanceSquared; + delete[] yDistanceSquared; + delete[] zDistanceSquared; +} + +torch::Tensor BilateralFilterCpuBackward(torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Preparing output tensor. + torch::Tensor gradientOutputTensor = torch::zeros_like(gradientInputTensor); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradientInputTensor.scalar_type(), "BilateralFilterCpuBackward_3d", ([&] { + BilateralFilterCpuBackward_3d( + gradientInputTensor, + gradientOutputTensor, + inputTensor, + outputTensor, + outputWeightsTensor, + dO_dx_ki, + sigma_x, + sigma_y, + sigma_z, + colorSigma); + })); + + return gradientOutputTensor; +} diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_forward.cpp b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_forward.cpp new file mode 100644 index 0000000000..ccfa1fe640 --- /dev/null +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_forward.cpp @@ -0,0 +1,306 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "trainable_bilateral.h" + +struct Indexer { +public: + Indexer(int dimensions, int* sizes) { + m_dimensions = dimensions; + m_sizes = sizes; + m_index = new int[dimensions]{0}; + } + ~Indexer(){ + delete [] m_index; + } + + bool operator++(int) { + for (int i = 0; i < m_dimensions; i++) { + m_index[i] += 1; + + if (m_index[i] < m_sizes[i]) { + return true; + } else { + m_index[i] = 0; + } + } + + return false; + } + + int& operator[](int dimensionIndex) { + return m_index[dimensionIndex]; + } + +private: + int m_dimensions; + int* m_sizes; + int* m_index; +}; + +template +void BilateralFilterCpuForward_3d(torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + torch::Tensor dO_dsig_r, + torch::Tensor dO_dsig_x, + torch::Tensor dO_dsig_y, + torch::Tensor dO_dsig_z, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + // Raw tensor data pointers. + scalar_t* inputTensorData = inputTensor.data_ptr(); + scalar_t* outputTensorData = outputTensor.data_ptr(); + scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr(); + scalar_t* dO_dx_kiData = dO_dx_ki.data_ptr(); + scalar_t* dO_dsig_rData = dO_dsig_r.data_ptr(); + scalar_t* dO_dsig_xData = dO_dsig_x.data_ptr(); + scalar_t* dO_dsig_yData = dO_dsig_y.data_ptr(); + scalar_t* dO_dsig_zData = dO_dsig_z.data_ptr(); + + // Pre-calculate common values + int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size + int halfWindowSize_x = floor(0.5f * windowSize_x); + int halfWindowSize_y = floor(0.5f * windowSize_y); + int halfWindowSize_z = floor(0.5f * windowSize_z); + int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; + scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); + scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); + scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); + scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + // Set kernel sizes with respect to the defined spatial sigmas. + int* kernelSizes = new int[desc.dimensions]; + + kernelSizes[0] = windowSize_x; + kernelSizes[1] = windowSize_y; + kernelSizes[2] = windowSize_z; + + // Pre-calculate gaussian kernel and distance map in 1D. + scalar_t* gaussianKernel_x = new scalar_t[windowSize_x]; + scalar_t* gaussianKernel_y = new scalar_t[windowSize_y]; + scalar_t* gaussianKernel_z = new scalar_t[windowSize_z]; + scalar_t* xDistanceSquared = new scalar_t[windowSize_x]; + scalar_t* yDistanceSquared = new scalar_t[windowSize_y]; + scalar_t* zDistanceSquared = new scalar_t[windowSize_z]; + + for (int i = 0; i < windowSize_x; i++) { + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_y; i++) { + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_z; i++) { + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; + } + + // Looping over the batches + for (int b = 0; b < desc.batchCount; b++) { + int batchOffset = b * desc.batchStride; + + // Looping over all dimensions for the home element + for (int z = 0; z < desc.sizes[2]; z++) +#pragma omp parallel for + for (int y = 0; y < desc.sizes[1]; y++) { + for (int x = 0; x < desc.sizes[0]; x++) { + // Calculating indexing offset for the home element + int homeOffset = batchOffset; + + int homeIndex[] = {x, y, z}; + homeOffset += x * desc.strides[0]; + homeOffset += y * desc.strides[1]; + homeOffset += z * desc.strides[2]; + + // Zero kernel aggregates. + scalar_t valueSum = 0; + scalar_t dw_dx_ki = 0; + scalar_t dfilter_dx_ki = 0; + scalar_t colorSum_w = 0; + scalar_t colorSum_alpha = 0; + scalar_t xSum_w = 0; + scalar_t xSum_alpha = 0; + scalar_t ySum_w = 0; + scalar_t ySum_alpha = 0; + scalar_t zSum_w = 0; + scalar_t zSum_alpha = 0; + + scalar_t weightSum = 0.0f; + + // Looping over all dimensions for the neighbour element + Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); + do // while(kernelIndex++) + { + // Calculating buffer offset for the neighbour element + // Index is clamped to the border in each dimension. + int neighbourOffset = batchOffset; + bool flagNotClamped = true; + + for (int i = 0; i < desc.dimensions; i++) { + int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i]; + int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex)); + neighbourOffset += neighbourIndexClamped * desc.strides[i]; + if (neighbourIndex != neighbourIndexClamped) { flagNotClamped = false; } + } + + // Euclidean color distance. + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; + + for (int i = 0; i < desc.channelCount; i++) { + scalar_t diff = inputTensorData[homeOffset + i * desc.channelStride] - + inputTensorData[neighbourOffset + i * desc.channelStride]; + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + // Calculating and combining the spatial + // and color weights. + scalar_t spatialWeight = 1; + + spatialWeight = gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * + gaussianKernel_z[kernelIndex[2]]; + + scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { + for (int i = 0; i < desc.channelCount; i++) { + valueSum += inputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight; + + // Derivative of weights with respect to X_i while i=k. + dw_dx_ki += (-1) * totalWeight * colorDistance / (colorSigma * colorSigma); + // Derivative of convolved image with respect to X_i while i=k. + dfilter_dx_ki += + (-1) * totalWeight * + inputTensorData[neighbourOffset + i * desc.channelStride] * + colorDistance / (colorSigma * + colorSigma); // Be careful, the +1 is missing here -> Added before filling dfilter_dx_kiData + + colorSum_w += + totalWeight * colorDistanceSquared / + std::abs(colorSigma * colorSigma * colorSigma); + colorSum_alpha += + totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + colorDistanceSquared / std::abs(colorSigma * colorSigma * colorSigma); + + xSum_w += + totalWeight * xDistanceSquared[kernelIndex[0]] / + std::abs(sigma_x * sigma_x * sigma_x); + xSum_alpha += + totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + xDistanceSquared[kernelIndex[0]] / std::abs(sigma_x * sigma_x * sigma_x); + + ySum_w += + totalWeight * yDistanceSquared[kernelIndex[1]] / + std::abs(sigma_y * sigma_y * sigma_y); + ySum_alpha += + totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + yDistanceSquared[kernelIndex[1]] / std::abs(sigma_y * sigma_y * sigma_y); + + zSum_w += + totalWeight * zDistanceSquared[kernelIndex[2]] / + std::abs(sigma_z * sigma_z * sigma_z); + zSum_alpha += + totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + zDistanceSquared[kernelIndex[2]] / std::abs(sigma_z * sigma_z * sigma_z); + } + + weightSum += totalWeight; + } + } while (kernelIndex++); + + // Do the filtering and calculate the values for the backward pass. + for (int i = 0; i < desc.channelCount; i++) { + + // Filtering: + outputTensorData[homeOffset + i * desc.channelStride] = valueSum / weightSum; + + // Pre-computations for the backward pass: + outputWeightsTensorData[homeOffset + i * desc.channelStride] = weightSum; + dO_dx_kiData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * + dw_dx_ki + + (1 / weightSum) * (dfilter_dx_ki + 1); // +1 for dfilter_dx_ki is added here + dO_dsig_rData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * + colorSum_w + (1 / weightSum) * colorSum_alpha; + dO_dsig_xData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * + xSum_w + (1 / weightSum) * xSum_alpha; + dO_dsig_yData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * + ySum_w + (1 / weightSum) * ySum_alpha; + dO_dsig_zData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * + zSum_w + (1 / weightSum) * zSum_alpha; + } + } + } + } + + delete[] kernelSizes; + delete[] gaussianKernel_x; + delete[] gaussianKernel_y; + delete[] gaussianKernel_z; + delete[] xDistanceSquared; + delete[] yDistanceSquared; + delete[] zDistanceSquared; +} + +std::tuple BilateralFilterCpuForward(torch::Tensor inputTensor, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Preparing output tensor. + torch::Tensor outputTensor = torch::zeros_like(inputTensor); + torch::Tensor outputWeightsTensor = torch::zeros_like(inputTensor); + torch::Tensor dO_dx_ki = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_r = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_x = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_y = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_z = torch::zeros_like(inputTensor); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.scalar_type(), "BilateralFilterCpuForward_3d", ([&] { + BilateralFilterCpuForward_3d( + inputTensor, + outputTensor, + outputWeightsTensor, + dO_dx_ki, + dO_dsig_r, + dO_dsig_x, + dO_dsig_y, + dO_dsig_z, + sigma_x, + sigma_y, + sigma_z, + colorSigma); + })); + + return {outputTensor, outputWeightsTensor, dO_dx_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z}; +} diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu new file mode 100644 index 0000000000..d7d2a3960a --- /dev/null +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu @@ -0,0 +1,272 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include + +#include "trainable_bilateral.h" +//#include "../utils/cuda_error_check.h" +#include "utils/meta_macros.h" + +__constant__ int cBatchStrideBack; +__constant__ int cColorStrideBack; + +__constant__ int cSizesBack[3]; +__constant__ int cStridesBack[3]; + +__constant__ int cKernelSizesBack[3]; +__constant__ int cHalfWindowSize_arrBack[3]; +__constant__ float cGaussianKernel_xBack[256]; +__constant__ float cGaussianKernel_yBack[256]; +__constant__ float cGaussianKernel_zBack[256]; +__constant__ float cXDistanceSquaredBack[256]; +__constant__ float cYDistanceSquaredBack[256]; +__constant__ float cZDistanceSquaredBack[256]; +__constant__ float cColorExponentConstantBack; +__constant__ float cSigma_xBack; +__constant__ float cSigma_yBack; +__constant__ float cSigma_zBack; +__constant__ float cColorSigmaBack; + + +template +__global__ void BilateralFilterCudaKernel3DBackward(scalar_t* gradientInputTensor, + scalar_t* gradientOutputTensor, + scalar_t* inputTensor, + scalar_t* outputTensor, + scalar_t* outputWeightsTensor, + scalar_t* dO_dx_ki) { + int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; + int batchOffset = blockIdx.y * cBatchStrideBack; + + if (homeOffset >= cColorStrideBack) + return; + + int homeX = homeOffset / cStridesBack[0]; + int homeY = (homeOffset - homeX * cStridesBack[0]) / cStridesBack[1]; + int homeZ = (homeOffset - homeX * cStridesBack[0] - homeY * cStridesBack[1]) / cStridesBack[2]; + int homeIndex[] = {homeX, homeY, homeZ}; + + // Zero kernel aggregates. + scalar_t valueSum = 0; + scalar_t weightSum = 0; + + for (int kernelX = 0; kernelX < cKernelSizesBack[0]; kernelX++) { + int neighbourX = max(0, min(homeX + (kernelX - cHalfWindowSize_arrBack[0]), cSizesBack[0] - 1)); + scalar_t gaussianX = cGaussianKernel_xBack[kernelX]; + + for (int kernelY = 0; kernelY < cKernelSizesBack[1]; kernelY++) { + int neighbourY = max(0, min(homeY + (kernelY - cHalfWindowSize_arrBack[1]), cSizesBack[1] - 1)); + scalar_t gaussianY = cGaussianKernel_yBack[kernelY]; + + for (int kernelZ = 0; kernelZ < cKernelSizesBack[2]; kernelZ++) { + int neighbourZ = max(0, min(homeZ + (kernelZ - cHalfWindowSize_arrBack[2]), cSizesBack[2] - 1)); + scalar_t gaussianZ = cGaussianKernel_zBack[kernelZ]; + + int neighbourOffset = neighbourX * cStridesBack[0] + neighbourY * cStridesBack[1] + neighbourZ; + + bool flagNotClamped = true; + int kernelIndex[] = {kernelX, kernelY, kernelZ}; + int dimensions = 3; // Must equal the number of spatial dimensions. + + for (int i = 0; i < dimensions; i++) { + int HalfWindowSizeBack = cHalfWindowSize_arrBack[i]; // Define constant memory as new variable here (!!), otherwise: cudaErrorMisalignedAddress + int neighbourIndex = homeIndex[i] + kernelIndex[i] - HalfWindowSizeBack; + int neighbourIndexClamped = min(cSizesBack[i] - 1, max(0, neighbourIndex)); + if (neighbourIndex != neighbourIndexClamped) { flagNotClamped = false; } + } + + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = inputTensor[batchOffset + neighbourOffset + c * cColorStrideBack]; + scalar_t b = inputTensor[batchOffset + homeOffset + c * cColorStrideBack]; // Be careful: Here it is (X_k - X_i) and not (X_i - X_q) + scalar_t diff = a - b; + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussianX * gaussianY * gaussianZ; + scalar_t colorWeight = exp(cColorExponentConstantBack * colorDistanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { + scalar_t filter_kernel_back; + +#pragma unroll + for (int c = 0; c < C; c++) { + // Distinguish cases for k!=i (calculation is done here) + // and k==i (partial derivatives are precalculated). + // If statement replaces center element of neighborhood/kernel. + if (kernelX != cHalfWindowSize_arrBack[0] || kernelY != cHalfWindowSize_arrBack[1] || + kernelZ != cHalfWindowSize_arrBack[2]) { + + filter_kernel_back = + -(1 / + outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) * + outputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * + totalWeight * + colorDistance / (cColorSigmaBack * cColorSigmaBack) + + (1 / + outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) * + totalWeight * + (1 + + inputTensor[batchOffset + homeOffset + c * cColorStrideBack] * colorDistance / + (cColorSigmaBack * cColorSigmaBack)); // inputTensorData[homeOffset] !! + } else { + + filter_kernel_back = dO_dx_ki[batchOffset + homeOffset + c * cColorStrideBack]; + } + + valueSum += + gradientInputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * + filter_kernel_back; + + } + + weightSum += totalWeight; + + } + } + } + } + +#pragma unroll + for (int c = 0; c < C; c++) { + gradientOutputTensor[batchOffset + homeOffset + c * cColorStrideBack] = valueSum; + + } +} + +template +void BilateralFilterCudaBackwardFunction(torch::Tensor gradientInputTensor, + torch::Tensor gradientOutputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + // Pre-calculating gaussian kernel. + int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size + int halfWindowSize_x = floor(0.5f * windowSize_x); + int halfWindowSize_y = floor(0.5f * windowSize_y); + int halfWindowSize_z = floor(0.5f * windowSize_z); + int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; + float spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); + float spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); + float spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); + float colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + int* kernelSizes = new int[desc.dimensions]; + kernelSizes[0] = windowSize_x; + kernelSizes[1] = windowSize_y; + kernelSizes[2] = windowSize_z; + + auto* gaussianKernel_x = new float[windowSize_x]; + auto* gaussianKernel_y = new float[windowSize_y]; + auto* gaussianKernel_z = new float[windowSize_z]; + auto* xDistanceSquared = new float[windowSize_x]; + auto* yDistanceSquared = new float[windowSize_y]; + auto* zDistanceSquared = new float[windowSize_z]; + + for (int i = 0; i < windowSize_x; i++) { + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_y; i++) { + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_z; i++) { + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; + } + + // Writing constant memory. + cudaMemcpyToSymbol(cBatchStrideBack, &desc.batchStride, sizeof(int)); + cudaMemcpyToSymbol(cColorStrideBack, &desc.channelStride, sizeof(int)); + cudaMemcpyToSymbol(cSizesBack, desc.sizes, sizeof(int) * 3); + cudaMemcpyToSymbol(cStridesBack, desc.strides, sizeof(int) * 3); + cudaMemcpyToSymbol(cKernelSizesBack, kernelSizes, sizeof(int) * desc.dimensions); + cudaMemcpyToSymbol(cHalfWindowSize_arrBack, halfWindowSize_arr, sizeof(int) * desc.dimensions); + cudaMemcpyToSymbol(cGaussianKernel_xBack, gaussianKernel_x, sizeof(float) * windowSize_x); + cudaMemcpyToSymbol(cGaussianKernel_yBack, gaussianKernel_y, sizeof(float) * windowSize_y); + cudaMemcpyToSymbol(cGaussianKernel_zBack, gaussianKernel_z, sizeof(float) * windowSize_z); + cudaMemcpyToSymbol(cXDistanceSquaredBack, xDistanceSquared, sizeof(float) * windowSize_x); + cudaMemcpyToSymbol(cYDistanceSquaredBack, yDistanceSquared, sizeof(float) * windowSize_y); + cudaMemcpyToSymbol(cZDistanceSquaredBack, zDistanceSquared, sizeof(float) * windowSize_z); + cudaMemcpyToSymbol(cColorExponentConstantBack, &colorExpConstant, sizeof(float)); + cudaMemcpyToSymbol(cSigma_xBack, &sigma_x, sizeof(float)); + cudaMemcpyToSymbol(cSigma_yBack, &sigma_y, sizeof(float)); + cudaMemcpyToSymbol(cSigma_zBack, &sigma_z, sizeof(float)); + cudaMemcpyToSymbol(cColorSigmaBack, &colorSigma, sizeof(float)); + +// cuda_error_check("Cuda check before kernel call."); + +#define BLOCK_SIZE 32 + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + inputTensor.scalar_type(), "BilateralFilterCudaKernel3DBackward", ([&] { + BilateralFilterCudaKernel3DBackward + <<>>( + gradientInputTensor.data_ptr(), + gradientOutputTensor.data_ptr(), + inputTensor.data_ptr(), + outputTensor.data_ptr(), + outputWeightsTensor.data_ptr(), + dO_dx_ki.data_ptr()); + })); + +// cuda_error_check("Cuda check after kernel call."); +// delete[] kernel; + delete[] kernelSizes; + delete[] gaussianKernel_x; + delete[] gaussianKernel_y; + delete[] gaussianKernel_z; + delete[] xDistanceSquared; + delete[] yDistanceSquared; + delete[] zDistanceSquared; +} + +// Function to choose template implementation based on dynamic, channels and dimensions +torch::Tensor BilateralFilterCudaBackward(torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + torch::Tensor gradientOutputTensor = torch::zeros_like(gradientInputTensor); +// cuda_error_check("beginning"); + +#define CASE(c, d) BilateralFilterCudaBackwardFunction(gradientInputTensor, gradientOutputTensor, inputTensor, outputTensor, outputWeightsTensor, dO_dx_ki, sigma_x, sigma_y, sigma_z, colorSigma); + SWITCH_AB(CASE, BF_CUDA_MAX_CHANNELS, BF_CUDA_MAX_SPATIAL_DIMENSION, gradientInputTensor.size(1), gradientInputTensor.dim() - 2); + + return gradientOutputTensor; +} diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_forward.cu b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_forward.cu new file mode 100644 index 0000000000..36473ee379 --- /dev/null +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_forward.cu @@ -0,0 +1,319 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include + +#include "trainable_bilateral.h" +//#include "../utils/cuda_error_check.h" +#include "utils/meta_macros.h" + +__constant__ int cBatchStride; +__constant__ int cColorStride; + +__constant__ int cSizes[3]; +__constant__ int cStrides[3]; + +__constant__ int cKernelSizes[3]; +__constant__ int cHalfWindowSize_arr[3]; +__constant__ float cGaussianKernel_x[256]; +__constant__ float cGaussianKernel_y[256]; +__constant__ float cGaussianKernel_z[256]; +__constant__ float cXDistanceSquared[256]; +__constant__ float cYDistanceSquared[256]; +__constant__ float cZDistanceSquared[256]; +__constant__ float cColorExponentConstant; +__constant__ float cSigma_x; +__constant__ float cSigma_y; +__constant__ float cSigma_z; +__constant__ float cColorSigma; + + +template +__global__ void BilateralFilterCudaKernel3DForward(scalar_t* input, + scalar_t* output, + scalar_t* outputWeightsTensor, + scalar_t* dO_dx_ki, + scalar_t* dO_dsig_r, + scalar_t* dO_dsig_x, + scalar_t* dO_dsig_y, + scalar_t* dO_dsig_z) { + int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; + int batchOffset = blockIdx.y * cBatchStride; + + if (homeOffset >= cColorStride) + return; + + int homeX = homeOffset / cStrides[0]; + int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1]; + int homeZ = (homeOffset - homeX * cStrides[0] - homeY * cStrides[1]) / cStrides[2]; + int homeIndex[] = {homeX, homeY, homeZ}; + + // Zero kernel aggregates. + scalar_t valueSum = 0; + scalar_t dw_dx_ki = 0; + scalar_t dfilter_dx_ki = 0; + scalar_t colorSum_w = 0; + scalar_t colorSum_alpha = 0; + scalar_t xSum_w = 0; + scalar_t xSum_alpha = 0; + scalar_t ySum_w = 0; + scalar_t ySum_alpha = 0; + scalar_t zSum_w = 0; + scalar_t zSum_alpha = 0; + scalar_t weightSum = 0; + + for (int kernelX = 0; kernelX < cKernelSizes[0]; kernelX++) { + int neighbourX = max(0, min(homeX + (kernelX - cHalfWindowSize_arr[0]), cSizes[0] - 1)); + scalar_t gaussianX = cGaussianKernel_x[kernelX]; + + for (int kernelY = 0; kernelY < cKernelSizes[1]; kernelY++) { + int neighbourY = max(0, min(homeY + (kernelY - cHalfWindowSize_arr[1]), cSizes[1] - 1)); + scalar_t gaussianY = cGaussianKernel_y[kernelY]; + + for (int kernelZ = 0; kernelZ < cKernelSizes[2]; kernelZ++) { + int neighbourZ = max(0, min(homeZ + (kernelZ - cHalfWindowSize_arr[2]), cSizes[2] - 1)); + scalar_t gaussianZ = cGaussianKernel_z[kernelZ]; + + int neighbourOffset = neighbourX * cStrides[0] + neighbourY * cStrides[1] + neighbourZ; + + bool flagNotClamped = true; + int kernelIndex[] = {kernelX, kernelY, kernelZ}; + int dimensions = 3; // Must equal the number of spatial dimensions. + + for (int i = 0; i < dimensions; i++) { + int HalfWindowSizeBack = cHalfWindowSize_arr[i]; // Define constant memory as new variable here (!!), otherwise: cudaErrorMisalignedAddress + int neighbourIndex = homeIndex[i] + kernelIndex[i] - HalfWindowSizeBack; + int neighbourIndexClamped = min(cSizes[i] - 1, max(0, neighbourIndex)); + if (neighbourIndex != neighbourIndexClamped) { flagNotClamped = false; } + } + + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + homeOffset + c * cColorStride]; + scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride]; // Home - neighbor (!!) in backward the other way around !! + scalar_t diff = a - b; + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussianX * gaussianY * gaussianZ; + scalar_t colorWeight = exp(cColorExponentConstant * colorDistanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { + +#pragma unroll + for (int c = 0; c < C; c++) { + valueSum += input[batchOffset + neighbourOffset + c * cColorStride] * totalWeight; + + // Derivative of weights with respect to X_i while i=k. + dw_dx_ki += (-1) * totalWeight * colorDistance / (cColorSigma * cColorSigma); + // Derivative of convolved image with respect to X_i while i=k. + dfilter_dx_ki += (-1) * totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + colorDistance / (cColorSigma * + cColorSigma); // Be careful, the +1 is missing here -> Added before filling dfilter_dx_kiData + + colorSum_w += + totalWeight * colorDistanceSquared / + std::abs(cColorSigma * cColorSigma * cColorSigma); + colorSum_alpha += + totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + colorDistanceSquared / std::abs(cColorSigma * cColorSigma * cColorSigma); + + xSum_w += + totalWeight * cXDistanceSquared[kernelX] / + std::abs(cSigma_x * cSigma_x * cSigma_x); + xSum_alpha += + totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + cXDistanceSquared[kernelX] / std::abs(cSigma_x * cSigma_x * cSigma_x); + + ySum_w += + totalWeight * cYDistanceSquared[kernelY] / + std::abs(cSigma_y * cSigma_y * cSigma_y); + ySum_alpha += + totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + cYDistanceSquared[kernelY] / std::abs(cSigma_y * cSigma_y * cSigma_y); + + zSum_w += + totalWeight * cZDistanceSquared[kernelZ] / + std::abs(cSigma_z * cSigma_z * cSigma_z); + zSum_alpha += + totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + cZDistanceSquared[kernelZ] / std::abs(cSigma_z * cSigma_z * cSigma_z); + + } + + weightSum += totalWeight; + + } + } + } + } + +#pragma unroll + for (int c = 0; c < C; c++) { +// output[batchOffset + homeOffset + c * cColorStride] /= weightSum; + output[batchOffset + homeOffset + c * cColorStride] = valueSum / weightSum; + + // Pre-computations for the backward pass: + outputWeightsTensor[batchOffset + homeOffset + c * cColorStride] = weightSum; + dO_dx_ki[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * + dw_dx_ki + + (1 / weightSum) * (dfilter_dx_ki + 1); // +1 for dfilter_dx_ki is added here + dO_dsig_r[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * + colorSum_w + (1 / weightSum) * colorSum_alpha; + dO_dsig_x[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * + xSum_w + (1 / weightSum) * xSum_alpha; + dO_dsig_y[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * + ySum_w + (1 / weightSum) * ySum_alpha; + dO_dsig_z[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * + zSum_w + (1 / weightSum) * zSum_alpha; + } +} + +template +void BilateralFilterCudaForwardFunction(torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + torch::Tensor dO_dsig_r, + torch::Tensor dO_dsig_x, + torch::Tensor dO_dsig_y, + torch::Tensor dO_dsig_z, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + // Pre-calculating gaussian kernel. + int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size + int halfWindowSize_x = floor(0.5f * windowSize_x); + int halfWindowSize_y = floor(0.5f * windowSize_y); + int halfWindowSize_z = floor(0.5f * windowSize_z); + int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; + float spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); + float spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); + float spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); + float colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + int* kernelSizes = new int[desc.dimensions]; + kernelSizes[0] = windowSize_x; + kernelSizes[1] = windowSize_y; + kernelSizes[2] = windowSize_z; + + auto* gaussianKernel_x = new float[windowSize_x]; + auto* gaussianKernel_y = new float[windowSize_y]; + auto* gaussianKernel_z = new float[windowSize_z]; + auto* xDistanceSquared = new float[windowSize_x]; + auto* yDistanceSquared = new float[windowSize_y]; + auto* zDistanceSquared = new float[windowSize_z]; + + for (int i = 0; i < windowSize_x; i++) { + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_y; i++) { + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_z; i++) { + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; + } + + // Writing constant memory. + cudaMemcpyToSymbol(cBatchStride, &desc.batchStride, sizeof(int)); + cudaMemcpyToSymbol(cColorStride, &desc.channelStride, sizeof(int)); + cudaMemcpyToSymbol(cSizes, desc.sizes, sizeof(int) * 3); + cudaMemcpyToSymbol(cStrides, desc.strides, sizeof(int) * 3); + cudaMemcpyToSymbol(cKernelSizes, kernelSizes, sizeof(int) * desc.dimensions); + cudaMemcpyToSymbol(cHalfWindowSize_arr, halfWindowSize_arr, sizeof(int) * desc.dimensions); + cudaMemcpyToSymbol(cGaussianKernel_x, gaussianKernel_x, sizeof(float) * windowSize_x); + cudaMemcpyToSymbol(cGaussianKernel_y, gaussianKernel_y, sizeof(float) * windowSize_y); + cudaMemcpyToSymbol(cGaussianKernel_z, gaussianKernel_z, sizeof(float) * windowSize_z); + cudaMemcpyToSymbol(cXDistanceSquared, xDistanceSquared, sizeof(float) * windowSize_x); + cudaMemcpyToSymbol(cYDistanceSquared, yDistanceSquared, sizeof(float) * windowSize_y); + cudaMemcpyToSymbol(cZDistanceSquared, zDistanceSquared, sizeof(float) * windowSize_z); + cudaMemcpyToSymbol(cColorExponentConstant, &colorExpConstant, sizeof(float)); + cudaMemcpyToSymbol(cSigma_x, &sigma_x, sizeof(float)); + cudaMemcpyToSymbol(cSigma_y, &sigma_y, sizeof(float)); + cudaMemcpyToSymbol(cSigma_z, &sigma_z, sizeof(float)); + cudaMemcpyToSymbol(cColorSigma, &colorSigma, sizeof(float)); + +// cuda_error_check("Cuda check before kernel call."); + +#define BLOCK_SIZE 32 + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + inputTensor.scalar_type(), "BilateralFilterCudaKernel3DForward", ([&] { + BilateralFilterCudaKernel3DForward + <<>>( + inputTensor.data_ptr(), + outputTensor.data_ptr(), + outputWeightsTensor.data_ptr(), + dO_dx_ki.data_ptr(), + dO_dsig_r.data_ptr(), + dO_dsig_x.data_ptr(), + dO_dsig_y.data_ptr(), + dO_dsig_z.data_ptr()); + } + )); + +// cuda_error_check("Cuda check after kernel call."); +// delete[] kernel; + delete[] kernelSizes; + delete[] gaussianKernel_x; + delete[] gaussianKernel_y; + delete[] gaussianKernel_z; + delete[] xDistanceSquared; + delete[] yDistanceSquared; + delete[] zDistanceSquared; +} + +// Function to choose template implementation based on dynamic, channels and dimensions +std::tuple BilateralFilterCudaForward(torch::Tensor inputTensor, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + torch::Tensor outputTensor = torch::zeros_like(inputTensor); + torch::Tensor outputWeightsTensor = torch::zeros_like(inputTensor); + torch::Tensor dO_dx_ki = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_r = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_x = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_y = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_z = torch::zeros_like(inputTensor); +// cuda_error_check("beginning"); + +#define CASE(c, d) BilateralFilterCudaForwardFunction(inputTensor, outputTensor, outputWeightsTensor, dO_dx_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z, sigma_x, sigma_y, sigma_z, colorSigma); + SWITCH_AB(CASE, BF_CUDA_MAX_CHANNELS, BF_CUDA_MAX_SPATIAL_DIMENSION, inputTensor.size(1), inputTensor.dim() - 2); + + return {outputTensor, outputWeightsTensor, dO_dx_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z}; +} diff --git a/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.cpp b/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.cpp new file mode 100644 index 0000000000..90cd609f5f --- /dev/null +++ b/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.cpp @@ -0,0 +1,103 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include + +#include "trainable_bilateral.h" +#include "utils/common_utils.h" + +std::tuple TrainableBilateralFilterForward(torch::Tensor inputTensor, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + std::tuple (*filterFunction)(torch::Tensor, float, float, float, float); + +#ifdef WITH_CUDA + + if (torch::cuda::is_available() && inputTensor.is_cuda()) { + CHECK_CONTIGUOUS_CUDA(inputTensor); + + if (inputTensor.size(1) > BF_CUDA_MAX_CHANNELS) { + throw std::runtime_error( + "Bilateral filtering not implemented for channel count > " + std::to_string(BF_CUDA_MAX_CHANNELS)); + } + + if (inputTensor.dim() - 2 > BF_CUDA_MAX_SPATIAL_DIMENSION) { + throw std::runtime_error( + "Bilateral filtering not implemented for spatial dimension > " + + std::to_string(BF_CUDA_MAX_SPATIAL_DIMENSION)); + } + + filterFunction = &BilateralFilterCudaForward; + } else { + filterFunction = &BilateralFilterCpuForward; + } +#else + filterFunction = &BilateralFilterCpuForward; +#endif + + return filterFunction(inputTensor, sigma_x, sigma_y, sigma_z, colorSigma); +} + +torch::Tensor TrainableBilateralFilterBackward(torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + torch::Tensor (*filterFunction)(torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, float, float, float, float); + +#ifdef WITH_CUDA + + if (torch::cuda::is_available() && gradientInputTensor.is_cuda()) { + CHECK_CONTIGUOUS_CUDA(gradientInputTensor); + + if (gradientInputTensor.size(1) > BF_CUDA_MAX_CHANNELS) { + throw std::runtime_error( + "Bilateral filtering not implemented for channel count > " + std::to_string(BF_CUDA_MAX_CHANNELS)); + } + + if (gradientInputTensor.dim() - 2 > BF_CUDA_MAX_SPATIAL_DIMENSION) { + throw std::runtime_error( + "Bilateral filtering not implemented for spatial dimension > " + + std::to_string(BF_CUDA_MAX_SPATIAL_DIMENSION)); + } + + filterFunction = &BilateralFilterCudaBackward; + } else { + filterFunction = &BilateralFilterCpuBackward; + } +#else + filterFunction = &BilateralFilterCpuBackward; +#endif + + return filterFunction(gradientInputTensor, inputTensor, outputTensor, outputWeightsTensor, dO_dx_ki, sigma_x, sigma_y, sigma_z, colorSigma); +} diff --git a/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.h b/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.h new file mode 100644 index 0000000000..d34cc7849b --- /dev/null +++ b/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.h @@ -0,0 +1,79 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once + +#include +#include "utils/common_utils.h" +#include "utils/tensor_description.h" +#include +#include +#include + +#define BF_CUDA_MAX_CHANNELS 16 +#define BF_CUDA_MAX_SPATIAL_DIMENSION 3 + +#ifdef WITH_CUDA +std::tuple BilateralFilterCudaForward(torch::Tensor inputTensor, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); +torch::Tensor BilateralFilterCudaBackward(torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); +#endif + +std::tuple BilateralFilterCpuForward(torch::Tensor inputTensor, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); + +torch::Tensor BilateralFilterCpuBackward(torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); + +std::tuple TrainableBilateralFilterForward(torch::Tensor inputTensor, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); + +torch::Tensor TrainableBilateralFilterBackward(torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index 31bd36dd8f..73ab43064e 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -12,7 +12,7 @@ from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding from .drop_path import DropPath from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args -from .filtering import BilateralFilter, PHLFilter +from .filtering import BilateralFilter, PHLFilter, TrainableBilateralFilter from .gmm import GaussianMixtureModel from .simplelayers import ( LLTM, diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index bbf925eba9..104458c94a 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -15,7 +15,7 @@ _C, _ = optional_import("monai._C") -__all__ = ["BilateralFilter", "PHLFilter"] +__all__ = ["BilateralFilter", "PHLFilter", "TrainableBilateralFilter"] class BilateralFilter(torch.autograd.Function): @@ -99,3 +99,150 @@ def backward(ctx, grad_output): # scaled_features, = ctx.saved_variables # grad_input = _C.phl_filter(grad_output, scaled_features) # return grad_input + + +class TrainableBilateralFilterFunction(torch.autograd.Function): + """ + torch.autograd.Function for the TrainableBilateralFilter layer. + + See: + F. Wagner, et al., Ultralow-parameter denoising: Trainable bilateral filter layers in + computed tomography, Medical Physics (2022), https://doi.org/10.1002/mp.15718 + + Args: + input: input tensor to be filtered. + + sigma x: trainable standard deviation of the spatial filter kernel in x direction. + + sigma y: trainable standard deviation of the spatial filter kernel in y direction. + + sigma z: trainable standard deviation of the spatial filter kernel in z direction. + + color sigma: trainable standard deviation of the intensity range kernel. This filter + parameter determines the degree of edge preservation. + + Returns: + output (torch.Tensor): filtered tensor. + """ + + @staticmethod + def forward(ctx, input_img, sigma_x, sigma_y, sigma_z, color_sigma): + output_tensor, output_weights_tensor, do_dx_ki, do_dsig_r, do_dsig_x, do_dsig_y, do_dsig_z = _C.tbf_forward( + input_img, sigma_x, sigma_y, sigma_z, color_sigma + ) + + ctx.save_for_backward( + input_img, + sigma_x, + sigma_y, + sigma_z, + color_sigma, + output_tensor, + output_weights_tensor, + do_dx_ki, + do_dsig_r, + do_dsig_x, + do_dsig_y, + do_dsig_z, + ) + + return output_tensor + + @staticmethod + def backward(ctx, grad_output): + input_img = ctx.saved_tensors[0] # input image + sigma_x = ctx.saved_tensors[1] + sigma_y = ctx.saved_tensors[2] + sigma_z = ctx.saved_tensors[3] + color_sigma = ctx.saved_tensors[4] + output_tensor = ctx.saved_tensors[5] # filtered image + output_weights_tensor = ctx.saved_tensors[6] # weights + do_dx_ki = ctx.saved_tensors[7] # derivative of output with respect to input, while k==i + do_dsig_r = ctx.saved_tensors[8] # derivative of output with respect to range sigma + do_dsig_x = ctx.saved_tensors[9] # derivative of output with respect to sigma x + do_dsig_y = ctx.saved_tensors[10] # derivative of output with respect to sigma y + do_dsig_z = ctx.saved_tensors[11] # derivative of output with respect to sigma z + + # calculate gradient with respect to the sigmas + grad_color_sigma = torch.sum(grad_output * do_dsig_r) + grad_sig_x = torch.sum(grad_output * do_dsig_x) + grad_sig_y = torch.sum(grad_output * do_dsig_y) + grad_sig_z = torch.sum(grad_output * do_dsig_z) + + grad_output_tensor = _C.tbf_backward( + grad_output, + input_img, + output_tensor, + output_weights_tensor, + do_dx_ki, + sigma_x, + sigma_y, + sigma_z, + color_sigma, + ) + + return grad_output_tensor, grad_sig_x, grad_sig_y, grad_sig_z, grad_color_sigma + + +class TrainableBilateralFilter(torch.nn.Module): + """ + Implementation of a trainable bilateral filter layer as proposed in the corresponding publication. + All filter parameters can be trained data-driven. The spatial filter kernels x, y, and z determine + image smoothing whereas the color parameter specifies the amount of edge preservation. + Can run on 1D, 2D, or 3D tensors (on top of Batch and Channel dimensions). + + See: + F. Wagner, et al., Ultralow-parameter denoising: Trainable bilateral filter layers in + computed tomography, Medical Physics (2022), https://doi.org/10.1002/mp.15718 + + Args: + input: input tensor to be filtered. + + sigma x: trainable standard deviation of the spatial filter kernel in x direction. + + sigma y: trainable standard deviation of the spatial filter kernel in y direction. + + sigma z: trainable standard deviation of the spatial filter kernel in z direction. + + color sigma: trainable standard deviation of the intensity range kernel. This filter + parameter determines the degree of edge preservation. + + Returns: + output (torch.Tensor): filtered tensor. + """ + + def __init__(self, sigma_x, sigma_y, sigma_z, color_sigma): + super(TrainableBilateralFilter, self).__init__() + + # Register sigmas as trainable parameters. + self.sigma_x = torch.nn.Parameter(torch.tensor(sigma_x)) + self.sigma_y = torch.nn.Parameter(torch.tensor(sigma_y)) + self.sigma_z = torch.nn.Parameter(torch.tensor(sigma_z)) + self.color_sigma = torch.nn.Parameter(torch.tensor(color_sigma)) + + def forward(self, input_tensor): + assert input_tensor.shape[1] == 1, ( + "Currently channel dimensions >1 are not supported. " + "Please use multiple parallel filter layers if you want " + "to filter multiple channels." + ) + + len_input = len(input_tensor.shape) + + # C++ extension so far only supports 5-dim inputs. + if len_input == 3: + input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) + elif len_input == 4: + input_tensor = input_tensor.unsqueeze(4) + + prediction = TrainableBilateralFilterFunction.apply( + input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.color_sigma + ) + + # Make sure to return tensor of the same shape as the input. + if len_input == 3: + prediction = prediction.squeeze(4).squeeze(3) + elif len_input == 4: + prediction = prediction.squeeze(4) + + return prediction diff --git a/tests/test_trainable_bilateral.py b/tests/test_trainable_bilateral.py new file mode 100644 index 0000000000..a42794616a --- /dev/null +++ b/tests/test_trainable_bilateral.py @@ -0,0 +1,476 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized +from torch.autograd import gradcheck + +from monai.networks.layers.filtering import TrainableBilateralFilterFunction +from tests.utils import skip_if_no_cpp_extension, skip_if_no_cuda + +TEST_CASES = [ + [ + # Case Description + "1 dimension, 1 channel, low spatial sigmas, low color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (1.0, 1.0, 1.0, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.999997, 0.000001, 0.000000, 0.000001, 0.999997] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000001, 0.999995, 0.000001, 0.000000] + ], + ], + ], + [ + # Case Description + "1 dimension, 1 channel, low spatial sigmas, high color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (1, 1, 1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.714200, 0.158126, 0.061890, 0.158126, 0.714200] + ], + # Batch 1 + [ + # Channel 0 + [0.043465, 0.158126, 0.555452, 0.158126, 0.043465] + ], + ], + ], + [ + # Case Description + "1 dimension, 1 channel, high spatial sigmas, low color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (4, 4, 4, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.999994, 0.000002, 0.000002, 0.000002, 0.999994] + ], + # Batch 1 + [ + # Channel 0 + [0.000001, 0.000001, 0.999986, 0.000001, 0.000001] + ], + ], + ], + [ + # Case Description + "1 dimension, 1 channel, high spatial sigmas, high color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (4, 4, 4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.533282, 0.245915, 0.244711, 0.245915, 0.533282] + ], + # Batch 1 + [ + # Channel 0 + [0.125052, 0.126608, 0.333592, 0.126608, 0.125052] + ], + ], + ], + [ + # Case Description + "2 dimensions, 1 channel, high spatial sigmas, high color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (4, 4, 4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]] + ], + # Batch 1 + [ + # Channel 0 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.239789, 0.082990, 0.082630, 0.082990, 0.239789], + [0.082990, 0.081934, 0.081579, 0.081934, 0.082990], + [0.082630, 0.081579, 0.081225, 0.081579, 0.082630], + [0.082990, 0.081934, 0.081579, 0.081934, 0.082990], + [0.239789, 0.082990, 0.082630, 0.082990, 0.239789], + ] + ], + # Batch 1 + [ + # Channel 0 + [ + [0.024155, 0.024432, 0.024525, 0.024432, 0.024155], + [0.024432, 0.024712, 0.024806, 0.024712, 0.024432], + [0.024525, 0.024806, 0.080686, 0.024806, 0.024525], + [0.024432, 0.024712, 0.024806, 0.024712, 0.024432], + [0.024155, 0.024432, 0.024525, 0.024432, 0.024155], + ] + ], + ], + ], + [ + # Case Description + "3 dimensions, 1 channel, high spatial sigmas, high color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (4, 4, 4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + # Frame 1 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 2 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 3 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 4 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + ] + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [ + [0.098142, 0.030317, 0.030191, 0.030316, 0.098142], + [0.030316, 0.029947, 0.029822, 0.029947, 0.030316], + [0.030191, 0.029822, 0.029698, 0.029822, 0.030191], + [0.030316, 0.029947, 0.029822, 0.029947, 0.030316], + [0.098142, 0.030317, 0.030191, 0.030317, 0.098142], + ], + # Frame 1 + [ + [0.030316, 0.029947, 0.029822, 0.029947, 0.030316], + [0.029947, 0.029581, 0.029458, 0.029581, 0.029947], + [0.029822, 0.029458, 0.029336, 0.029458, 0.029822], + [0.029947, 0.029581, 0.029458, 0.029581, 0.029947], + [0.030316, 0.029947, 0.029822, 0.029947, 0.030316], + ], + # Frame 2 + [ + [0.030191, 0.029822, 0.029698, 0.029822, 0.030191], + [0.029822, 0.029458, 0.029336, 0.029458, 0.029822], + [0.029698, 0.029336, 0.029214, 0.029336, 0.029698], + [0.029822, 0.029458, 0.029336, 0.029458, 0.029822], + [0.030191, 0.029822, 0.029698, 0.029822, 0.030191], + ], + # Frame 3 + [ + [0.030316, 0.029947, 0.029822, 0.029947, 0.030317], + [0.029947, 0.029581, 0.029458, 0.029581, 0.029947], + [0.029822, 0.029458, 0.029336, 0.029458, 0.029822], + [0.029947, 0.029581, 0.029458, 0.029581, 0.029947], + [0.030316, 0.029947, 0.029822, 0.029947, 0.030316], + ], + # Frame 4 + [ + [0.098142, 0.030317, 0.030191, 0.030317, 0.098142], + [0.030317, 0.029947, 0.029822, 0.029947, 0.030316], + [0.030191, 0.029822, 0.029698, 0.029822, 0.030191], + [0.030317, 0.029947, 0.029822, 0.029947, 0.030316], + [0.098142, 0.030317, 0.030191, 0.030316, 0.098142], + ], + ] + ] + ], + ], +] + + +@skip_if_no_cpp_extension +class BilateralFilterTestCaseCpuPrecise(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_cpu_precise(self, test_case_description, sigmas, input, expected): + + # Params to determine the implementation to test + device = torch.device("cpu") + + # Create input tensor and apply filter + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device) + + len_input = len(input_tensor.shape) + # C++ extension so far only supports 5-dim inputs. + if len_input == 3: + input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) + elif len_input == 4: + input_tensor = input_tensor.unsqueeze(4) + + output = TrainableBilateralFilterFunction.apply(input_tensor, *sigmas).cpu().numpy() + + # Make sure to return tensor of the same shape as the input. + if len_input == 3: + output = output.squeeze(4).squeeze(3) + elif len_input == 4: + output = output.squeeze(4) + + # Ensure result are as expected. + np.testing.assert_allclose(output, expected, atol=1e-5) + + @parameterized.expand(TEST_CASES) + def test_cpu_precise_backwards(self, test_case_description, sigmas, input, expected): + + # Params to determine the implementation to test + device = torch.device("cpu") + + # Prepare input tensor + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device) + input_tensor.requires_grad = True + + # C++ extension so far only supports 5-dim inputs. + len_input = len(input_tensor.shape) + if len_input == 3: + input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) + elif len_input == 4: + input_tensor = input_tensor.unsqueeze(4) + + # Check gradient toward input. + args = ( + input_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-6, atol=1e-5, raise_exception=False) + input_tensor = input_tensor.detach() + input_tensor.requires_grad = False + + # Check gradient toward sigma_x. + args = ( + input_tensor, + torch.tensor(sigmas[0], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_y. + args = ( + input_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_z. + args = ( + input_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_color. + args = ( + input_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3], dtype=torch.double, requires_grad=True), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-3, atol=1e-3, raise_exception=False) + + +@skip_if_no_cuda +@skip_if_no_cpp_extension +class BilateralFilterTestCaseCudaPrecise(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_cuda_precise(self, test_case_description, sigmas, input, expected): + + # Skip this test + if not torch.cuda.is_available(): + return + + # Params to determine the implementation to test + device = torch.device("cuda") + + # Create input tensor and apply filter + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device) + + len_input = len(input_tensor.shape) + # C++ extension so far only supports 5-dim inputs. + if len_input == 3: + input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) + elif len_input == 4: + input_tensor = input_tensor.unsqueeze(4) + + output = TrainableBilateralFilterFunction.apply(input_tensor, *sigmas).cpu().numpy() + + # Make sure to return tensor of the same shape as the input. + if len_input == 3: + output = output.squeeze(4).squeeze(3) + elif len_input == 4: + output = output.squeeze(4) + + # Ensure result are as expected. + np.testing.assert_allclose(output, expected, atol=1e-5) + + @parameterized.expand(TEST_CASES) + def test_cuda_precise_backwards(self, test_case_description, sigmas, input, expected): + + # Params to determine the implementation to test + device = torch.device("cuda") + + # Prepare input tensor + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device) + input_tensor.requires_grad = True + + # C++ extension so far only supports 5-dim inputs. + len_input = len(input_tensor.shape) + if len_input == 3: + input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) + elif len_input == 4: + input_tensor = input_tensor.unsqueeze(4) + + # Check gradient toward input. + args = ( + input_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-6, atol=1e-5, raise_exception=False) + input_tensor = input_tensor.detach() + input_tensor.requires_grad = False + + # Check gradient toward sigma_x. + args = ( + input_tensor, + torch.tensor(sigmas[0], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_y. + args = ( + input_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_z. + args = ( + input_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_color. + args = ( + input_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3], dtype=torch.double, requires_grad=True), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-3, atol=1e-3, raise_exception=False) + + +if __name__ == "__main__": + unittest.main() From 19cf51e0bfd12d4d372b509fe7315864b559463a Mon Sep 17 00:00:00 2001 From: Fabian Wagner Date: Thu, 15 Dec 2022 15:24:45 +0100 Subject: [PATCH 2/9] Included the documentation for the TBF. Signed-off-by: Fabian Wagner --- docs/source/networks.rst | 5 +++++ monai/networks/layers/filtering.py | 12 ++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index a4c225de29..64ba0313c4 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -349,6 +349,11 @@ Layers .. autoclass:: BilateralFilter :members: +`TrainableBilateralFilter` +~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: TrainableBilateralFilter + :members: + `PHLFilter` ~~~~~~~~~~~ .. autoclass:: PHLFilter diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 104458c94a..311616ba57 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -119,7 +119,7 @@ class TrainableBilateralFilterFunction(torch.autograd.Function): sigma z: trainable standard deviation of the spatial filter kernel in z direction. color sigma: trainable standard deviation of the intensity range kernel. This filter - parameter determines the degree of edge preservation. + parameter determines the degree of edge preservation. Returns: output (torch.Tensor): filtered tensor. @@ -204,21 +204,21 @@ class TrainableBilateralFilter(torch.nn.Module): sigma z: trainable standard deviation of the spatial filter kernel in z direction. - color sigma: trainable standard deviation of the intensity range kernel. This filter - parameter determines the degree of edge preservation. + sigma color: trainable standard deviation of the intensity range kernel. This filter + parameter determines the degree of edge preservation. Returns: output (torch.Tensor): filtered tensor. """ - def __init__(self, sigma_x, sigma_y, sigma_z, color_sigma): + def __init__(self, sigma_x, sigma_y, sigma_z, sigma_color): super(TrainableBilateralFilter, self).__init__() # Register sigmas as trainable parameters. self.sigma_x = torch.nn.Parameter(torch.tensor(sigma_x)) self.sigma_y = torch.nn.Parameter(torch.tensor(sigma_y)) self.sigma_z = torch.nn.Parameter(torch.tensor(sigma_z)) - self.color_sigma = torch.nn.Parameter(torch.tensor(color_sigma)) + self.sigma_color = torch.nn.Parameter(torch.tensor(sigma_color)) def forward(self, input_tensor): assert input_tensor.shape[1] == 1, ( @@ -236,7 +236,7 @@ def forward(self, input_tensor): input_tensor = input_tensor.unsqueeze(4) prediction = TrainableBilateralFilterFunction.apply( - input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.color_sigma + input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color ) # Make sure to return tensor of the same shape as the input. From 663b219920528faf916c8c8d37a7d384fef2330f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Dec 2022 14:43:30 +0000 Subject: [PATCH 3/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/layers/filtering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 311616ba57..1b4d717850 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -212,7 +212,7 @@ class TrainableBilateralFilter(torch.nn.Module): """ def __init__(self, sigma_x, sigma_y, sigma_z, sigma_color): - super(TrainableBilateralFilter, self).__init__() + super().__init__() # Register sigmas as trainable parameters. self.sigma_x = torch.nn.Parameter(torch.tensor(sigma_x)) From af66d72a132d6792768de39308a1e9c5e21c408a Mon Sep 17 00:00:00 2001 From: monai-bot Date: Thu, 15 Dec 2022 14:49:18 +0000 Subject: [PATCH 4/9] [MONAI] code formatting Signed-off-by: monai-bot --- .../bf_layer_cpu_backward.cpp | 453 ++++++++------- .../bf_layer_cpu_forward.cpp | 532 +++++++++--------- .../bf_layer_gpu_backward.cu | 184 +++--- .../bf_layer_gpu_forward.cu | 226 ++++---- .../trainable_bilateral.cpp | 60 +- .../trainable_bilateral/trainable_bilateral.h | 95 ++-- 6 files changed, 761 insertions(+), 789 deletions(-) diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp index b91e24458c..8e34d7a8b0 100644 --- a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp @@ -14,247 +14,240 @@ limitations under the License. #include "trainable_bilateral.h" struct Indexer { -public: - Indexer(int dimensions, int* sizes) { - m_dimensions = dimensions; - m_sizes = sizes; - m_index = new int[dimensions]{0}; + public: + Indexer(int dimensions, int* sizes) { + m_dimensions = dimensions; + m_sizes = sizes; + m_index = new int[dimensions]{0}; + } + ~Indexer() { + delete[] m_index; + } + + bool operator++(int) { + for (int i = 0; i < m_dimensions; i++) { + m_index[i] += 1; + + if (m_index[i] < m_sizes[i]) { + return true; + } else { + m_index[i] = 0; + } } - ~Indexer(){ - delete [] m_index; - } - - bool operator++(int) { - for (int i = 0; i < m_dimensions; i++) { - m_index[i] += 1; - - if (m_index[i] < m_sizes[i]) { - return true; - } else { - m_index[i] = 0; - } - } - return false; - } + return false; + } - int& operator[](int dimensionIndex) { - return m_index[dimensionIndex]; - } + int& operator[](int dimensionIndex) { + return m_index[dimensionIndex]; + } -private: - int m_dimensions; - int* m_sizes; - int* m_index; + private: + int m_dimensions; + int* m_sizes; + int* m_index; }; template -void BilateralFilterCpuBackward_3d(torch::Tensor gradientInputTensor, - torch::Tensor gradientOutputTensor, - torch::Tensor inputTensor, - torch::Tensor outputTensor, - torch::Tensor outputWeightsTensor, - torch::Tensor dO_dx_ki, - float sigma_x, - float sigma_y, - float sigma_z, - float colorSigma) { - // Getting tensor description. - TensorDescription desc = TensorDescription(gradientInputTensor); - - // Raw tensor data pointers. - scalar_t* gradientInputTensorData = gradientInputTensor.data_ptr(); - scalar_t* gradientOutputTensorData = gradientOutputTensor.data_ptr(); - scalar_t* inputTensorData = inputTensor.data_ptr(); - scalar_t* outputTensorData = outputTensor.data_ptr(); - scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr(); - scalar_t* dO_dx_kiData = dO_dx_ki.data_ptr(); - - // Pre-calculate common values - int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size - int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size - int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size - int halfWindowSize_x = floor(0.5f * windowSize_x); - int halfWindowSize_y = floor(0.5f * windowSize_y); - int halfWindowSize_z = floor(0.5f * windowSize_z); - int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; - scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); - scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); - scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); - scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); - - // Set kernel sizes with respect to the defined spatial sigmas. - int* kernelSizes = new int[desc.dimensions]; - - kernelSizes[0] = windowSize_x; - kernelSizes[1] = windowSize_y; - kernelSizes[2] = windowSize_z; - - // Pre-calculate gaussian kernel and distance map in 1D. - scalar_t* gaussianKernel_x = new scalar_t[windowSize_x]; - scalar_t* gaussianKernel_y = new scalar_t[windowSize_y]; - scalar_t* gaussianKernel_z = new scalar_t[windowSize_z]; - scalar_t* xDistanceSquared = new scalar_t[windowSize_x]; - scalar_t* yDistanceSquared = new scalar_t[windowSize_y]; - scalar_t* zDistanceSquared = new scalar_t[windowSize_z]; - - for (int i = 0; i < windowSize_x; i++) { - int distance = i - halfWindowSize_x; - gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); - xDistanceSquared[i] = distance * distance; - } - for (int i = 0; i < windowSize_y; i++) { - int distance = i - halfWindowSize_y; - gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); - yDistanceSquared[i] = distance * distance; - } - for (int i = 0; i < windowSize_z; i++) { - int distance = i - halfWindowSize_z; - gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); - zDistanceSquared[i] = distance * distance; - } +void BilateralFilterCpuBackward_3d( + torch::Tensor gradientInputTensor, + torch::Tensor gradientOutputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(gradientInputTensor); + + // Raw tensor data pointers. + scalar_t* gradientInputTensorData = gradientInputTensor.data_ptr(); + scalar_t* gradientOutputTensorData = gradientOutputTensor.data_ptr(); + scalar_t* inputTensorData = inputTensor.data_ptr(); + scalar_t* outputTensorData = outputTensor.data_ptr(); + scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr(); + scalar_t* dO_dx_kiData = dO_dx_ki.data_ptr(); + + // Pre-calculate common values + int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size + int halfWindowSize_x = floor(0.5f * windowSize_x); + int halfWindowSize_y = floor(0.5f * windowSize_y); + int halfWindowSize_z = floor(0.5f * windowSize_z); + int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; + scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); + scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); + scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); + scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + // Set kernel sizes with respect to the defined spatial sigmas. + int* kernelSizes = new int[desc.dimensions]; + + kernelSizes[0] = windowSize_x; + kernelSizes[1] = windowSize_y; + kernelSizes[2] = windowSize_z; + + // Pre-calculate gaussian kernel and distance map in 1D. + scalar_t* gaussianKernel_x = new scalar_t[windowSize_x]; + scalar_t* gaussianKernel_y = new scalar_t[windowSize_y]; + scalar_t* gaussianKernel_z = new scalar_t[windowSize_z]; + scalar_t* xDistanceSquared = new scalar_t[windowSize_x]; + scalar_t* yDistanceSquared = new scalar_t[windowSize_y]; + scalar_t* zDistanceSquared = new scalar_t[windowSize_z]; + + for (int i = 0; i < windowSize_x; i++) { + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_y; i++) { + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_z; i++) { + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; + } + + // Looping over the batches + for (int b = 0; b < desc.batchCount; b++) { + int batchOffset = b * desc.batchStride; + + // Looping over all dimensions for the home element + for (int z = 0; z < desc.sizes[2]; z++) +#pragma omp parallel for + for (int y = 0; y < desc.sizes[1]; y++) { + for (int x = 0; x < desc.sizes[0]; x++) { + // Calculating indexing offset for the home element + int homeOffset = batchOffset; + + int homeIndex[] = {x, y, z}; + homeOffset += x * desc.strides[0]; + homeOffset += y * desc.strides[1]; + homeOffset += z * desc.strides[2]; + + // Zero kernel aggregates. + scalar_t filter_kernel = 0; + scalar_t valueSum = 0; + + scalar_t weightSum = 0.0f; + + // Looping over all dimensions for the neighbour element + Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); + do // while(kernelIndex++) + { + // Calculating buffer offset for the neighbour element + // Index is clamped to the border in each dimension. + int neighbourOffset = batchOffset; + bool flagNotClamped = true; + + for (int i = 0; i < desc.dimensions; i++) { + int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i]; + int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex)); + neighbourOffset += neighbourIndexClamped * desc.strides[i]; + if (neighbourIndex != neighbourIndexClamped) { + flagNotClamped = false; + } + } - // Looping over the batches - for (int b = 0; b < desc.batchCount; b++) { - int batchOffset = b * desc.batchStride; + // Euclidean color distance. + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; - // Looping over all dimensions for the home element - for (int z = 0; z < desc.sizes[2]; z++) -#pragma omp parallel for - for (int y = 0; y < desc.sizes[1]; y++) { - for (int x = 0; x < desc.sizes[0]; x++) { - // Calculating indexing offset for the home element - int homeOffset = batchOffset; - - int homeIndex[] = {x, y, z}; - homeOffset += x * desc.strides[0]; - homeOffset += y * desc.strides[1]; - homeOffset += z * desc.strides[2]; - - // Zero kernel aggregates. - scalar_t filter_kernel = 0; - scalar_t valueSum = 0; - - scalar_t weightSum = 0.0f; - - // Looping over all dimensions for the neighbour element - Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); - do // while(kernelIndex++) - { - // Calculating buffer offset for the neighbour element - // Index is clamped to the border in each dimension. - int neighbourOffset = batchOffset; - bool flagNotClamped = true; - - for (int i = 0; i < desc.dimensions; i++) { - int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i]; - int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex)); - neighbourOffset += neighbourIndexClamped * desc.strides[i]; - if (neighbourIndex != neighbourIndexClamped) { flagNotClamped = false; } - } - - // Euclidean color distance. - scalar_t colorDistance = 0; - scalar_t colorDistanceSquared = 0; - - for (int i = 0; i < desc.channelCount; i++) { - scalar_t diff = inputTensorData[neighbourOffset + i * desc.channelStride] - - inputTensorData[homeOffset + i * - desc.channelStride]; // Be careful: Here it is (X_k - X_i) and not (X_i - X_q) - colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. - colorDistanceSquared += diff * diff; - } - - // Calculating and combining the spatial - // and color weights. - scalar_t spatialWeight = 1; - - spatialWeight = gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * - gaussianKernel_z[kernelIndex[2]]; - - scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant); - scalar_t totalWeight = spatialWeight * colorWeight; - - // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. - if (flagNotClamped) { - for (int i = 0; i < desc.channelCount; i++) { - - // Distinguish cases for k!=i (calculation is done here) - // and k==i (partial derivatives are precalculated). - // If statement replaces center element of neighborhood/kernel. - if (kernelIndex[0] != halfWindowSize_x || kernelIndex[1] != halfWindowSize_y || - kernelIndex[2] != halfWindowSize_z) { - - filter_kernel = - -(1 / - outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * - outputTensorData[neighbourOffset + i * desc.channelStride] * - totalWeight * - colorDistance / (colorSigma * colorSigma) + - (1 / - outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * - totalWeight * - (1 + - inputTensorData[homeOffset + i * desc.channelStride] * colorDistance / - (colorSigma * colorSigma)); // inputTensorData[homeOffset] !! - } else { - - filter_kernel = dO_dx_kiData[homeOffset + i * desc.channelStride]; - } - - valueSum += - gradientInputTensorData[neighbourOffset + i * desc.channelStride] * - filter_kernel; - - } - - weightSum += totalWeight; - } - } while (kernelIndex++); - - // Do the filtering and calculate the values for the backward pass. - for (int i = 0; i < desc.channelCount; i++) { - // Filtering: - gradientOutputTensorData[homeOffset + i * desc.channelStride] = valueSum; - } - } + for (int i = 0; i < desc.channelCount; i++) { + scalar_t diff = inputTensorData[neighbourOffset + i * desc.channelStride] - + inputTensorData[homeOffset + + i * desc.channelStride]; // Be careful: Here it is (X_k - X_i) and not (X_i - X_q) + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + // Calculating and combining the spatial + // and color weights. + scalar_t spatialWeight = 1; + + spatialWeight = + gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * gaussianKernel_z[kernelIndex[2]]; + + scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { + for (int i = 0; i < desc.channelCount; i++) { + // Distinguish cases for k!=i (calculation is done here) + // and k==i (partial derivatives are precalculated). + // If statement replaces center element of neighborhood/kernel. + if (kernelIndex[0] != halfWindowSize_x || kernelIndex[1] != halfWindowSize_y || + kernelIndex[2] != halfWindowSize_z) { + filter_kernel = -(1 / outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * + outputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight * colorDistance / + (colorSigma * colorSigma) + + (1 / outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * totalWeight * + (1 + + inputTensorData[homeOffset + i * desc.channelStride] * colorDistance / + (colorSigma * colorSigma)); // inputTensorData[homeOffset] !! + } else { + filter_kernel = dO_dx_kiData[homeOffset + i * desc.channelStride]; } - } - delete[] kernelSizes; - delete[] gaussianKernel_x; - delete[] gaussianKernel_y; - delete[] gaussianKernel_z; - delete[] xDistanceSquared; - delete[] yDistanceSquared; - delete[] zDistanceSquared; + valueSum += gradientInputTensorData[neighbourOffset + i * desc.channelStride] * filter_kernel; + } + + weightSum += totalWeight; + } + } while (kernelIndex++); + + // Do the filtering and calculate the values for the backward pass. + for (int i = 0; i < desc.channelCount; i++) { + // Filtering: + gradientOutputTensorData[homeOffset + i * desc.channelStride] = valueSum; + } + } + } + } + + delete[] kernelSizes; + delete[] gaussianKernel_x; + delete[] gaussianKernel_y; + delete[] gaussianKernel_z; + delete[] xDistanceSquared; + delete[] yDistanceSquared; + delete[] zDistanceSquared; } -torch::Tensor BilateralFilterCpuBackward(torch::Tensor gradientInputTensor, - torch::Tensor inputTensor, - torch::Tensor outputTensor, - torch::Tensor outputWeightsTensor, - torch::Tensor dO_dx_ki, - float sigma_x, - float sigma_y, - float sigma_z, - float colorSigma) { - // Preparing output tensor. - torch::Tensor gradientOutputTensor = torch::zeros_like(gradientInputTensor); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradientInputTensor.scalar_type(), "BilateralFilterCpuBackward_3d", ([&] { - BilateralFilterCpuBackward_3d( - gradientInputTensor, - gradientOutputTensor, - inputTensor, - outputTensor, - outputWeightsTensor, - dO_dx_ki, - sigma_x, - sigma_y, - sigma_z, - colorSigma); - })); - - return gradientOutputTensor; +torch::Tensor BilateralFilterCpuBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Preparing output tensor. + torch::Tensor gradientOutputTensor = torch::zeros_like(gradientInputTensor); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradientInputTensor.scalar_type(), "BilateralFilterCpuBackward_3d", ([&] { + BilateralFilterCpuBackward_3d( + gradientInputTensor, + gradientOutputTensor, + inputTensor, + outputTensor, + outputWeightsTensor, + dO_dx_ki, + sigma_x, + sigma_y, + sigma_z, + colorSigma); + })); + + return gradientOutputTensor; } diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_forward.cpp b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_forward.cpp index ccfa1fe640..f6f8979d29 100644 --- a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_forward.cpp +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_forward.cpp @@ -14,293 +14,273 @@ limitations under the License. #include "trainable_bilateral.h" struct Indexer { -public: - Indexer(int dimensions, int* sizes) { - m_dimensions = dimensions; - m_sizes = sizes; - m_index = new int[dimensions]{0}; - } - ~Indexer(){ - delete [] m_index; + public: + Indexer(int dimensions, int* sizes) { + m_dimensions = dimensions; + m_sizes = sizes; + m_index = new int[dimensions]{0}; + } + ~Indexer() { + delete[] m_index; + } + + bool operator++(int) { + for (int i = 0; i < m_dimensions; i++) { + m_index[i] += 1; + + if (m_index[i] < m_sizes[i]) { + return true; + } else { + m_index[i] = 0; + } } - bool operator++(int) { - for (int i = 0; i < m_dimensions; i++) { - m_index[i] += 1; + return false; + } + + int& operator[](int dimensionIndex) { + return m_index[dimensionIndex]; + } + + private: + int m_dimensions; + int* m_sizes; + int* m_index; +}; - if (m_index[i] < m_sizes[i]) { - return true; - } else { - m_index[i] = 0; +template +void BilateralFilterCpuForward_3d( + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + torch::Tensor dO_dsig_r, + torch::Tensor dO_dsig_x, + torch::Tensor dO_dsig_y, + torch::Tensor dO_dsig_z, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + // Raw tensor data pointers. + scalar_t* inputTensorData = inputTensor.data_ptr(); + scalar_t* outputTensorData = outputTensor.data_ptr(); + scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr(); + scalar_t* dO_dx_kiData = dO_dx_ki.data_ptr(); + scalar_t* dO_dsig_rData = dO_dsig_r.data_ptr(); + scalar_t* dO_dsig_xData = dO_dsig_x.data_ptr(); + scalar_t* dO_dsig_yData = dO_dsig_y.data_ptr(); + scalar_t* dO_dsig_zData = dO_dsig_z.data_ptr(); + + // Pre-calculate common values + int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size + int halfWindowSize_x = floor(0.5f * windowSize_x); + int halfWindowSize_y = floor(0.5f * windowSize_y); + int halfWindowSize_z = floor(0.5f * windowSize_z); + int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; + scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); + scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); + scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); + scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + // Set kernel sizes with respect to the defined spatial sigmas. + int* kernelSizes = new int[desc.dimensions]; + + kernelSizes[0] = windowSize_x; + kernelSizes[1] = windowSize_y; + kernelSizes[2] = windowSize_z; + + // Pre-calculate gaussian kernel and distance map in 1D. + scalar_t* gaussianKernel_x = new scalar_t[windowSize_x]; + scalar_t* gaussianKernel_y = new scalar_t[windowSize_y]; + scalar_t* gaussianKernel_z = new scalar_t[windowSize_z]; + scalar_t* xDistanceSquared = new scalar_t[windowSize_x]; + scalar_t* yDistanceSquared = new scalar_t[windowSize_y]; + scalar_t* zDistanceSquared = new scalar_t[windowSize_z]; + + for (int i = 0; i < windowSize_x; i++) { + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_y; i++) { + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_z; i++) { + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; + } + + // Looping over the batches + for (int b = 0; b < desc.batchCount; b++) { + int batchOffset = b * desc.batchStride; + + // Looping over all dimensions for the home element + for (int z = 0; z < desc.sizes[2]; z++) +#pragma omp parallel for + for (int y = 0; y < desc.sizes[1]; y++) { + for (int x = 0; x < desc.sizes[0]; x++) { + // Calculating indexing offset for the home element + int homeOffset = batchOffset; + + int homeIndex[] = {x, y, z}; + homeOffset += x * desc.strides[0]; + homeOffset += y * desc.strides[1]; + homeOffset += z * desc.strides[2]; + + // Zero kernel aggregates. + scalar_t valueSum = 0; + scalar_t dw_dx_ki = 0; + scalar_t dfilter_dx_ki = 0; + scalar_t colorSum_w = 0; + scalar_t colorSum_alpha = 0; + scalar_t xSum_w = 0; + scalar_t xSum_alpha = 0; + scalar_t ySum_w = 0; + scalar_t ySum_alpha = 0; + scalar_t zSum_w = 0; + scalar_t zSum_alpha = 0; + + scalar_t weightSum = 0.0f; + + // Looping over all dimensions for the neighbour element + Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); + do // while(kernelIndex++) + { + // Calculating buffer offset for the neighbour element + // Index is clamped to the border in each dimension. + int neighbourOffset = batchOffset; + bool flagNotClamped = true; + + for (int i = 0; i < desc.dimensions; i++) { + int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i]; + int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex)); + neighbourOffset += neighbourIndexClamped * desc.strides[i]; + if (neighbourIndex != neighbourIndexClamped) { + flagNotClamped = false; + } } - } - return false; - } + // Euclidean color distance. + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; - int& operator[](int dimensionIndex) { - return m_index[dimensionIndex]; - } + for (int i = 0; i < desc.channelCount; i++) { + scalar_t diff = inputTensorData[homeOffset + i * desc.channelStride] - + inputTensorData[neighbourOffset + i * desc.channelStride]; + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } -private: - int m_dimensions; - int* m_sizes; - int* m_index; -}; + // Calculating and combining the spatial + // and color weights. + scalar_t spatialWeight = 1; -template -void BilateralFilterCpuForward_3d(torch::Tensor inputTensor, - torch::Tensor outputTensor, - torch::Tensor outputWeightsTensor, - torch::Tensor dO_dx_ki, - torch::Tensor dO_dsig_r, - torch::Tensor dO_dsig_x, - torch::Tensor dO_dsig_y, - torch::Tensor dO_dsig_z, - float sigma_x, - float sigma_y, - float sigma_z, - float colorSigma) { - // Getting tensor description. - TensorDescription desc = TensorDescription(inputTensor); - - // Raw tensor data pointers. - scalar_t* inputTensorData = inputTensor.data_ptr(); - scalar_t* outputTensorData = outputTensor.data_ptr(); - scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr(); - scalar_t* dO_dx_kiData = dO_dx_ki.data_ptr(); - scalar_t* dO_dsig_rData = dO_dsig_r.data_ptr(); - scalar_t* dO_dsig_xData = dO_dsig_x.data_ptr(); - scalar_t* dO_dsig_yData = dO_dsig_y.data_ptr(); - scalar_t* dO_dsig_zData = dO_dsig_z.data_ptr(); - - // Pre-calculate common values - int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size - int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size - int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size - int halfWindowSize_x = floor(0.5f * windowSize_x); - int halfWindowSize_y = floor(0.5f * windowSize_y); - int halfWindowSize_z = floor(0.5f * windowSize_z); - int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; - scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); - scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); - scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); - scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); - - // Set kernel sizes with respect to the defined spatial sigmas. - int* kernelSizes = new int[desc.dimensions]; - - kernelSizes[0] = windowSize_x; - kernelSizes[1] = windowSize_y; - kernelSizes[2] = windowSize_z; - - // Pre-calculate gaussian kernel and distance map in 1D. - scalar_t* gaussianKernel_x = new scalar_t[windowSize_x]; - scalar_t* gaussianKernel_y = new scalar_t[windowSize_y]; - scalar_t* gaussianKernel_z = new scalar_t[windowSize_z]; - scalar_t* xDistanceSquared = new scalar_t[windowSize_x]; - scalar_t* yDistanceSquared = new scalar_t[windowSize_y]; - scalar_t* zDistanceSquared = new scalar_t[windowSize_z]; - - for (int i = 0; i < windowSize_x; i++) { - int distance = i - halfWindowSize_x; - gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); - xDistanceSquared[i] = distance * distance; - } - for (int i = 0; i < windowSize_y; i++) { - int distance = i - halfWindowSize_y; - gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); - yDistanceSquared[i] = distance * distance; - } - for (int i = 0; i < windowSize_z; i++) { - int distance = i - halfWindowSize_z; - gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); - zDistanceSquared[i] = distance * distance; - } + spatialWeight = + gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * gaussianKernel_z[kernelIndex[2]]; - // Looping over the batches - for (int b = 0; b < desc.batchCount; b++) { - int batchOffset = b * desc.batchStride; + scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant); + scalar_t totalWeight = spatialWeight * colorWeight; - // Looping over all dimensions for the home element - for (int z = 0; z < desc.sizes[2]; z++) -#pragma omp parallel for - for (int y = 0; y < desc.sizes[1]; y++) { - for (int x = 0; x < desc.sizes[0]; x++) { - // Calculating indexing offset for the home element - int homeOffset = batchOffset; - - int homeIndex[] = {x, y, z}; - homeOffset += x * desc.strides[0]; - homeOffset += y * desc.strides[1]; - homeOffset += z * desc.strides[2]; - - // Zero kernel aggregates. - scalar_t valueSum = 0; - scalar_t dw_dx_ki = 0; - scalar_t dfilter_dx_ki = 0; - scalar_t colorSum_w = 0; - scalar_t colorSum_alpha = 0; - scalar_t xSum_w = 0; - scalar_t xSum_alpha = 0; - scalar_t ySum_w = 0; - scalar_t ySum_alpha = 0; - scalar_t zSum_w = 0; - scalar_t zSum_alpha = 0; - - scalar_t weightSum = 0.0f; - - // Looping over all dimensions for the neighbour element - Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); - do // while(kernelIndex++) - { - // Calculating buffer offset for the neighbour element - // Index is clamped to the border in each dimension. - int neighbourOffset = batchOffset; - bool flagNotClamped = true; - - for (int i = 0; i < desc.dimensions; i++) { - int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i]; - int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex)); - neighbourOffset += neighbourIndexClamped * desc.strides[i]; - if (neighbourIndex != neighbourIndexClamped) { flagNotClamped = false; } - } - - // Euclidean color distance. - scalar_t colorDistance = 0; - scalar_t colorDistanceSquared = 0; - - for (int i = 0; i < desc.channelCount; i++) { - scalar_t diff = inputTensorData[homeOffset + i * desc.channelStride] - - inputTensorData[neighbourOffset + i * desc.channelStride]; - colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. - colorDistanceSquared += diff * diff; - } - - // Calculating and combining the spatial - // and color weights. - scalar_t spatialWeight = 1; - - spatialWeight = gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * - gaussianKernel_z[kernelIndex[2]]; - - scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant); - scalar_t totalWeight = spatialWeight * colorWeight; - - // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. - if (flagNotClamped) { - for (int i = 0; i < desc.channelCount; i++) { - valueSum += inputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight; - - // Derivative of weights with respect to X_i while i=k. - dw_dx_ki += (-1) * totalWeight * colorDistance / (colorSigma * colorSigma); - // Derivative of convolved image with respect to X_i while i=k. - dfilter_dx_ki += - (-1) * totalWeight * - inputTensorData[neighbourOffset + i * desc.channelStride] * - colorDistance / (colorSigma * - colorSigma); // Be careful, the +1 is missing here -> Added before filling dfilter_dx_kiData - - colorSum_w += - totalWeight * colorDistanceSquared / - std::abs(colorSigma * colorSigma * colorSigma); - colorSum_alpha += - totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * - colorDistanceSquared / std::abs(colorSigma * colorSigma * colorSigma); - - xSum_w += - totalWeight * xDistanceSquared[kernelIndex[0]] / - std::abs(sigma_x * sigma_x * sigma_x); - xSum_alpha += - totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * - xDistanceSquared[kernelIndex[0]] / std::abs(sigma_x * sigma_x * sigma_x); - - ySum_w += - totalWeight * yDistanceSquared[kernelIndex[1]] / - std::abs(sigma_y * sigma_y * sigma_y); - ySum_alpha += - totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * - yDistanceSquared[kernelIndex[1]] / std::abs(sigma_y * sigma_y * sigma_y); - - zSum_w += - totalWeight * zDistanceSquared[kernelIndex[2]] / - std::abs(sigma_z * sigma_z * sigma_z); - zSum_alpha += - totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * - zDistanceSquared[kernelIndex[2]] / std::abs(sigma_z * sigma_z * sigma_z); - } - - weightSum += totalWeight; - } - } while (kernelIndex++); - - // Do the filtering and calculate the values for the backward pass. - for (int i = 0; i < desc.channelCount; i++) { - - // Filtering: - outputTensorData[homeOffset + i * desc.channelStride] = valueSum / weightSum; - - // Pre-computations for the backward pass: - outputWeightsTensorData[homeOffset + i * desc.channelStride] = weightSum; - dO_dx_kiData[homeOffset + i * desc.channelStride] = - -(1 / weightSum) * (valueSum / weightSum) * - dw_dx_ki + - (1 / weightSum) * (dfilter_dx_ki + 1); // +1 for dfilter_dx_ki is added here - dO_dsig_rData[homeOffset + i * desc.channelStride] = - -(1 / weightSum) * (valueSum / weightSum) * - colorSum_w + (1 / weightSum) * colorSum_alpha; - dO_dsig_xData[homeOffset + i * desc.channelStride] = - -(1 / weightSum) * (valueSum / weightSum) * - xSum_w + (1 / weightSum) * xSum_alpha; - dO_dsig_yData[homeOffset + i * desc.channelStride] = - -(1 / weightSum) * (valueSum / weightSum) * - ySum_w + (1 / weightSum) * ySum_alpha; - dO_dsig_zData[homeOffset + i * desc.channelStride] = - -(1 / weightSum) * (valueSum / weightSum) * - zSum_w + (1 / weightSum) * zSum_alpha; - } - } - } - } + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { + for (int i = 0; i < desc.channelCount; i++) { + valueSum += inputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight; + + // Derivative of weights with respect to X_i while i=k. + dw_dx_ki += (-1) * totalWeight * colorDistance / (colorSigma * colorSigma); + // Derivative of convolved image with respect to X_i while i=k. + dfilter_dx_ki += (-1) * totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + colorDistance / + (colorSigma * + colorSigma); // Be careful, the +1 is missing here -> Added before filling dfilter_dx_kiData - delete[] kernelSizes; - delete[] gaussianKernel_x; - delete[] gaussianKernel_y; - delete[] gaussianKernel_z; - delete[] xDistanceSquared; - delete[] yDistanceSquared; - delete[] zDistanceSquared; + colorSum_w += totalWeight * colorDistanceSquared / std::abs(colorSigma * colorSigma * colorSigma); + colorSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + colorDistanceSquared / std::abs(colorSigma * colorSigma * colorSigma); + + xSum_w += totalWeight * xDistanceSquared[kernelIndex[0]] / std::abs(sigma_x * sigma_x * sigma_x); + xSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + xDistanceSquared[kernelIndex[0]] / std::abs(sigma_x * sigma_x * sigma_x); + + ySum_w += totalWeight * yDistanceSquared[kernelIndex[1]] / std::abs(sigma_y * sigma_y * sigma_y); + ySum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + yDistanceSquared[kernelIndex[1]] / std::abs(sigma_y * sigma_y * sigma_y); + + zSum_w += totalWeight * zDistanceSquared[kernelIndex[2]] / std::abs(sigma_z * sigma_z * sigma_z); + zSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + zDistanceSquared[kernelIndex[2]] / std::abs(sigma_z * sigma_z * sigma_z); + } + + weightSum += totalWeight; + } + } while (kernelIndex++); + + // Do the filtering and calculate the values for the backward pass. + for (int i = 0; i < desc.channelCount; i++) { + // Filtering: + outputTensorData[homeOffset + i * desc.channelStride] = valueSum / weightSum; + + // Pre-computations for the backward pass: + outputWeightsTensorData[homeOffset + i * desc.channelStride] = weightSum; + dO_dx_kiData[homeOffset + i * desc.channelStride] = -(1 / weightSum) * (valueSum / weightSum) * dw_dx_ki + + (1 / weightSum) * (dfilter_dx_ki + 1); // +1 for dfilter_dx_ki is added here + dO_dsig_rData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * colorSum_w + (1 / weightSum) * colorSum_alpha; + dO_dsig_xData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * xSum_w + (1 / weightSum) * xSum_alpha; + dO_dsig_yData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * ySum_w + (1 / weightSum) * ySum_alpha; + dO_dsig_zData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * zSum_w + (1 / weightSum) * zSum_alpha; + } + } + } + } + + delete[] kernelSizes; + delete[] gaussianKernel_x; + delete[] gaussianKernel_y; + delete[] gaussianKernel_z; + delete[] xDistanceSquared; + delete[] yDistanceSquared; + delete[] zDistanceSquared; } -std::tuple BilateralFilterCpuForward(torch::Tensor inputTensor, - float sigma_x, - float sigma_y, - float sigma_z, - float colorSigma) { - // Preparing output tensor. - torch::Tensor outputTensor = torch::zeros_like(inputTensor); - torch::Tensor outputWeightsTensor = torch::zeros_like(inputTensor); - torch::Tensor dO_dx_ki = torch::zeros_like(inputTensor); - torch::Tensor dO_dsig_r = torch::zeros_like(inputTensor); - torch::Tensor dO_dsig_x = torch::zeros_like(inputTensor); - torch::Tensor dO_dsig_y = torch::zeros_like(inputTensor); - torch::Tensor dO_dsig_z = torch::zeros_like(inputTensor); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.scalar_type(), "BilateralFilterCpuForward_3d", ([&] { - BilateralFilterCpuForward_3d( - inputTensor, - outputTensor, - outputWeightsTensor, - dO_dx_ki, - dO_dsig_r, - dO_dsig_x, - dO_dsig_y, - dO_dsig_z, - sigma_x, - sigma_y, - sigma_z, - colorSigma); - })); - - return {outputTensor, outputWeightsTensor, dO_dx_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z}; +std::tuple +BilateralFilterCpuForward(torch::Tensor inputTensor, float sigma_x, float sigma_y, float sigma_z, float colorSigma) { + // Preparing output tensor. + torch::Tensor outputTensor = torch::zeros_like(inputTensor); + torch::Tensor outputWeightsTensor = torch::zeros_like(inputTensor); + torch::Tensor dO_dx_ki = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_r = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_x = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_y = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_z = torch::zeros_like(inputTensor); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.scalar_type(), "BilateralFilterCpuForward_3d", ([&] { + BilateralFilterCpuForward_3d( + inputTensor, + outputTensor, + outputWeightsTensor, + dO_dx_ki, + dO_dsig_r, + dO_dsig_x, + dO_dsig_y, + dO_dsig_z, + sigma_x, + sigma_y, + sigma_z, + colorSigma); + })); + + return {outputTensor, outputWeightsTensor, dO_dx_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z}; } diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu index d7d2a3960a..ffd8eb32a2 100644 --- a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu @@ -38,14 +38,14 @@ __constant__ float cSigma_yBack; __constant__ float cSigma_zBack; __constant__ float cColorSigmaBack; - template -__global__ void BilateralFilterCudaKernel3DBackward(scalar_t* gradientInputTensor, - scalar_t* gradientOutputTensor, - scalar_t* inputTensor, - scalar_t* outputTensor, - scalar_t* outputWeightsTensor, - scalar_t* dO_dx_ki) { +__global__ void BilateralFilterCudaKernel3DBackward( + scalar_t* gradientInputTensor, + scalar_t* gradientOutputTensor, + scalar_t* inputTensor, + scalar_t* outputTensor, + scalar_t* outputWeightsTensor, + scalar_t* dO_dx_ki) { int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; int batchOffset = blockIdx.y * cBatchStrideBack; @@ -77,13 +77,16 @@ __global__ void BilateralFilterCudaKernel3DBackward(scalar_t* gradientInputTenso bool flagNotClamped = true; int kernelIndex[] = {kernelX, kernelY, kernelZ}; - int dimensions = 3; // Must equal the number of spatial dimensions. + int dimensions = 3; // Must equal the number of spatial dimensions. for (int i = 0; i < dimensions; i++) { - int HalfWindowSizeBack = cHalfWindowSize_arrBack[i]; // Define constant memory as new variable here (!!), otherwise: cudaErrorMisalignedAddress - int neighbourIndex = homeIndex[i] + kernelIndex[i] - HalfWindowSizeBack; - int neighbourIndexClamped = min(cSizesBack[i] - 1, max(0, neighbourIndex)); - if (neighbourIndex != neighbourIndexClamped) { flagNotClamped = false; } + int HalfWindowSizeBack = cHalfWindowSize_arrBack[i]; // Define constant memory as new variable here (!!), + // otherwise: cudaErrorMisalignedAddress + int neighbourIndex = homeIndex[i] + kernelIndex[i] - HalfWindowSizeBack; + int neighbourIndexClamped = min(cSizesBack[i] - 1, max(0, neighbourIndex)); + if (neighbourIndex != neighbourIndexClamped) { + flagNotClamped = false; + } } scalar_t colorDistance = 0; @@ -92,7 +95,8 @@ __global__ void BilateralFilterCudaKernel3DBackward(scalar_t* gradientInputTenso #pragma unroll for (int c = 0; c < C; c++) { scalar_t a = inputTensor[batchOffset + neighbourOffset + c * cColorStrideBack]; - scalar_t b = inputTensor[batchOffset + homeOffset + c * cColorStrideBack]; // Be careful: Here it is (X_k - X_i) and not (X_i - X_q) + scalar_t b = inputTensor[batchOffset + homeOffset + c * cColorStrideBack]; // Be careful: Here it is (X_k - + // X_i) and not (X_i - X_q) scalar_t diff = a - b; colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. colorDistanceSquared += diff * diff; @@ -104,41 +108,30 @@ __global__ void BilateralFilterCudaKernel3DBackward(scalar_t* gradientInputTenso // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. if (flagNotClamped) { - scalar_t filter_kernel_back; + scalar_t filter_kernel_back; #pragma unroll - for (int c = 0; c < C; c++) { - // Distinguish cases for k!=i (calculation is done here) - // and k==i (partial derivatives are precalculated). - // If statement replaces center element of neighborhood/kernel. - if (kernelX != cHalfWindowSize_arrBack[0] || kernelY != cHalfWindowSize_arrBack[1] || - kernelZ != cHalfWindowSize_arrBack[2]) { - - filter_kernel_back = - -(1 / - outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) * - outputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * - totalWeight * - colorDistance / (cColorSigmaBack * cColorSigmaBack) + - (1 / - outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) * - totalWeight * - (1 + - inputTensor[batchOffset + homeOffset + c * cColorStrideBack] * colorDistance / - (cColorSigmaBack * cColorSigmaBack)); // inputTensorData[homeOffset] !! - } else { - - filter_kernel_back = dO_dx_ki[batchOffset + homeOffset + c * cColorStrideBack]; - } - - valueSum += - gradientInputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * - filter_kernel_back; - + for (int c = 0; c < C; c++) { + // Distinguish cases for k!=i (calculation is done here) + // and k==i (partial derivatives are precalculated). + // If statement replaces center element of neighborhood/kernel. + if (kernelX != cHalfWindowSize_arrBack[0] || kernelY != cHalfWindowSize_arrBack[1] || + kernelZ != cHalfWindowSize_arrBack[2]) { + filter_kernel_back = -(1 / outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) * + outputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * totalWeight * colorDistance / + (cColorSigmaBack * cColorSigmaBack) + + (1 / outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) * totalWeight * + (1 + + inputTensor[batchOffset + homeOffset + c * cColorStrideBack] * colorDistance / + (cColorSigmaBack * cColorSigmaBack)); // inputTensorData[homeOffset] !! + } else { + filter_kernel_back = dO_dx_ki[batchOffset + homeOffset + c * cColorStrideBack]; } - weightSum += totalWeight; + valueSum += gradientInputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * filter_kernel_back; + } + weightSum += totalWeight; } } } @@ -147,21 +140,21 @@ __global__ void BilateralFilterCudaKernel3DBackward(scalar_t* gradientInputTenso #pragma unroll for (int c = 0; c < C; c++) { gradientOutputTensor[batchOffset + homeOffset + c * cColorStrideBack] = valueSum; - } } template -void BilateralFilterCudaBackwardFunction(torch::Tensor gradientInputTensor, - torch::Tensor gradientOutputTensor, - torch::Tensor inputTensor, - torch::Tensor outputTensor, - torch::Tensor outputWeightsTensor, - torch::Tensor dO_dx_ki, - float sigma_x, - float sigma_y, - float sigma_z, - float colorSigma) { +void BilateralFilterCudaBackwardFunction( + torch::Tensor gradientInputTensor, + torch::Tensor gradientOutputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { // Getting tensor description. TensorDescription desc = TensorDescription(inputTensor); @@ -191,19 +184,19 @@ void BilateralFilterCudaBackwardFunction(torch::Tensor gradientInputTensor, auto* zDistanceSquared = new float[windowSize_z]; for (int i = 0; i < windowSize_x; i++) { - int distance = i - halfWindowSize_x; - gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); - xDistanceSquared[i] = distance * distance; + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; } for (int i = 0; i < windowSize_y; i++) { - int distance = i - halfWindowSize_y; - gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); - yDistanceSquared[i] = distance * distance; + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; } for (int i = 0; i < windowSize_z; i++) { - int distance = i - halfWindowSize_z; - gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); - zDistanceSquared[i] = distance * distance; + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; } // Writing constant memory. @@ -225,24 +218,24 @@ void BilateralFilterCudaBackwardFunction(torch::Tensor gradientInputTensor, cudaMemcpyToSymbol(cSigma_zBack, &sigma_z, sizeof(float)); cudaMemcpyToSymbol(cColorSigmaBack, &colorSigma, sizeof(float)); -// cuda_error_check("Cuda check before kernel call."); + // cuda_error_check("Cuda check before kernel call."); #define BLOCK_SIZE 32 AT_DISPATCH_FLOATING_TYPES_AND_HALF( inputTensor.scalar_type(), "BilateralFilterCudaKernel3DBackward", ([&] { - BilateralFilterCudaKernel3DBackward - <<>>( - gradientInputTensor.data_ptr(), - gradientOutputTensor.data_ptr(), - inputTensor.data_ptr(), - outputTensor.data_ptr(), - outputWeightsTensor.data_ptr(), - dO_dx_ki.data_ptr()); + BilateralFilterCudaKernel3DBackward + <<>>( + gradientInputTensor.data_ptr(), + gradientOutputTensor.data_ptr(), + inputTensor.data_ptr(), + outputTensor.data_ptr(), + outputWeightsTensor.data_ptr(), + dO_dx_ki.data_ptr()); })); -// cuda_error_check("Cuda check after kernel call."); -// delete[] kernel; + // cuda_error_check("Cuda check after kernel call."); + // delete[] kernel; delete[] kernelSizes; delete[] gaussianKernel_x; delete[] gaussianKernel_y; @@ -253,20 +246,37 @@ void BilateralFilterCudaBackwardFunction(torch::Tensor gradientInputTensor, } // Function to choose template implementation based on dynamic, channels and dimensions -torch::Tensor BilateralFilterCudaBackward(torch::Tensor gradientInputTensor, - torch::Tensor inputTensor, - torch::Tensor outputTensor, - torch::Tensor outputWeightsTensor, - torch::Tensor dO_dx_ki, - float sigma_x, - float sigma_y, - float sigma_z, - float colorSigma) { +torch::Tensor BilateralFilterCudaBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { torch::Tensor gradientOutputTensor = torch::zeros_like(gradientInputTensor); -// cuda_error_check("beginning"); - -#define CASE(c, d) BilateralFilterCudaBackwardFunction(gradientInputTensor, gradientOutputTensor, inputTensor, outputTensor, outputWeightsTensor, dO_dx_ki, sigma_x, sigma_y, sigma_z, colorSigma); - SWITCH_AB(CASE, BF_CUDA_MAX_CHANNELS, BF_CUDA_MAX_SPATIAL_DIMENSION, gradientInputTensor.size(1), gradientInputTensor.dim() - 2); + // cuda_error_check("beginning"); + +#define CASE(c, d) \ + BilateralFilterCudaBackwardFunction( \ + gradientInputTensor, \ + gradientOutputTensor, \ + inputTensor, \ + outputTensor, \ + outputWeightsTensor, \ + dO_dx_ki, \ + sigma_x, \ + sigma_y, \ + sigma_z, \ + colorSigma); + SWITCH_AB( + CASE, + BF_CUDA_MAX_CHANNELS, + BF_CUDA_MAX_SPATIAL_DIMENSION, + gradientInputTensor.size(1), + gradientInputTensor.dim() - 2); return gradientOutputTensor; } diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_forward.cu b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_forward.cu index 36473ee379..29278feca1 100644 --- a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_forward.cu +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_forward.cu @@ -38,16 +38,16 @@ __constant__ float cSigma_y; __constant__ float cSigma_z; __constant__ float cColorSigma; - template -__global__ void BilateralFilterCudaKernel3DForward(scalar_t* input, - scalar_t* output, - scalar_t* outputWeightsTensor, - scalar_t* dO_dx_ki, - scalar_t* dO_dsig_r, - scalar_t* dO_dsig_x, - scalar_t* dO_dsig_y, - scalar_t* dO_dsig_z) { +__global__ void BilateralFilterCudaKernel3DForward( + scalar_t* input, + scalar_t* output, + scalar_t* outputWeightsTensor, + scalar_t* dO_dx_ki, + scalar_t* dO_dsig_r, + scalar_t* dO_dsig_x, + scalar_t* dO_dsig_y, + scalar_t* dO_dsig_z) { int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; int batchOffset = blockIdx.y * cBatchStride; @@ -89,13 +89,16 @@ __global__ void BilateralFilterCudaKernel3DForward(scalar_t* input, bool flagNotClamped = true; int kernelIndex[] = {kernelX, kernelY, kernelZ}; - int dimensions = 3; // Must equal the number of spatial dimensions. + int dimensions = 3; // Must equal the number of spatial dimensions. for (int i = 0; i < dimensions; i++) { - int HalfWindowSizeBack = cHalfWindowSize_arr[i]; // Define constant memory as new variable here (!!), otherwise: cudaErrorMisalignedAddress - int neighbourIndex = homeIndex[i] + kernelIndex[i] - HalfWindowSizeBack; - int neighbourIndexClamped = min(cSizes[i] - 1, max(0, neighbourIndex)); - if (neighbourIndex != neighbourIndexClamped) { flagNotClamped = false; } + int HalfWindowSizeBack = cHalfWindowSize_arr[i]; // Define constant memory as new variable here (!!), + // otherwise: cudaErrorMisalignedAddress + int neighbourIndex = homeIndex[i] + kernelIndex[i] - HalfWindowSizeBack; + int neighbourIndexClamped = min(cSizes[i] - 1, max(0, neighbourIndex)); + if (neighbourIndex != neighbourIndexClamped) { + flagNotClamped = false; + } } scalar_t colorDistance = 0; @@ -104,7 +107,8 @@ __global__ void BilateralFilterCudaKernel3DForward(scalar_t* input, #pragma unroll for (int c = 0; c < C; c++) { scalar_t a = input[batchOffset + homeOffset + c * cColorStride]; - scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride]; // Home - neighbor (!!) in backward the other way around !! + scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride]; // Home - neighbor (!!) in backward the + // other way around !! scalar_t diff = a - b; colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. colorDistanceSquared += diff * diff; @@ -116,50 +120,36 @@ __global__ void BilateralFilterCudaKernel3DForward(scalar_t* input, // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. if (flagNotClamped) { - #pragma unroll - for (int c = 0; c < C; c++) { - valueSum += input[batchOffset + neighbourOffset + c * cColorStride] * totalWeight; - - // Derivative of weights with respect to X_i while i=k. - dw_dx_ki += (-1) * totalWeight * colorDistance / (cColorSigma * cColorSigma); - // Derivative of convolved image with respect to X_i while i=k. - dfilter_dx_ki += (-1) * totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * - colorDistance / (cColorSigma * - cColorSigma); // Be careful, the +1 is missing here -> Added before filling dfilter_dx_kiData - - colorSum_w += - totalWeight * colorDistanceSquared / - std::abs(cColorSigma * cColorSigma * cColorSigma); - colorSum_alpha += - totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * - colorDistanceSquared / std::abs(cColorSigma * cColorSigma * cColorSigma); - - xSum_w += - totalWeight * cXDistanceSquared[kernelX] / - std::abs(cSigma_x * cSigma_x * cSigma_x); - xSum_alpha += - totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * - cXDistanceSquared[kernelX] / std::abs(cSigma_x * cSigma_x * cSigma_x); - - ySum_w += - totalWeight * cYDistanceSquared[kernelY] / - std::abs(cSigma_y * cSigma_y * cSigma_y); - ySum_alpha += - totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * - cYDistanceSquared[kernelY] / std::abs(cSigma_y * cSigma_y * cSigma_y); - - zSum_w += - totalWeight * cZDistanceSquared[kernelZ] / - std::abs(cSigma_z * cSigma_z * cSigma_z); - zSum_alpha += - totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * - cZDistanceSquared[kernelZ] / std::abs(cSigma_z * cSigma_z * cSigma_z); - - } - - weightSum += totalWeight; - + for (int c = 0; c < C; c++) { + valueSum += input[batchOffset + neighbourOffset + c * cColorStride] * totalWeight; + + // Derivative of weights with respect to X_i while i=k. + dw_dx_ki += (-1) * totalWeight * colorDistance / (cColorSigma * cColorSigma); + // Derivative of convolved image with respect to X_i while i=k. + dfilter_dx_ki += (-1) * totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + colorDistance / + (cColorSigma * + cColorSigma); // Be careful, the +1 is missing here -> Added before filling dfilter_dx_kiData + + colorSum_w += totalWeight * colorDistanceSquared / std::abs(cColorSigma * cColorSigma * cColorSigma); + colorSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + colorDistanceSquared / std::abs(cColorSigma * cColorSigma * cColorSigma); + + xSum_w += totalWeight * cXDistanceSquared[kernelX] / std::abs(cSigma_x * cSigma_x * cSigma_x); + xSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + cXDistanceSquared[kernelX] / std::abs(cSigma_x * cSigma_x * cSigma_x); + + ySum_w += totalWeight * cYDistanceSquared[kernelY] / std::abs(cSigma_y * cSigma_y * cSigma_y); + ySum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + cYDistanceSquared[kernelY] / std::abs(cSigma_y * cSigma_y * cSigma_y); + + zSum_w += totalWeight * cZDistanceSquared[kernelZ] / std::abs(cSigma_z * cSigma_z * cSigma_z); + zSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + cZDistanceSquared[kernelZ] / std::abs(cSigma_z * cSigma_z * cSigma_z); + } + + weightSum += totalWeight; } } } @@ -167,43 +157,38 @@ __global__ void BilateralFilterCudaKernel3DForward(scalar_t* input, #pragma unroll for (int c = 0; c < C; c++) { -// output[batchOffset + homeOffset + c * cColorStride] /= weightSum; + // output[batchOffset + homeOffset + c * cColorStride] /= weightSum; output[batchOffset + homeOffset + c * cColorStride] = valueSum / weightSum; // Pre-computations for the backward pass: outputWeightsTensor[batchOffset + homeOffset + c * cColorStride] = weightSum; - dO_dx_ki[batchOffset + homeOffset + c * cColorStride] = - -(1 / weightSum) * (valueSum / weightSum) * - dw_dx_ki + - (1 / weightSum) * (dfilter_dx_ki + 1); // +1 for dfilter_dx_ki is added here + dO_dx_ki[batchOffset + homeOffset + c * cColorStride] = -(1 / weightSum) * (valueSum / weightSum) * dw_dx_ki + + (1 / weightSum) * (dfilter_dx_ki + 1); // +1 for dfilter_dx_ki is added here dO_dsig_r[batchOffset + homeOffset + c * cColorStride] = - -(1 / weightSum) * (valueSum / weightSum) * - colorSum_w + (1 / weightSum) * colorSum_alpha; + -(1 / weightSum) * (valueSum / weightSum) * colorSum_w + (1 / weightSum) * colorSum_alpha; dO_dsig_x[batchOffset + homeOffset + c * cColorStride] = - -(1 / weightSum) * (valueSum / weightSum) * - xSum_w + (1 / weightSum) * xSum_alpha; + -(1 / weightSum) * (valueSum / weightSum) * xSum_w + (1 / weightSum) * xSum_alpha; dO_dsig_y[batchOffset + homeOffset + c * cColorStride] = - -(1 / weightSum) * (valueSum / weightSum) * - ySum_w + (1 / weightSum) * ySum_alpha; + -(1 / weightSum) * (valueSum / weightSum) * ySum_w + (1 / weightSum) * ySum_alpha; dO_dsig_z[batchOffset + homeOffset + c * cColorStride] = - -(1 / weightSum) * (valueSum / weightSum) * - zSum_w + (1 / weightSum) * zSum_alpha; + -(1 / weightSum) * (valueSum / weightSum) * zSum_w + (1 / weightSum) * zSum_alpha; } } template -void BilateralFilterCudaForwardFunction(torch::Tensor inputTensor, - torch::Tensor outputTensor, - torch::Tensor outputWeightsTensor, - torch::Tensor dO_dx_ki, - torch::Tensor dO_dsig_r, - torch::Tensor dO_dsig_x, - torch::Tensor dO_dsig_y, - torch::Tensor dO_dsig_z, - float sigma_x, - float sigma_y, - float sigma_z, - float colorSigma) { +void BilateralFilterCudaForwardFunction( + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + torch::Tensor dO_dsig_r, + torch::Tensor dO_dsig_x, + torch::Tensor dO_dsig_y, + torch::Tensor dO_dsig_z, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { // Getting tensor description. TensorDescription desc = TensorDescription(inputTensor); @@ -233,19 +218,19 @@ void BilateralFilterCudaForwardFunction(torch::Tensor inputTensor, auto* zDistanceSquared = new float[windowSize_z]; for (int i = 0; i < windowSize_x; i++) { - int distance = i - halfWindowSize_x; - gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); - xDistanceSquared[i] = distance * distance; + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; } for (int i = 0; i < windowSize_y; i++) { - int distance = i - halfWindowSize_y; - gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); - yDistanceSquared[i] = distance * distance; + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; } for (int i = 0; i < windowSize_z; i++) { - int distance = i - halfWindowSize_z; - gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); - zDistanceSquared[i] = distance * distance; + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; } // Writing constant memory. @@ -267,27 +252,26 @@ void BilateralFilterCudaForwardFunction(torch::Tensor inputTensor, cudaMemcpyToSymbol(cSigma_z, &sigma_z, sizeof(float)); cudaMemcpyToSymbol(cColorSigma, &colorSigma, sizeof(float)); -// cuda_error_check("Cuda check before kernel call."); + // cuda_error_check("Cuda check before kernel call."); #define BLOCK_SIZE 32 AT_DISPATCH_FLOATING_TYPES_AND_HALF( inputTensor.scalar_type(), "BilateralFilterCudaKernel3DForward", ([&] { - BilateralFilterCudaKernel3DForward - <<>>( - inputTensor.data_ptr(), - outputTensor.data_ptr(), - outputWeightsTensor.data_ptr(), - dO_dx_ki.data_ptr(), - dO_dsig_r.data_ptr(), - dO_dsig_x.data_ptr(), - dO_dsig_y.data_ptr(), - dO_dsig_z.data_ptr()); - } - )); - -// cuda_error_check("Cuda check after kernel call."); -// delete[] kernel; + BilateralFilterCudaKernel3DForward + <<>>( + inputTensor.data_ptr(), + outputTensor.data_ptr(), + outputWeightsTensor.data_ptr(), + dO_dx_ki.data_ptr(), + dO_dsig_r.data_ptr(), + dO_dsig_x.data_ptr(), + dO_dsig_y.data_ptr(), + dO_dsig_z.data_ptr()); + })); + + // cuda_error_check("Cuda check after kernel call."); + // delete[] kernel; delete[] kernelSizes; delete[] gaussianKernel_x; delete[] gaussianKernel_y; @@ -298,11 +282,8 @@ void BilateralFilterCudaForwardFunction(torch::Tensor inputTensor, } // Function to choose template implementation based on dynamic, channels and dimensions -std::tuple BilateralFilterCudaForward(torch::Tensor inputTensor, - float sigma_x, - float sigma_y, - float sigma_z, - float colorSigma) { +std::tuple +BilateralFilterCudaForward(torch::Tensor inputTensor, float sigma_x, float sigma_y, float sigma_z, float colorSigma) { torch::Tensor outputTensor = torch::zeros_like(inputTensor); torch::Tensor outputWeightsTensor = torch::zeros_like(inputTensor); torch::Tensor dO_dx_ki = torch::zeros_like(inputTensor); @@ -310,9 +291,22 @@ std::tuple(inputTensor, outputTensor, outputWeightsTensor, dO_dx_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z, sigma_x, sigma_y, sigma_z, colorSigma); + // cuda_error_check("beginning"); + +#define CASE(c, d) \ + BilateralFilterCudaForwardFunction( \ + inputTensor, \ + outputTensor, \ + outputWeightsTensor, \ + dO_dx_ki, \ + dO_dsig_r, \ + dO_dsig_x, \ + dO_dsig_y, \ + dO_dsig_z, \ + sigma_x, \ + sigma_y, \ + sigma_z, \ + colorSigma); SWITCH_AB(CASE, BF_CUDA_MAX_CHANNELS, BF_CUDA_MAX_SPATIAL_DIMENSION, inputTensor.size(1), inputTensor.dim() - 2); return {outputTensor, outputWeightsTensor, dO_dx_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z}; diff --git a/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.cpp b/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.cpp index 90cd609f5f..0cef92810f 100644 --- a/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.cpp +++ b/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.cpp @@ -18,24 +18,15 @@ limitations under the License. #include "trainable_bilateral.h" #include "utils/common_utils.h" -std::tuple TrainableBilateralFilterForward(torch::Tensor inputTensor, - float sigma_x, - float sigma_y, - float sigma_z, - float colorSigma) { - std::tuple (*filterFunction)(torch::Tensor, float, float, float, float); +std::tuple +TrainableBilateralFilterForward( + torch::Tensor inputTensor, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + std::tuple ( + *filterFunction)(torch::Tensor, float, float, float, float); #ifdef WITH_CUDA @@ -64,16 +55,18 @@ std::tuple +#include +#include +#include #include "utils/common_utils.h" #include "utils/tensor_description.h" -#include -#include -#include #define BF_CUDA_MAX_CHANNELS 16 #define BF_CUDA_MAX_SPATIAL_DIMENSION 3 #ifdef WITH_CUDA -std::tuple BilateralFilterCudaForward(torch::Tensor inputTensor, - float sigma_x, - float sigma_y, - float sigma_z, - float colorSigma); -torch::Tensor BilateralFilterCudaBackward(torch::Tensor gradientInputTensor, - torch::Tensor inputTensor, - torch::Tensor outputTensor, - torch::Tensor outputWeightsTensor, - torch::Tensor dO_dx_ki, - float sigma_x, - float sigma_y, - float sigma_z, - float colorSigma); +std::tuple +BilateralFilterCudaForward(torch::Tensor inputTensor, float sigma_x, float sigma_y, float sigma_z, float colorSigma); +torch::Tensor BilateralFilterCudaBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); #endif -std::tuple BilateralFilterCpuForward(torch::Tensor inputTensor, - float sigma_x, - float sigma_y, - float sigma_z, - float colorSigma); +std::tuple +BilateralFilterCpuForward(torch::Tensor inputTensor, float sigma_x, float sigma_y, float sigma_z, float colorSigma); -torch::Tensor BilateralFilterCpuBackward(torch::Tensor gradientInputTensor, - torch::Tensor inputTensor, - torch::Tensor outputTensor, - torch::Tensor outputWeightsTensor, - torch::Tensor dO_dx_ki, - float sigma_x, - float sigma_y, - float sigma_z, - float colorSigma); +torch::Tensor BilateralFilterCpuBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); -std::tuple TrainableBilateralFilterForward(torch::Tensor inputTensor, - float sigma_x, - float sigma_y, - float sigma_z, - float colorSigma); +std::tuple +TrainableBilateralFilterForward( + torch::Tensor inputTensor, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); -torch::Tensor TrainableBilateralFilterBackward(torch::Tensor gradientInputTensor, - torch::Tensor inputTensor, - torch::Tensor outputTensor, - torch::Tensor outputWeightsTensor, - torch::Tensor dO_dx_ki, - float sigma_x, - float sigma_y, - float sigma_z, - float colorSigma); +torch::Tensor TrainableBilateralFilterBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); From 68ff70b3f0455bae709b787b45c078ed53d6093b Mon Sep 17 00:00:00 2001 From: Fabian Wagner Date: Mon, 16 Jan 2023 16:32:19 +0100 Subject: [PATCH 5/9] Remove unused variable. --- .../filtering/trainable_bilateral/bf_layer_cpu_backward.cpp | 3 --- .../filtering/trainable_bilateral/bf_layer_gpu_backward.cu | 2 -- 2 files changed, 5 deletions(-) diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp index 8e34d7a8b0..bf18689902 100644 --- a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp @@ -136,8 +136,6 @@ void BilateralFilterCpuBackward_3d( scalar_t filter_kernel = 0; scalar_t valueSum = 0; - scalar_t weightSum = 0.0f; - // Looping over all dimensions for the neighbour element Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); do // while(kernelIndex++) @@ -200,7 +198,6 @@ void BilateralFilterCpuBackward_3d( valueSum += gradientInputTensorData[neighbourOffset + i * desc.channelStride] * filter_kernel; } - weightSum += totalWeight; } } while (kernelIndex++); diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu index ffd8eb32a2..f9235af8b5 100644 --- a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu @@ -59,7 +59,6 @@ __global__ void BilateralFilterCudaKernel3DBackward( // Zero kernel aggregates. scalar_t valueSum = 0; - scalar_t weightSum = 0; for (int kernelX = 0; kernelX < cKernelSizesBack[0]; kernelX++) { int neighbourX = max(0, min(homeX + (kernelX - cHalfWindowSize_arrBack[0]), cSizesBack[0] - 1)); @@ -131,7 +130,6 @@ __global__ void BilateralFilterCudaKernel3DBackward( valueSum += gradientInputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * filter_kernel_back; } - weightSum += totalWeight; } } } From b3e0fa0e71f9fc446c9db57d4e00927433b54738 Mon Sep 17 00:00:00 2001 From: Fabian Wagner Date: Mon, 16 Jan 2023 17:33:30 +0100 Subject: [PATCH 6/9] Revert "Remove unused variable." This reverts commit 68ff70b3f0455bae709b787b45c078ed53d6093b. --- .../filtering/trainable_bilateral/bf_layer_cpu_backward.cpp | 3 +++ .../filtering/trainable_bilateral/bf_layer_gpu_backward.cu | 2 ++ 2 files changed, 5 insertions(+) diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp index bf18689902..8e34d7a8b0 100644 --- a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp @@ -136,6 +136,8 @@ void BilateralFilterCpuBackward_3d( scalar_t filter_kernel = 0; scalar_t valueSum = 0; + scalar_t weightSum = 0.0f; + // Looping over all dimensions for the neighbour element Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); do // while(kernelIndex++) @@ -198,6 +200,7 @@ void BilateralFilterCpuBackward_3d( valueSum += gradientInputTensorData[neighbourOffset + i * desc.channelStride] * filter_kernel; } + weightSum += totalWeight; } } while (kernelIndex++); diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu index f9235af8b5..ffd8eb32a2 100644 --- a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu @@ -59,6 +59,7 @@ __global__ void BilateralFilterCudaKernel3DBackward( // Zero kernel aggregates. scalar_t valueSum = 0; + scalar_t weightSum = 0; for (int kernelX = 0; kernelX < cKernelSizesBack[0]; kernelX++) { int neighbourX = max(0, min(homeX + (kernelX - cHalfWindowSize_arrBack[0]), cSizesBack[0] - 1)); @@ -130,6 +131,7 @@ __global__ void BilateralFilterCudaKernel3DBackward( valueSum += gradientInputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * filter_kernel_back; } + weightSum += totalWeight; } } } From 30f5b379494711b3af7856ce5b5f736b3e7ac9eb Mon Sep 17 00:00:00 2001 From: Fabian Wagner Date: Mon, 16 Jan 2023 17:33:30 +0100 Subject: [PATCH 7/9] Revert "Remove unused variable." This reverts commit 68ff70b3f0455bae709b787b45c078ed53d6093b. Signed-off-by: Fabian Wagner --- .../filtering/trainable_bilateral/bf_layer_cpu_backward.cpp | 3 +++ .../filtering/trainable_bilateral/bf_layer_gpu_backward.cu | 2 ++ 2 files changed, 5 insertions(+) diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp index bf18689902..8e34d7a8b0 100644 --- a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp @@ -136,6 +136,8 @@ void BilateralFilterCpuBackward_3d( scalar_t filter_kernel = 0; scalar_t valueSum = 0; + scalar_t weightSum = 0.0f; + // Looping over all dimensions for the neighbour element Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); do // while(kernelIndex++) @@ -198,6 +200,7 @@ void BilateralFilterCpuBackward_3d( valueSum += gradientInputTensorData[neighbourOffset + i * desc.channelStride] * filter_kernel; } + weightSum += totalWeight; } } while (kernelIndex++); diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu index f9235af8b5..ffd8eb32a2 100644 --- a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu @@ -59,6 +59,7 @@ __global__ void BilateralFilterCudaKernel3DBackward( // Zero kernel aggregates. scalar_t valueSum = 0; + scalar_t weightSum = 0; for (int kernelX = 0; kernelX < cKernelSizesBack[0]; kernelX++) { int neighbourX = max(0, min(homeX + (kernelX - cHalfWindowSize_arrBack[0]), cSizesBack[0] - 1)); @@ -130,6 +131,7 @@ __global__ void BilateralFilterCudaKernel3DBackward( valueSum += gradientInputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * filter_kernel_back; } + weightSum += totalWeight; } } } From 607298ad22cf1b33abe05a47df17da2aafdb0f73 Mon Sep 17 00:00:00 2001 From: Fabian Wagner Date: Mon, 16 Jan 2023 17:52:37 +0100 Subject: [PATCH 8/9] Remove unused variable in backward. Signed-off-by: Fabian Wagner --- .../filtering/trainable_bilateral/bf_layer_cpu_backward.cpp | 3 --- .../filtering/trainable_bilateral/bf_layer_gpu_backward.cu | 2 -- 2 files changed, 5 deletions(-) diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp index 8e34d7a8b0..bf18689902 100644 --- a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp @@ -136,8 +136,6 @@ void BilateralFilterCpuBackward_3d( scalar_t filter_kernel = 0; scalar_t valueSum = 0; - scalar_t weightSum = 0.0f; - // Looping over all dimensions for the neighbour element Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); do // while(kernelIndex++) @@ -200,7 +198,6 @@ void BilateralFilterCpuBackward_3d( valueSum += gradientInputTensorData[neighbourOffset + i * desc.channelStride] * filter_kernel; } - weightSum += totalWeight; } } while (kernelIndex++); diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu index ffd8eb32a2..f9235af8b5 100644 --- a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu @@ -59,7 +59,6 @@ __global__ void BilateralFilterCudaKernel3DBackward( // Zero kernel aggregates. scalar_t valueSum = 0; - scalar_t weightSum = 0; for (int kernelX = 0; kernelX < cKernelSizesBack[0]; kernelX++) { int neighbourX = max(0, min(homeX + (kernelX - cHalfWindowSize_arrBack[0]), cSizesBack[0] - 1)); @@ -131,7 +130,6 @@ __global__ void BilateralFilterCudaKernel3DBackward( valueSum += gradientInputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * filter_kernel_back; } - weightSum += totalWeight; } } } From 6953305e9dc39089a0173bf8e6b90995601460ce Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 17 Jan 2023 08:00:11 +0000 Subject: [PATCH 9/9] [MONAI] code formatting Signed-off-by: monai-bot --- .../filtering/trainable_bilateral/bf_layer_cpu_backward.cpp | 1 - .../csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu | 1 - tests/test_trainable_bilateral.py | 2 ++ 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp index bf18689902..a9e224d9dc 100644 --- a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp @@ -197,7 +197,6 @@ void BilateralFilterCpuBackward_3d( valueSum += gradientInputTensorData[neighbourOffset + i * desc.channelStride] * filter_kernel; } - } } while (kernelIndex++); diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu index f9235af8b5..23c532af41 100644 --- a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu @@ -129,7 +129,6 @@ __global__ void BilateralFilterCudaKernel3DBackward( valueSum += gradientInputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * filter_kernel_back; } - } } } diff --git a/tests/test_trainable_bilateral.py b/tests/test_trainable_bilateral.py index a42794616a..1300e5068d 100644 --- a/tests/test_trainable_bilateral.py +++ b/tests/test_trainable_bilateral.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np