
Filter dataset before training to 5% of the window. (#93)
* add filter param to preprocessing for get_input_tensors

* remove premature references to symmetries

* rename argument to better reflect usage

* address PR comments
amj authored Jan 28, 2018
1 parent 506f349 commit 84e9a21
Showing 3 changed files with 33 additions and 14 deletions.
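The change threads a new filter_amount argument through read_tf_records() and get_input_tensors(), with get_input_tensors() defaulting it to 0.05 so that training ingests roughly 5% of the examples in the window. A minimal usage sketch of the new parameter, not part of this commit; the batch size and record file name are placeholders:

import preprocessing

# Hypothetical caller: thin the training data to ~5% of the records.
input_tensors = preprocessing.get_input_tensors(
    batch_size=16,                                # placeholder value
    tf_records=['selfplay-000001.tfrecord.zz'],   # placeholder file name
    num_repeats=1,
    shuffle_records=True,
    shuffle_examples=True,
    filter_amount=0.05)                           # keep ~5% of records (the new default)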
2 changes: 0 additions & 2 deletions dual_net.py
@@ -82,8 +82,6 @@ def train(self, tf_records, init_from=None, logdir=None, num_steps=None):
tensor_values = self.sess.run(train_tensors)
except tf.errors.OutOfRangeError:
break
if i % 1000 == 0:
print(tensor_values)
if logdir is not None:
training_stats.report(
tensor_values['policy_cost'],
28 changes: 18 additions & 10 deletions preprocessing.py
@@ -2,6 +2,7 @@
import functools
import numpy as np
import tensorflow as tf
import random

import coords
import features as features_lib
@@ -82,29 +83,33 @@ def batch_parse_tf_example(batch_size, example_batch):
}

def read_tf_records(batch_size, tf_records, num_repeats=None,
shuffle_records=True, shuffle_examples=True):
shuffle_records=True, shuffle_examples=True,
filter_amount=1.0):
'''
Args:
batch_size: batch size to return
tf_records: a list of tf_record filenames
num_repeats: how many times the data should be read (default: infinite)
shuffle_records: whether to shuffle the order of files read
shuffle_examples: whether to shuffle the tf.Examples
filter_amount: what fraction of records to keep
Returns:
a dict of batched tensors
a tf dataset of batched tensors
'''
# compression_type here must agree with write_tf_examples
# cycle_length = how many tfrecord files are read in parallel
# block_length = how many tf.Examples are read from each file before
# moving to the next file
# The idea is to shuffle both the order of the files being read,
# and the examples being read from the files.
record_list = tf.data.Dataset.from_tensor_slices(tf_records)
if shuffle_records:
record_list = record_list.shuffle(buffer_size=1000)
random.shuffle(tf_records)
record_list = tf.data.Dataset.from_tensor_slices(tf_records)
dataset = record_list.interleave(lambda x:
tf.data.TFRecordDataset(x, compression_type='ZLIB'),
cycle_length=8, block_length=16)
cycle_length=64, block_length=16)
dataset = dataset.filter(lambda x: tf.less(tf.random_uniform([1]), filter_amount)[0])
# TODO(amj): apply py_func for transforms here.
if num_repeats is not None:
dataset = dataset.repeat(num_repeats)
else:
@@ -115,13 +120,16 @@ def read_tf_records(batch_size, tf_records, num_repeats=None,
return dataset

def get_input_tensors(batch_size, tf_records, num_repeats=None,
shuffle_records=True, shuffle_examples=True):
shuffle_records=True, shuffle_examples=True,
filter_amount=0.05):
'''Read tf.Records and prepare them for ingestion by dual_net
Returns a dict of tensors (see return value of batch_parse_tf_example)
'''
dataset = read_tf_records(batch_size, tf_records, num_repeats=num_repeats,
shuffle_records=shuffle_records, shuffle_examples=shuffle_examples)
shuffle_records=shuffle_records,
shuffle_examples=shuffle_examples,
filter_amount=filter_amount)
dataset = dataset.filter(lambda t: tf.equal(tf.shape(t)[0], batch_size))
dataset = dataset.map(functools.partial(batch_parse_tf_example, batch_size))
return dataset.make_one_shot_iterator().get_next()
@@ -149,16 +157,16 @@ def _make_tf_example_from_pwc(position_w_context):
value = position_w_context.result
return make_tf_example(features, pi, value)

def shuffle_tf_examples(batch_size, records_to_shuffle):
def shuffle_tf_examples(gather_size, records_to_shuffle):
'''Read through tf.Record and yield shuffled, but unparsed tf.Examples
Args:
batch_size: The number of tf.Examples to be batched together
gather_size: The number of tf.Examples to be gathered together
records_to_shuffle: A list of filenames
Returns:
An iterator yielding lists of bytes, which are serialized tf.Examples.
'''
dataset = read_tf_records(batch_size, records_to_shuffle, num_repeats=1)
dataset = read_tf_records(gather_size, records_to_shuffle, num_repeats=1)
batch = dataset.make_one_shot_iterator().get_next()
sess = tf.Session()
while True:
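The core of the change is the random filter added to read_tf_records() above: every serialized record draws an independent uniform sample and is kept only if that sample falls below filter_amount, so on average a fraction filter_amount of the records survives, spread across the whole window rather than clustered in a few files. A toy sketch, assuming the TF 1.x API used in this repo, of the same per-record filter on a 100-element dataset:

import tensorflow as tf

filter_amount = 0.05
dataset = tf.data.Dataset.range(100)
# Each element is kept independently with probability filter_amount.
dataset = dataset.filter(
    lambda x: tf.less(tf.random_uniform([1]), filter_amount)[0])

next_element = dataset.make_one_shot_iterator().get_next()
kept = 0
with tf.Session() as sess:
    while True:
        try:
            sess.run(next_element)
            kept += 1
        except tf.errors.OutOfRangeError:
            break
print(kept)  # about 5 on average, but it varies from run to run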
17 changes: 15 additions & 2 deletions tests/test_preprocessing.py
@@ -22,9 +22,10 @@ def create_random_data(self, num_examples):
raw_data.append((feature, pi, value))
return raw_data

def extract_data(self, tf_record):
def extract_data(self, tf_record, filter_amount=1):
tf_example_tensor = preprocessing.get_input_tensors(
1, [tf_record], num_repeats=1, shuffle_records=False, shuffle_examples=False)
1, [tf_record], num_repeats=1, shuffle_records=False,
shuffle_examples=False, filter_amount=filter_amount)
recovered_data = []
with tf.Session() as sess:
while True:
@@ -62,6 +63,18 @@ def test_serialize_round_trip(self):

self.assertEqualData(raw_data, recovered_data)

def test_filter(self):
raw_data = self.create_random_data(100)
tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))

with tempfile.NamedTemporaryFile() as f:
preprocessing.write_tf_examples(f.name, tfexamples)
recovered_data = self.extract_data(f.name, filter_amount=.05)

# TODO: this will flake out very infrequently. Use set_random_seed
self.assertLess(len(recovered_data), 50)


def test_serialize_round_trip_no_parse(self):
np.random.seed(1)
raw_data = self.create_random_data(10)
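With 100 random examples and filter_amount=.05, the expected number of surviving records is about 5, so assertLess(len(recovered_data), 50) fails only with vanishingly small probability; the TODO in test_filter notes that fixing a seed would remove even that residual flakiness. A minimal sketch, assuming TF 1.x, of what that could look like before the input pipeline graph is built:

import tensorflow as tf

tf.reset_default_graph()
tf.set_random_seed(17)  # arbitrary fixed seed; makes the tf.random_uniform draws repeatable
# ...then build the pipeline via preprocessing.get_input_tensors(...) exactly as
# extract_data() does, and the filtered record count is identical on every run.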
