Try add unittest for sgd local updater
reyoung committed Dec 14, 2016
1 parent b0c6331 commit cbddad3
Showing 9 changed files with 149 additions and 7 deletions.
16 changes: 11 additions & 5 deletions paddle/trainer/ThreadParameterUpdater.cpp
@@ -51,11 +51,17 @@ void SgdThreadUpdater::init(std::vector<ParameterPtr>& parameters) {
     size_t numRows = para->isGradSparseUpdate() ? para->getConfig().dims(0) : 0;
     optimizers_[pid]->init(numRows, &para->getConfig());
     if (para->isGradSparseUpdate() && FLAGS_trainer_count == 1) {
-      // For trainer_count=1, the gradient machine is NeuralNetwork, which does
-      // not create parameter buf for PARAMETER_GRADIENT for sparse update in
-      // Parameter::enableType(). But gradient parameter buf is still used
-      // in SgdThreadUpdater. We need to explicitly create it.
-      para->enableBufType(PARAMETER_GRADIENT);
+      LOG(INFO) << "I'm here";
+      // For trainer_count=1, the gradient machine is NeuralNetwork, which does
+      // not create parameter buf for PARAMETER_GRADIENT for sparse update in
+      // Parameter::enableType(). But gradient parameter buf is still used
+      // in SgdThreadUpdater. We need to explicitly create it.
+      // para->enableBufType(PARAMETER_GRADIENT);
     }
   }
 }
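
For reference, a minimal sketch of what the now commented-out lines did, using only names that appear in the diff above (an illustration, not part of the commit):

    // Sketch: with a single trainer, the NeuralNetwork gradient machine
    // skips allocating the PARAMETER_GRADIENT buffer for sparsely-updated
    // parameters in Parameter::enableType(), so the updater creates it
    // explicitly before first use.
    if (para->isGradSparseUpdate() && FLAGS_trainer_count == 1) {
      para->enableBufType(PARAMETER_GRADIENT);
    }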
7 changes: 6 additions & 1 deletion paddle/trainer/TrainerConfigHelper.cpp
@@ -193,7 +193,12 @@ std::shared_ptr<TrainerConfigHelper> TrainerConfigHelper::createFromFlags() {
 std::shared_ptr<TrainerConfigHelper>
 TrainerConfigHelper::createFromFlagConfig() {
   CHECK(!FLAGS_config.empty());
-  return std::make_shared<TrainerConfigHelper>(FLAGS_config);
+  return create(FLAGS_config);
+}
+
+std::shared_ptr<TrainerConfigHelper> TrainerConfigHelper::create(
+    const std::string &configFilename) {
+  return std::make_shared<TrainerConfigHelper>(configFilename);
 }
 
 }  // namespace paddle
8 changes: 8 additions & 0 deletions paddle/trainer/TrainerConfigHelper.h
@@ -193,6 +193,14 @@ class TrainerConfigHelper /*final*/ {
    */
   static std::shared_ptr<TrainerConfigHelper> createFromFlagConfig();
 
+  /**
+   * @brief Create TrainerConfigHelper from a config file.
+   * @param configFilename config file path.
+   * @return nullptr if cannot load, otherwise return a TrainerConfigHelper.
+   */
+  static std::shared_ptr<TrainerConfigHelper> create(
+      const std::string& configFilename);
+
  private:
   static std::string getConfigNameFromPassId(int passId,
                                              const std::string& modelPath);
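A minimal caller sketch for the new factory (an illustration, not part of the commit; the filename is a placeholder, and the null check assumes the contract stated in the @return clause above):

    // Hypothetical usage; "trainer_config.py" is a placeholder path.
    auto conf = paddle::TrainerConfigHelper::create("trainer_config.py");
    CHECK(conf) << "cannot load trainer config";  // documented failure mode
    paddle::Trainer trainer;
    trainer.init(conf);
    trainer.startTrain();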
11 changes: 11 additions & 0 deletions paddle/trainer/tests/CMakeLists.txt
@@ -83,3 +83,14 @@ add_test(NAME test_config_parser
     COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
         python ${PROJ_ROOT}/paddle/trainer/tests/config_parser_test.py
     WORKING_DIRECTORY ${PROJ_ROOT}/paddle/)
+
+
+############# test_SgdLocalUpdaterForSparseNetwork ###########
+add_unittest_without_exec(test_SgdLocalUpdaterForSparseNetwork
+    test_SgdLocalUpdaterForSparseNetwork.cpp)
+
+add_test(NAME test_SgdLocalUpdaterForSparseNetwork
+    COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d
+        ${PROJ_ROOT}/python/
+        ${CMAKE_CURRENT_BINARY_DIR}/test_SgdLocalUpdaterForSparseNetwork
+    WORKING_DIRECTORY ${PROJ_ROOT}/paddle/trainer/tests/sgd_local_updater_sparse_network/)
1 change: 1 addition & 0 deletions, new file (path not preserved in this capture; given its single entry and the test working directory above, plausibly a .gitignore under paddle/trainer/tests/sgd_local_updater_sparse_network/)
@@ -0,0 +1 @@
train.list
22 changes: 22 additions & 0 deletions, new file (path not preserved in this capture; the test below loads it as sparse_updated_network.py from paddle/trainer/tests/sgd_local_updater_sparse_network/)
@@ -0,0 +1,22 @@
from paddle.trainer_config_helpers import *

define_py_data_sources2(
    train_list=["do_not_matter.txt"],
    test_list=None,
    module='sparse_updated_network_provider',
    obj='process')

settings(batch_size=100, learning_rate=1e-4)

outputs(
    classification_cost(
        input=fc_layer(
            size=10,
            act=SoftmaxActivation(),
            input=embedding_layer(
                size=64,
                input=data_layer(
                    name='word_id', size=600000),
                param_attr=ParamAttr(sparse_update=True))),
        label=data_layer(
            name='label', size=10)))
11 changes: 11 additions & 0 deletions, new file (path not preserved in this capture; from the module name in the config above, plausibly sparse_updated_network_provider.py in the same directory)
@@ -0,0 +1,11 @@
from paddle.trainer.PyDataProvider2 import *
import random


@provider(
    input_types={"word_id": integer_value(600000),
                 "label": integer_value(10)},
    min_pool_size=0)
def process(settings, filename):
    for _ in xrange(1000):
        yield random.randint(0, 600000 - 1), random.randint(0, 9)
78 changes: 78 additions & 0 deletions paddle/trainer/tests/test_SgdLocalUpdaterForSparseNetwork.cpp
@@ -0,0 +1,78 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "paddle/pserver/ParameterServer2.h"
#include "paddle/trainer/Trainer.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/utils/Util.h"

P_DECLARE_bool(local);

static std::unique_ptr<paddle::Trainer> createTrainer(
    bool useGpu, size_t trainerCount, const std::string& configFilename) {
  FLAGS_use_gpu = useGpu;
  FLAGS_trainer_count = trainerCount;
  // Own the Trainer from the start rather than wrapping a raw pointer later.
  std::unique_ptr<paddle::Trainer> trainer(new paddle::Trainer());
  trainer->init(paddle::TrainerConfigHelper::create(configFilename));
  return trainer;
}

TEST(SgdLocalUpdater, RemoteSparseNNCpu) {
  FLAGS_ports_num_for_sparse = 1;
  FLAGS_num_passes = 1;
  FLAGS_local = false;
  std::vector<std::shared_ptr<paddle::ParameterServer2>> pservers;

  // Bring up one local parameter server per port (the dense ports plus the
  // extra sparse port) so the remote updater has something to talk to.
  for (int i = 0; i < FLAGS_ports_num + FLAGS_ports_num_for_sparse; ++i) {
    auto pserver =
        std::make_shared<paddle::ParameterServer2>("127.0.0.1", FLAGS_port + i);
    pserver->init();
    pserver->start();
    pservers.push_back(pserver);
  }

  auto trainerPtr = createTrainer(false, 1, "sparse_updated_network.py");
  ASSERT_TRUE(trainerPtr != nullptr);
  paddle::Trainer& trainer = *trainerPtr;
  trainer.startTrain();
  trainer.train(1);
  trainer.finishTrain();
}

TEST(SgdLocalUpdater, LocalSparseNNCpu) {
  FLAGS_local = true;
  auto trainerPtr = createTrainer(false, 1, "sparse_updated_network.py");
  ASSERT_TRUE(trainerPtr != nullptr);
  paddle::Trainer& trainer = *trainerPtr;
  trainer.startTrain();
  trainer.train(1);
  trainer.finishTrain();
}
// TEST(SgdLocalUpdater, SparseNNGpu) {
//   auto trainerPtr = createTrainer(true, 1, "sparse_updated_network.py");
//   ASSERT_TRUE(trainerPtr != nullptr);
//   paddle::Trainer& trainer = *trainerPtr;
//   trainer.startTrain();
//   trainer.train(1);
//   trainer.finishTrain();
// }

int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  paddle::initMain(argc, argv);
  paddle::initPython(argc, argv);
  return RUN_ALL_TESTS();
}
2 changes: 1 addition & 1 deletion python/paddle/trainer_config_helpers/data_sources.py
@@ -69,7 +69,7 @@ def define_py_data_source(file_list,
"""
if isinstance(file_list, list):
file_list_name = 'train.list'
if isinstance(cls, TestData):
if cls == TestData:
file_list_name = 'test.list'
with open(file_list_name, 'w') as f:
f.writelines(file_list)
