Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Splitting source files between CUDA and CPU code. #172

Merged
merged 1 commit into from
Feb 27, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/caffe/layers/bnll_layer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2013 Yangqing Jia

#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
#include <algorithm>

using std::min;

namespace caffe {

const float kBNLL_THRESHOLD = 50.;

template <typename Dtype>
void BNLLLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
const int count = bottom[0]->count();
for (int i = 0; i < count; ++i) {
top_data[i] = bottom_data[i] > 0 ?
bottom_data[i] + log(1. + exp(-bottom_data[i])) :
log(1. + exp(bottom_data[i]));
}
}

template <typename Dtype>
Dtype BNLLLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down,
vector<Blob<Dtype>*>* bottom) {
if (propagate_down) {
const Dtype* bottom_data = (*bottom)[0]->cpu_data();
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
const int count = (*bottom)[0]->count();
Dtype expval;
for (int i = 0; i < count; ++i) {
expval = exp(min(bottom_data[i], Dtype(kBNLL_THRESHOLD)));
bottom_diff[i] = top_diff[i] * expval / (expval + 1.);
}
}
return Dtype(0);
}


INSTANTIATE_CLASS(BNLLLayer);


} // namespace caffe
31 changes: 0 additions & 31 deletions src/caffe/layers/bnll_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,6 @@ namespace caffe {

const float kBNLL_THRESHOLD = 50.;

template <typename Dtype>
void BNLLLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
const int count = bottom[0]->count();
for (int i = 0; i < count; ++i) {
top_data[i] = bottom_data[i] > 0 ?
bottom_data[i] + log(1. + exp(-bottom_data[i])) :
log(1. + exp(bottom_data[i]));
}
}

template <typename Dtype>
Dtype BNLLLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down,
vector<Blob<Dtype>*>* bottom) {
if (propagate_down) {
const Dtype* bottom_data = (*bottom)[0]->cpu_data();
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
const int count = (*bottom)[0]->count();
Dtype expval;
for (int i = 0; i < count; ++i) {
expval = exp(min(bottom_data[i], Dtype(kBNLL_THRESHOLD)));
bottom_diff[i] = top_diff[i] * expval / (expval + 1.);
}
}
return Dtype(0);
}

template <typename Dtype>
__global__ void BNLLForward(const int n, const Dtype* in, Dtype* out) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
Expand Down
88 changes: 0 additions & 88 deletions src/caffe/layers/conv_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,36 +106,6 @@ void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
}
}

template <typename Dtype>
void ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
Dtype* col_data = col_buffer_.mutable_gpu_data();
const Dtype* weight = this->blobs_[0]->gpu_data();
int weight_offset = M_ * K_;
int col_offset = K_ * N_;
int top_offset = M_ * N_;
for (int n = 0; n < NUM_; ++n) {
// First, im2col
im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, PAD_, STRIDE_, col_data);
// Second, innerproduct with groups
for (int g = 0; g < GROUP_; ++g) {
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
(Dtype)1., weight + weight_offset * g, col_data + col_offset * g,
(Dtype)0., top_data + (*top)[0]->offset(n) + top_offset * g);
}
// third, add bias
if (biasterm_) {
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, NUM_OUTPUT_,
N_, 1, (Dtype)1., this->blobs_[1]->gpu_data(),
reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data()),
(Dtype)1., top_data + (*top)[0]->offset(n));
}
}
}

template <typename Dtype>
Dtype ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
Expand Down Expand Up @@ -192,64 +162,6 @@ Dtype ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
return Dtype(0.);
}

template <typename Dtype>
Dtype ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* weight = this->blobs_[0]->gpu_data();
Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
const Dtype* bottom_data = (*bottom)[0]->gpu_data();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
Dtype* col_data = col_buffer_.mutable_gpu_data();
Dtype* col_diff = col_buffer_.mutable_gpu_diff();
// bias gradient if necessary
Dtype* bias_diff = NULL;

if (biasterm_) {
bias_diff = this->blobs_[1]->mutable_gpu_diff();
CUDA_CHECK(cudaMemset(bias_diff, 0,
sizeof(Dtype) * this->blobs_[1]->count()));
for (int n = 0; n < NUM_; ++n) {
caffe_gpu_gemv<Dtype>(CblasNoTrans, NUM_OUTPUT_, N_,
1., top_diff + top[0]->offset(n),
reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data()),
1., bias_diff);
}
}

int weight_offset = M_ * K_;
int col_offset = K_ * N_;
int top_offset = M_ * N_;
CUDA_CHECK(cudaMemset(weight_diff, 0,
sizeof(Dtype) * this->blobs_[0]->count()));
for (int n = 0; n < NUM_; ++n) {
// since we saved memory in the forward pass by not storing all col data,
// we will need to recompute them.
im2col_gpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, PAD_, STRIDE_, col_data);
// gradient w.r.t. weight. Note that we will accumulate diffs.
for (int g = 0; g < GROUP_; ++g) {
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_,
(Dtype)1., top_diff + top[0]->offset(n) + top_offset * g,
col_data + col_offset * g, (Dtype)1.,
weight_diff + weight_offset * g);
}
// gradient w.r.t. bottom data, if necessary
if (propagate_down) {
for (int g = 0; g < GROUP_; ++g) {
caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_,
(Dtype)1., weight + weight_offset * g,
top_diff + top[0]->offset(n) + top_offset * g,
(Dtype)0., col_diff + col_offset * g);
}
// col2im back to the data
col2im_gpu(col_diff, CHANNELS_, HEIGHT_, WIDTH_, KSIZE_, PAD_, STRIDE_,
bottom_diff + (*bottom)[0]->offset(n));
}
}
return Dtype(0.);
}

