Skip to content

Commit

Permalink
Merge pull request deepmodeling#296 from njzjz/ntypes
Browse files Browse the repository at this point in the history
allow ntypes_model > ntypes_data (fix deepmodeling#261)
  • Loading branch information
amcadmus authored Nov 24, 2020
2 parents 045b8b3 + d232647 commit bd61329
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions source/train/Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,10 @@ def build (self,
data,
stop_batch = 0) :
self.ntypes = self.model.get_ntypes()
assert (self.ntypes == data.get_ntypes()), "ntypes should match that found in data"
# Usually, the type number of the model should be equal to that of the data
# However, nt_model > nt_data should be allowed, since users may only want to
# train using a dataset that only have some of elements
assert (self.ntypes >= data.get_ntypes()), "ntypes should match that found in data"
self.stop_batch = stop_batch

self.batch_size = data.get_batch_size()
Expand Down Expand Up @@ -492,4 +495,4 @@ def test_on_the_fly (self,
feed_dict_batch)
print_str += " %8.1e\n" % current_lr
fp.write(print_str)
fp.flush ()
fp.flush ()

0 comments on commit bd61329

Please sign in to comment.