Skip to content
This repository has been archived by the owner on Mar 22, 2024. It is now read-only.

wrap fname update in a function #20

Merged
merged 6 commits into from
Dec 3, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 25 additions & 19 deletions graphprot/NeuralNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,27 @@ def plot_hit_rate(self, data='eval', threshold=4, mode='percentage', name=''):
except:
print('No hit rate plot could be generated for you {} task'.format(
self.task))

@staticmethod
def update_name(hdf5, outdir):
"""Check if the file already exists
if so, update the name
ex. 1: train.hdf5 -> train_001.hdf5
ex. 2: train_001.hdf5 -> train_002.hdf5
"""

fname = os.path.join(outdir, hdf5)

count = 0
hdf5_name = hdf5.split('.')[0]

# If file exists, change its name with a number
while os.path.exists(fname) :
count += 1
hdf5 = '{}_{:03d}.hdf5'.format(hdf5_name, count)
fname = os.path.join(outdir, hdf5)

return fname

def train(self, nepoch=1, validate=False, plot=False, save_model='last', hdf5='train_data.hdf5', save_epoch='intermediate', save_every=5):
"""Train the model
Expand All @@ -213,16 +234,8 @@ def train(self, nepoch=1, validate=False, plot=False, save_model='last', hdf5='t
save_every (int, optional): save data every n epoch if save_epoch == 'intermediate'. Defaults to 5
"""

# Output file
fname = os.path.join(self.outdir, hdf5)

# If file exists, change its name with a number
count = 0
hdf5_name = hdf5.split('.')[0]
while os.path.exists(fname) :
count += 1
hdf5 = '{}_{:03d}.hdf5'.format(hdf5_name, count)
fname = os.path.join(self.outdir, hdf5)
# Output file name
fname = self.update_name(hdf5, self.outdir)

# Open output file for writting
self.f5 = h5py.File(fname, 'w')
Expand Down Expand Up @@ -342,15 +355,8 @@ def test(self, database_test, threshold=4, hdf5='test_data.hdf5'):
"""

# Output file
fname = os.path.join(self.outdir, hdf5)

# If file exists, change its name with a number
count = 0
hdf5_name = hdf5.split('.')[0]
while os.path.exists(fname) :
count += 1
hdf5 = '{}_{:03d}.hdf5'.format(hdf5_name, count)
fname = os.path.join(self.outdir, hdf5)
# Output file name
fname = self.update_name(hdf5, self.outdir)

# Open output file for writting
self.f5 = h5py.File(fname, 'w')
Expand Down