Re-design command options for testing for better understanding #411

Merged
merged 12 commits into from
Dec 5, 2016
2 changes: 1 addition & 1 deletion doc/howto/cmd_parameter/arguments.md
@@ -143,7 +143,7 @@ It looks like there are a lot of arguments. However, most of them are for develo
</tr>

<tr>
<td class="left" rowspan = "2">testing during training</td><td class="left">test_all_data_in_one_period</td>
<td class="left" rowspan = "2">testing during training</td><td class="left">test_period</td>
<td class="left">√</td><td class="left">√</td><td class="left"></td><td class="left"></td>
</tr>

10 changes: 3 additions & 7 deletions doc/howto/cmd_parameter/detail_introduction.md
@@ -31,7 +31,7 @@
- type: string (default: null).

* `--version`
- Whether to print version infomatrion.
- Whether to print version information.
- type: bool (default: 0).

* `--show_layer_stat`
@@ -110,8 +110,8 @@
- type: int32 (default: -1).

* `--test_period`
- Run testing every test_period train batches. If not set, run testing each pass.
- type: int32 (default: 1000).
- if equal 0, do test on all test data at the end of each pass while if equal non-zero, do test on all test data once each test_period batches passed while training is going on.
- type: int32 (default: 0).
Review comment (Contributor):

if equal 0, do test on all test data at the end of each pass. While if equal non-zero, do test on all test data every test_period batches.

There are several similar places below; please change them all together.

Reply (Contributor Author):

done
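
For readers following this thread, the new `--test_period` behavior can be sketched as a small predicate. This is only an illustration of the documented wording, not code from this PR: `shouldRunTest`, `batchesDone`, and `endOfPass` are hypothetical names. Whether a non-zero `test_period` also triggers an extra run at the end of a pass is not stated in the description, so the sketch does not model it.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of Paddle): decides whether a run over
// all test data should start, following the documented --test_period rule.
//   testPeriod  - value of --test_period (0 means "test at end of pass")
//   batchesDone - number of training batches finished so far
//   endOfPass   - true when the current pass has just finished
bool shouldRunTest(int32_t testPeriod, int64_t batchesDone, bool endOfPass) {
  assert(testPeriod >= 0 && batchesDone >= 0);
  if (testPeriod == 0) {
    // Test on all test data once, at the end of each pass.
    return endOfPass;
  }
  // Test on all test data every testPeriod training batches.
  return batchesDone > 0 && batchesDone % testPeriod == 0;
}
```

With `testPeriod = 1000`, this returns true after batches 1000, 2000, and so on; with `testPeriod = 0`, only the end of a pass triggers a test.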


* `--test_wait`
- Whether to wait for parameter per pass if not exist. If set test_data_path in submitting environment of cluster, it will launch one process to perfom testing, so we need to set test_wait=1. Note that in the cluster submitting environment, this argument has been set True by default.
@@ -121,10 +121,6 @@
- File that saves the model list when testing. It was set automatically when using cluster submitting environment after setting model_path.
- type: string (default: "", null).

* `--test_all_data_in_one_period`
- This argument is usually used in testing period during traning. If true, all data will be tested in one test period. Otherwise (batch_size * log_peroid) data will be tested.
- type: bool (default: 0).

* `--predict_output_dir`
- Directory that saves the layer output. It is configured in Outputs() in network config. Default, this argument is null, meaning save nothing. Specify this directory if you want to save feature map of some layers in testing mode. Note that, layer outputs are values after activation function.
- type: string (default: "", null).
5 changes: 2 additions & 3 deletions doc/howto/cmd_parameter/use_case.md
@@ -10,9 +10,8 @@ paddle train \
--config=network_config \
--save_dir=output \
--trainer_count=COUNT \ #(default:1)
--test_period=M \ #(default:1000)
--test_all_data_in_one_period=true \ #(default:false)
--num_passes=N \ #(defalut:100)
--test_period=M \ #(default:0)
--num_passes=N \ #(defalut:100)
--log_period=K \ #(default:100)
--dot_period=1000 \ #(default:1)
#[--show_parameter_stats_period=100] \ #(default:0)
12 changes: 3 additions & 9 deletions paddle/trainer/Tester.cpp
@@ -87,10 +87,8 @@ void Tester::testOneDataBatch(const DataBatch& dataBatch,
void Tester::testOnePeriod() {
DataBatch dataBatch;
int64_t batchSize = config_->getOptConfig().batch_size();
bool testAllData =
intconfig_->testPeriod == 0 || intconfig_->testAllDataInOnePeriod;
int batches =
testAllData ? std::numeric_limits<int>::max() : intconfig_->testPeriod;

int batches = std::numeric_limits<int>::max();

std::vector<Argument> outArgs;

@@ -102,11 +100,7 @@ void Tester::testOnePeriod() {
if (intconfig_->prevBatchState) {
gradientMachine_->resetState();
}
if (testAllData) {
break;
} else {
num = testDataProvider_->getNextBatch(batchSize, &dataBatch);
}
break;
Review comment (Collaborator):

startTestPeriod();
while(int num = testDataProvider_->getNextBatch(batchSize, &dataBatch)) {
   testOneDataBatch(dataBatch, &outArgs);
}
testDataProvider_->reset();
if (intconfig_->prevBatchState) {
  gradientMachine_->resetState();
}
finishTestPeriod();
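
The loop shape suggested above can be tried in isolation with a stand-in data provider. The sketch below is an assumption-laden illustration: `FakeProvider` and this free-standing `testOnePeriod` are hypothetical and only mimic the `getNextBatch`/`reset` calls visible in the diff; the real `Tester` also runs the forward pass and the start/finish hooks.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a test data provider (not Paddle's class).
class FakeProvider {
public:
  explicit FakeProvider(int64_t totalSamples) : total_(totalSamples) {}

  // Fills `batch` with up to `batchSize` sample ids and returns how many
  // were produced; returns 0 once the data is exhausted.
  int64_t getNextBatch(int64_t batchSize, std::vector<int64_t>* batch) {
    assert(batchSize > 0 && batch != nullptr);
    batch->clear();
    const int64_t n = std::min(batchSize, total_ - cursor_);
    for (int64_t i = 0; i < n; ++i) {
      batch->push_back(cursor_ + i);
    }
    cursor_ += n;
    return n;
  }

  // Rewinds to the first sample so the next test period starts fresh.
  void reset() { cursor_ = 0; }

private:
  int64_t total_;
  int64_t cursor_ = 0;
};

// Visits every test sample exactly once, then rewinds the provider.
// Returns the number of samples seen.
int64_t testOnePeriod(FakeProvider* provider, int64_t batchSize) {
  std::vector<int64_t> batch;
  int64_t seen = 0;
  while (int64_t num = provider->getNextBatch(batchSize, &batch)) {
    seen += num;  // a real tester would evaluate the batch here
  }
  provider->reset();
  return seen;
}
```

Because the provider is reset after the loop, calling `testOnePeriod` twice on the same provider covers the full test set both times, which matches the "always test on all test data" intent of this change.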

}
testOneDataBatch(dataBatch, &outArgs);
}
5 changes: 0 additions & 5 deletions paddle/trainer/TesterConfig.h
@@ -39,11 +39,6 @@ struct TesterConfig {
*/
int testPeriod;

/**
* indicate whether testing data in one period
*/
bool testAllDataInOnePeriod;

/**
* indicate whether to save previous batch state
*/
35 changes: 22 additions & 13 deletions paddle/trainer/Trainer.cpp
@@ -39,20 +39,17 @@ limitations under the License. */
#include "TrainerConfigHelper.h"

P_DEFINE_string(config, "", "Trainer config file");
P_DEFINE_int32(test_period,
0,
"Run test every so many train batches."
" 0 for testing after each pass."
" If not 0, test log_period batches."
" If 0, test on all test data");

P_DEFINE_bool(local, true, "Train in local mode or not");
P_DEFINE_int32(test_period, 0,
"if equal 0, do test on all test data at the end of "
"each pass while if equal non-zero, do test on all test "
"data once each test_period batches passed while "
"training is going on");
P_DEFINE_bool(test_all_data_in_one_period, false,
"This option was deprecated, since we will always do "
"test on all test set ");

P_DEFINE_bool(
test_all_data_in_one_period,
false,
"true will test all data in one test peroid."
"Otherwise test (batch_size * log_peroid) data in one test period.");
P_DEFINE_bool(local, true, "Train in local mode or not");

P_DEFINE_int32(average_test_period,
0,
@@ -633,8 +630,20 @@ void Trainer::test() { tester_->test(); }

std::unique_ptr<TesterConfig> Trainer::createTesterConfig() {
TesterConfig* conf = new TesterConfig;
if (FLAGS_test_period) {
LOG(WARNING)
<< "The meaning of --test_period is changed: "
<< "if equal 0, do test on all test data at the end of "
<< "each pass while if equal non-zero, do test on all test "
<< "data once each test_period batches passed while "
<< "training is going on";
}
if (FLAGS_test_all_data_in_one_period) {
LOG(WARNING)
<< "--test_all_data_in_one_period was deprecated, since "
<< "we will always do test on all test set ";
}
conf->testPeriod = FLAGS_test_period;
conf->testAllDataInOnePeriod = FLAGS_test_all_data_in_one_period;
conf->prevBatchState = FLAGS_prev_batch_state;
conf->logPeriod = FLAGS_log_period;
conf->loadsaveParametersInPserver = FLAGS_loadsave_parameters_in_pserver;
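
The warning logic added to `createTesterConfig()` can be exercised without the flag and logging macros by collecting messages into a list. This is a minimal sketch under stated assumptions: the `Flags` struct and `collectFlagWarnings` are hypothetical names, and the message strings are abbreviated paraphrases rather than the exact log text.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical plain-struct replacement for the two command-line flags.
struct Flags {
  int32_t testPeriod = 0;               // mirrors --test_period
  bool testAllDataInOnePeriod = false;  // mirrors the deprecated flag
};

// Returns the warnings that would be logged for the given flag values.
std::vector<std::string> collectFlagWarnings(const Flags& flags) {
  assert(flags.testPeriod >= 0);
  std::vector<std::string> warnings;
  if (flags.testPeriod != 0) {
    // A non-zero value now means "test on all test data every N batches".
    warnings.push_back("The meaning of --test_period has changed");
  }
  if (flags.testAllDataInOnePeriod) {
    // The flag is still accepted but no longer changes behavior.
    warnings.push_back("--test_all_data_in_one_period is deprecated");
  }
  return warnings;
}
```

Keeping the deprecated flag defined, and warning when it is set, lets existing launch scripts keep running instead of failing on an unknown option.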