fix template related build issues with gcc (#128)
* add missing template parameters
* add `using` statements to derived classes (https://stackoverflow.com/q/50321788); see the sketch below
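
The `using` fix works around C++ two-phase name lookup: inside a class template, unqualified names inherited from a *dependent* base class are not looked up at definition time, so gcc rejects direct uses of `C`, `H`, `W`, `GetC()`, etc., even though more permissive compilers accept them. A minimal sketch of the pattern, with hypothetical `Base`/`Derived` types rather than the lc0 classes:

```cpp
#include <iostream>

template <typename T>
struct Base {
  int C = 0;
  int GetC() const { return C; }
};

template <typename T>
struct Derived : Base<T> {
  // Bring the dependent-base members into scope; without these declarations,
  // gcc's two-phase lookup cannot find unqualified C / GetC() below.
  using Base<T>::C;
  using Base<T>::GetC;

  int Sum() const {
    return C + GetC();  // OK thanks to the using-declarations
    // Equivalent fix without `using`: return this->C + this->GetC();
  }
};

int main() {
  Derived<float> d;
  d.C = 21;
  std::cout << d.Sum() << "\n";  // prints 42
}
```

Writing `this->C` at every use site works equally well; the `using` declarations just leave the member function bodies untouched.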
blin00 authored and mooskagh committed Jul 2, 2018
1 parent 9399e6d commit 3455131
Showing 1 changed file with 25 additions and 12 deletions.
src/neural/network_cudnn.cu
@@ -126,8 +126,15 @@ class BaseLayer {
 
 template <typename DataType>
 class ConvLayer : public BaseLayer<DataType> {
+  using BaseLayer<DataType>::C;
+  using BaseLayer<DataType>::H;
+  using BaseLayer<DataType>::W;
+  using BaseLayer<DataType>::GetC;
+  using BaseLayer<DataType>::GetH;
+  using BaseLayer<DataType>::GetW;
+
  public:
-  ConvLayer(BaseLayer *ip, int C, int H, int W, int size, int Cin,
+  ConvLayer(BaseLayer<DataType> *ip, int C, int H, int W, int size, int Cin,
             bool relu = false, bool bias = false);
   ~ConvLayer();
   void LoadWeights(float *pfilter, float *pBias, void *scratch);
@@ -156,8 +163,12 @@ class ConvLayer : public BaseLayer<DataType> {
 
 template <typename DataType>
 class SoftMaxLayer : public BaseLayer<DataType> {
+  using BaseLayer<DataType>::GetC;
+  using BaseLayer<DataType>::GetH;
+  using BaseLayer<DataType>::GetW;
+
  public:
-  SoftMaxLayer(BaseLayer *ip);
+  SoftMaxLayer(BaseLayer<DataType> *ip);
   void Eval(int N, DataType *output, const DataType *input,
             const DataType *input2, void *scratch, cudnnHandle_t cudnn,
             cublasHandle_t cublas) override;
@@ -168,8 +179,10 @@ class SoftMaxLayer : public BaseLayer<DataType> {
 
 template <typename DataType>
 class BNLayer : public BaseLayer<DataType> {
+  using BaseLayer<DataType>::C;
+
  public:
-  BNLayer(BaseLayer *ip, bool relu);
+  BNLayer(BaseLayer<DataType> *ip, bool relu);
   ~BNLayer();
 
   void LoadWeights(float *cpuMeans, float *cpuVar);
@@ -189,7 +202,7 @@ class BNLayer : public BaseLayer<DataType> {
 template <typename DataType>
 class FCLayer : public BaseLayer<DataType> {
  public:
-  FCLayer(BaseLayer *ip, int C, int H, int W, bool relu, bool bias,
+  FCLayer(BaseLayer<DataType> *ip, int C, int H, int W, bool relu, bool bias,
           bool tanh = false);
   ~FCLayer();
 
@@ -443,8 +456,8 @@ BaseLayer<DataType>::BaseLayer(int c, int h, int w, BaseLayer *ip)
     : C(c), H(h), W(w), input_(ip) {}
 
 template <typename DataType>
-SoftMaxLayer<DataType>::SoftMaxLayer(BaseLayer *ip)
-    : BaseLayer(ip->GetC(), ip->GetH(), ip->GetW(), ip) {
+SoftMaxLayer<DataType>::SoftMaxLayer(BaseLayer<DataType> *ip)
+    : BaseLayer<DataType>(ip->GetC(), ip->GetH(), ip->GetW(), ip) {
   cudnnCreateTensorDescriptor(&out_tensor_desc_);
 }
 
@@ -470,9 +483,9 @@ void SoftMaxLayer<DataType>::Eval(int N, DataType *output,
 }
 
 template <typename DataType>
-ConvLayer<DataType>::ConvLayer(BaseLayer *ip, int C, int H, int W, int filter,
+ConvLayer<DataType>::ConvLayer(BaseLayer<DataType> *ip, int C, int H, int W, int filter,
                                int Cin, bool relu, bool bias)
-    : BaseLayer(C, H, W, ip),
+    : BaseLayer<DataType>(C, H, W, ip),
       filter_size_(filter),
       c_input_(Cin),
       use_relu_(relu),
@@ -614,8 +627,8 @@ ConvLayer<DataType>::~ConvLayer() {
 }
 
 template <typename DataType>
-BNLayer<DataType>::BNLayer(BaseLayer *ip, bool relu)
-    : BaseLayer(ip->GetC(), ip->GetH(), ip->GetW(), ip), use_relu_(relu) {
+BNLayer<DataType>::BNLayer(BaseLayer<DataType> *ip, bool relu)
+    : BaseLayer<DataType>(ip->GetC(), ip->GetH(), ip->GetW(), ip), use_relu_(relu) {
   size_t weightSize = sizeof(float) * C;
 
   reportCUDAErrors(cudaMalloc(&means_, weightSize));
@@ -654,9 +667,9 @@ BNLayer<DataType>::~BNLayer() {
 }
 
 template <typename DataType>
-FCLayer<DataType>::FCLayer(BaseLayer *ip, int C, int H, int W, bool relu,
+FCLayer<DataType>::FCLayer(BaseLayer<DataType> *ip, int C, int H, int W, bool relu,
                            bool bias, bool tanh)
-    : BaseLayer(C, H, W, ip),
+    : BaseLayer<DataType>(C, H, W, ip),
       use_relu_(relu),
       use_bias_(bias),
       use_tanh_(tanh) {
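The other half of the fix, visible throughout the hunks above, follows from the same lookup rule: `BaseLayer` on its own names a dependent base, so its injected-class-name is not visible in the derived constructors' parameter lists or mem-initializers, and the template argument must be spelled out as `BaseLayer<DataType>`. A minimal sketch of that pattern, again with hypothetical `Base`/`Derived` types rather than the lc0 classes:

```cpp
#include <cstdio>

template <typename T>
struct Base {
  explicit Base(int c) : C(c) {}
  int C;
};

template <typename T>
struct Derived : Base<T> {
  // `Derived(Base *ip)` fails under gcc: Base<T> is a dependent base, so its
  // injected-class-name `Base` is not found by unqualified lookup here.
  explicit Derived(Base<T> *ip);
};

// The same rule applies to the out-of-line definition and the
// mem-initializer: the base must be written with its template argument.
template <typename T>
Derived<T>::Derived(Base<T> *ip) : Base<T>(ip->C) {}

int main() {
  Base<int> b(7);
  Derived<int> d(&b);
  std::printf("%d\n", d.C);  // prints 7
}
```

Compilers that delay template parsing (notably older MSVC) accept the unqualified spelling, which would explain why this code built elsewhere while failing under gcc.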
