Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Browse files Browse the repository at this point in the history
… cross_entropy
  • Loading branch information
qingqing01 committed Aug 19, 2017
2 parents 8f6c878 + 0d9846f commit a8863a8
Show file tree
Hide file tree
Showing 17 changed files with 275 additions and 33 deletions.
2 changes: 1 addition & 1 deletion cmake/cudnn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ if(NOT WITH_GPU)
return()
endif()

set(CUDNN_ROOT "" CACHE PATH "CUDNN ROOT")
set(CUDNN_ROOT "/usr" CACHE PATH "CUDNN ROOT")
find_path(CUDNN_INCLUDE_DIR cudnn.h
PATHS ${CUDNN_ROOT} ${CUDNN_ROOT}/include
$ENV{CUDNN_ROOT} $ENV{CUDNN_ROOT}/include ${CUDA_TOOLKIT_INCLUDE}
Expand Down
10 changes: 10 additions & 0 deletions paddle/memory/memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ limitations under the License. */
#include <memory> // for unique_ptr
#include <mutex> // for call_once

#include "glog/logging.h"

#include "paddle/memory/detail/buddy_allocator.h"
#include "paddle/memory/detail/system_allocator.h"
#include "paddle/platform/gpu_info.h"

DECLARE_double(fraction_of_gpu_memory_to_use);

namespace paddle {
namespace memory {
Expand Down Expand Up @@ -80,6 +85,11 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
platform::GpuMinChunkSize(),
platform::GpuMaxChunkSize()));
}
VLOG(3) << "\n\nNOTE: each GPU device use "
<< FLAGS_fraction_of_gpu_memory_to_use * 100 << "% of GPU memory.\n"
<< "You can set environment variable '"
<< platform::kEnvFractionGpuMemoryToUse
<< "' to change the fraction of GPU usage.\n\n";
});

platform::SetDeviceId(gpu_id);
Expand Down
1 change: 0 additions & 1 deletion paddle/memory/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ limitations under the License. */

#pragma once

#include "paddle/platform/gpu_info.h"
#include "paddle/platform/place.h"

namespace paddle {
Expand Down
8 changes: 4 additions & 4 deletions paddle/operators/math/math_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
const float alpha, const float* A,
const float* B, const float beta, float* C,
platform::DeviceContext* context) {
int lda = K;
int ldb = N;
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
Expand All @@ -40,8 +40,8 @@ void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
const double* B, const double beta,
double* C,
platform::DeviceContext* context) {
int lda = K;
int ldb = N;
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
Expand Down
25 changes: 21 additions & 4 deletions paddle/operators/mul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
namespace paddle {
namespace operators {

using framework::Tensor;

class MulOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
Expand Down Expand Up @@ -59,10 +61,23 @@ class MulOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
std::string DebugString() const override {
LOG(INFO) << "MulGrad";
return "";
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE(x_dims[0] == out_dims[0],
"Out@GRAD M X N must equal to X dims 0, M ");
PADDLE_ENFORCE(y_dims[1] == out_dims[1],
"Out@GRAD M X N must equal to Y dims 1, N ");

x_grad->Resize(x_dims);
y_grad->Resize(y_dims);
}
};

Expand All @@ -72,3 +87,5 @@ class MulOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mul_grad,
ops::MulGradKernel<paddle::platform::CPUPlace, float>);
2 changes: 2 additions & 0 deletions paddle/operators/mul_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(mul_grad,
ops::MulGradKernel<paddle::platform::GPUPlace, float>);
40 changes: 28 additions & 12 deletions paddle/operators/mul_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,34 @@ template <typename Place, typename T>
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
auto* input0 = context.Input<Tensor>("X");
auto* input1 = context.Input<Tensor>("Y");
auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
auto X = EigenMatrix<T>::From(*input0);
auto Y = EigenMatrix<T>::From(*input1);
auto Z = EigenMatrix<T>::From(*output);
auto& place = context.GetEigenDevice<Place>();

Z.device(place) = X.contract(Y, dim_pair);
auto* X = context.Input<Tensor>("X");
auto* Y = context.Input<Tensor>("Y");
auto* Z = context.Output<Tensor>("Out");
Z->mutable_data<T>(context.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
math::matmul<Place, T>(*X, false, *Y, false, 1, Z, 0, device_context);
}
};

template <typename Place, typename T>
class MulGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* X = ctx.Input<Tensor>("X");
auto* Y = ctx.Input<Tensor>("Y");
auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));

auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
dX->mutable_data<T>(ctx.GetPlace());
dY->mutable_data<T>(ctx.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(ctx.device_context_);
// dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N
math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context);
// dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K
math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context);
}
};

