diff --git a/paddle/trainer/ThreadParameterUpdater.cpp b/paddle/trainer/ThreadParameterUpdater.cpp index bee7f061fed3a..19efce70c353d 100644 --- a/paddle/trainer/ThreadParameterUpdater.cpp +++ b/paddle/trainer/ThreadParameterUpdater.cpp @@ -51,11 +51,17 @@ void SgdThreadUpdater::init(std::vector& parameters) { size_t numRows = para->isGradSparseUpdate() ? para->getConfig().dims(0) : 0; optimizers_[pid]->init(numRows, &para->getConfig()); if (para->isGradSparseUpdate() && FLAGS_trainer_count == 1) { - // For trainer_count=1, the gradient machine is NeuralNetwork, which does - // not create parameter buf for PARAMETER_GRADIENT for sparse update in - // Parameter::enableType(). But gradient parameter buf is still used - // in SgdThreadUpdater. We need to explicitly create it. - para->enableBufType(PARAMETER_GRADIENT); + LOG(INFO) << "I'm here"; + // // For trainer_count=1, the gradient machine is NeuralNetwork, + // which + // does + // // not create parameter buf for PARAMETER_GRADIENT for sparse + // update + // in + // // Parameter::enableType(). But gradient parameter buf is still + // used + // // in SgdThreadUpdater. We need to explicitly create it. 
+ // para->enableBufType(PARAMETER_GRADIENT); } } } diff --git a/paddle/trainer/TrainerConfigHelper.cpp b/paddle/trainer/TrainerConfigHelper.cpp index 2017a08d20d49..c8548e96ec660 100644 --- a/paddle/trainer/TrainerConfigHelper.cpp +++ b/paddle/trainer/TrainerConfigHelper.cpp @@ -193,7 +193,12 @@ std::shared_ptr TrainerConfigHelper::createFromFlags() { std::shared_ptr TrainerConfigHelper::createFromFlagConfig() { CHECK(!FLAGS_config.empty()); - return std::make_shared(FLAGS_config); + return create(FLAGS_config); +} + +std::shared_ptr TrainerConfigHelper::create( + const std::string &configFilename) { + return std::make_shared(configFilename); } } // namespace paddle diff --git a/paddle/trainer/TrainerConfigHelper.h b/paddle/trainer/TrainerConfigHelper.h index f1366cc041b0d..36caa3f40747a 100644 --- a/paddle/trainer/TrainerConfigHelper.h +++ b/paddle/trainer/TrainerConfigHelper.h @@ -193,6 +193,14 @@ class TrainerConfigHelper /*final*/ { */ static std::shared_ptr createFromFlagConfig(); + /** + * @brief Create TrainerConfigHelper from a config file. + * @param configFilename config file path. + * @return nullptr if cannot load, otherwise return a TrainerConfigHelper. 
+ */ + static std::shared_ptr create( + const std::string& configFilename); + private: static std::string getConfigNameFromPassId(int passId, const std::string& modelPath); diff --git a/paddle/trainer/tests/CMakeLists.txt b/paddle/trainer/tests/CMakeLists.txt index 60c129f4e2386..98bde6a4b7e8e 100644 --- a/paddle/trainer/tests/CMakeLists.txt +++ b/paddle/trainer/tests/CMakeLists.txt @@ -83,3 +83,14 @@ add_test(NAME test_config_parser COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/ python ${PROJ_ROOT}/paddle/trainer/tests/config_parser_test.py WORKING_DIRECTORY ${PROJ_ROOT}/paddle/) + + +############# test_SgdLocalUpdaterForSparseNetwork ########### +add_unittest_without_exec(test_SgdLocalUpdaterForSparseNetwork + test_SgdLocalUpdaterForSparseNetwork.cpp) + +add_test(NAME test_SgdLocalUpdaterForSparseNetwork + COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d + ${PROJ_ROOT}/python/ + ${CMAKE_CURRENT_BINARY_DIR}/test_SgdLocalUpdaterForSparseNetwork + WORKING_DIRECTORY ${PROJ_ROOT}/paddle/trainer/tests/sgd_local_updater_sparse_network/) diff --git a/paddle/trainer/tests/sgd_local_updater_sparse_network/.gitignore b/paddle/trainer/tests/sgd_local_updater_sparse_network/.gitignore new file mode 100644 index 0000000000000..a21555ad1e26e --- /dev/null +++ b/paddle/trainer/tests/sgd_local_updater_sparse_network/.gitignore @@ -0,0 +1 @@ +train.list diff --git a/paddle/trainer/tests/sgd_local_updater_sparse_network/sparse_updated_network.py b/paddle/trainer/tests/sgd_local_updater_sparse_network/sparse_updated_network.py new file mode 100644 index 0000000000000..026d55f7fcf36 --- /dev/null +++ b/paddle/trainer/tests/sgd_local_updater_sparse_network/sparse_updated_network.py @@ -0,0 +1,22 @@ +from paddle.trainer_config_helpers import * + +define_py_data_sources2( + train_list=["do_not_matter.txt"], + test_list=None, + module='sparse_updated_network_provider', + obj='process') + +settings(batch_size=100, learning_rate=1e-4) + +outputs( + 
classification_cost( + input=fc_layer( + size=10, + act=SoftmaxActivation(), + input=embedding_layer( + size=64, + input=data_layer( + name='word_id', size=600000), + param_attr=ParamAttr(sparse_update=True))), + label=data_layer( + name='label', size=10))) diff --git a/paddle/trainer/tests/sgd_local_updater_sparse_network/sparse_updated_network_provider.py b/paddle/trainer/tests/sgd_local_updater_sparse_network/sparse_updated_network_provider.py new file mode 100644 index 0000000000000..39f19ad9400e5 --- /dev/null +++ b/paddle/trainer/tests/sgd_local_updater_sparse_network/sparse_updated_network_provider.py @@ -0,0 +1,11 @@ +from paddle.trainer.PyDataProvider2 import * +import random + + +@provider( + input_types={"word_id": integer_value(600000), + "label": integer_value(10)}, + min_pool_size=0) +def process(settings, filename): + for _ in xrange(1000): + yield random.randint(0, 600000 - 1), random.randint(0, 9) diff --git a/paddle/trainer/tests/test_SgdLocalUpdaterForSparseNetwork.cpp b/paddle/trainer/tests/test_SgdLocalUpdaterForSparseNetwork.cpp new file mode 100644 index 0000000000000..83e7446ff6bc0 --- /dev/null +++ b/paddle/trainer/tests/test_SgdLocalUpdaterForSparseNetwork.cpp @@ -0,0 +1,78 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include "paddle/pserver/ParameterServer2.h" +#include "paddle/trainer/Trainer.h" +#include "paddle/utils/PythonUtil.h" +#include "paddle/utils/Util.h" + +P_DECLARE_bool(local); + +static std::unique_ptr createTrainer( + bool useGpu, size_t trainerCount, const std::string& configFilename) { + FLAGS_use_gpu = useGpu; + FLAGS_trainer_count = trainerCount; + paddle::Trainer* trainer = new paddle::Trainer(); + + trainer->init(paddle::TrainerConfigHelper::create(configFilename)); + return std::unique_ptr(trainer); +} + +TEST(SgdLocalUpdater, RemoteSparseNNCpu) { + FLAGS_ports_num_for_sparse = 1; + FLAGS_num_passes = 1; + FLAGS_local = false; + std::vector> pservers; + + for (int i = 0; i < FLAGS_ports_num + FLAGS_ports_num_for_sparse; ++i) { + auto pserver = + std::make_shared("127.0.0.1", FLAGS_port + i); + pserver->init(); + pserver->start(); + pservers.push_back(pserver); + } + + auto trainerPtr = createTrainer(false, 1, "sparse_updated_network.py"); + ASSERT_TRUE(trainerPtr != nullptr); + paddle::Trainer& trainer = *trainerPtr; + trainer.startTrain(); + trainer.train(1); + trainer.finishTrain(); +} + +TEST(SgdLocalUpdater, LocalSparseNNCpu) { + FLAGS_local = true; + auto trainerPtr = createTrainer(false, 1, "sparse_updated_network.py"); + ASSERT_TRUE(trainerPtr != nullptr); + paddle::Trainer& trainer = *trainerPtr; + trainer.startTrain(); + trainer.train(1); + trainer.finishTrain(); +} +// TEST(SgdLocalUpdater, SparseNNGpu) { +// auto trainerPtr = createTrainer(true, 1, "sparse_updated_network.py"); +// ASSERT_TRUE(trainerPtr != nullptr); +// paddle::Trainer& trainer = *trainerPtr; +// trainer.startTrain(); +// trainer.train(1); +// trainer.finishTrain(); +//} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + paddle::initMain(argc, argv); + paddle::initPython(argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/python/paddle/trainer_config_helpers/data_sources.py b/python/paddle/trainer_config_helpers/data_sources.py 
index c62553f54cc30..d2ab04146b1a6 100644 --- a/python/paddle/trainer_config_helpers/data_sources.py +++ b/python/paddle/trainer_config_helpers/data_sources.py @@ -69,7 +69,7 @@ def define_py_data_source(file_list, """ if isinstance(file_list, list): file_list_name = 'train.list' - if isinstance(cls, TestData): + if cls == TestData: file_list_name = 'test.list' with open(file_list_name, 'w') as f: f.writelines(file_list)