diff --git a/docs/source/networks.rst b/docs/source/networks.rst index a4c225de29..64ba0313c4 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -349,6 +349,11 @@ Layers .. autoclass:: BilateralFilter :members: +`TrainableBilateralFilter` +~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: TrainableBilateralFilter + :members: + `PHLFilter` ~~~~~~~~~~~ .. autoclass:: PHLFilter diff --git a/monai/csrc/ext.cpp b/monai/csrc/ext.cpp index ac43e6fd3e..fc47247473 100644 --- a/monai/csrc/ext.cpp +++ b/monai/csrc/ext.cpp @@ -22,6 +22,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // filtering m.def("bilateral_filter", &BilateralFilter, "Bilateral Filter"); m.def("phl_filter", &PermutohedralFilter, "Permutohedral Filter"); + m.def("tbf_forward", &TrainableBilateralFilterForward, "Trainable Bilateral Filter Forward"); + m.def("tbf_backward", &TrainableBilateralFilterBackward, "Trainable Bilateral Filter Backward"); // lltm m.def("lltm_forward", &lltm_forward, "LLTM forward"); diff --git a/monai/csrc/filtering/filtering.h b/monai/csrc/filtering/filtering.h index 3e680010ed..2b0c499259 100644 --- a/monai/csrc/filtering/filtering.h +++ b/monai/csrc/filtering/filtering.h @@ -15,3 +15,4 @@ limitations under the License. #include "bilateral/bilateral.h" #include "permutohedral/permutohedral.h" +#include "trainable_bilateral/trainable_bilateral.h" diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp new file mode 100644 index 0000000000..a9e224d9dc --- /dev/null +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp @@ -0,0 +1,249 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "trainable_bilateral.h" + +struct Indexer { + public: + Indexer(int dimensions, int* sizes) { + m_dimensions = dimensions; + m_sizes = sizes; + m_index = new int[dimensions]{0}; + } + ~Indexer() { + delete[] m_index; + } + + bool operator++(int) { + for (int i = 0; i < m_dimensions; i++) { + m_index[i] += 1; + + if (m_index[i] < m_sizes[i]) { + return true; + } else { + m_index[i] = 0; + } + } + + return false; + } + + int& operator[](int dimensionIndex) { + return m_index[dimensionIndex]; + } + + private: + int m_dimensions; + int* m_sizes; + int* m_index; +}; + +template +void BilateralFilterCpuBackward_3d( + torch::Tensor gradientInputTensor, + torch::Tensor gradientOutputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(gradientInputTensor); + + // Raw tensor data pointers. 
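+  // Raw pointers plus the precomputed strides from TensorDescription are used for all
+  // indexing below, which keeps the innermost filter loops free of tensor-accessor overhead.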
+  scalar_t* gradientInputTensorData = gradientInputTensor.data_ptr<scalar_t>();
+  scalar_t* gradientOutputTensorData = gradientOutputTensor.data_ptr<scalar_t>();
+  scalar_t* inputTensorData = inputTensor.data_ptr<scalar_t>();
+  scalar_t* outputTensorData = outputTensor.data_ptr<scalar_t>();
+  scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr<scalar_t>();
+  scalar_t* dO_dx_kiData = dO_dx_ki.data_ptr<scalar_t>();
+
+  // Pre-calculate common values
+  int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size
+  int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size
+  int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size
+  int halfWindowSize_x = floor(0.5f * windowSize_x);
+  int halfWindowSize_y = floor(0.5f * windowSize_y);
+  int halfWindowSize_z = floor(0.5f * windowSize_z);
+  int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z};
+  scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x);
+  scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y);
+  scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z);
+  scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma);
+
+  // Set kernel sizes with respect to the defined spatial sigmas.
+  int* kernelSizes = new int[desc.dimensions];
+
+  kernelSizes[0] = windowSize_x;
+  kernelSizes[1] = windowSize_y;
+  kernelSizes[2] = windowSize_z;
+
+  // Pre-calculate gaussian kernel and distance map in 1D.
+  scalar_t* gaussianKernel_x = new scalar_t[windowSize_x];
+  scalar_t* gaussianKernel_y = new scalar_t[windowSize_y];
+  scalar_t* gaussianKernel_z = new scalar_t[windowSize_z];
+  scalar_t* xDistanceSquared = new scalar_t[windowSize_x];
+  scalar_t* yDistanceSquared = new scalar_t[windowSize_y];
+  scalar_t* zDistanceSquared = new scalar_t[windowSize_z];
+
+  for (int i = 0; i < windowSize_x; i++) {
+    int distance = i - halfWindowSize_x;
+    gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x);
+    xDistanceSquared[i] = distance * distance;
+  }
+  for (int i = 0; i < windowSize_y; i++) {
+    int distance = i - halfWindowSize_y;
+    gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y);
+    yDistanceSquared[i] = distance * distance;
+  }
+  for (int i = 0; i < windowSize_z; i++) {
+    int distance = i - halfWindowSize_z;
+    gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z);
+    zDistanceSquared[i] = distance * distance;
+  }
+
+  // Looping over the batches
+  for (int b = 0; b < desc.batchCount; b++) {
+    int batchOffset = b * desc.batchStride;
+
+    // Looping over all dimensions for the home element
+    for (int z = 0; z < desc.sizes[2]; z++)
+#pragma omp parallel for
+      for (int y = 0; y < desc.sizes[1]; y++) {
+        for (int x = 0; x < desc.sizes[0]; x++) {
+          // Calculating indexing offset for the home element
+          int homeOffset = batchOffset;
+
+          int homeIndex[] = {x, y, z};
+          homeOffset += x * desc.strides[0];
+          homeOffset += y * desc.strides[1];
+          homeOffset += z * desc.strides[2];
+
+          // Zero kernel aggregates.
+          scalar_t filter_kernel = 0;
+          scalar_t valueSum = 0;
+
+          // Looping over all dimensions for the neighbour element
+          Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes);
+          do // while(kernelIndex++)
+          {
+            // Calculating buffer offset for the neighbour element
+            // Index is clamped to the border in each dimension.
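+            // Clamping only keeps the pointer arithmetic in bounds; neighbours that fall
+            // outside the image are flagged via flagNotClamped below and skipped entirely
+            // during aggregation.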
+ int neighbourOffset = batchOffset; + bool flagNotClamped = true; + + for (int i = 0; i < desc.dimensions; i++) { + int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i]; + int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex)); + neighbourOffset += neighbourIndexClamped * desc.strides[i]; + if (neighbourIndex != neighbourIndexClamped) { + flagNotClamped = false; + } + } + + // Euclidean color distance. + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; + + for (int i = 0; i < desc.channelCount; i++) { + scalar_t diff = inputTensorData[neighbourOffset + i * desc.channelStride] - + inputTensorData[homeOffset + + i * desc.channelStride]; // Be careful: Here it is (X_k - X_i) and not (X_i - X_q) + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + // Calculating and combining the spatial + // and color weights. + scalar_t spatialWeight = 1; + + spatialWeight = + gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * gaussianKernel_z[kernelIndex[2]]; + + scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { + for (int i = 0; i < desc.channelCount; i++) { + // Distinguish cases for k!=i (calculation is done here) + // and k==i (partial derivatives are precalculated). + // If statement replaces center element of neighborhood/kernel. + if (kernelIndex[0] != halfWindowSize_x || kernelIndex[1] != halfWindowSize_y || + kernelIndex[2] != halfWindowSize_z) { + filter_kernel = -(1 / outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * + outputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight * colorDistance / + (colorSigma * colorSigma) + + (1 / outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * totalWeight * + (1 + + inputTensorData[homeOffset + i * desc.channelStride] * colorDistance / + (colorSigma * colorSigma)); // inputTensorData[homeOffset] !! + } else { + filter_kernel = dO_dx_kiData[homeOffset + i * desc.channelStride]; + } + + valueSum += gradientInputTensorData[neighbourOffset + i * desc.channelStride] * filter_kernel; + } + } + } while (kernelIndex++); + + // Do the filtering and calculate the values for the backward pass. + for (int i = 0; i < desc.channelCount; i++) { + // Filtering: + gradientOutputTensorData[homeOffset + i * desc.channelStride] = valueSum; + } + } + } + } + + delete[] kernelSizes; + delete[] gaussianKernel_x; + delete[] gaussianKernel_y; + delete[] gaussianKernel_z; + delete[] xDistanceSquared; + delete[] yDistanceSquared; + delete[] zDistanceSquared; +} + +torch::Tensor BilateralFilterCpuBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Preparing output tensor. 
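+  // zeros_like matches shape, dtype and device of the incoming gradient, so the kernel
+  // below can accumulate into it directly for every scalar type the dispatch covers.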
+  torch::Tensor gradientOutputTensor = torch::zeros_like(gradientInputTensor);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradientInputTensor.scalar_type(), "BilateralFilterCpuBackward_3d", ([&] {
+                                        BilateralFilterCpuBackward_3d<scalar_t>(
+                                            gradientInputTensor,
+                                            gradientOutputTensor,
+                                            inputTensor,
+                                            outputTensor,
+                                            outputWeightsTensor,
+                                            dO_dx_ki,
+                                            sigma_x,
+                                            sigma_y,
+                                            sigma_z,
+                                            colorSigma);
+                                      }));
+
+  return gradientOutputTensor;
+}
diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_forward.cpp b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_forward.cpp
new file mode 100644
index 0000000000..f6f8979d29
--- /dev/null
+++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_forward.cpp
@@ -0,0 +1,286 @@
+/*
+Copyright (c) MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include "trainable_bilateral.h"
+
+struct Indexer {
+ public:
+  Indexer(int dimensions, int* sizes) {
+    m_dimensions = dimensions;
+    m_sizes = sizes;
+    m_index = new int[dimensions]{0};
+  }
+  ~Indexer() {
+    delete[] m_index;
+  }
+
+  bool operator++(int) {
+    for (int i = 0; i < m_dimensions; i++) {
+      m_index[i] += 1;
+
+      if (m_index[i] < m_sizes[i]) {
+        return true;
+      } else {
+        m_index[i] = 0;
+      }
+    }
+
+    return false;
+  }
+
+  int& operator[](int dimensionIndex) {
+    return m_index[dimensionIndex];
+  }
+
+ private:
+  int m_dimensions;
+  int* m_sizes;
+  int* m_index;
+};
+
+template <typename scalar_t>
+void BilateralFilterCpuForward_3d(
+    torch::Tensor inputTensor,
+    torch::Tensor outputTensor,
+    torch::Tensor outputWeightsTensor,
+    torch::Tensor dO_dx_ki,
+    torch::Tensor dO_dsig_r,
+    torch::Tensor dO_dsig_x,
+    torch::Tensor dO_dsig_y,
+    torch::Tensor dO_dsig_z,
+    float sigma_x,
+    float sigma_y,
+    float sigma_z,
+    float colorSigma) {
+  // Getting tensor description.
+  TensorDescription desc = TensorDescription(inputTensor);
+
+  // Raw tensor data pointers.
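+  // Besides the filtered output, the forward pass fills per-voxel derivative buffers
+  // (dO_dx_ki and the dO_dsig_* maps) that the backward pass and the sigma gradients
+  // reuse, so no filter window has to be revisited at gradient time for the sigmas.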
+  scalar_t* inputTensorData = inputTensor.data_ptr<scalar_t>();
+  scalar_t* outputTensorData = outputTensor.data_ptr<scalar_t>();
+  scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr<scalar_t>();
+  scalar_t* dO_dx_kiData = dO_dx_ki.data_ptr<scalar_t>();
+  scalar_t* dO_dsig_rData = dO_dsig_r.data_ptr<scalar_t>();
+  scalar_t* dO_dsig_xData = dO_dsig_x.data_ptr<scalar_t>();
+  scalar_t* dO_dsig_yData = dO_dsig_y.data_ptr<scalar_t>();
+  scalar_t* dO_dsig_zData = dO_dsig_z.data_ptr<scalar_t>();
+
+  // Pre-calculate common values
+  int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size
+  int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size
+  int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size
+  int halfWindowSize_x = floor(0.5f * windowSize_x);
+  int halfWindowSize_y = floor(0.5f * windowSize_y);
+  int halfWindowSize_z = floor(0.5f * windowSize_z);
+  int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z};
+  scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x);
+  scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y);
+  scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z);
+  scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma);
+
+  // Set kernel sizes with respect to the defined spatial sigmas.
+  int* kernelSizes = new int[desc.dimensions];
+
+  kernelSizes[0] = windowSize_x;
+  kernelSizes[1] = windowSize_y;
+  kernelSizes[2] = windowSize_z;
+
+  // Pre-calculate gaussian kernel and distance map in 1D.
+  scalar_t* gaussianKernel_x = new scalar_t[windowSize_x];
+  scalar_t* gaussianKernel_y = new scalar_t[windowSize_y];
+  scalar_t* gaussianKernel_z = new scalar_t[windowSize_z];
+  scalar_t* xDistanceSquared = new scalar_t[windowSize_x];
+  scalar_t* yDistanceSquared = new scalar_t[windowSize_y];
+  scalar_t* zDistanceSquared = new scalar_t[windowSize_z];
+
+  for (int i = 0; i < windowSize_x; i++) {
+    int distance = i - halfWindowSize_x;
+    gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x);
+    xDistanceSquared[i] = distance * distance;
+  }
+  for (int i = 0; i < windowSize_y; i++) {
+    int distance = i - halfWindowSize_y;
+    gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y);
+    yDistanceSquared[i] = distance * distance;
+  }
+  for (int i = 0; i < windowSize_z; i++) {
+    int distance = i - halfWindowSize_z;
+    gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z);
+    zDistanceSquared[i] = distance * distance;
+  }
+
+  // Looping over the batches
+  for (int b = 0; b < desc.batchCount; b++) {
+    int batchOffset = b * desc.batchStride;
+
+    // Looping over all dimensions for the home element
+    for (int z = 0; z < desc.sizes[2]; z++)
+#pragma omp parallel for
+      for (int y = 0; y < desc.sizes[1]; y++) {
+        for (int x = 0; x < desc.sizes[0]; x++) {
+          // Calculating indexing offset for the home element
+          int homeOffset = batchOffset;
+
+          int homeIndex[] = {x, y, z};
+          homeOffset += x * desc.strides[0];
+          homeOffset += y * desc.strides[1];
+          homeOffset += z * desc.strides[2];
+
+          // Zero kernel aggregates.
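+          // valueSum/weightSum form the usual bilateral numerator and denominator; the
+          // *_w and *_alpha sums accumulate the corresponding derivative terms of the
+          // weights and of the weighted intensities for each trainable sigma.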
+ scalar_t valueSum = 0; + scalar_t dw_dx_ki = 0; + scalar_t dfilter_dx_ki = 0; + scalar_t colorSum_w = 0; + scalar_t colorSum_alpha = 0; + scalar_t xSum_w = 0; + scalar_t xSum_alpha = 0; + scalar_t ySum_w = 0; + scalar_t ySum_alpha = 0; + scalar_t zSum_w = 0; + scalar_t zSum_alpha = 0; + + scalar_t weightSum = 0.0f; + + // Looping over all dimensions for the neighbour element + Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); + do // while(kernelIndex++) + { + // Calculating buffer offset for the neighbour element + // Index is clamped to the border in each dimension. + int neighbourOffset = batchOffset; + bool flagNotClamped = true; + + for (int i = 0; i < desc.dimensions; i++) { + int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i]; + int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex)); + neighbourOffset += neighbourIndexClamped * desc.strides[i]; + if (neighbourIndex != neighbourIndexClamped) { + flagNotClamped = false; + } + } + + // Euclidean color distance. + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; + + for (int i = 0; i < desc.channelCount; i++) { + scalar_t diff = inputTensorData[homeOffset + i * desc.channelStride] - + inputTensorData[neighbourOffset + i * desc.channelStride]; + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + // Calculating and combining the spatial + // and color weights. + scalar_t spatialWeight = 1; + + spatialWeight = + gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * gaussianKernel_z[kernelIndex[2]]; + + scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { + for (int i = 0; i < desc.channelCount; i++) { + valueSum += inputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight; + + // Derivative of weights with respect to X_i while i=k. + dw_dx_ki += (-1) * totalWeight * colorDistance / (colorSigma * colorSigma); + // Derivative of convolved image with respect to X_i while i=k. 
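+            // Product rule: x_k * dw_k/dx_i is accumulated here; the direct term for
+            // k == i is the centre weight (== 1), added as the "+1" when dO_dx_ki is written.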
+            dfilter_dx_ki += (-1) * totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] *
+                colorDistance /
+                (colorSigma *
+                 colorSigma); // Be careful, the +1 is missing here -> Added before filling dfilter_dx_kiData
+
+            colorSum_w += totalWeight * colorDistanceSquared / std::abs(colorSigma * colorSigma * colorSigma);
+            colorSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] *
+                colorDistanceSquared / std::abs(colorSigma * colorSigma * colorSigma);
+
+            xSum_w += totalWeight * xDistanceSquared[kernelIndex[0]] / std::abs(sigma_x * sigma_x * sigma_x);
+            xSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] *
+                xDistanceSquared[kernelIndex[0]] / std::abs(sigma_x * sigma_x * sigma_x);
+
+            ySum_w += totalWeight * yDistanceSquared[kernelIndex[1]] / std::abs(sigma_y * sigma_y * sigma_y);
+            ySum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] *
+                yDistanceSquared[kernelIndex[1]] / std::abs(sigma_y * sigma_y * sigma_y);
+
+            zSum_w += totalWeight * zDistanceSquared[kernelIndex[2]] / std::abs(sigma_z * sigma_z * sigma_z);
+            zSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] *
+                zDistanceSquared[kernelIndex[2]] / std::abs(sigma_z * sigma_z * sigma_z);
+          }
+
+          weightSum += totalWeight;
+        }
+      } while (kernelIndex++);
+
+      // Do the filtering and calculate the values for the backward pass.
+      for (int i = 0; i < desc.channelCount; i++) {
+        // Filtering:
+        outputTensorData[homeOffset + i * desc.channelStride] = valueSum / weightSum;
+
+        // Pre-computations for the backward pass:
+        outputWeightsTensorData[homeOffset + i * desc.channelStride] = weightSum;
+        dO_dx_kiData[homeOffset + i * desc.channelStride] = -(1 / weightSum) * (valueSum / weightSum) * dw_dx_ki +
+            (1 / weightSum) * (dfilter_dx_ki + 1); // +1 for dfilter_dx_ki is added here
+        dO_dsig_rData[homeOffset + i * desc.channelStride] =
+            -(1 / weightSum) * (valueSum / weightSum) * colorSum_w + (1 / weightSum) * colorSum_alpha;
+        dO_dsig_xData[homeOffset + i * desc.channelStride] =
+            -(1 / weightSum) * (valueSum / weightSum) * xSum_w + (1 / weightSum) * xSum_alpha;
+        dO_dsig_yData[homeOffset + i * desc.channelStride] =
+            -(1 / weightSum) * (valueSum / weightSum) * ySum_w + (1 / weightSum) * ySum_alpha;
+        dO_dsig_zData[homeOffset + i * desc.channelStride] =
+            -(1 / weightSum) * (valueSum / weightSum) * zSum_w + (1 / weightSum) * zSum_alpha;
+      }
+    }
+  }
+  }
+
+  delete[] kernelSizes;
+  delete[] gaussianKernel_x;
+  delete[] gaussianKernel_y;
+  delete[] gaussianKernel_z;
+  delete[] xDistanceSquared;
+  delete[] yDistanceSquared;
+  delete[] zDistanceSquared;
+}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+BilateralFilterCpuForward(torch::Tensor inputTensor, float sigma_x, float sigma_y, float sigma_z, float colorSigma) {
+  // Preparing output tensor.
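+  // Seven buffers are returned: the filtered image, the per-voxel weight sums, and the
+  // five precomputed derivative maps consumed by autograd in the backward pass.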
+  torch::Tensor outputTensor = torch::zeros_like(inputTensor);
+  torch::Tensor outputWeightsTensor = torch::zeros_like(inputTensor);
+  torch::Tensor dO_dx_ki = torch::zeros_like(inputTensor);
+  torch::Tensor dO_dsig_r = torch::zeros_like(inputTensor);
+  torch::Tensor dO_dsig_x = torch::zeros_like(inputTensor);
+  torch::Tensor dO_dsig_y = torch::zeros_like(inputTensor);
+  torch::Tensor dO_dsig_z = torch::zeros_like(inputTensor);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.scalar_type(), "BilateralFilterCpuForward_3d", ([&] {
+                                        BilateralFilterCpuForward_3d<scalar_t>(
+                                            inputTensor,
+                                            outputTensor,
+                                            outputWeightsTensor,
+                                            dO_dx_ki,
+                                            dO_dsig_r,
+                                            dO_dsig_x,
+                                            dO_dsig_y,
+                                            dO_dsig_z,
+                                            sigma_x,
+                                            sigma_y,
+                                            sigma_z,
+                                            colorSigma);
+                                      }));
+
+  return {outputTensor, outputWeightsTensor, dO_dx_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z};
+}
diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu
new file mode 100644
index 0000000000..23c532af41
--- /dev/null
+++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu
@@ -0,0 +1,279 @@
+/*
+Copyright (c) MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include "trainable_bilateral.h"
+//#include "../utils/cuda_error_check.h"
+#include "utils/meta_macros.h"
+
+__constant__ int cBatchStrideBack;
+__constant__ int cColorStrideBack;
+
+__constant__ int cSizesBack[3];
+__constant__ int cStridesBack[3];
+
+__constant__ int cKernelSizesBack[3];
+__constant__ int cHalfWindowSize_arrBack[3];
+__constant__ float cGaussianKernel_xBack[256];
+__constant__ float cGaussianKernel_yBack[256];
+__constant__ float cGaussianKernel_zBack[256];
+__constant__ float cXDistanceSquaredBack[256];
+__constant__ float cYDistanceSquaredBack[256];
+__constant__ float cZDistanceSquaredBack[256];
+__constant__ float cColorExponentConstantBack;
+__constant__ float cSigma_xBack;
+__constant__ float cSigma_yBack;
+__constant__ float cSigma_zBack;
+__constant__ float cColorSigmaBack;
+
+template <typename scalar_t, int C>
+__global__ void BilateralFilterCudaKernel3DBackward(
+    scalar_t* gradientInputTensor,
+    scalar_t* gradientOutputTensor,
+    scalar_t* inputTensor,
+    scalar_t* outputTensor,
+    scalar_t* outputWeightsTensor,
+    scalar_t* dO_dx_ki) {
+  int homeOffset = blockIdx.x * blockDim.x + threadIdx.x;
+  int batchOffset = blockIdx.y * cBatchStrideBack;
+
+  if (homeOffset >= cColorStrideBack)
+    return;
+
+  int homeX = homeOffset / cStridesBack[0];
+  int homeY = (homeOffset - homeX * cStridesBack[0]) / cStridesBack[1];
+  int homeZ = (homeOffset - homeX * cStridesBack[0] - homeY * cStridesBack[1]) / cStridesBack[2];
+  int homeIndex[] = {homeX, homeY, homeZ};
+
+  // Zero kernel aggregates.
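+  // One thread per output voxel: homeOffset enumerates the spatial positions of a batch
+  // element, and the full filter window is traversed sequentially by this thread.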
+ scalar_t valueSum = 0; + + for (int kernelX = 0; kernelX < cKernelSizesBack[0]; kernelX++) { + int neighbourX = max(0, min(homeX + (kernelX - cHalfWindowSize_arrBack[0]), cSizesBack[0] - 1)); + scalar_t gaussianX = cGaussianKernel_xBack[kernelX]; + + for (int kernelY = 0; kernelY < cKernelSizesBack[1]; kernelY++) { + int neighbourY = max(0, min(homeY + (kernelY - cHalfWindowSize_arrBack[1]), cSizesBack[1] - 1)); + scalar_t gaussianY = cGaussianKernel_yBack[kernelY]; + + for (int kernelZ = 0; kernelZ < cKernelSizesBack[2]; kernelZ++) { + int neighbourZ = max(0, min(homeZ + (kernelZ - cHalfWindowSize_arrBack[2]), cSizesBack[2] - 1)); + scalar_t gaussianZ = cGaussianKernel_zBack[kernelZ]; + + int neighbourOffset = neighbourX * cStridesBack[0] + neighbourY * cStridesBack[1] + neighbourZ; + + bool flagNotClamped = true; + int kernelIndex[] = {kernelX, kernelY, kernelZ}; + int dimensions = 3; // Must equal the number of spatial dimensions. + + for (int i = 0; i < dimensions; i++) { + int HalfWindowSizeBack = cHalfWindowSize_arrBack[i]; // Define constant memory as new variable here (!!), + // otherwise: cudaErrorMisalignedAddress + int neighbourIndex = homeIndex[i] + kernelIndex[i] - HalfWindowSizeBack; + int neighbourIndexClamped = min(cSizesBack[i] - 1, max(0, neighbourIndex)); + if (neighbourIndex != neighbourIndexClamped) { + flagNotClamped = false; + } + } + + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = inputTensor[batchOffset + neighbourOffset + c * cColorStrideBack]; + scalar_t b = inputTensor[batchOffset + homeOffset + c * cColorStrideBack]; // Be careful: Here it is (X_k - + // X_i) and not (X_i - X_q) + scalar_t diff = a - b; + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussianX * gaussianY * gaussianZ; + scalar_t colorWeight = exp(cColorExponentConstantBack * colorDistanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { + scalar_t filter_kernel_back; + +#pragma unroll + for (int c = 0; c < C; c++) { + // Distinguish cases for k!=i (calculation is done here) + // and k==i (partial derivatives are precalculated). + // If statement replaces center element of neighborhood/kernel. + if (kernelX != cHalfWindowSize_arrBack[0] || kernelY != cHalfWindowSize_arrBack[1] || + kernelZ != cHalfWindowSize_arrBack[2]) { + filter_kernel_back = -(1 / outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) * + outputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * totalWeight * colorDistance / + (cColorSigmaBack * cColorSigmaBack) + + (1 / outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) * totalWeight * + (1 + + inputTensor[batchOffset + homeOffset + c * cColorStrideBack] * colorDistance / + (cColorSigmaBack * cColorSigmaBack)); // inputTensorData[homeOffset] !! 
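+              // The k != i branch applies the quotient rule through the neighbour's
+              // weight sum; note the sign flip of colorDistance relative to the forward
+              // pass, since here the roles of home and neighbour are exchanged.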
+ } else { + filter_kernel_back = dO_dx_ki[batchOffset + homeOffset + c * cColorStrideBack]; + } + + valueSum += gradientInputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * filter_kernel_back; + } + } + } + } + } + +#pragma unroll + for (int c = 0; c < C; c++) { + gradientOutputTensor[batchOffset + homeOffset + c * cColorStrideBack] = valueSum; + } +} + +template +void BilateralFilterCudaBackwardFunction( + torch::Tensor gradientInputTensor, + torch::Tensor gradientOutputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + // Pre-calculating gaussian kernel. + int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size + int halfWindowSize_x = floor(0.5f * windowSize_x); + int halfWindowSize_y = floor(0.5f * windowSize_y); + int halfWindowSize_z = floor(0.5f * windowSize_z); + int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; + float spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); + float spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); + float spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); + float colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + int* kernelSizes = new int[desc.dimensions]; + kernelSizes[0] = windowSize_x; + kernelSizes[1] = windowSize_y; + kernelSizes[2] = windowSize_z; + + auto* gaussianKernel_x = new float[windowSize_x]; + auto* gaussianKernel_y = new float[windowSize_y]; + auto* gaussianKernel_z = new float[windowSize_z]; + auto* xDistanceSquared = new float[windowSize_x]; + auto* yDistanceSquared = new float[windowSize_y]; + auto* zDistanceSquared = new float[windowSize_z]; + + for (int i = 0; i < windowSize_x; i++) { + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_y; i++) { + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_z; i++) { + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; + } + + // Writing constant memory. 
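+  // The constant-memory kernel tables are fixed at 256 entries, which bounds the
+  // supported window size (and therefore the spatial sigmas) on the GPU path.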
+  cudaMemcpyToSymbol(cBatchStrideBack, &desc.batchStride, sizeof(int));
+  cudaMemcpyToSymbol(cColorStrideBack, &desc.channelStride, sizeof(int));
+  cudaMemcpyToSymbol(cSizesBack, desc.sizes, sizeof(int) * 3);
+  cudaMemcpyToSymbol(cStridesBack, desc.strides, sizeof(int) * 3);
+  cudaMemcpyToSymbol(cKernelSizesBack, kernelSizes, sizeof(int) * desc.dimensions);
+  cudaMemcpyToSymbol(cHalfWindowSize_arrBack, halfWindowSize_arr, sizeof(int) * desc.dimensions);
+  cudaMemcpyToSymbol(cGaussianKernel_xBack, gaussianKernel_x, sizeof(float) * windowSize_x);
+  cudaMemcpyToSymbol(cGaussianKernel_yBack, gaussianKernel_y, sizeof(float) * windowSize_y);
+  cudaMemcpyToSymbol(cGaussianKernel_zBack, gaussianKernel_z, sizeof(float) * windowSize_z);
+  cudaMemcpyToSymbol(cXDistanceSquaredBack, xDistanceSquared, sizeof(float) * windowSize_x);
+  cudaMemcpyToSymbol(cYDistanceSquaredBack, yDistanceSquared, sizeof(float) * windowSize_y);
+  cudaMemcpyToSymbol(cZDistanceSquaredBack, zDistanceSquared, sizeof(float) * windowSize_z);
+  cudaMemcpyToSymbol(cColorExponentConstantBack, &colorExpConstant, sizeof(float));
+  cudaMemcpyToSymbol(cSigma_xBack, &sigma_x, sizeof(float));
+  cudaMemcpyToSymbol(cSigma_yBack, &sigma_y, sizeof(float));
+  cudaMemcpyToSymbol(cSigma_zBack, &sigma_z, sizeof(float));
+  cudaMemcpyToSymbol(cColorSigmaBack, &colorSigma, sizeof(float));
+
+  // cuda_error_check("Cuda check before kernel call.");
+
+#define BLOCK_SIZE 32
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      inputTensor.scalar_type(), "BilateralFilterCudaKernel3DBackward", ([&] {
+        BilateralFilterCudaKernel3DBackward<scalar_t, C>
+            <<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(
+                gradientInputTensor.data_ptr<scalar_t>(),
+                gradientOutputTensor.data_ptr<scalar_t>(),
+                inputTensor.data_ptr<scalar_t>(),
+                outputTensor.data_ptr<scalar_t>(),
+                outputWeightsTensor.data_ptr<scalar_t>(),
+                dO_dx_ki.data_ptr<scalar_t>());
+      }));
+
+  // cuda_error_check("Cuda check after kernel call.");
+  // delete[] kernel;
+  delete[] kernelSizes;
+  delete[] gaussianKernel_x;
+  delete[] gaussianKernel_y;
+  delete[] gaussianKernel_z;
+  delete[] xDistanceSquared;
+  delete[] yDistanceSquared;
+  delete[] zDistanceSquared;
+}
+
+// Function to choose template implementation based on dynamic, channels and dimensions
+torch::Tensor BilateralFilterCudaBackward(
+    torch::Tensor gradientInputTensor,
+    torch::Tensor inputTensor,
+    torch::Tensor outputTensor,
+    torch::Tensor outputWeightsTensor,
+    torch::Tensor dO_dx_ki,
+    float sigma_x,
+    float sigma_y,
+    float sigma_z,
+    float colorSigma) {
+  torch::Tensor gradientOutputTensor = torch::zeros_like(gradientInputTensor);
+  // cuda_error_check("beginning");
+
+#define CASE(c, d)                        \
+  BilateralFilterCudaBackwardFunction<c>( \
+      gradientInputTensor,                \
+      gradientOutputTensor,               \
+      inputTensor,                        \
+      outputTensor,                       \
+      outputWeightsTensor,                \
+      dO_dx_ki,                           \
+      sigma_x,                            \
+      sigma_y,                            \
+      sigma_z,                            \
+      colorSigma);
+  SWITCH_AB(
+      CASE,
+      BF_CUDA_MAX_CHANNELS,
+      BF_CUDA_MAX_SPATIAL_DIMENSION,
+      gradientInputTensor.size(1),
+      gradientInputTensor.dim() - 2);
+
+  return gradientOutputTensor;
+}
diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_forward.cu b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_forward.cu
new file mode 100644
index 0000000000..29278feca1
--- /dev/null
+++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_forward.cu
@@ -0,0 +1,313 @@
+/*
+Copyright (c) MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include "trainable_bilateral.h"
+//#include "../utils/cuda_error_check.h"
+#include "utils/meta_macros.h"
+
+__constant__ int cBatchStride;
+__constant__ int cColorStride;
+
+__constant__ int cSizes[3];
+__constant__ int cStrides[3];
+
+__constant__ int cKernelSizes[3];
+__constant__ int cHalfWindowSize_arr[3];
+__constant__ float cGaussianKernel_x[256];
+__constant__ float cGaussianKernel_y[256];
+__constant__ float cGaussianKernel_z[256];
+__constant__ float cXDistanceSquared[256];
+__constant__ float cYDistanceSquared[256];
+__constant__ float cZDistanceSquared[256];
+__constant__ float cColorExponentConstant;
+__constant__ float cSigma_x;
+__constant__ float cSigma_y;
+__constant__ float cSigma_z;
+__constant__ float cColorSigma;
+
+template <typename scalar_t, int C>
+__global__ void BilateralFilterCudaKernel3DForward(
+    scalar_t* input,
+    scalar_t* output,
+    scalar_t* outputWeightsTensor,
+    scalar_t* dO_dx_ki,
+    scalar_t* dO_dsig_r,
+    scalar_t* dO_dsig_x,
+    scalar_t* dO_dsig_y,
+    scalar_t* dO_dsig_z) {
+  int homeOffset = blockIdx.x * blockDim.x + threadIdx.x;
+  int batchOffset = blockIdx.y * cBatchStride;
+
+  if (homeOffset >= cColorStride)
+    return;
+
+  int homeX = homeOffset / cStrides[0];
+  int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1];
+  int homeZ = (homeOffset - homeX * cStrides[0] - homeY * cStrides[1]) / cStrides[2];
+  int homeIndex[] = {homeX, homeY, homeZ};
+
+  // Zero kernel aggregates.
+  scalar_t valueSum = 0;
+  scalar_t dw_dx_ki = 0;
+  scalar_t dfilter_dx_ki = 0;
+  scalar_t colorSum_w = 0;
+  scalar_t colorSum_alpha = 0;
+  scalar_t xSum_w = 0;
+  scalar_t xSum_alpha = 0;
+  scalar_t ySum_w = 0;
+  scalar_t ySum_alpha = 0;
+  scalar_t zSum_w = 0;
+  scalar_t zSum_alpha = 0;
+  scalar_t weightSum = 0;
+
+  for (int kernelX = 0; kernelX < cKernelSizes[0]; kernelX++) {
+    int neighbourX = max(0, min(homeX + (kernelX - cHalfWindowSize_arr[0]), cSizes[0] - 1));
+    scalar_t gaussianX = cGaussianKernel_x[kernelX];
+
+    for (int kernelY = 0; kernelY < cKernelSizes[1]; kernelY++) {
+      int neighbourY = max(0, min(homeY + (kernelY - cHalfWindowSize_arr[1]), cSizes[1] - 1));
+      scalar_t gaussianY = cGaussianKernel_y[kernelY];
+
+      for (int kernelZ = 0; kernelZ < cKernelSizes[2]; kernelZ++) {
+        int neighbourZ = max(0, min(homeZ + (kernelZ - cHalfWindowSize_arr[2]), cSizes[2] - 1));
+        scalar_t gaussianZ = cGaussianKernel_z[kernelZ];
+
+        int neighbourOffset = neighbourX * cStrides[0] + neighbourY * cStrides[1] + neighbourZ;
+
+        bool flagNotClamped = true;
+        int kernelIndex[] = {kernelX, kernelY, kernelZ};
+        int dimensions = 3; // Must equal the number of spatial dimensions.
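+        // Same border rule as the CPU path: offsets are clamped for safe addressing,
+        // and any clamped neighbour is excluded from the sums via flagNotClamped.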
+ + for (int i = 0; i < dimensions; i++) { + int HalfWindowSizeBack = cHalfWindowSize_arr[i]; // Define constant memory as new variable here (!!), + // otherwise: cudaErrorMisalignedAddress + int neighbourIndex = homeIndex[i] + kernelIndex[i] - HalfWindowSizeBack; + int neighbourIndexClamped = min(cSizes[i] - 1, max(0, neighbourIndex)); + if (neighbourIndex != neighbourIndexClamped) { + flagNotClamped = false; + } + } + + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + homeOffset + c * cColorStride]; + scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride]; // Home - neighbor (!!) in backward the + // other way around !! + scalar_t diff = a - b; + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussianX * gaussianY * gaussianZ; + scalar_t colorWeight = exp(cColorExponentConstant * colorDistanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { +#pragma unroll + for (int c = 0; c < C; c++) { + valueSum += input[batchOffset + neighbourOffset + c * cColorStride] * totalWeight; + + // Derivative of weights with respect to X_i while i=k. + dw_dx_ki += (-1) * totalWeight * colorDistance / (cColorSigma * cColorSigma); + // Derivative of convolved image with respect to X_i while i=k. + dfilter_dx_ki += (-1) * totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + colorDistance / + (cColorSigma * + cColorSigma); // Be careful, the +1 is missing here -> Added before filling dfilter_dx_kiData + + colorSum_w += totalWeight * colorDistanceSquared / std::abs(cColorSigma * cColorSigma * cColorSigma); + colorSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + colorDistanceSquared / std::abs(cColorSigma * cColorSigma * cColorSigma); + + xSum_w += totalWeight * cXDistanceSquared[kernelX] / std::abs(cSigma_x * cSigma_x * cSigma_x); + xSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + cXDistanceSquared[kernelX] / std::abs(cSigma_x * cSigma_x * cSigma_x); + + ySum_w += totalWeight * cYDistanceSquared[kernelY] / std::abs(cSigma_y * cSigma_y * cSigma_y); + ySum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + cYDistanceSquared[kernelY] / std::abs(cSigma_y * cSigma_y * cSigma_y); + + zSum_w += totalWeight * cZDistanceSquared[kernelZ] / std::abs(cSigma_z * cSigma_z * cSigma_z); + zSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + cZDistanceSquared[kernelZ] / std::abs(cSigma_z * cSigma_z * cSigma_z); + } + + weightSum += totalWeight; + } + } + } + } + +#pragma unroll + for (int c = 0; c < C; c++) { + // output[batchOffset + homeOffset + c * cColorStride] /= weightSum; + output[batchOffset + homeOffset + c * cColorStride] = valueSum / weightSum; + + // Pre-computations for the backward pass: + outputWeightsTensor[batchOffset + homeOffset + c * cColorStride] = weightSum; + dO_dx_ki[batchOffset + homeOffset + c * cColorStride] = -(1 / weightSum) * (valueSum / weightSum) * dw_dx_ki + + (1 / weightSum) * (dfilter_dx_ki + 1); // +1 for dfilter_dx_ki is added here + dO_dsig_r[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * colorSum_w + (1 / weightSum) * 
colorSum_alpha; + dO_dsig_x[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * xSum_w + (1 / weightSum) * xSum_alpha; + dO_dsig_y[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * ySum_w + (1 / weightSum) * ySum_alpha; + dO_dsig_z[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * zSum_w + (1 / weightSum) * zSum_alpha; + } +} + +template +void BilateralFilterCudaForwardFunction( + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + torch::Tensor dO_dsig_r, + torch::Tensor dO_dsig_x, + torch::Tensor dO_dsig_y, + torch::Tensor dO_dsig_z, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + // Pre-calculating gaussian kernel. + int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size + int halfWindowSize_x = floor(0.5f * windowSize_x); + int halfWindowSize_y = floor(0.5f * windowSize_y); + int halfWindowSize_z = floor(0.5f * windowSize_z); + int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; + float spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); + float spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); + float spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); + float colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + int* kernelSizes = new int[desc.dimensions]; + kernelSizes[0] = windowSize_x; + kernelSizes[1] = windowSize_y; + kernelSizes[2] = windowSize_z; + + auto* gaussianKernel_x = new float[windowSize_x]; + auto* gaussianKernel_y = new float[windowSize_y]; + auto* gaussianKernel_z = new float[windowSize_z]; + auto* xDistanceSquared = new float[windowSize_x]; + auto* yDistanceSquared = new float[windowSize_y]; + auto* zDistanceSquared = new float[windowSize_z]; + + for (int i = 0; i < windowSize_x; i++) { + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_y; i++) { + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_z; i++) { + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; + } + + // Writing constant memory. 
+  cudaMemcpyToSymbol(cBatchStride, &desc.batchStride, sizeof(int));
+  cudaMemcpyToSymbol(cColorStride, &desc.channelStride, sizeof(int));
+  cudaMemcpyToSymbol(cSizes, desc.sizes, sizeof(int) * 3);
+  cudaMemcpyToSymbol(cStrides, desc.strides, sizeof(int) * 3);
+  cudaMemcpyToSymbol(cKernelSizes, kernelSizes, sizeof(int) * desc.dimensions);
+  cudaMemcpyToSymbol(cHalfWindowSize_arr, halfWindowSize_arr, sizeof(int) * desc.dimensions);
+  cudaMemcpyToSymbol(cGaussianKernel_x, gaussianKernel_x, sizeof(float) * windowSize_x);
+  cudaMemcpyToSymbol(cGaussianKernel_y, gaussianKernel_y, sizeof(float) * windowSize_y);
+  cudaMemcpyToSymbol(cGaussianKernel_z, gaussianKernel_z, sizeof(float) * windowSize_z);
+  cudaMemcpyToSymbol(cXDistanceSquared, xDistanceSquared, sizeof(float) * windowSize_x);
+  cudaMemcpyToSymbol(cYDistanceSquared, yDistanceSquared, sizeof(float) * windowSize_y);
+  cudaMemcpyToSymbol(cZDistanceSquared, zDistanceSquared, sizeof(float) * windowSize_z);
+  cudaMemcpyToSymbol(cColorExponentConstant, &colorExpConstant, sizeof(float));
+  cudaMemcpyToSymbol(cSigma_x, &sigma_x, sizeof(float));
+  cudaMemcpyToSymbol(cSigma_y, &sigma_y, sizeof(float));
+  cudaMemcpyToSymbol(cSigma_z, &sigma_z, sizeof(float));
+  cudaMemcpyToSymbol(cColorSigma, &colorSigma, sizeof(float));
+
+  // cuda_error_check("Cuda check before kernel call.");
+
+#define BLOCK_SIZE 32
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      inputTensor.scalar_type(), "BilateralFilterCudaKernel3DForward", ([&] {
+        BilateralFilterCudaKernel3DForward<scalar_t, C>
+            <<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(
+                inputTensor.data_ptr<scalar_t>(),
+                outputTensor.data_ptr<scalar_t>(),
+                outputWeightsTensor.data_ptr<scalar_t>(),
+                dO_dx_ki.data_ptr<scalar_t>(),
+                dO_dsig_r.data_ptr<scalar_t>(),
+                dO_dsig_x.data_ptr<scalar_t>(),
+                dO_dsig_y.data_ptr<scalar_t>(),
+                dO_dsig_z.data_ptr<scalar_t>());
+      }));
+
+  // cuda_error_check("Cuda check after kernel call.");
+  // delete[] kernel;
+  delete[] kernelSizes;
+  delete[] gaussianKernel_x;
+  delete[] gaussianKernel_y;
+  delete[] gaussianKernel_z;
+  delete[] xDistanceSquared;
+  delete[] yDistanceSquared;
+  delete[] zDistanceSquared;
+}
+
+// Function to choose template implementation based on dynamic, channels and dimensions
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+BilateralFilterCudaForward(torch::Tensor inputTensor, float sigma_x, float sigma_y, float sigma_z, float colorSigma) {
+  torch::Tensor outputTensor = torch::zeros_like(inputTensor);
+  torch::Tensor outputWeightsTensor = torch::zeros_like(inputTensor);
+  torch::Tensor dO_dx_ki = torch::zeros_like(inputTensor);
+  torch::Tensor dO_dsig_r = torch::zeros_like(inputTensor);
+  torch::Tensor dO_dsig_x = torch::zeros_like(inputTensor);
+  torch::Tensor dO_dsig_y = torch::zeros_like(inputTensor);
+  torch::Tensor dO_dsig_z = torch::zeros_like(inputTensor);
+  // cuda_error_check("beginning");
+
+#define CASE(c, d)                       \
+  BilateralFilterCudaForwardFunction<c>( \
+      inputTensor,                       \
+      outputTensor,                      \
+      outputWeightsTensor,               \
+      dO_dx_ki,                          \
+      dO_dsig_r,                         \
+      dO_dsig_x,                         \
+      dO_dsig_y,                         \
+      dO_dsig_z,                         \
+      sigma_x,                           \
+      sigma_y,                           \
+      sigma_z,                           \
+      colorSigma);
+  SWITCH_AB(CASE, BF_CUDA_MAX_CHANNELS, BF_CUDA_MAX_SPATIAL_DIMENSION, inputTensor.size(1), inputTensor.dim() - 2);
+
+  return {outputTensor, outputWeightsTensor, dO_dx_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z};
+}
diff --git a/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.cpp b/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.cpp
new file mode 100644
index 0000000000..0cef92810f
--- /dev/null
+++ b/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.cpp
@@ -0,0 +1,105 @@
+/*
+Copyright (c) MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <torch/extension.h>
+#include <string>
+#include <tuple>
+
+#include "trainable_bilateral.h"
+#include "utils/common_utils.h"
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+TrainableBilateralFilterForward(
+    torch::Tensor inputTensor,
+    float sigma_x,
+    float sigma_y,
+    float sigma_z,
+    float colorSigma) {
+  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> (
+      *filterFunction)(torch::Tensor, float, float, float, float);
+
+#ifdef WITH_CUDA
+
+  if (torch::cuda::is_available() && inputTensor.is_cuda()) {
+    CHECK_CONTIGUOUS_CUDA(inputTensor);
+
+    if (inputTensor.size(1) > BF_CUDA_MAX_CHANNELS) {
+      throw std::runtime_error(
+          "Bilateral filtering not implemented for channel count > " + std::to_string(BF_CUDA_MAX_CHANNELS));
+    }
+
+    if (inputTensor.dim() - 2 > BF_CUDA_MAX_SPATIAL_DIMENSION) {
+      throw std::runtime_error(
+          "Bilateral filtering not implemented for spatial dimension > " +
+          std::to_string(BF_CUDA_MAX_SPATIAL_DIMENSION));
+    }
+
+    filterFunction = &BilateralFilterCudaForward;
+  } else {
+    filterFunction = &BilateralFilterCpuForward;
+  }
+#else
+  filterFunction = &BilateralFilterCpuForward;
+#endif
+
+  return filterFunction(inputTensor, sigma_x, sigma_y, sigma_z, colorSigma);
+}
+
+torch::Tensor TrainableBilateralFilterBackward(
+    torch::Tensor gradientInputTensor,
+    torch::Tensor inputTensor,
+    torch::Tensor outputTensor,
+    torch::Tensor outputWeightsTensor,
+    torch::Tensor dO_dx_ki,
+    float sigma_x,
+    float sigma_y,
+    float sigma_z,
+    float colorSigma) {
+  torch::Tensor (*filterFunction)(
+      torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, float, float, float, float);
+
+#ifdef WITH_CUDA
+
+  if (torch::cuda::is_available() && gradientInputTensor.is_cuda()) {
+    CHECK_CONTIGUOUS_CUDA(gradientInputTensor);
+
+    if (gradientInputTensor.size(1) > BF_CUDA_MAX_CHANNELS) {
+      throw std::runtime_error(
+          "Bilateral filtering not implemented for channel count > " + std::to_string(BF_CUDA_MAX_CHANNELS));
+    }
+
+    if (gradientInputTensor.dim() - 2 > BF_CUDA_MAX_SPATIAL_DIMENSION) {
+      throw std::runtime_error(
+          "Bilateral filtering not implemented for spatial dimension > " +
+          std::to_string(BF_CUDA_MAX_SPATIAL_DIMENSION));
+    }
+
+    filterFunction = &BilateralFilterCudaBackward;
+  } else {
+    filterFunction = &BilateralFilterCpuBackward;
+  }
+#else
+  filterFunction = &BilateralFilterCpuBackward;
+#endif
+
+  return filterFunction(
+      gradientInputTensor,
+      inputTensor,
+      outputTensor,
+      outputWeightsTensor,
+      dO_dx_ki,
+      sigma_x,
+      sigma_y,
+      sigma_z,
+      colorSigma);
+}
diff --git a/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.h b/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.h
new file mode 100644
index 0000000000..ab5c6e1a74
--- /dev/null
+++ b/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.h
@@ -0,0 +1,72 @@
+/*
+Copyright (c) MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#pragma once
+
+#include <torch/extension.h>
+#include <string>
+#include <tuple>
+#include <vector>
+#include "utils/common_utils.h"
+#include "utils/tensor_description.h"
+
+#define BF_CUDA_MAX_CHANNELS 16
+#define BF_CUDA_MAX_SPATIAL_DIMENSION 3
+
+#ifdef WITH_CUDA
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+BilateralFilterCudaForward(torch::Tensor inputTensor, float sigma_x, float sigma_y, float sigma_z, float colorSigma);
+torch::Tensor BilateralFilterCudaBackward(
+    torch::Tensor gradientInputTensor,
+    torch::Tensor inputTensor,
+    torch::Tensor outputTensor,
+    torch::Tensor outputWeightsTensor,
+    torch::Tensor dO_dx_ki,
+    float sigma_x,
+    float sigma_y,
+    float sigma_z,
+    float colorSigma);
+#endif
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+BilateralFilterCpuForward(torch::Tensor inputTensor, float sigma_x, float sigma_y, float sigma_z, float colorSigma);
+
+torch::Tensor BilateralFilterCpuBackward(
+    torch::Tensor gradientInputTensor,
+    torch::Tensor inputTensor,
+    torch::Tensor outputTensor,
+    torch::Tensor outputWeightsTensor,
+    torch::Tensor dO_dx_ki,
+    float sigma_x,
+    float sigma_y,
+    float sigma_z,
+    float colorSigma);
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+TrainableBilateralFilterForward(
+    torch::Tensor inputTensor,
+    float sigma_x,
+    float sigma_y,
+    float sigma_z,
+    float colorSigma);
+
+torch::Tensor TrainableBilateralFilterBackward(
+    torch::Tensor gradientInputTensor,
+    torch::Tensor inputTensor,
+    torch::Tensor outputTensor,
+    torch::Tensor outputWeightsTensor,
+    torch::Tensor dO_dx_ki,
+    float sigma_x,
+    float sigma_y,
+    float sigma_z,
+    float colorSigma);
diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py
index 1d62a19e8a..08418c400b 100644
--- a/monai/networks/layers/__init__.py
+++ b/monai/networks/layers/__init__.py
@@ -14,7 +14,7 @@
 from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding
 from .drop_path import DropPath
 from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args
-from .filtering import BilateralFilter, PHLFilter
+from .filtering import BilateralFilter, PHLFilter, TrainableBilateralFilter
 from .gmm import GaussianMixtureModel
 from .simplelayers import (
     LLTM,
diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py
index c0ddd64a87..20bcdb26c1 100644
--- a/monai/networks/layers/filtering.py
+++ b/monai/networks/layers/filtering.py
@@ -17,7 +17,7 @@
 
 _C, _ = optional_import("monai._C")
 
-__all__ = ["BilateralFilter", "PHLFilter"]
+__all__ = ["BilateralFilter", "PHLFilter", "TrainableBilateralFilter"]
 
 
 class BilateralFilter(torch.autograd.Function):
@@ -101,3 +101,150 @@ def backward(ctx, grad_output):
 #        scaled_features, = ctx.saved_variables
 #        grad_input = _C.phl_filter(grad_output, scaled_features)
 #        return grad_input
+
+
+class TrainableBilateralFilterFunction(torch.autograd.Function):
+    """
+    torch.autograd.Function for the TrainableBilateralFilter layer.
+
+    See:
+        F. Wagner, et al., Ultralow-parameter denoising: Trainable bilateral filter layers in
+        computed tomography, Medical Physics (2022), https://doi.org/10.1002/mp.15718
+
+    Args:
+        input: input tensor to be filtered.
+ + sigma x: trainable standard deviation of the spatial filter kernel in x direction. + + sigma y: trainable standard deviation of the spatial filter kernel in y direction. + + sigma z: trainable standard deviation of the spatial filter kernel in z direction. + + color sigma: trainable standard deviation of the intensity range kernel. This filter + parameter determines the degree of edge preservation. + + Returns: + output (torch.Tensor): filtered tensor. + """ + + @staticmethod + def forward(ctx, input_img, sigma_x, sigma_y, sigma_z, color_sigma): + output_tensor, output_weights_tensor, do_dx_ki, do_dsig_r, do_dsig_x, do_dsig_y, do_dsig_z = _C.tbf_forward( + input_img, sigma_x, sigma_y, sigma_z, color_sigma + ) + + ctx.save_for_backward( + input_img, + sigma_x, + sigma_y, + sigma_z, + color_sigma, + output_tensor, + output_weights_tensor, + do_dx_ki, + do_dsig_r, + do_dsig_x, + do_dsig_y, + do_dsig_z, + ) + + return output_tensor + + @staticmethod + def backward(ctx, grad_output): + input_img = ctx.saved_tensors[0] # input image + sigma_x = ctx.saved_tensors[1] + sigma_y = ctx.saved_tensors[2] + sigma_z = ctx.saved_tensors[3] + color_sigma = ctx.saved_tensors[4] + output_tensor = ctx.saved_tensors[5] # filtered image + output_weights_tensor = ctx.saved_tensors[6] # weights + do_dx_ki = ctx.saved_tensors[7] # derivative of output with respect to input, while k==i + do_dsig_r = ctx.saved_tensors[8] # derivative of output with respect to range sigma + do_dsig_x = ctx.saved_tensors[9] # derivative of output with respect to sigma x + do_dsig_y = ctx.saved_tensors[10] # derivative of output with respect to sigma y + do_dsig_z = ctx.saved_tensors[11] # derivative of output with respect to sigma z + + # calculate gradient with respect to the sigmas + grad_color_sigma = torch.sum(grad_output * do_dsig_r) + grad_sig_x = torch.sum(grad_output * do_dsig_x) + grad_sig_y = torch.sum(grad_output * do_dsig_y) + grad_sig_z = torch.sum(grad_output * do_dsig_z) + + grad_output_tensor = _C.tbf_backward( + grad_output, + input_img, + output_tensor, + output_weights_tensor, + do_dx_ki, + sigma_x, + sigma_y, + sigma_z, + color_sigma, + ) + + return grad_output_tensor, grad_sig_x, grad_sig_y, grad_sig_z, grad_color_sigma + + +class TrainableBilateralFilter(torch.nn.Module): + """ + Implementation of a trainable bilateral filter layer as proposed in the corresponding publication. + All filter parameters can be trained data-driven. The spatial filter kernels x, y, and z determine + image smoothing whereas the color parameter specifies the amount of edge preservation. + Can run on 1D, 2D, or 3D tensors (on top of Batch and Channel dimensions). + + See: + F. Wagner, et al., Ultralow-parameter denoising: Trainable bilateral filter layers in + computed tomography, Medical Physics (2022), https://doi.org/10.1002/mp.15718 + + Args: + input: input tensor to be filtered. + + sigma x: trainable standard deviation of the spatial filter kernel in x direction. + + sigma y: trainable standard deviation of the spatial filter kernel in y direction. + + sigma z: trainable standard deviation of the spatial filter kernel in z direction. + + sigma color: trainable standard deviation of the intensity range kernel. This filter + parameter determines the degree of edge preservation. + + Returns: + output (torch.Tensor): filtered tensor. + """ + + def __init__(self, sigma_x, sigma_y, sigma_z, sigma_color): + super().__init__() + + # Register sigmas as trainable parameters. 
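+        # Wrapping the sigmas in nn.Parameter makes them show up in
+        # layer.parameters(), so a standard optimizer step updates the filter
+        # widths alongside the rest of the model.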
+ self.sigma_x = torch.nn.Parameter(torch.tensor(sigma_x)) + self.sigma_y = torch.nn.Parameter(torch.tensor(sigma_y)) + self.sigma_z = torch.nn.Parameter(torch.tensor(sigma_z)) + self.sigma_color = torch.nn.Parameter(torch.tensor(sigma_color)) + + def forward(self, input_tensor): + assert input_tensor.shape[1] == 1, ( + "Currently channel dimensions >1 are not supported. " + "Please use multiple parallel filter layers if you want " + "to filter multiple channels." + ) + + len_input = len(input_tensor.shape) + + # C++ extension so far only supports 5-dim inputs. + if len_input == 3: + input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) + elif len_input == 4: + input_tensor = input_tensor.unsqueeze(4) + + prediction = TrainableBilateralFilterFunction.apply( + input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color + ) + + # Make sure to return tensor of the same shape as the input. + if len_input == 3: + prediction = prediction.squeeze(4).squeeze(3) + elif len_input == 4: + prediction = prediction.squeeze(4) + + return prediction diff --git a/tests/test_trainable_bilateral.py b/tests/test_trainable_bilateral.py new file mode 100644 index 0000000000..1300e5068d --- /dev/null +++ b/tests/test_trainable_bilateral.py @@ -0,0 +1,478 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
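+
+# A minimal usage sketch for the layer under test (illustrative only, not part of
+# the test suite; the shapes and sigma values below are arbitrary assumptions):
+#
+#     import torch
+#     from monai.networks.layers.filtering import TrainableBilateralFilter
+#
+#     layer = TrainableBilateralFilter(sigma_x=1.0, sigma_y=1.0, sigma_z=1.0, sigma_color=0.5)
+#     img = torch.rand(1, 1, 16, 16, 16)  # (B, C, X, Y, Z); C must currently be 1
+#     out = layer(img)                    # filtered tensor, same shape as img
+#     out.mean().backward()               # populates gradients for all four sigmas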
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+from torch.autograd import gradcheck
+
+from monai.networks.layers.filtering import TrainableBilateralFilterFunction
+from tests.utils import skip_if_no_cpp_extension, skip_if_no_cuda
+
+TEST_CASES = [
+    [
+        # Case Description
+        "1 dimension, 1 channel, low spatial sigmas, low color sigma",
+        # (sigma_x, sigma_y, sigma_z, color_sigma)
+        (1.0, 1.0, 1.0, 0.2),
+        # Input
+        [
+            # Batch 0
+            [
+                # Channel 0
+                [1, 0, 0, 0, 1]
+            ],
+            # Batch 1
+            [
+                # Channel 0
+                [0, 0, 1, 0, 0]
+            ],
+        ],
+        # Expected
+        [
+            # Batch 0
+            [
+                # Channel 0
+                [0.999997, 0.000001, 0.000000, 0.000001, 0.999997]
+            ],
+            # Batch 1
+            [
+                # Channel 0
+                [0.000000, 0.000001, 0.999995, 0.000001, 0.000000]
+            ],
+        ],
+    ],
+    [
+        # Case Description
+        "1 dimension, 1 channel, low spatial sigmas, high color sigma",
+        # (sigma_x, sigma_y, sigma_z, color_sigma)
+        (1, 1, 1, 0.9),
+        # Input
+        [
+            # Batch 0
+            [
+                # Channel 0
+                [1, 0, 0, 0, 1]
+            ],
+            # Batch 1
+            [
+                # Channel 0
+                [0, 0, 1, 0, 0]
+            ],
+        ],
+        # Expected
+        [
+            # Batch 0
+            [
+                # Channel 0
+                [0.714200, 0.158126, 0.061890, 0.158126, 0.714200]
+            ],
+            # Batch 1
+            [
+                # Channel 0
+                [0.043465, 0.158126, 0.555452, 0.158126, 0.043465]
+            ],
+        ],
+    ],
+    [
+        # Case Description
+        "1 dimension, 1 channel, high spatial sigmas, low color sigma",
+        # (sigma_x, sigma_y, sigma_z, color_sigma)
+        (4, 4, 4, 0.2),
+        # Input
+        [
+            # Batch 0
+            [
+                # Channel 0
+                [1, 0, 0, 0, 1]
+            ],
+            # Batch 1
+            [
+                # Channel 0
+                [0, 0, 1, 0, 0]
+            ],
+        ],
+        # Expected
+        [
+            # Batch 0
+            [
+                # Channel 0
+                [0.999994, 0.000002, 0.000002, 0.000002, 0.999994]
+            ],
+            # Batch 1
+            [
+                # Channel 0
+                [0.000001, 0.000001, 0.999986, 0.000001, 0.000001]
+            ],
+        ],
+    ],
+    [
+        # Case Description
+        "1 dimension, 1 channel, high spatial sigmas, high color sigma",
+        # (sigma_x, sigma_y, sigma_z, color_sigma)
+        (4, 4, 4, 0.9),
+        # Input
+        [
+            # Batch 0
+            [
+                # Channel 0
+                [1, 0, 0, 0, 1]
+            ],
+            # Batch 1
+            [
+                # Channel 0
+                [0, 0, 1, 0, 0]
+            ],
+        ],
+        # Expected
+        [
+            # Batch 0
+            [
+                # Channel 0
+                [0.533282, 0.245915, 0.244711, 0.245915, 0.533282]
+            ],
+            # Batch 1
+            [
+                # Channel 0
+                [0.125052, 0.126608, 0.333592, 0.126608, 0.125052]
+            ],
+        ],
+    ],
+    [
+        # Case Description
+        "2 dimensions, 1 channel, high spatial sigmas, high color sigma",
+        # (sigma_x, sigma_y, sigma_z, color_sigma)
+        (4, 4, 4, 0.9),
+        # Input
+        [
+            # Batch 0
+            [
+                # Channel 0
+                [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]]
+            ],
+            # Batch 1
+            [
+                # Channel 0
+                [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]
+            ],
+        ],
+        # Expected
+        [
+            # Batch 0
+            [
+                # Channel 0
+                [
+                    [0.239789, 0.082990, 0.082630, 0.082990, 0.239789],
+                    [0.082990, 0.081934, 0.081579, 0.081934, 0.082990],
+                    [0.082630, 0.081579, 0.081225, 0.081579, 0.082630],
+                    [0.082990, 0.081934, 0.081579, 0.081934, 0.082990],
+                    [0.239789, 0.082990, 0.082630, 0.082990, 0.239789],
+                ]
+            ],
+            # Batch 1
+            [
+                # Channel 0
+                [
+                    [0.024155, 0.024432, 0.024525, 0.024432, 0.024155],
+                    [0.024432, 0.024712, 0.024806, 0.024712, 0.024432],
+                    [0.024525, 0.024806, 0.080686, 0.024806, 0.024525],
+                    [0.024432, 0.024712, 0.024806, 0.024712, 0.024432],
+                    [0.024155, 0.024432, 0.024525, 0.024432, 0.024155],
+                ]
+            ],
+        ],
+    ],
+    [
+        # Case Description
+        "3 dimensions, 1 channel, high spatial sigmas, high color sigma",
+        # (sigma_x, sigma_y, sigma_z, color_sigma)
+        (4, 4, 4, 0.9),
+        # Input
+        [
+            # Batch 0
+            [
+                # Channel 0
+                [
+                    # Frame 0
+                    [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]],
+                    # Frame 1
+                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
+                    # Frame 2
+                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
+                    # Frame 3
+                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
+                    # Frame 4
+                    [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]],
+                ]
+            ]
+        ],
+        # Expected
+        [
+            # Batch 0
+            [
+                # Channel 0
+                [
+                    # Frame 0
+                    [
+                        [0.098142, 0.030317, 0.030191, 0.030316, 0.098142],
+                        [0.030316, 0.029947, 0.029822, 0.029947, 0.030316],
+                        [0.030191, 0.029822, 0.029698, 0.029822, 0.030191],
+                        [0.030316, 0.029947, 0.029822, 0.029947, 0.030316],
+                        [0.098142, 0.030317, 0.030191, 0.030317, 0.098142],
+                    ],
+                    # Frame 1
+                    [
+                        [0.030316, 0.029947, 0.029822, 0.029947, 0.030316],
+                        [0.029947, 0.029581, 0.029458, 0.029581, 0.029947],
+                        [0.029822, 0.029458, 0.029336, 0.029458, 0.029822],
+                        [0.029947, 0.029581, 0.029458, 0.029581, 0.029947],
+                        [0.030316, 0.029947, 0.029822, 0.029947, 0.030316],
+                    ],
+                    # Frame 2
+                    [
+                        [0.030191, 0.029822, 0.029698, 0.029822, 0.030191],
+                        [0.029822, 0.029458, 0.029336, 0.029458, 0.029822],
+                        [0.029698, 0.029336, 0.029214, 0.029336, 0.029698],
+                        [0.029822, 0.029458, 0.029336, 0.029458, 0.029822],
+                        [0.030191, 0.029822, 0.029698, 0.029822, 0.030191],
+                    ],
+                    # Frame 3
+                    [
+                        [0.030316, 0.029947, 0.029822, 0.029947, 0.030317],
+                        [0.029947, 0.029581, 0.029458, 0.029581, 0.029947],
+                        [0.029822, 0.029458, 0.029336, 0.029458, 0.029822],
+                        [0.029947, 0.029581, 0.029458, 0.029581, 0.029947],
+                        [0.030316, 0.029947, 0.029822, 0.029947, 0.030316],
+                    ],
+                    # Frame 4
+                    [
+                        [0.098142, 0.030317, 0.030191, 0.030317, 0.098142],
+                        [0.030317, 0.029947, 0.029822, 0.029947, 0.030316],
+                        [0.030191, 0.029822, 0.029698, 0.029822, 0.030191],
+                        [0.030317, 0.029947, 0.029822, 0.029947, 0.030316],
+                        [0.098142, 0.030317, 0.030191, 0.030316, 0.098142],
+                    ],
+                ]
+            ]
+        ],
+    ],
+]
+
+
+@skip_if_no_cpp_extension
+class BilateralFilterTestCaseCpuPrecise(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_cpu_precise(self, test_case_description, sigmas, input, expected):
+
+        # Params to determine the implementation to test
+        device = torch.device("cpu")
+
+        # Create input tensor and apply filter
+        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device)
+
+        len_input = len(input_tensor.shape)
+        # C++ extension so far only supports 5-dim inputs.
+        if len_input == 3:
+            input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
+        elif len_input == 4:
+            input_tensor = input_tensor.unsqueeze(4)
+
+        output = TrainableBilateralFilterFunction.apply(input_tensor, *sigmas).cpu().numpy()
+
+        # Squeeze the output back to the shape of the original input.
+        if len_input == 3:
+            output = output.squeeze(4).squeeze(3)
+        elif len_input == 4:
+            output = output.squeeze(4)
+
+        # Ensure results are as expected.
+        np.testing.assert_allclose(output, expected, atol=1e-5)
+
+    @parameterized.expand(TEST_CASES)
+    def test_cpu_precise_backwards(self, test_case_description, sigmas, input, expected):
+
+        # Params to determine the implementation to test
+        device = torch.device("cpu")
+
+        # Prepare input tensor
+        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device)
+        input_tensor.requires_grad = True
+
+        # C++ extension so far only supports 5-dim inputs.
+        len_input = len(input_tensor.shape)
+        if len_input == 3:
+            input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
+        elif len_input == 4:
+            input_tensor = input_tensor.unsqueeze(4)
+
+        # Check gradient toward input.
+        args = (
+            input_tensor,
+            torch.tensor(sigmas[0]),
+            torch.tensor(sigmas[1]),
+            torch.tensor(sigmas[2]),
+            torch.tensor(sigmas[3]),
+        )
+        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-6, atol=1e-5, raise_exception=False)
+        input_tensor = input_tensor.detach()
+        input_tensor.requires_grad = False
+
+        # Check gradient toward sigma_x.
+        args = (
+            input_tensor,
+            torch.tensor(sigmas[0], dtype=torch.double, requires_grad=True),
+            torch.tensor(sigmas[1]),
+            torch.tensor(sigmas[2]),
+            torch.tensor(sigmas[3]),
+        )
+        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)
+
+        # Check gradient toward sigma_y.
+        args = (
+            input_tensor,
+            torch.tensor(sigmas[0]),
+            torch.tensor(sigmas[1], dtype=torch.double, requires_grad=True),
+            torch.tensor(sigmas[2]),
+            torch.tensor(sigmas[3]),
+        )
+        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)
+
+        # Check gradient toward sigma_z.
+        args = (
+            input_tensor,
+            torch.tensor(sigmas[0]),
+            torch.tensor(sigmas[1]),
+            torch.tensor(sigmas[2], dtype=torch.double, requires_grad=True),
+            torch.tensor(sigmas[3]),
+        )
+        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)
+
+        # Check gradient toward sigma_color.
+        args = (
+            input_tensor,
+            torch.tensor(sigmas[0]),
+            torch.tensor(sigmas[1]),
+            torch.tensor(sigmas[2]),
+            torch.tensor(sigmas[3], dtype=torch.double, requires_grad=True),
+        )
+        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-3, atol=1e-3, raise_exception=False)
+
+
+@skip_if_no_cuda
+@skip_if_no_cpp_extension
+class BilateralFilterTestCaseCudaPrecise(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_cuda_precise(self, test_case_description, sigmas, input, expected):
+
+        # Skip this test if CUDA is not available.
+        if not torch.cuda.is_available():
+            return
+
+        # Params to determine the implementation to test
+        device = torch.device("cuda")
+
+        # Create input tensor and apply filter
+        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device)
+
+        len_input = len(input_tensor.shape)
+        # C++ extension so far only supports 5-dim inputs.
+        if len_input == 3:
+            input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
+        elif len_input == 4:
+            input_tensor = input_tensor.unsqueeze(4)
+
+        output = TrainableBilateralFilterFunction.apply(input_tensor, *sigmas).cpu().numpy()
+
+        # Squeeze the output back to the shape of the original input.
+        if len_input == 3:
+            output = output.squeeze(4).squeeze(3)
+        elif len_input == 4:
+            output = output.squeeze(4)
+
+        # Ensure results are as expected.
+        np.testing.assert_allclose(output, expected, atol=1e-5)
+
+    @parameterized.expand(TEST_CASES)
+    def test_cuda_precise_backwards(self, test_case_description, sigmas, input, expected):
+
+        # Params to determine the implementation to test
+        device = torch.device("cuda")
+
+        # Prepare input tensor
+        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device)
+        input_tensor.requires_grad = True
+
+        # C++ extension so far only supports 5-dim inputs.
+        len_input = len(input_tensor.shape)
+        if len_input == 3:
+            input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
+        elif len_input == 4:
+            input_tensor = input_tensor.unsqueeze(4)
+
+        # Check gradient toward input.
+        args = (
+            input_tensor,
+            torch.tensor(sigmas[0]),
+            torch.tensor(sigmas[1]),
+            torch.tensor(sigmas[2]),
+            torch.tensor(sigmas[3]),
+        )
+        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-6, atol=1e-5, raise_exception=False)
+        input_tensor = input_tensor.detach()
+        input_tensor.requires_grad = False
+
+        # Check gradient toward sigma_x.
+        args = (
+            input_tensor,
+            torch.tensor(sigmas[0], dtype=torch.double, requires_grad=True),
+            torch.tensor(sigmas[1]),
+            torch.tensor(sigmas[2]),
+            torch.tensor(sigmas[3]),
+        )
+        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)
+
+        # Check gradient toward sigma_y.
+        args = (
+            input_tensor,
+            torch.tensor(sigmas[0]),
+            torch.tensor(sigmas[1], dtype=torch.double, requires_grad=True),
+            torch.tensor(sigmas[2]),
+            torch.tensor(sigmas[3]),
+        )
+        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)
+
+        # Check gradient toward sigma_z.
+        args = (
+            input_tensor,
+            torch.tensor(sigmas[0]),
+            torch.tensor(sigmas[1]),
+            torch.tensor(sigmas[2], dtype=torch.double, requires_grad=True),
+            torch.tensor(sigmas[3]),
+        )
+        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)
+
+        # Check gradient toward sigma_color.
+        args = (
+            input_tensor,
+            torch.tensor(sigmas[0]),
+            torch.tensor(sigmas[1]),
+            torch.tensor(sigmas[2]),
+            torch.tensor(sigmas[3], dtype=torch.double, requires_grad=True),
+        )
+        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-3, atol=1e-3, raise_exception=False)
+
+
+if __name__ == "__main__":
+    unittest.main()
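
Editor's note (not part of the patch): for reviewers who want to try the new layer end to end, the following is a minimal sketch of how `TrainableBilateralFilter` can be dropped into a denoising loop. It assumes MONAI was built with the C++/CUDA extension (so `_C.tbf_forward`/`_C.tbf_backward` are available). The toy tensors, initial sigma values, learning rate, and iteration count are illustrative assumptions only, not recommendations from the patch.

    # Minimal sketch: optimizing the four sigma parameters on a toy 2D
    # denoising task. Requires the compiled MONAI C++ extension.
    import torch

    from monai.networks.layers.filtering import TrainableBilateralFilter

    # Shape (B, C, H, W); channel dimensions > 1 are not supported yet.
    noisy = torch.rand(1, 1, 64, 64)
    target = torch.rand(1, 1, 64, 64)  # stand-in for a clean reference image

    # sigma_x, sigma_y, sigma_z, sigma_color are registered as nn.Parameters,
    # so they receive gradients via tbf_backward.
    layer = TrainableBilateralFilter(sigma_x=1.0, sigma_y=1.0, sigma_z=1.0, sigma_color=0.5)
    optimizer = torch.optim.Adam(layer.parameters(), lr=0.01)

    for _ in range(10):
        optimizer.zero_grad()
        prediction = layer(noisy)  # returned with the same shape as the input
        loss = torch.nn.functional.mse_loss(prediction, target)
        loss.backward()
        optimizer.step()

Because the layer internally unsqueezes 3-dim and 4-dim inputs to the 5-dim layout expected by the extension, the same four-parameter layer works unchanged for 1D, 2D, and 3D data; for 2D inputs sigma_z simply acts on a singleton dimension.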