-
Notifications
You must be signed in to change notification settings - Fork 0
/
tensor.py
154 lines (112 loc) · 4.83 KB
/
tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import csv
import os
import random

import numpy as np
import tensorflow as tf
# Run-mode switch: True trains the model and writes a checkpoint per epoch;
# False restores the latest training checkpoint and generates text from it.
isTraining = False
# Load the (cleaned) tweet corpus as one string. Read it as UTF-8 explicitly so
# the result does not depend on the platform's default encoding, and let the
# `with` block close the file handle once the contents are read.
with open('tweets.txt', encoding='utf-8') as tweet_file:
    text = tweet_file.read()
print(len(text))

# Uncomment to truncate the input for quicker training.
# text = text[:100000]

# Map every unique character to an integer index, and back again.
vocabulary = sorted(set(text))
print('{} unique characters'.format(len(vocabulary)))
char2idx = {u: i for i, u in enumerate(vocabulary)}
idx2char = np.array(vocabulary)
# The whole corpus encoded as an array of character indices.
text_as_int = np.array([char2idx[c] for c in text])
# Maximum number of input characters in one training sequence.
seq_length = 280
# Number of complete (seq_length + 1)-character chunks in the corpus.
examples_per_epoch = len(text) // (seq_length + 1)

# Training sets, step 1: a stream of single character indices.
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
for char_id in char_dataset.take(5):
    print(idx2char[char_id.numpy()])

# Step 2: group the stream into chunks of seq_length + 1 characters. The extra
# character lets each chunk be split into an input and a target sequence
# shifted by one; drop_remainder discards the final, shorter chunk.
sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)
for chunk in sequences.take(5):
    chunk_text = ''.join(idx2char[chunk.numpy()])
    print(repr(chunk_text))
def split_input_target(chunk):
    """Split a sequence into an (input, target) pair offset by one element.

    The input is everything except the last element; the target is everything
    except the first. So target[k] is the element that follows input[k].
    """
    return chunk[:-1], chunk[1:]
# Turn every chunk into an (input, target) training pair.
dataset = sequences.map(split_input_target)

# Preview the first pair, then walk through its first five positions to show
# that each target character is the next input character.
for input_example, target_example in dataset.take(1):
    print('Input data: ', repr(''.join(idx2char[input_example.numpy()])))
    print('Target data:', repr(''.join(idx2char[target_example.numpy()])))
    first_steps = zip(input_example[:5], target_example[:5])
    for step, (input_idx, target_idx) in enumerate(first_steps):
        print("Step {:4d}".format(step))
        print(" input: {} ({:s})".format(input_idx, repr(idx2char[input_idx])))
        print(" expected output: {} ({:s})".format(target_idx, repr(idx2char[target_idx])))

# Batch and shuffle-buffer sizes.
BATCH_SIZE = 32
BUFFER_SIZE = 10000

# Shuffle the pairs through a BUFFER_SIZE-element buffer, then group them into
# batches of exactly BATCH_SIZE (the incomplete final batch is dropped).
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)

# Model hyperparameters.
vocabulary_size = len(vocabulary)  # number of distinct characters
embedding_dim = 256                # length of each character embedding vector
rnn_units = 1024                   # units in the GRU layer
def build_model(vocabulary_size, embedding_dim, rnn_units, batch_size):
    """Build the character-level GRU language model (uncompiled).

    Args:
        vocabulary_size: number of distinct characters; size of both the
            embedding table and the output layer.
        embedding_dim: length of each character's embedding vector.
        rnn_units: number of units in the GRU layer.
        batch_size: batch dimension fixed into the model's input shape. The
            GRU is stateful, so the model only accepts batches of this size.

    Returns:
        A tf.keras.Sequential mapping character-id sequences to per-step,
        unnormalised scores over the vocabulary (no output activation).
    """
    embedding = tf.keras.layers.Embedding(
        vocabulary_size, embedding_dim, batch_input_shape=[batch_size, None]
    )
    # return_sequences=True yields a prediction at every time step, and
    # stateful=True carries the hidden state across successive calls.
    recurrent = tf.keras.layers.GRU(
        rnn_units,
        return_sequences=True,
        stateful=True,
        recurrent_initializer='glorot_uniform',
    )
    output_scores = tf.keras.layers.Dense(vocabulary_size)
    return tf.keras.Sequential([embedding, recurrent, output_scores])