INSTANTIATE_CLASS(ConvolutionLayer);

} // namespace caffe
104 changes: 104 additions & 0 deletions src/caffe/layers/conv_layer.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright 2013 Yangqing Jia

#include <vector>

#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
#include "caffe/util/im2col.hpp"
#include "caffe/filler.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
void ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
Dtype* col_data = col_buffer_.mutable_gpu_data();
const Dtype* weight = this->blobs_[0]->gpu_data();
int weight_offset = M_ * K_;
int col_offset = K_ * N_;
int top_offset = M_ * N_;
for (int n = 0; n < NUM_; ++n) {
// First, im2col
im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, PAD_, STRIDE_, col_data);
// Second, innerproduct with groups
for (int g = 0; g < GROUP_; ++g) {
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
(Dtype)1., weight + weight_offset * g, col_data + col_offset * g,
(Dtype)0., top_data + (*top)[0]->offset(n) + top_offset * g);
}
// third, add bias
if (biasterm_) {
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, NUM_OUTPUT_,
N_, 1, (Dtype)1., this->blobs_[1]->gpu_data(),
reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data()),
(Dtype)1., top_data + (*top)[0]->offset(n));
}
}
}

template <typename Dtype>
Dtype ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* weight = this->blobs_[0]->gpu_data();
Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
const Dtype* bottom_data = (*bottom)[0]->gpu_data();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
Dtype* col_data = col_buffer_.mutable_gpu_data();
Dtype* col_diff = col_buffer_.mutable_gpu_diff();
// bias gradient if necessary
Dtype* bias_diff = NULL;

if (biasterm_) {
bias_diff = this->blobs_[1]->mutable_gpu_diff();
CUDA_CHECK(cudaMemset(bias_diff, 0,
sizeof(Dtype) * this->blobs_[1]->count()));
for (int n = 0; n < NUM_; ++n) {
caffe_gpu_gemv<Dtype>(CblasNoTrans, NUM_OUTPUT_, N_,
1., top_diff + top[0]->offset(n),
reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data()),
1., bias_diff);
}
}

int weight_offset = M_ * K_;
int col_offset = K_ * N_;
int top_offset = M_ * N_;
CUDA_CHECK(cudaMemset(weight_diff, 0,
sizeof(Dtype) * this->blobs_[0]->count()));
for (int n = 0; n < NUM_; ++n) {
// since we saved memory in the forward pass by not storing all col data,
// we will need to recompute them.
im2col_gpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, PAD_, STRIDE_, col_data);
// gradient w.r.t. weight. Note that we will accumulate diffs.
for (int g = 0; g < GROUP_; ++g) {
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_,
(Dtype)1., top_diff + top[0]->offset(n) + top_offset * g,
col_data + col_offset * g, (Dtype)1.,
weight_diff + weight_offset * g);
}
// gradient w.r.t. bottom data, if necessary
if (propagate_down) {
for (int g = 0; g < GROUP_; ++g) {
caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_,
(Dtype)1., weight + weight_offset * g,
top_diff + top[0]->offset(n) + top_offset * g,
(Dtype)0., col_diff + col_offset * g);
}
// col2im back to the data
col2im_gpu(col_diff, CHANNELS_, HEIGHT_, WIDTH_, KSIZE_, PAD_, STRIDE_,
bottom_diff + (*bottom)[0]->offset(n));
}
}
return Dtype(0.);
}


INSTANTIATE_CLASS(ConvolutionLayer);

} // namespace caffe
23 changes: 0 additions & 23 deletions src/caffe/layers/data_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,36 +227,13 @@ void DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
reinterpret_cast<void*>(this))) << "Pthread execution failed.";
}

template <typename Dtype>
void DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
// First, join the thread
CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
// Copy the data
CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(),
prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(),
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
cudaMemcpyHostToDevice));
// Start a new prefetch thread
CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
reinterpret_cast<void*>(this))) << "Pthread execution failed.";
}

// The backward operations are dummy - they do not carry any computation.
template <typename Dtype>
Dtype DataLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
return Dtype(0.);
}

template <typename Dtype>
Dtype DataLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
return Dtype(0.);
}

INSTANTIATE_CLASS(DataLayer);

} // namespace caffe
44 changes: 44 additions & 0 deletions src/caffe/layers/data_layer.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright 2013 Yangqing Jia

#include <stdint.h>
#include <leveldb/db.h>
#include <pthread.h>

#include <string>
#include <vector>

#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"

using std::string;

namespace caffe {

template <typename Dtype>
void DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
// First, join the thread
CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
// Copy the data
CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(),
prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(),
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
cudaMemcpyHostToDevice));
// Start a new prefetch thread
CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
reinterpret_cast<void*>(this))) << "Pthread execution failed.";
}

// The backward operations are dummy - they do not carry any computation.
template <typename Dtype>
Dtype DataLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
return Dtype(0.);
}

INSTANTIATE_CLASS(DataLayer);

} // namespace caffe
Loading