-
Notifications
You must be signed in to change notification settings - Fork 2
/
kernel_ops.py
executable file
·38 lines (28 loc) · 1.21 KB
/
kernel_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""
description: CUDA kernel for Image Filtering and Upsampling
@author: Xiaoxu Meng ([email protected])
"""
import os
import tensorflow as tf
from tensorflow.python.framework import ops
##################################################################
# Directory that contains this module; the compiled op libraries are looked up
# in sub-directories next to it.
path = os.path.dirname(os.path.abspath(__file__))
_kernel_filter_so = os.path.join(
    path, '0_kernel_functions', 'kernel_filter.so')
path_kernel_functions = tf.resource_loader.get_path_to_datafile(
    _kernel_filter_so)
# Load the shared object and expose its forward op under a module-level name.
kernel_filter_lib = tf.load_op_library(path_kernel_functions)
kernel_filter = kernel_filter_lib.kernel_filter
@ops.RegisterGradient('KernelFilter')
def _kernel_filter_grad(op, grad):
    """Gradient function registered for the 'KernelFilter' op type.

    Args:
        op: The forward op; its first two inputs are handed to the
            gradient op in the same order.
        grad: Incoming gradient, passed through as the third argument.

    Returns:
        The result of ``kernel_filter_lib.kernel_filter_grad``.
        NOTE(review): presumably one gradient per forward input
        (image, kernel) -- confirm against the op's registration.
    """
    src_image, filter_kernel = op.inputs[0], op.inputs[1]
    return kernel_filter_lib.kernel_filter_grad(
        src_image, filter_kernel, grad)
##################################################################
# Directory that contains this module; the compiled upsampling op library
# is looked up in a sub-directory next to it.
path = os.path.dirname(os.path.abspath(__file__))
_upsampling_so = os.path.join(path, '0_upsampling', 'upsampling.so')
path_upsampling_functions = tf.resource_loader.get_path_to_datafile(
    _upsampling_so)
# Load the shared object and expose its forward op under a module-level name.
upsampling_lib = tf.load_op_library(path_upsampling_functions)
upsampling = upsampling_lib.upsampling
@ops.RegisterGradient('Upsampling')
def _upsampling_grad(op, grad):
    """Gradient function registered for the 'Upsampling' op type.

    Args:
        op: The forward op; its first input is handed to the gradient op.
        grad: Incoming gradient, passed through as the second argument.

    Returns:
        The result of ``upsampling_lib.upsampling_grad``.
        NOTE(review): presumably the gradient with respect to the single
        forward input -- confirm against the op's registration.
    """
    src_image = op.inputs[0]
    return upsampling_lib.upsampling_grad(src_image, grad)