Wrap up SyncedMem resize from @kloudkl; make train/test nets share data blobs #355

Closed · wants to merge 19 commits

Commits (19)
f2fed7a  Use Thrust to manage the underlying memory of SyncedMem (kloudkl, Mar 22, 2014)
f42e347  Implement and test SyncedMem::resize and reserve (kloudkl, Mar 22, 2014)
a278657  Add Blob::capacity and make Blob::Reshape capacity aware (kloudkl, Mar 22, 2014)
1b0c1e0  Merge the Blob constructors and initialize all the fields properly (kloudkl, Apr 4, 2014)
3b574fd  Change the Makefile to be consistent in style (kloudkl, Apr 7, 2014)
ff98a33  Use static_cast instead of reinterpret_cast to cast void pointers (kloudkl, Apr 10, 2014)
993c55a  minor cast/style fixes (jeffdonahue, Apr 23, 2014)
a495223  minor makefile style/dependency fixes; move more general rule after more (jeffdonahue, Apr 24, 2014)
ccdd521  use handwritten resize methods instead of thrust host/device vectors (jeffdonahue, Apr 23, 2014)
ccccbdb  make TestResize less trivial given new syncedmem resize design (jeffdonahue, Apr 24, 2014)
fd896bf  add net constructor that takes pointer to other net with which we will (jeffdonahue, Apr 23, 2014)
7272ca0  increase imagenet val batch size to 250 since little additional memor… (jeffdonahue, Apr 24, 2014)
5a94f27  misc minor cleanup (jeffdonahue, Apr 24, 2014)
7a89d81  check that malloc/realloc succeed (jeffdonahue, Apr 24, 2014)
43bdc6f  add comments explaining memory_share_net (jeffdonahue, Apr 24, 2014)
78b35f3  make memory_share_blob input to Blob constructor const (jeffdonahue, Apr 25, 2014)
dbba6c2  rebase fixup: rm duplicate constructor (jeffdonahue, Apr 25, 2014)
3011d82  post-rebase cleanup (jeffdonahue, May 4, 2014)
5717ab5  add comment documenting function of set_size method in SyncedMemory (jeffdonahue, May 4, 2014)
22 changes: 13 additions & 9 deletions Makefile
@@ -305,6 +305,10 @@ $(TEST_BIN_DIR)/%.testbin: $(TEST_BUILD_DIR)/%.o $(GTEST_OBJ) $(STATIC_NAME) \
-o $@ $(CXXFLAGS) $(LDFLAGS) $(WARNINGS)
@ echo

$(GTEST_OBJ): $(GTEST_SRC) | $(GTEST_BUILD_DIR)

Member:

What's up with the Makefile changes introduced in 1d4ea4be7e77203306a157580a44f70fb475093e?


Contributor Author:

@kloudkl had added the +$(OBJ_BUILD_DIR)/%.cuo: rule (presumably due to the renaming of syncedmem.cpp -> syncedmem.cu -- we had no rule for building *.cu files in the top-level src/caffe directory), but I moved it down in the list because older versions of make (including the one I use, the default Ubuntu installation) match the first matching rule rather than the most specific one, so this rule was also matching *.cu files in subdirectories; I fixed that by moving it after all the other *.cu rules.

The rest of the changes were basically style changes, plus adding a dependency on the header files $(HXX_SRCS) where I happened to notice there wasn't one before (in one case changing from $(PROTO_GEN_HEADER), which is a subset of $(HXX_SRCS)). Sorry for mixing these changes into an unrelated PR... I can remove them from history if desired.


Member:

Fine to keep it in this PR for convenience, but could you split the Makefile changes into their own commit (or at least mention them in the message for 1d4ea4b)?


Contributor Author:

Done - moved the Makefile changes into a separate commit right before the "use handwritten resize methods" one (GitHub displays it as being the last commit, though).

$(CXX) $< $(CXXFLAGS) -c -o $@
@ echo

$(TOOL_BINS): %.bin : %.o $(STATIC_NAME)
$(CXX) $< $(STATIC_NAME) -o $@ $(CXXFLAGS) $(LDFLAGS) $(WARNINGS)
@ echo
@@ -327,29 +331,29 @@ $(UTIL_BUILD_DIR)/%.o: src/$(PROJECT)/util/%.cpp $(HXX_SRCS) | $(UTIL_BUILD_DIR)
$(CXX) $< $(CXXFLAGS) -c -o $@
@ echo

$(GTEST_OBJ): $(GTEST_SRC) | $(GTEST_BUILD_DIR)
$(CXX) $< $(CXXFLAGS) -c -o $@
@ echo

$(LAYER_BUILD_DIR)/%.cuo: src/$(PROJECT)/layers/%.cu $(HXX_SRCS) \
| $(LAYER_BUILD_DIR)
$(CUDA_DIR)/bin/nvcc $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@
@ echo

$(UTIL_BUILD_DIR)/%.cuo: src/$(PROJECT)/util/%.cu | $(UTIL_BUILD_DIR)
$(UTIL_BUILD_DIR)/%.cuo: src/$(PROJECT)/util/%.cu $(HXX_SRCS) \
| $(UTIL_BUILD_DIR)
$(CUDA_DIR)/bin/nvcc $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@
@ echo

$(OBJ_BUILD_DIR)/%.cuo: src/$(PROJECT)/%.cu $(HXX_SRCS) | $(OBJ_BUILD_DIR)
$(CUDA_DIR)/bin/nvcc $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@
@ echo

$(TOOL_BUILD_DIR)/%.o: tools/%.cpp $(PROTO_GEN_HEADER) | $(TOOL_BUILD_DIR)
$(TOOL_BUILD_DIR)/%.o: tools/%.cpp $(HXX_SRCS) | $(TOOL_BUILD_DIR)
$(CXX) $< $(CXXFLAGS) -c -o $@ $(LDFLAGS)
@ echo

$(EXAMPLE_BUILD_DIR)/%.o: examples/%.cpp $(PROTO_GEN_HEADER) \
| $(EXAMPLE_BUILD_DIRS)
$(EXAMPLE_BUILD_DIR)/%.o: examples/%.cpp $(HXX_SRCS) | $(EXAMPLE_BUILD_DIRS)
$(CXX) $< $(CXXFLAGS) -c -o $@ $(LDFLAGS)
@ echo

$(BUILD_DIR)/src/$(PROJECT)/%.o: src/$(PROJECT)/%.cpp $(HXX_SRCS)
$(OBJ_BUILD_DIR)/%.o: src/$(PROJECT)/%.cpp $(HXX_SRCS) | $(OBJ_BUILD_DIR)
$(CXX) $< $(CXXFLAGS) -c -o $@
@ echo

2 changes: 1 addition & 1 deletion examples/imagenet/imagenet_solver.prototxt
@@ -1,6 +1,6 @@
train_net: "imagenet_train.prototxt"
test_net: "imagenet_val.prototxt"
test_iter: 1000
test_iter: 200

Member:

No reason to change this.


Contributor Author:

I changed the batch size from 50 to 250, so leaving the setting as test_iter: 1000 would run through the 50K val set 5 times. At least I think.
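(For reference, the arithmetic: the ILSVRC-2012 validation set has 50,000 images, so with batch_size: 250 one full pass takes 50,000 / 250 = 200 iterations, while test_iter: 1000 would cover 1000 × 250 = 250,000 images, i.e. five passes.)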


Member:

Sorry, this is right; I misread this as test_interval.

test_interval: 1000
base_lr: 0.01
lr_policy: "step"
2 changes: 1 addition & 1 deletion examples/imagenet/imagenet_val.prototxt
@@ -7,7 +7,7 @@ layers {
data_param {
source: "ilsvrc12_val_leveldb"
mean_file: "../../data/ilsvrc12/imagenet_mean.binaryproto"
batch_size: 50
batch_size: 250
crop_size: 227
mirror: false
}
17 changes: 8 additions & 9 deletions include/caffe/blob.hpp
@@ -12,19 +12,17 @@ namespace caffe {
template <typename Dtype>
class Blob {
public:
Blob()
: num_(0), channels_(0), height_(0), width_(0), count_(0), data_(),
diff_() {}
explicit Blob(const int num, const int channels, const int height,
const int width);
void Reshape(const int num, const int channels, const int height,
const int width);
explicit Blob(const int num = 0, const int channels = 0,
const int height = 0, const int width = 0);

Member:

Maybe remove the default arguments? It would be odd to set some of them to 0 and others not, but if all are 0, it is effectively just the default constructor Blob().


Contributor Author:

@kloudkl got rid of the no-arg blob constructor -- the one with the default args sets member variables and does Reshape, so I don't think it's quite the same?

explicit Blob(const Blob& memory_share_blob);
void Reshape(const int num, const int channels,
const int height, const int width);
void ReshapeLike(const Blob& other);
inline int num() const { return num_; }
inline int channels() const { return channels_; }
inline int height() const { return height_; }
inline int width() const { return width_; }
inline int count() const {return count_; }
inline int count() const { return count_; }
inline int offset(const int n, const int c = 0, const int h = 0,
const int w = 0) const {
CHECK_GE(n, 0);
@@ -91,8 +89,9 @@ class Blob {
int height_;
int width_;
int count_;
size_t space_requirement_;

DISABLE_COPY_AND_ASSIGN(Blob);
DISABLE_ASSIGN(Blob);
}; // class Blob

} // namespace caffe
5 changes: 5 additions & 0 deletions include/caffe/common.hpp
@@ -16,6 +16,11 @@ private:\
classname(const classname&);\
classname& operator=(const classname&)

// Disable just the assignment operator for a class.
#define DISABLE_ASSIGN(classname) \
private:\
classname& operator=(const classname&)

// Instantiate a class with float and double specifications.
#define INSTANTIATE_CLASS(classname) \
template class classname<float>; \
12 changes: 8 additions & 4 deletions include/caffe/net.hpp
@@ -22,12 +22,16 @@ namespace caffe {
template <typename Dtype>
class Net {
public:
explicit Net(const NetParameter& param);
explicit Net(const string& param_file);
explicit Net(const NetParameter& param, Net<Dtype>* memory_share_net = NULL);
explicit Net(const string& param_file, Net<Dtype>* memory_share_net = NULL);
virtual ~Net() {}

// Initialize a network with the network parameter.
void Init(const NetParameter& param);
// Initialize a network with the network parameter. If memory_share_net is
// non-null, any top/bottom blob in this net with an identically-named blob
// in memory_share_net will share its memory location to save on memory, using
// memory proportional to max(net_a_blob_size, net_b_blob_size) rather than
// (net_a_blob_size + net_b_blob_size).
void Init(const NetParameter& param, Net<Dtype>* memory_share_net = NULL);

// Run forward with the input blobs already fed separately. You can get the
// input blobs using input_blobs().
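To make the memory_share_net mechanism concrete, here is a minimal usage sketch. It is not part of this diff; the function name and prototxt paths are illustrative only, and it assumes the constructor signatures shown above.

#include "caffe/net.hpp"

// Hypothetical example: build a train net, then a test net that reuses the
// train net's blob memory wherever blob names match.
void BuildSharedNets() {
  caffe::Net<float> train_net("imagenet_train.prototxt");
  // Identically-named top/bottom blobs in the test net alias the train net's
  // SyncedMemory, so the pair uses roughly max(train_blob_size, test_blob_size)
  // per shared blob instead of the sum.
  caffe::Net<float> test_net("imagenet_val.prototxt", &train_net);
}

Sharing is only safe when the two nets run alternately (as in the solver's train/test loop), since a forward pass through one net overwrites the shared buffers.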
37 changes: 29 additions & 8 deletions include/caffe/syncedmem.hpp
@@ -25,6 +25,14 @@ namespace caffe {

inline void CaffeMallocHost(void** ptr, size_t size) {
*ptr = malloc(size);
CHECK(*ptr) << "malloc failed when attempting to allocate "
<< size << " bytes of host memory.";
}

inline void CaffeReallocHost(void** ptr, size_t size) {
*ptr = realloc(*ptr, size);
CHECK(*ptr) << "realloc failed when attempting to allocate "
<< size << " bytes of host memory.";
}

inline void CaffeFreeHost(void* ptr) {
@@ -34,11 +42,9 @@ inline void CaffeFreeHost(void* ptr) {

class SyncedMemory {
public:
SyncedMemory()
: cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED),
own_cpu_data_(false) {}
explicit SyncedMemory(size_t size)
: cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED),
explicit SyncedMemory(const size_t size = 0)
: cpu_data_(NULL), gpu_data_(NULL), size_(size),
cpu_capacity_(0), gpu_capacity_(0), head_(UNINITIALIZED),
own_cpu_data_(false) {}
~SyncedMemory();
const void* cpu_data();
@@ -48,13 +54,28 @@ class SyncedMemory {
void* mutable_gpu_data();
enum SyncedHead { UNINITIALIZED, HEAD_AT_CPU, HEAD_AT_GPU, SYNCED };
SyncedHead head() { return head_; }
size_t size() { return size_; }
inline size_t size() const { return size_; }
// set_size sets the "size_" variable. The current size_ is checked whenever
// a data accessor/mutator method (cpu_data, gpu_data, mutable_cpu_data, ...)
// is called and if the current CPU or GPU memory is insufficient to hold the
// size_, extra space is allocated. If the current allocation on the device/
// host (depending on whether cpu_* or gpu_data was called) is sufficient
// (i.e., capacity >= size_), no action is taken as a result of the set_size
// call. Therefore, the actual allocation can only grow, never shrinking
// (until the SyncedMemory itself is freed/deleted).
inline void set_size(const size_t size) { size_ = size; }
inline size_t cpu_capacity() const { return cpu_capacity_; }
inline size_t gpu_capacity() const { return gpu_capacity_; }

private:
void to_cpu();
void to_gpu();
void* cpu_ptr_;
void* gpu_ptr_;
size_t cpu_resize();
size_t gpu_resize();
void* cpu_data_;
void* gpu_data_;
size_t cpu_capacity_;
size_t gpu_capacity_;
size_t size_;
SyncedHead head_;
bool own_cpu_data_;
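As a quick illustration of the grow-only behavior documented for set_size above, consider the following sketch (the byte counts and function name are arbitrary, not taken from the diff):

#include "caffe/syncedmem.hpp"

void GrowOnlySketch() {
  caffe::SyncedMemory mem(100);  // size_ = 100; nothing is allocated yet
  mem.mutable_cpu_data();        // first access: host buffer grows to 100 bytes
  mem.set_size(40);              // shrinking only updates size_; capacity stays 100
  mem.mutable_cpu_data();        // capacity >= size_, so no reallocation
  mem.set_size(400);             // growing past the current capacity...
  mem.mutable_cpu_data();        // ...triggers a realloc to 400 bytes here
}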
75 changes: 48 additions & 27 deletions src/caffe/blob.cpp
@@ -11,8 +11,22 @@
namespace caffe {

template <typename Dtype>
void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
const int width) {
Blob<Dtype>::Blob(const int num, const int channels,
const int height, const int width) :
count_(0), space_requirement_(0), data_(), diff_() {
Reshape(num, channels, height, width);
}

template <typename Dtype>
Blob<Dtype>::Blob(const Blob& memory_share_blob) :
count_(0), space_requirement_(0), data_(), diff_() {
ShareData(memory_share_blob);
ShareDiff(memory_share_blob);
}

template <typename Dtype>
void Blob<Dtype>::Reshape(const int num, const int channels,
const int height, const int width) {
CHECK_GE(num, 0);
CHECK_GE(channels, 0);
CHECK_GE(height, 0);
@@ -22,12 +36,16 @@ void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
height_ = height;
width_ = width;
count_ = num_ * channels_ * height_ * width_;
if (count_) {
data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
space_requirement_ = count_ * sizeof(Dtype);
if (data_) {
data_->set_size(space_requirement_);
} else {
data_.reset(reinterpret_cast<SyncedMemory*>(NULL));
diff_.reset(reinterpret_cast<SyncedMemory*>(NULL));
data_.reset(new SyncedMemory(space_requirement_));
}
if (diff_) {
diff_->set_size(space_requirement_);
} else {
diff_.reset(new SyncedMemory(space_requirement_));
}
}

@@ -36,76 +54,79 @@ void Blob<Dtype>::ReshapeLike(const Blob<Dtype>& other) {
Reshape(other.num(), other.channels(), other.height(), other.width());
}

template <typename Dtype>
Blob<Dtype>::Blob(const int num, const int channels, const int height,
const int width) {
Reshape(num, channels, height, width);
}

template <typename Dtype>
const Dtype* Blob<Dtype>::cpu_data() const {
CHECK(data_);
return (const Dtype*)data_->cpu_data();
data_->set_size(space_requirement_);
return static_cast<const Dtype*>(data_->cpu_data());
}

template <typename Dtype>
void Blob<Dtype>::set_cpu_data(Dtype* data) {
CHECK(data);
data_->set_size(space_requirement_);
data_->set_cpu_data(data);
}

template <typename Dtype>
const Dtype* Blob<Dtype>::gpu_data() const {
CHECK(data_);
return (const Dtype*)data_->gpu_data();
data_->set_size(space_requirement_);
return static_cast<const Dtype*>(data_->gpu_data());
}

template <typename Dtype>
const Dtype* Blob<Dtype>::cpu_diff() const {
CHECK(diff_);
return (const Dtype*)diff_->cpu_data();
diff_->set_size(space_requirement_);
return static_cast<const Dtype*>(diff_->cpu_data());
}

template <typename Dtype>
const Dtype* Blob<Dtype>::gpu_diff() const {
CHECK(diff_);
return (const Dtype*)diff_->gpu_data();
diff_->set_size(space_requirement_);
return static_cast<const Dtype*>(diff_->gpu_data());
}

template <typename Dtype>
Dtype* Blob<Dtype>::mutable_cpu_data() {
CHECK(data_);
return reinterpret_cast<Dtype*>(data_->mutable_cpu_data());
data_->set_size(space_requirement_);
return static_cast<Dtype*>(data_->mutable_cpu_data());
}

template <typename Dtype>
Dtype* Blob<Dtype>::mutable_gpu_data() {
CHECK(data_);
return reinterpret_cast<Dtype*>(data_->mutable_gpu_data());
data_->set_size(space_requirement_);
return static_cast<Dtype*>(data_->mutable_gpu_data());
}

template <typename Dtype>
Dtype* Blob<Dtype>::mutable_cpu_diff() {
CHECK(diff_);
return reinterpret_cast<Dtype*>(diff_->mutable_cpu_data());
diff_->set_size(space_requirement_);
return static_cast<Dtype*>(diff_->mutable_cpu_data());
}

template <typename Dtype>
Dtype* Blob<Dtype>::mutable_gpu_diff() {
CHECK(diff_);
return reinterpret_cast<Dtype*>(diff_->mutable_gpu_data());
diff_->set_size(space_requirement_);
return static_cast<Dtype*>(diff_->mutable_gpu_data());
}

template <typename Dtype>
void Blob<Dtype>::ShareData(const Blob& other) {
CHECK_EQ(count_, other.count());
data_ = other.data();
CHECK(data_);
}

template <typename Dtype>
void Blob<Dtype>::ShareDiff(const Blob& other) {
CHECK_EQ(count_, other.count());
diff_ = other.diff();
CHECK(diff_);
}

template <typename Dtype>
@@ -115,15 +136,15 @@ void Blob<Dtype>::Update() {
case SyncedMemory::HEAD_AT_CPU:
// perform computation on CPU
caffe_axpy<Dtype>(count_, Dtype(-1),
reinterpret_cast<const Dtype*>(diff_->cpu_data()),
reinterpret_cast<Dtype*>(data_->mutable_cpu_data()));
static_cast<const Dtype*>(diff_->cpu_data()),
static_cast<Dtype*>(data_->mutable_cpu_data()));
break;
case SyncedMemory::HEAD_AT_GPU:
case SyncedMemory::SYNCED:
// perform computation on GPU
caffe_gpu_axpy<Dtype>(count_, Dtype(-1),
reinterpret_cast<const Dtype*>(diff_->gpu_data()),
reinterpret_cast<Dtype*>(data_->mutable_gpu_data()));
static_cast<const Dtype*>(diff_->gpu_data()),
static_cast<Dtype*>(data_->mutable_gpu_data()));
break;
default:
LOG(FATAL) << "Syncedmem not initialized.";
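Finally, a small sketch of how the capacity-aware Reshape and the new memory-sharing Blob constructor fit together; the shapes and the function name are hypothetical, chosen only to illustrate the behavior of the code above.

#include "caffe/blob.hpp"

void BlobShareSketch() {
  caffe::Blob<float> a(2, 3, 4, 5);  // sets up data/diff for 120 floats (allocated lazily on first access)
  a.Reshape(1, 3, 4, 5);             // smaller shape: the allocation is kept, only size_ shrinks
  a.Reshape(4, 3, 4, 5);             // larger shape: the memory grows on the next data access

  caffe::Blob<float> b(a);           // memory-share constructor: b reuses a's data_ and diff_
  b.ReshapeLike(a);                  // give b its own shape; the underlying memory stays shared
  b.mutable_cpu_data()[0] = 1.0f;    // this write is visible through a.cpu_data()[0] as well,
                                     // since both blobs point at the same SyncedMemory
}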