add stride and change parameter
icoxfog417 committed Jul 14, 2017
1 parent 4a86190 commit e494408
Showing 5 changed files with 16 additions and 15 deletions.
2 changes: 1 addition & 1 deletion model/augmented_model.py
@@ -15,7 +15,7 @@ def __init__(self,
                  sequence_size,
                  setting=None,
                  checkpoint_path="",
-                 temperature=10,
+                 temperature=20,
                  tying=False):

         super().__init__(vocab_size, sequence_size, setting, checkpoint_path)
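The only change here bumps the default softmax temperature from 10 to 20. For reference, a minimal sketch of temperature-scaled softmax (the function name and usage are illustrative assumptions; the actual use of `temperature` lives in the elided body of the class):

```python
import numpy as np

def softmax_with_temperature(logits, temperature=20.0):
    # Dividing logits by a temperature > 1 flattens the distribution;
    # raising the default from 10 to 20 softens the targets further.
    scaled = logits / temperature
    scaled = scaled - scaled.max()  # shift for numerical stability
    exp = np.exp(scaled)
    return exp / exp.sum()

print(softmax_with_temperature(np.array([2.0, 1.0, 0.1])))
```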
6 changes: 3 additions & 3 deletions model/data_processor.py
@@ -18,7 +18,7 @@ def get_wiki2(self, data_root, vocab_size=30000, force=False):
         r_idx = r.to_indexed().make_vocab(vocab_size=vocab_size, force=force)
         return r_idx

-    def make_batch_iter(self, r_idx, kind="train", batch_size=20, sequence_size=35):
+    def make_batch_iter(self, r_idx, kind="train", batch_size=20, sequence_size=35, stride=0):
         # count all tokens
         word_count = 0
         path = r_idx.train_file_path
@@ -34,14 +34,14 @@ def make_batch_iter(self, r_idx, kind="train", batch_size=20, sequence_size=35):

         vocab_size = len(r_idx.vocab_data())
         steps_per_epoch = 0
-        for i in range(sequence_size):
+        for i in range(stride + 1):
             steps_per_epoch += (word_count - i - 1) // sequence_size
         steps_per_epoch = steps_per_epoch // batch_size

         def generator():
             while True:
                 buffer = []
-                for i in range(sequence_size):
+                for i in range(stride + 1):
                     initial_slide = False
                     with open(path, encoding="utf-8") as f:
                         for line in f:
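The new `stride` argument replaces `sequence_size` as the loop bound, so the corpus is now swept `stride + 1` times at token offsets 0..stride instead of `sequence_size` times. A standalone sketch of the `steps_per_epoch` arithmetic above, with hypothetical names and an illustrative token count:

```python
def count_steps(word_count, sequence_size=35, batch_size=20, stride=0):
    # Each pass i starts i tokens into the stream, so stride=0 means a single
    # pass; the pre-commit behavior corresponds to stride = sequence_size - 1.
    steps = 0
    for i in range(stride + 1):
        steps += (word_count - i - 1) // sequence_size
    return steps // batch_size

print(count_steps(100000))            # single pass over the tokens
print(count_steps(100000, stride=1))  # two offset passes, roughly double
```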
2 changes: 1 addition & 1 deletion model/one_hot_model.py
@@ -48,7 +48,7 @@ def compile(self):
     @classmethod
     def perplexity(cls, y_true, y_pred):
         cross_entropy = K.mean(K.categorical_crossentropy(y_pred, y_true))
-        perplexity = K.exp(cross_entropy)
+        perplexity = K.pow(2.0, cross_entropy)
         return perplexity

     def fit(self, x_train, y_train, x_test, y_test, batch_size=20, epochs=20):
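Swapping `K.exp` for `K.pow(2.0, ...)` changes the base of the perplexity exponent. Perplexity equals b ** H only when the cross-entropy H is measured in base-b logarithms, and Keras's `categorical_crossentropy` returns natural-log (nat) values. A small NumPy sketch of the relationship, with the 4.6 value purely illustrative:

```python
import numpy as np

h_nats = 4.6                    # illustrative mean cross-entropy in nats
h_bits = h_nats / np.log(2.0)   # the same quantity converted to bits
print(np.exp(h_nats))           # perplexity computed from nats: ~99.5
print(2.0 ** h_bits)            # identical perplexity computed from bits
```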
4 changes: 2 additions & 2 deletions model/setting.py
@@ -40,6 +40,6 @@ def __init__(self, network_size="small", dataset_kind="ptb"):
         self.dropout = 0.35 if dataset_kind == "ptb" else 0.6

         if dataset_kind == "ptb":
-            self.gamma = 0.5
+            self.gamma = 0.65  # 0.5~0.8
         elif kind == "wiki2":
-            self.gamma = 1.0
+            self.gamma = 1.25  # 1.0~1.5
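The inline comments record the sweep ranges the new defaults were picked from (0.5~0.8 for PTB, 1.0~1.5 for wiki2). As a hedged sketch of where `gamma` plausibly enters, assuming it weights the augmented loss term against the standard cross-entropy (an assumption about the augmented-loss setup, not code from this file):

```python
def combined_loss(cross_entropy, augmented_term, gamma):
    # gamma trades off the standard objective against the augmented one.
    return cross_entropy + gamma * augmented_term
```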
17 changes: 9 additions & 8 deletions train.py
@@ -22,16 +22,16 @@ def prepare_dataset(dataset_kind):
     return dataset


-def train_baseline(network_size, dataset_kind, epochs=40):
+def train_baseline(network_size, dataset_kind, epochs=40, stride=0):
     # prepare the data
     setting = ProposedSetting(network_size, dataset_kind)
     dataset = prepare_dataset(dataset_kind)
     vocab_size = len(dataset.vocab_data())
     sequence_size = 20

     dp = DataProcessor()
-    train_steps, train_generator = dp.make_batch_iter(dataset, sequence_size=sequence_size)
-    valid_steps, valid_generator = dp.make_batch_iter(dataset, kind="valid", sequence_size=sequence_size)
+    train_steps, train_generator = dp.make_batch_iter(dataset, sequence_size=sequence_size, stride=stride)
+    valid_steps, valid_generator = dp.make_batch_iter(dataset, kind="valid", sequence_size=sequence_size, stride=stride)

     # make one hot model
     model = OneHotModel(vocab_size, sequence_size, setting, LOG_ROOT)
@@ -40,16 +40,16 @@ def train_baseline(network_size, dataset_kind, epochs=40):
     model.save(MODEL_ROOT)


-def train_augmented(network_size, dataset_kind, tying=False, epochs=40):
+def train_augmented(network_size, dataset_kind, tying=False, epochs=40, stride=0):
     # prepare the data
     setting = ProposedSetting(network_size, dataset_kind)
     dataset = prepare_dataset(dataset_kind)
     vocab_size = len(dataset.vocab_data())
     sequence_size = 20

     dp = DataProcessor()
-    train_steps, train_generator = dp.make_batch_iter(dataset, sequence_size=sequence_size)
-    valid_steps, valid_generator = dp.make_batch_iter(dataset, kind="valid", sequence_size=sequence_size)
+    train_steps, train_generator = dp.make_batch_iter(dataset, sequence_size=sequence_size, stride=stride)
+    valid_steps, valid_generator = dp.make_batch_iter(dataset, kind="valid", sequence_size=sequence_size, stride=stride)

     # make one hot model
     model = AugmentedModel(vocab_size, sequence_size, setting, tying=tying, checkpoint_path=LOG_ROOT)
@@ -67,6 +67,7 @@ def train_augmented(network_size, dataset_kind, tying=False, epochs=40):
     parser.add_argument("--nsize", default="small", help="network size (small, medium, large)")
     parser.add_argument("--dataset", default="ptb", help="dataset kind (ptb or wiki2)")
     parser.add_argument("--epochs", type=int, default=40, help="epoch to train")
+    parser.add_argument("--stride", type=int, default=0, help="stride of the sequence")
     args = parser.parse_args()

     n_size = args.nsize
@@ -77,6 +78,6 @@

     if args.aug or args.tying:
         print("Use Augmented Model (tying={})".format(args.tying))
-        train_augmented(n_size, dataset, args.tying, args.epochs)
+        train_augmented(n_size, dataset, args.tying, args.epochs, args.stride)
     else:
-        train_baseline(n_size, dataset, args.epochs)
+        train_baseline(n_size, dataset, args.epochs, args.stride)
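With these changes, the new flag threads from argparse through `train_baseline`/`train_augmented` down to `make_batch_iter`; for example, `python train.py --dataset ptb --nsize small --epochs 40 --stride 1` (assuming data and dependencies are in place) trains the baseline with one extra offset pass per epoch.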