# Training model: its batch dimension is fixed at BATCH_SIZE.
model = build_model(
    vocabulary_size=vocabulary_size,
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    batch_size=BATCH_SIZE,
)

# Run the untrained model on one batch and print the output shape.
for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocabulary_size)")
model.summary()

# Sample one character index per time step from the first sequence's scores.
# NOTE(review): sampled_indices is not printed or read later in this file.
first_sequence_scores = example_batch_predictions[0]
sampled_indices = tf.squeeze(
    tf.random.categorical(first_sequence_scores, num_samples=1), axis=-1
).numpy()
def loss(labels, logits):
    """Return the sparse categorical cross-entropy of logits against labels.

    labels holds integer character ids; logits are unnormalised scores,
    hence from_logits=True.
    """
    per_step_loss = tf.keras.losses.sparse_categorical_crossentropy(
        y_true=labels, y_pred=logits, from_logits=True
    )
    return per_step_loss
# Loss of the untrained model on the example batch (computed, not printed).
example_batch_loss = loss(target_example_batch, example_batch_predictions)
model.compile(loss=loss, optimizer='adam')

# Save weights-only training checkpoints under ./training_checkpoints, with
# the epoch number substituted into the "ckpt_{epoch}" file prefix.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True,
)
if isTraining:
    # Train the batched model; checkpoint_callback saves the weights as
    # training progresses.
    EPOCHS = 5
    history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])
else:
    # Locate the newest checkpoint first, and fail with a clear message
    # instead of passing None to load_weights.
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_checkpoint is None:
        raise FileNotFoundError(
            'No training checkpoint found in {!r}; set isTraining = True and '
            'train the model first.'.format(checkpoint_dir)
        )
    # Rebuild the network for generation with a batch size of 1. The GRU is
    # stateful, so the batch dimension is part of the model's input shape.
    model = build_model(vocabulary_size, embedding_dim, rnn_units, batch_size=1)
    model.load_weights(latest_checkpoint)
    # Keep the build shape consistent with batch_size=1 above.
    model.build(tf.TensorShape([1, None]))
def generate_text(model, start_string, num_generate=250, temperature=0.8):
    """Generate text from the character model, seeded with start_string.

    The seed is fed as a batch containing one sequence, so the model must
    accept a batch size of 1. Its recurrent state is reset before generating.

    Args:
        model: Keras model mapping character-id sequences to per-step scores.
        start_string: non-empty seed text; every character must be a key of
            char2idx.
        num_generate: number of characters to generate after the seed.
        temperature: positive sampling temperature. Lower values give more
            predictable output; higher values give more varied output.

    Returns:
        start_string followed by the generated characters.

    Raises:
        ValueError: if start_string is empty or temperature is not positive.
        KeyError: if start_string contains a character missing from char2idx.
    """
    if not start_string:
        raise ValueError('start_string must be a non-empty string')
    if temperature <= 0:
        raise ValueError('temperature must be positive, got {!r}'.format(temperature))

    # Encode the seed as a single-sequence batch: shape (1, len(start_string)).
    input_eval = tf.expand_dims([char2idx[s] for s in start_string], 0)
    text_generated = []

    model.reset_states()
    for _ in range(num_generate):
        predictions = model(input_eval)
        # Drop the batch dimension, leaving one row of scores per time step.
        predictions = tf.squeeze(predictions, 0)
        predictions = predictions / temperature
        # Sample the next character from the last time step's distribution.
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()
        # The stateful GRU remembers the context, so only the newly sampled
        # character is fed back in on the next step.
        input_eval = tf.expand_dims([predicted_id], 0)
        text_generated.append(idx2char[predicted_id])

    return start_string + ''.join(text_generated)
# Seed phrases used to start each generated tweet.
starters = ["Americans ", "U.S. ", "We must ", "Lazy Joe Biden ", "Dems ", "The ", "We "]

# Generation only runs in inference mode. After training, `model` is still the
# stateful BATCH_SIZE training model, which cannot accept a batch of one.
if not isTraining:
    for i in range(10):
        starter = random.choice(starters)
        print("tweet ", i, ": ", generate_text(model, start_string=starter), "\n")