Skip to content

Commit

Permalink
give a clear message if model.get_ntypes()<data.get_ntypes() (deepm…
Browse files Browse the repository at this point in the history
  • Loading branch information
njzjz authored Aug 23, 2021
1 parent 38bdaf5 commit 84fb302
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion deepmd/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,16 @@ def build (self,
# Usually, the type number of the model should be equal to that of the data
# However, nt_model > nt_data should be allowed, since users may only want to
# train using a dataset that only have some of elements
assert (self.ntypes >= data.get_ntypes()), "ntypes should match that found in data"
if self.ntypes < data.get_ntypes():
raise ValueError(
"The number of types of the training data is %d, but that of the "
"model is only %d. The latter must be no less than the former. "
"You may need to reset one or both of them. Usually, the former "
"is given by `model/type_map` in the training parameter (if set) "
"or the maximum number in the training data. The latter is given "
"by `model/descriptor/sel` in the training parameter." % (
data.get_ntypes(), self.ntypes
))
self.type_map = data.get_type_map()
self.batch_size = data.get_batch_size()
self.model.data_stat(data)
Expand Down

0 comments on commit 84fb302

Please sign in to comment.