This repository has been archived by the owner on Apr 28, 2024. It is now read-only.
forked from yilundu/comet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tetrominoes.py
67 lines (57 loc) · 2.73 KB
/
tetrominoes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tetrominoes dataset reader."""
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Compression string ('GZIP') passed to `tf.data.TFRecordDataset`.
COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string('GZIP')
IMAGE_SIZE = [35, 35]
# Upper bound on the number of foreground plus background entities in a
# scene of the provided dataset; each example carries exactly this many
# segmentation masks.
MAX_NUM_ENTITIES = 4
# Features stored as raw byte strings that must be unpacked to uint8 after
# parsing.
BYTE_FEATURES = ['mask', 'image']

# Feature spec handed to `tf.io.parse_single_example`: each entry fixes the
# shape and serialized dtype of one `tf.Example` feature.
features = dict(
    image=tf.io.FixedLenFeature([*IMAGE_SIZE, 3], tf.string),
    mask=tf.io.FixedLenFeature([MAX_NUM_ENTITIES, *IMAGE_SIZE, 1], tf.string),
    x=tf.io.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
    y=tf.io.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
    shape=tf.io.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
    color=tf.io.FixedLenFeature([MAX_NUM_ENTITIES, 3], tf.float32),
    visibility=tf.io.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
)
def _decode(example_proto):
  """Turn one serialized `tf.Example` into a dict of tensors.

  Args:
    example_proto: string tensor holding one serialized `tf.Example`
      (presumably a single record read from the TFRecord file -- confirm
      against the caller).

  Returns:
    Dict keyed like the module-level `features` spec. Entries named in
    `BYTE_FEATURES` are decoded from raw bytes into `tf.uint8` tensors; all
    other entries are returned exactly as parsed.
  """
  parsed = tf.io.parse_single_example(example_proto, features)
  # decode_raw appends one trailing axis holding the bytes of each string
  # element; squeezing that axis (which requires it to have length 1) leaves
  # the feature with its declared shape.
  unpacked = {
      name: tf.squeeze(tf.io.decode_raw(parsed[name], tf.uint8), axis=-1)
      for name in BYTE_FEATURES
  }
  parsed.update(unpacked)
  return parsed
def dataset(tfrecords_path, read_buffer_size=None, map_parallel_calls=None):
  """Read, decompress, and parse the TFRecords file.

  Args:
    tfrecords_path: str. Path to the dataset file.
    read_buffer_size: int. Number of bytes in the read buffer. See documentation
      for `tf.data.TFRecordDataset.__init__`.
    map_parallel_calls: int. Number of elements decoded asynchronously in
      parallel. See documentation for `tf.data.Dataset.map`.

  Returns:
    An unbatched `tf.data.Dataset` (the result of mapping `_decode` over a
    GZIP-compressed `tf.data.TFRecordDataset`, not the `TFRecordDataset`
    itself). Each element is a dict keyed like the module-level `features`
    spec, with the `BYTE_FEATURES` entries decoded to `tf.uint8`.
  """
  raw_dataset = tf.data.TFRecordDataset(
      tfrecords_path, compression_type=COMPRESSION_TYPE,
      buffer_size=read_buffer_size)
  # Decode each serialized record into a dict of tensors.
  return raw_dataset.map(_decode, num_parallel_calls=map_parallel_calls)