forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
L1Cost.cu
34 lines (30 loc) · 821 Bytes
/
L1Cost.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#include <THCUNN/THCUNN.h>
#include <THCUNN/common.h>
#include <TH/THHalf.h>
#include <THCUNN/THCHalfAutoNumerics.cuh>
#include <thrust/device_ptr.h>
#include <thrust/reduce.h>
#include <thrust/transform_reduce.h>
template <typename Dtype, typename Acctype>
struct l1cost_functor
{
  /**
   * Per-element L1 term: |x|, evaluated in the accumulation type.
   *
   * The input is widened to Acctype *before* taking the absolute value, so the
   * magnitude is produced at accumulation precision rather than in Dtype.
   * NOTE(review): presumably used as the unary transform of a
   * thrust::transform_reduce in generic/L1Cost.cu (thrust headers are included
   * above) — confirm there.
   *
   * @param x  one input element of type Dtype
   * @return   |x| as Acctype
   */
  __host__ __device__ Acctype operator()(Dtype x) const
  {
    const Acctype widened = ScalarConvert<Dtype, Acctype>::to(x);
    return THCNumerics<Acctype>::abs(widened);
  }
};
template <typename Dtype>
struct l1cost_updateGradInput_functor
{
  /**
   * Element-wise sign(x), i.e. the (sub)gradient of |x|:
   *   +1 when x > 0, -1 when x < 0, and 0 otherwise.
   *
   * "Otherwise" covers x == 0. For floating types it also covers NaN, since
   * both comparisons below are false for NaN.
   * NOTE(review): for half Dtype the comparisons against the int literal 0
   * rely on operator overloads presumably supplied by THCHalfAutoNumerics.cuh
   * — confirm.
   *
   * @param x  one input element
   * @return   the sign of x, converted to Dtype
   */
  __host__ __device__ Dtype operator()(Dtype x) const
  {
    // Choose the sign as a plain int first, then convert to Dtype exactly once.
    const int sgn = (x > 0) ? 1 : ((x < 0) ? -1 : 0);
    return ScalarConvert<int, Dtype>::to(sgn);
  }
};
#include <THCUNN/generic/L1Cost.cu>
#include <THC/THCGenerateFloatTypes.h>