Expand Down
34 changes: 28 additions & 6 deletions paddle/operators/rowwise_add_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
namespace paddle {
namespace operators {

class RowWiseAddOp : public framework::OperatorWithKernel {
using framework::Tensor;

class RowwiseAddOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

Expand All @@ -34,9 +36,9 @@ class RowWiseAddOp : public framework::OperatorWithKernel {
}
};

class RowWiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
class RowwiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RowWiseAddOpMaker(framework::OpProto *proto,
RowwiseAddOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left input of row-wise add op, must be matrix");
Expand All @@ -49,12 +51,32 @@ for i in xrange(X.shape[0]):
)DOC");
}
};
class RowwiseAddGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto dims0 = ctx.Input<Tensor>("X")->dims();
auto dims1 = ctx.Input<Tensor>("b")->dims();
PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0);
ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(rowwise_add, ops::RowWiseAddOp,
ops::RowWiseAddOpMaker);
REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker,
rowwise_add_grad, ops::RowwiseAddGradOp);
REGISTER_OP_CPU_KERNEL(
rowwise_add, ops::RowwiseAddKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
rowwise_add, ops::RowWiseAddKernel<paddle::platform::CPUPlace, float>);
rowwise_add_grad,
ops::RowwiseAddGradKernel<paddle::platform::CPUPlace, float>);
2 changes: 1 addition & 1 deletion paddle/operators/rowwise_add_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
rowwise_add, ops::RowWiseAddKernel<paddle::platform::GPUPlace, float>);
rowwise_add, ops::RowwiseAddKernel<paddle::platform::GPUPlace, float>);
22 changes: 21 additions & 1 deletion paddle/operators/rowwise_add_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename Place, typename T>
class RowWiseAddKernel : public framework::OpKernel {
class RowwiseAddKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto out = context.Output<Tensor>("Out");
Expand All @@ -47,5 +47,25 @@ class RowWiseAddKernel : public framework::OpKernel {
}
};

template <typename Place, typename T>
class RowwiseAddGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
auto* db = context.Output<Tensor>(framework::GradVarName("b"));
dX->mutable_data<T>(context.GetPlace());
db->mutable_data<T>(context.GetPlace());

auto OutGrad = EigenMatrix<T>::From(*dOut);
auto place = context.GetEigenDevice<Place>();
EigenMatrix<T>::From(*dX).device(place) = OutGrad;

// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
// colwise add
Eigen::array<int, 1> dims{{1}}; /* dimension to reduce */
EigenVector<T>::Flatten(*db).device(place) = OutGrad.sum(dims);
}
};
} // namespace operators
} // namespace paddle
3 changes: 2 additions & 1 deletion paddle/platform/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
cc_library(cpu_info SRCS cpu_info.cc DEPS gflags glog)
cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info)

nv_library(gpu_info SRCS gpu_info.cc DEPS gflags)
nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog)

cc_library(place SRCS place.cc)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags)

add_subdirectory(dynload)

cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece)
cc_test(environment_test SRCS environment_test.cc DEPS stringpiece)

IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
Expand Down
60 changes: 60 additions & 0 deletions paddle/platform/environment.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <stdlib.h>
#include <unistd.h>
#include <vector>

#include "paddle/platform/enforce.h"
#include "paddle/string/piece.h"

extern char** environ; // for environment variables

namespace paddle {
namespace platform {

inline void SetEnvVariable(const std::string& name, const std::string& value) {
PADDLE_ENFORCE_NE(setenv(name.c_str(), value.c_str(), 1), -1,
"Failed to set environment variable %s=%s", name, value);
}

inline void UnsetEnvVariable(const std::string& name) {
PADDLE_ENFORCE_NE(unsetenv(name.c_str()), -1,
"Failed to unset environment variable %s", name);
}

inline bool IsEnvVarDefined(const std::string& name) {
return std::getenv(name.c_str()) != nullptr;
}

inline std::string GetEnvValue(const std::string& name) {
PADDLE_ENFORCE(IsEnvVarDefined(name),
"Tried to access undefined environment variable %s", name);
return std::getenv(name.c_str());
}

inline std::vector<std::string> GetAllEnvVariables() {
std::vector<std::string> vars;
for (auto var = environ; *var != nullptr; ++var) {
auto tail = string::Index(*var, "=");
auto name = string::SubStr(*var, 0, tail).ToString();
vars.push_back(name);
}
return vars;
}

} // namespace platform
} // namespace paddle
54 changes: 54 additions & 0 deletions paddle/platform/environment_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/platform/environment.h"

#include "glog/logging.h"
#include "gtest/gtest.h"

TEST(ENVIRONMENT, ACCESS) {
namespace platform = paddle::platform;
namespace string = paddle::string;

platform::SetEnvVariable("PADDLE_USE_ENV", "TRUE");

EXPECT_TRUE(platform::IsEnvVarDefined("PADDLE_USE_ENV"));
EXPECT_EQ(platform::GetEnvValue("PADDLE_USE_ENV"), "TRUE");

platform::UnsetEnvVariable("PADDLE_USE_ENV");
EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV"));

platform::SetEnvVariable("PADDLE_USE_ENV1", "Hello ");
platform::SetEnvVariable("PADDLE_USE_ENV2", "World, ");
platform::SetEnvVariable("PADDLE_USE_ENV3", "PaddlePaddle!");

std::string env_info;
auto vars = platform::GetAllEnvVariables();
for_each(vars.begin(), vars.end(), [&](const std::string& var) {
env_info += platform::GetEnvValue(var);
});

EXPECT_TRUE(string::Contains(env_info, "Hello World, PaddlePaddle!"));
platform::UnsetEnvVariable("PADDLE_USE_ENV1");
platform::UnsetEnvVariable("PADDLE_USE_ENV2");
platform::UnsetEnvVariable("PADDLE_USE_ENV3");

env_info.clear();
vars = platform::GetAllEnvVariables();
for_each(vars.begin(), vars.end(), [&](const std::string& var) {
env_info += platform::GetEnvValue(var);
});

EXPECT_FALSE(string::Contains(env_info, "Hello World, PaddlePaddle!"));
EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV1"));
EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV2"));
EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV3"));
}
Loading

0 comments on commit a8863a8

Please sign in to comment.