Skip to content

Commit

Permalink
1. Edit the Reconstructed MI of the Supplementary Materials Fig. S1 2…
Browse files Browse the repository at this point in the history
…. Change default data_format of DeepConvNet and EEGNet from `channels_last` to `channels_first` (our paper and the original use `channels_first`) 3. DeepConvNet and EEGNet can use both `channels_last` and `channels_first`
  • Loading branch information
top-chaisaen committed Feb 5, 2022
1 parent 4882640 commit 1ed001b
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 28 deletions.
9 changes: 9 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
experiments/datasets/
experiments/log*
experiments/pretrained*
build/
dist/
min2net.egg-info/
.ipynb
.DS_Store
__pycache__
24 changes: 12 additions & 12 deletions experiments/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 +154,20 @@
},
'BCIC2a': {
'n_subjects': 9,
'input_shape': (20,400,1),
'data_format': 'NCTD',
'input_shape': (1,20,400),
'data_format': 'NDCT',
'num_class': 2
},
'OpenBMI': {
'n_subjects': 54,
'input_shape': (20,400,1),
'data_format': 'NCTD',
'input_shape': (1,20,400),
'data_format': 'NDCT',
'num_class': 2
},
'SMR_BCI': {
'n_subjects': 14,
'input_shape': (15,400,1),
'data_format': 'NCTD',
'input_shape': (1,15,400),
'data_format': 'NDCT',
'num_class': 2
},
},
Expand Down Expand Up @@ -199,20 +199,20 @@
},
'BCIC2a': {
'n_subjects': 9,
'input_shape': (20,400,1),
'data_format': 'NCTD',
'input_shape': (1,20,400),
'data_format': 'NDCT',
'num_class': 2
},
'OpenBMI': {
'n_subjects': 54,
'input_shape': (20,400,1),
'data_format': 'NCTD',
'input_shape': (1,20,400),
'data_format': 'NDCT',
'num_class': 2
},
'SMR_BCI': {
'n_subjects': 14,
'input_shape': (15,400,1),
'data_format': 'NCTD',
'input_shape': (1,15,400),
'data_format': 'NDCT',
'num_class': 2
}
},
Expand Down
6 changes: 3 additions & 3 deletions experiments/train_MIN2Net_k-fold-CV.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
python train_MIN2Net_k-fold-CV.py \
--dataset 'OpenBMI' \
--train_type 'subject_independent' --GPU 1 \
--margin 1.0 --loss_weights 1.0 1.0 1.0
--margin 1.0 --loss_weights 0.5 0.5 1.0
'''

Expand Down Expand Up @@ -77,7 +77,7 @@ def k_fold_cross_validation(subject):

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--loss_weights', nargs='+', default=[0.5, 1, 1], type=float, help='loss_weights (beta): ex. [beta1,beta2,beta3]')
parser.add_argument('--loss_weights', nargs='+', default=[0.5, 0.5, 1.0], type=float, help='loss_weights (beta): ex. [beta1,beta2,beta3]')
parser.add_argument('--save_path', type=str, default='logs/MIN2Net', help='path to save logs')
parser.add_argument('--data_path', type=str, default='datasets', help='path to datasets')
parser.add_argument('--dataset', type=str, default='OpenBMI', help='dataset name: ex. [BCIC2a/SMR_BCI/OpenBMI]')
Expand Down Expand Up @@ -112,7 +112,7 @@ def k_fold_cross_validation(subject):
print('TRAIN SET: {}'.format(args.dataset))
print('The size of latent vector: {}'.format(latent_dim))

log_path = '{}_margin{}/{}_{}_classes_{}'.format(args.save_path, str(args.margin), args.train_type, str(num_class), args.dataset, str(args.loss_weights))
log_path = '{}_margin{}/{}_{}_classes_{}_{}'.format(args.save_path, str(args.margin), args.train_type, str(num_class), args.dataset, str(args.loss_weights))
for directory in [log_path]:
if not os.path.exists(directory):
os.makedirs(directory)
Expand Down
21 changes: 15 additions & 6 deletions min2net/model/DeepConvNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class DeepConvNet:
def __init__(self,
input_shape=(20,400,1),
input_shape=(1,20,400),
num_class=2,
loss='sparse_categorical_crossentropy',
epochs=200,
Expand Down Expand Up @@ -49,16 +49,14 @@ def __init__(self,
self.time_log = log_path+'/'+model_name+'_time_log.csv'

# use **kwargs to set the new value of below args.
self.Chans = input_shape[0]
self.Samples = input_shape[1]
self.kernLength = 125
self.F1 = 8
self.D = 2
self.F2 = int(self.F1*self.D)
self.norm_rate = 0.25
self.dropout_rate = 0.5
self.f1_average = 'binary' if self.num_class == 2 else 'macro'
self.data_format = 'channels_last'
self.data_format = 'channels_first'
self.shuffle = False
self.metrics = 'accuracy'
self.monitor = 'val_loss'
Expand All @@ -71,6 +69,13 @@ def __init__(self,

for k in kwargs.keys():
self.__setattr__(k, kwargs[k])

if self.data_format == 'channels_first':
self.Chans = self.input_shape[1]
self.Samples = self.input_shape[2]
else:
self.Chans = self.input_shape[0]
self.Samples = self.input_shape[1]

np.random.seed(self.seed)
tf.random.set_seed(self.seed)
Expand Down Expand Up @@ -150,8 +155,12 @@ def fit(self, X_train, y_train, X_val, y_val):
raise Exception('ValueError: `X_val` is incompatible: expected ndim=4, found ndim='+str(X_val.ndim))

self.input_shape = X_train.shape[1:]
self.Chans = self.input_shape[0]
self.Samples = self.input_shape[1]
if self.data_format == 'channels_first':
self.Chans = self.input_shape[1]
self.Samples = self.input_shape[2]
else:
self.Chans = self.input_shape[0]
self.Samples = self.input_shape[1]

csv_logger = CSVLogger(self.csv_dir)
time_callback = TimeHistory(self.time_log)
Expand Down
21 changes: 15 additions & 6 deletions min2net/model/EEGNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class EEGNet:
def __init__(self,
input_shape=(20,400,1),
input_shape=(1,20,400),
num_class=2,
loss='sparse_categorical_crossentropy',
epochs=200,
Expand Down Expand Up @@ -49,16 +49,14 @@ def __init__(self,
self.time_log = log_path+'/'+model_name+'_time_log.csv'

# use **kwargs to set the new value of below args.
self.Chans = input_shape[0]
self.Samples = input_shape[1]
self.kernLength = 200
self.F1 = 8
self.D = 2
self.F2 = int(self.F1*self.D)
self.norm_rate = 0.25
self.dropout_rate = 0.5
self.f1_average = 'binary' if self.num_class == 2 else 'macro'
self.data_format = 'channels_last'
self.data_format = 'channels_first'
self.shuffle = False
self.metrics = 'accuracy'
self.monitor = 'val_loss'
Expand All @@ -71,6 +69,13 @@ def __init__(self,

for k in kwargs.keys():
self.__setattr__(k, kwargs[k])

if self.data_format == 'channels_first':
self.Chans = self.input_shape[1]
self.Samples = self.input_shape[2]
else:
self.Chans = self.input_shape[0]
self.Samples = self.input_shape[1]

np.random.seed(self.seed)
tf.random.set_seed(self.seed)
Expand Down Expand Up @@ -117,8 +122,12 @@ def fit(self, X_train, y_train, X_val, y_val):
raise Exception('ValueError: `X_val` is incompatible: expected ndim=4, found ndim='+str(X_val.ndim))

self.input_shape = X_train.shape[1:]
self.Chans = self.input_shape[0]
self.Samples = self.input_shape[1]
if self.data_format == 'channels_first':
self.Chans = self.input_shape[1]
self.Samples = self.input_shape[2]
else:
self.Chans = self.input_shape[0]
self.Samples = self.input_shape[1]

csv_logger = CSVLogger(self.csv_dir)
time_callback = TimeHistory(self.time_log)
Expand Down
5 changes: 4 additions & 1 deletion min2net/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def _change_data_format(self, X):
if self.data_format == 'NCTD':
# (#n_trial, #channels, #time, #depth)
X = X.reshape(X.shape[0], X.shape[1], X.shape[2], 1)
elif self.data_format == 'NDCT':
# (#n_trial, #depth, #channels, #time)
X = X.reshape(X.shape[0], 1, X.shape[1], X.shape[2])
elif self.data_format == 'NTCD':
# (#n_trial, #time, #channels, #depth)
X = X.reshape(X.shape[0], X.shape[1], X.shape[2], 1)
Expand All @@ -112,7 +115,7 @@ def _change_data_format(self, X):
elif self.data_format == None:
pass
else:
raise Exception('Value Error: data_format requires None, \'NCTD\', \'NTCD\' or \'NSHWD\', found data_format={}'.format(self.data_format))
raise Exception('Value Error: data_format requires None, \'NCTD\', \'NDCT\', \'NTCD\' or \'NSHWD\', found data_format={}'.format(self.data_format))
print('change data_format to \'{}\', new dimention is {}'.format(self.data_format, X.shape))
return X

Expand Down
Binary file modified supplementary materials/Supplementary Materials.pdf
Binary file not shown.

0 comments on commit 1ed001b

Please sign in to comment.