diff --git a/source/train/argcheck.py b/source/train/argcheck.py index 89034180ba..d87d41857c 100644 --- a/source/train/argcheck.py +++ b/source/train/argcheck.py @@ -321,7 +321,7 @@ def training_args(): Argument("seed", [int,None], optional = True, doc = doc_seed), Argument("disp_file", str, optional = True, default = 'lcueve.out', doc = doc_disp_file), Argument("disp_freq", int, optional = True, default = 1000, doc = doc_disp_freq), - Argument("numb_test", int, optional = True, default = 1, doc = doc_numb_test), + Argument("numb_test", [list,int,str], optional = True, default = 1, doc = doc_numb_test), Argument("save_freq", int, optional = True, default = 1000, doc = doc_save_freq), Argument("save_ckpt", str, optional = True, default = 'model.ckpt', doc = doc_save_ckpt), Argument("disp_training", bool, optional = True, default = True, doc = doc_disp_training),