[data] Add microbenchmark for reading and transforming images from preprocessed image files (#37610)

MosaicML streaming (MDS) and tf.data (TFRecord) each have a specialized file format that improves image load time. This adds a microbenchmark similar to read_images_comparison_microbenchmark_single_node, except that it first preprocesses the raw images into those formats.
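
A rough sketch of how the new preprocessing step is invoked (flags and defaults are taken from preprocess_images.py below; the benchmark wiring lives in the other changed files):

    python preprocess_images.py \
        --data-root /tmp/imagenet-1gb-data \
        --mosaic-data-root /tmp/mosaicml-data \
        --tf-data-root /tmp/tf-data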

---------

Signed-off-by: Stephanie Wang <[email protected]>
stephanie-wang authored Jul 27, 2023
1 parent 6f106ca commit e8db5da
Showing 7 changed files with 503 additions and 6 deletions.
3 changes: 3 additions & 0 deletions release/nightly_tests/dataset/app_config.yaml
@@ -4,6 +4,9 @@ python:
   pip_packages:
     - boto3
     - tqdm
+    - mosaicml-streaming
+    # Temporary fix to get around docker issue.
+    - typing-extensions<4.6.0
   conda_packages: []
 
 post_build_cmds:
27 changes: 21 additions & 6 deletions release/nightly_tests/dataset/image_loader_microbenchmark.py
@@ -15,12 +15,15 @@
 FULL_IMAGE_SIZE = (1213, 1546)
 
 
-def iterate(dataset, label, metrics):
+def iterate(dataset, label, batch_size, metrics):
     start = time.time()
     it = iter(dataset)
     num_rows = 0
     for batch in it:
-        num_rows += len(batch)
+        # NOTE(swang): This will be slightly off if batch_size does not divide
+        # evenly into number of images but should be okay for large enough
+        # datasets.
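+        # For example, 1,000 images at batch_size=32 would be counted as
+        # ceil(1000 / 32) * 32 = 1024 rows rather than 1,000.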
+        num_rows += batch_size
     end = time.time()
     print(label, end - start, "epoch", i)
 
@@ -98,6 +101,12 @@ def tf_crop_and_flip(image_buffer, num_channels=3):
     )
     # Flip to add a little more random distortion in.
     image_buffer = tf.image.random_flip_left_right(image_buffer)
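+    # Resize the randomly cropped and flipped image to the target size.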
+    image_buffer = tf.compat.v1.image.resize(
+        image_buffer,
+        [DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE],
+        method=tf.image.ResizeMethod.BILINEAR,
+        align_corners=False,
+    )
     return image_buffer
 
 
@@ -162,21 +171,21 @@ def crop_and_flip_image_batch(image_batch):
         args.data_root, batch_size=args.batch_size, image_size=FULL_IMAGE_SIZE
     )
     for i in range(args.num_epochs):
-        iterate(tf_dataset, "tf_data", metrics)
+        iterate(tf_dataset, "tf_data", args.batch_size, metrics)
     tf_dataset = tf_dataset.map(lambda img, label: (tf_crop_and_flip(img), label))
     for i in range(args.num_epochs):
-        iterate(tf_dataset, "tf_data+transform", metrics)
+        iterate(tf_dataset, "tf_data+transform", args.batch_size, metrics)
 
     torch_dataset = build_torch_dataset(
         args.data_root, args.batch_size, transform=torchvision.transforms.ToTensor()
     )
     for i in range(args.num_epochs):
-        iterate(torch_dataset, "torch", metrics)
+        iterate(torch_dataset, "torch", args.batch_size, metrics)
     torch_dataset = build_torch_dataset(
         args.data_root, args.batch_size, transform=get_transform(True)
     )
     for i in range(args.num_epochs):
-        iterate(torch_dataset, "torch+transform", metrics)
+        iterate(torch_dataset, "torch+transform", args.batch_size, metrics)
 
     ray_dataset = ray.data.read_images(args.data_root).map_batches(
         crop_and_flip_image_batch
@@ -185,6 +194,7 @@ def crop_and_flip_image_batch(image_batch):
     iterate(
         ray_dataset.iter_torch_batches(batch_size=args.batch_size),
         "ray_data+transform",
+        args.batch_size,
         metrics,
     )

@@ -195,6 +205,7 @@ def crop_and_flip_image_batch(image_batch):
     iterate(
         ray_dataset.iter_torch_batches(batch_size=args.batch_size),
         "ray_data+transform+zerocopy",
+        args.batch_size,
         metrics,
     )
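(Presumably the zerocopy variant calls map_batches with zero_copy_batch=True, e.g. ds.map_batches(crop_and_flip_image_batch, zero_copy_batch=True); the call site is collapsed in this excerpt.)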

@@ -203,6 +214,7 @@ def crop_and_flip_image_batch(image_batch):
     iterate(
         ray_dataset.iter_torch_batches(batch_size=args.batch_size),
         "ray_data",
+        args.batch_size,
         metrics,
     )

@@ -213,6 +225,7 @@ def crop_and_flip_image_batch(image_batch):
     iterate(
         ray_dataset.iter_torch_batches(batch_size=args.batch_size),
         "ray_data+dummy_pyarrow_transform",
+        args.batch_size,
         metrics,
     )

@@ -223,6 +236,7 @@ def crop_and_flip_image_batch(image_batch):
     iterate(
         ray_dataset.iter_torch_batches(batch_size=args.batch_size),
         "ray_data+dummy_np_transform",
+        args.batch_size,
         metrics,
     )

@@ -244,6 +258,7 @@ def load(batch):
     iterate(
         ray_dataset.iter_torch_batches(batch_size=args.batch_size),
         "ray_data_manual_load",
+        args.batch_size,
         metrics,
     )

164 changes: 164 additions & 0 deletions release/nightly_tests/dataset/preprocess_images.py
@@ -0,0 +1,164 @@
#!/usr/bin/env python3
"""
Download or generate a fake TF dataset from images.
"""

from typing import Union, Iterable, Tuple
import os
import sys
import PIL.Image

import tensorflow.compat.v1 as tf
from streaming import MDSWriter

import ray


class ImageCoder(object):
    """Helper class that provides TensorFlow image coding utilities."""

    def __init__(self):
        tf.disable_v2_behavior()

        # Create a single Session to run all image coding calls.
        self._sess = tf.Session()

        # Initializes function that decodes RGB JPEG data.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

    def decode_jpeg(self, image_data: bytes) -> tf.Tensor:
        """Decodes a JPEG image."""
        image = self._sess.run(
            self._decode_jpeg, feed_dict={self._decode_jpeg_data: image_data}
        )
        assert len(image.shape) == 3
        assert image.shape[2] == 3
        return image
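
For reference, a minimal sketch of the same decode under TF2 eager execution, without the graph-mode Session plumbing (not part of this commit):

    import tensorflow as tf  # TF2, eager mode

    def decode_jpeg_eager(image_data: bytes):
        # Returns an HWC uint8 tensor with 3 channels.
        return tf.io.decode_jpeg(image_data, channels=3)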


def parse_single_image(image_path: str) -> Tuple[bytes, int, int]:
    with open(image_path, "rb") as f:
        image_buffer = f.read()

    coder = ImageCoder()
    image = coder.decode_jpeg(image_buffer)
    height, width, _ = image.shape

    return image_buffer, height, width


def create_single_example(image_path: str) -> tf.train.Example:
    image_buffer, height, width = parse_single_image(image_path)

    label = 0

    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "image/class/label": _int64_feature(label),
                "image/encoded": _bytes_feature(image_buffer),
            }
        )
    )

    return example


def _int64_feature(value: Union[int, Iterable[int]]) -> tf.train.Feature:
    """Inserts int64 features into Example proto."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _bytes_feature(value: Union[bytes, str]) -> tf.train.Feature:
    """Inserts bytes features into Example proto."""
    if isinstance(value, str):
        value = bytes(value, "utf-8")
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def preprocess_tfdata(data_root, tf_data_root, max_images_per_file):
    examples = []
    num_shards = 0
    for image_path in os.listdir(data_root):
        example = create_single_example(
            os.path.join(data_root, image_path)
        ).SerializeToString()
        examples.append(example)

        if len(examples) >= max_images_per_file:
            output_filename = os.path.join(tf_data_root, f"data-{num_shards}.tfrecord")
            with tf.python_io.TFRecordWriter(output_filename) as writer:
                for example in examples:
                    writer.write(example)
            print(f"Done writing {output_filename}", file=sys.stderr)
            examples = []
            num_shards += 1

    # Write the final, possibly partial, shard.
    output_filename = os.path.join(tf_data_root, f"data-{num_shards}.tfrecord")
    with tf.python_io.TFRecordWriter(output_filename) as writer:
        for example in examples:
            writer.write(example)
    print(f"Done writing {output_filename}", file=sys.stderr)


def preprocess_mosaic(input_dir, output_dir):
    ds = ray.data.read_images(input_dir)
    it = ds.iter_rows()

    columns = {"image": "jpeg", "label": "int"}
    # If reading from local disk, turn off compression and use
    # streaming.LocalDataset.
    # If uploading to S3, turn on compression (e.g., compression="snappy") and
    # use streaming.StreamingDataset.
    with MDSWriter(out=output_dir, columns=columns, compression=None) as out:
        for i, img in enumerate(it):
            out.write(
                {
                    "image": PIL.Image.fromarray(img["image"]),
                    "label": 0,
                }
            )
            if i % 10 == 0:
                print(f"Wrote {i} images.")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Preprocess images into MDS and TFRecord formats."
    )
    parser.add_argument(
        "--data-root",
        default="/tmp/imagenet-1gb-data",
        type=str,
        help="Directory containing the raw input images.",
    )
    parser.add_argument(
        "--mosaic-data-root",
        default="/tmp/mosaicml-data",
        type=str,
        help="Output directory for MosaicML streaming (MDS) shards.",
    )
    parser.add_argument(
        "--tf-data-root",
        default="/tmp/tf-data",
        type=str,
        help="Output directory for TFRecord files.",
    )
    parser.add_argument(
        "--max-images-per-file",
        default=32,
        type=int,
        help="Maximum number of images per TFRecord shard.",
    )

    args = parser.parse_args()

    ray.init()

    preprocess_mosaic(args.data_root, args.mosaic_data_root)
    preprocess_tfdata(args.data_root, args.tf_data_root, args.max_images_per_file)