From 62f04420970a48bdcc288bef8d723ef5fc5aba8f Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Thu, 27 Jul 2023 18:41:16 -0500 Subject: [PATCH] [data] Add microbenchmark for reading and transforming images from preprocessed image files (#37610) MosaicML streaming and tf.data have their own specialized file formats to improve image load time. This adds a microbenchmark similar to read_images_comparison_microbenchmark_single_node, except that it first preprocesses the images into the right formats. --------- Signed-off-by: Stephanie Wang Signed-off-by: NripeshN --- release/nightly_tests/dataset/app_config.yaml | 3 + .../dataset/image_loader_microbenchmark.py | 27 +- .../dataset/preprocess_images.py | 164 +++++++++++ ...reprocessed_image_loader_microbenchmark.py | 261 ++++++++++++++++++ .../run_image_loader_microbenchmark.sh | 4 + ...reprocessed_image_loader_microbenchmark.sh | 24 ++ release/release_tests.yaml | 26 ++ 7 files changed, 503 insertions(+), 6 deletions(-) create mode 100644 release/nightly_tests/dataset/preprocess_images.py create mode 100644 release/nightly_tests/dataset/preprocessed_image_loader_microbenchmark.py create mode 100644 release/nightly_tests/dataset/run_preprocessed_image_loader_microbenchmark.sh diff --git a/release/nightly_tests/dataset/app_config.yaml b/release/nightly_tests/dataset/app_config.yaml index 12e8cd86d84b..3cc7d791696f 100644 --- a/release/nightly_tests/dataset/app_config.yaml +++ b/release/nightly_tests/dataset/app_config.yaml @@ -4,6 +4,9 @@ python: pip_packages: - boto3 - tqdm + - mosaicml-streaming + # Temporary fix to get around docker issue. 
+ - typing-extensions<4.6.0 conda_packages: [] post_build_cmds: diff --git a/release/nightly_tests/dataset/image_loader_microbenchmark.py b/release/nightly_tests/dataset/image_loader_microbenchmark.py index 811fb849d370..f419355f7753 100644 --- a/release/nightly_tests/dataset/image_loader_microbenchmark.py +++ b/release/nightly_tests/dataset/image_loader_microbenchmark.py @@ -15,12 +15,15 @@ FULL_IMAGE_SIZE = (1213, 1546) -def iterate(dataset, label, metrics): +def iterate(dataset, label, batch_size, metrics): start = time.time() it = iter(dataset) num_rows = 0 for batch in it: - num_rows += len(batch) + # NOTE(swang): This will be slightly off if batch_size does not divide + # evenly into number of images but should be okay for large enough + # datasets. + num_rows += batch_size end = time.time() print(label, end - start, "epoch", i) @@ -98,6 +101,12 @@ def tf_crop_and_flip(image_buffer, num_channels=3): ) # Flip to add a little more random distortion in. image_buffer = tf.image.random_flip_left_right(image_buffer) + image_buffer = tf.compat.v1.image.resize( + image_buffer, + [DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE], + method=tf.image.ResizeMethod.BILINEAR, + align_corners=False, + ) return image_buffer @@ -162,21 +171,21 @@ def crop_and_flip_image_batch(image_batch): args.data_root, batch_size=args.batch_size, image_size=FULL_IMAGE_SIZE ) for i in range(args.num_epochs): - iterate(tf_dataset, "tf_data", metrics) + iterate(tf_dataset, "tf_data", args.batch_size, metrics) tf_dataset = tf_dataset.map(lambda img, label: (tf_crop_and_flip(img), label)) for i in range(args.num_epochs): - iterate(tf_dataset, "tf_data+transform", metrics) + iterate(tf_dataset, "tf_data+transform", args.batch_size, metrics) torch_dataset = build_torch_dataset( args.data_root, args.batch_size, transform=torchvision.transforms.ToTensor() ) for i in range(args.num_epochs): - iterate(torch_dataset, "torch", metrics) + iterate(torch_dataset, "torch", args.batch_size, metrics) torch_dataset = 
build_torch_dataset( args.data_root, args.batch_size, transform=get_transform(True) ) for i in range(args.num_epochs): - iterate(torch_dataset, "torch+transform", metrics) + iterate(torch_dataset, "torch+transform", args.batch_size, metrics) ray_dataset = ray.data.read_images(args.data_root).map_batches( crop_and_flip_image_batch @@ -185,6 +194,7 @@ def crop_and_flip_image_batch(image_batch): iterate( ray_dataset.iter_torch_batches(batch_size=args.batch_size), "ray_data+transform", + args.batch_size, metrics, ) @@ -195,6 +205,7 @@ def crop_and_flip_image_batch(image_batch): iterate( ray_dataset.iter_torch_batches(batch_size=args.batch_size), "ray_data+transform+zerocopy", + args.batch_size, metrics, ) @@ -203,6 +214,7 @@ def crop_and_flip_image_batch(image_batch): iterate( ray_dataset.iter_torch_batches(batch_size=args.batch_size), "ray_data", + args.batch_size, metrics, ) @@ -213,6 +225,7 @@ def crop_and_flip_image_batch(image_batch): iterate( ray_dataset.iter_torch_batches(batch_size=args.batch_size), "ray_data+dummy_pyarrow_transform", + args.batch_size, metrics, ) @@ -223,6 +236,7 @@ def crop_and_flip_image_batch(image_batch): iterate( ray_dataset.iter_torch_batches(batch_size=args.batch_size), "ray_data+dummy_np_transform", + args.batch_size, metrics, ) @@ -244,6 +258,7 @@ def load(batch): iterate( ray_dataset.iter_torch_batches(batch_size=args.batch_size), "ray_data_manual_load", + args.batch_size, metrics, ) diff --git a/release/nightly_tests/dataset/preprocess_images.py b/release/nightly_tests/dataset/preprocess_images.py new file mode 100644 index 000000000000..3d853759c277 --- /dev/null +++ b/release/nightly_tests/dataset/preprocess_images.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +""" +Download or generate a fake TF dataset from images. 
+"""
+
+from typing import Union, Iterable, Tuple
+import os
+import sys
+import PIL.Image  # Bare "import PIL" does not load the Image submodule.
+
+import tensorflow.compat.v1 as tf
+from streaming import MDSWriter
+
+import ray
+
+
+class ImageCoder(object):
+    """Helper class that provides TensorFlow image coding utilities."""
+
+    def __init__(self):
+        tf.disable_v2_behavior()
+
+        # Create a single Session to run all image coding calls.
+        self._sess = tf.Session()
+
+        # Initializes function that decodes RGB JPEG data.
+        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
+        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
+
+    def decode_jpeg(self, image_data: bytes) -> tf.Tensor:
+        """Decodes a JPEG image."""
+        image = self._sess.run(
+            self._decode_jpeg, feed_dict={self._decode_jpeg_data: image_data}
+        )
+        assert len(image.shape) == 3
+        assert image.shape[2] == 3
+        return image
+
+
+def parse_single_image(image_path: str) -> Tuple[bytes, int, int]:
+    with open(image_path, "rb") as f:
+        image_buffer = f.read()
+
+    coder = ImageCoder()
+    image = coder.decode_jpeg(image_buffer)
+    height, width, _ = image.shape
+
+    return image_buffer, height, width
+
+
+def create_single_example(image_path: str) -> tf.train.Example:
+    image_buffer, height, width = parse_single_image(image_path)
+
+    label = 0
+
+    example = tf.train.Example(
+        features=tf.train.Features(
+            feature={
+                "image/class/label": _int64_feature(label),
+                "image/encoded": _bytes_feature(image_buffer),
+            }
+        )
+    )
+
+    return example
+
+
+def _int64_feature(value: Union[int, Iterable[int]]) -> tf.train.Feature:
+    """Inserts int64 features into Example proto."""
+    if not isinstance(value, list):
+        value = [value]
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
+
+
+def _bytes_feature(value: Union[bytes, str]) -> tf.train.Feature:
+    """Inserts bytes features into Example proto."""
+    if isinstance(value, str):
+        value = bytes(value, "utf-8")
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def 
preprocess_tfdata(data_root, tf_data_root, max_images_per_file):
+    examples = []
+    num_shards = 0
+    output_filename = os.path.join(tf_data_root, f"data-{num_shards}.tfrecord")
+    for image_path in os.listdir(data_root):
+        example = create_single_example(
+            os.path.join(data_root, image_path)
+        ).SerializeToString()
+        examples.append(example)
+
+        if len(examples) >= max_images_per_file:
+            output_filename = os.path.join(tf_data_root, f"data-{num_shards}.tfrecord")
+            with tf.python_io.TFRecordWriter(output_filename) as writer:
+                for example in examples:
+                    writer.write(example)
+            print(f"Done writing {output_filename}", file=sys.stderr)
+            examples = []
+            num_shards += 1
+
+    output_filename = os.path.join(tf_data_root, f"data-{num_shards}.tfrecord")
+    with tf.python_io.TFRecordWriter(output_filename) as writer:
+        for example in examples:
+            writer.write(example)
+    print(f"Done writing {output_filename}", file=sys.stderr)
+
+
+def preprocess_mosaic(input_dir, output_dir):
+    ds = ray.data.read_images(input_dir)
+    it = ds.iter_rows()
+
+    columns = {"image": "jpeg", "label": "int"}
+    # If reading from local disk, should turn off compression and use
+    # streaming.LocalDataset.
+    # If uploading to S3, turn on compression (e.g., compression="snappy") and
+    # streaming.StreamingDataset.
+    with MDSWriter(out=output_dir, columns=columns, compression=None) as out:
+        for i, img in enumerate(it, start=1):  # 1-based: i == images written.
+            out.write(
+                {
+                    "image": PIL.Image.fromarray(img["image"]),
+                    "label": 0,
+                }
+            )
+            if i % 10 == 0:
+                print(f"Wrote {i} images.")
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Preprocess images -> MDS and TFRecords."  # noqa: E501
+    )
+    parser.add_argument(
+        "--data-root",
+        default="/tmp/imagenet-1gb-data",
+        type=str,
+        help="Directory containing the raw input images to preprocess.",
+    )
+    parser.add_argument(
+        "--mosaic-data-root",
+        default="/tmp/mosaicml-data",
+        type=str,
+        help="Output directory for the MosaicML streaming (MDS) dataset.",
+    )
+    parser.add_argument(
+        "--tf-data-root",
+        default="/tmp/tf-data",
+        type=str,
+        help="Output directory for the generated TFRecord shards.",
+    )
+    parser.add_argument(
+        "--max-images-per-file",
+        default=32,
+        type=int,
+    )
+
+    args = parser.parse_args()
+
+    ray.init()
+
+    preprocess_mosaic(args.data_root, args.mosaic_data_root)
+    preprocess_tfdata(args.data_root, args.tf_data_root, args.max_images_per_file)
diff --git a/release/nightly_tests/dataset/preprocessed_image_loader_microbenchmark.py b/release/nightly_tests/dataset/preprocessed_image_loader_microbenchmark.py
new file mode 100644
index 000000000000..3c9d029905f9
--- /dev/null
+++ b/release/nightly_tests/dataset/preprocessed_image_loader_microbenchmark.py
@@ -0,0 +1,261 @@
+import torch
+import torchvision
+import os
+from typing import Any, Callable
+import time
+import tensorflow as tf
+import pandas as pd
+import json
+
+import ray
+from streaming import LocalDataset
+
+
+DEFAULT_IMAGE_SIZE = 224
+
+# tf.data needs to resize all images to the same size when loading.
+# This is the size of dog.jpg in s3://air-cuj-imagenet-1gb.
+FULL_IMAGE_SIZE = (1213, 1546)
+
+
+def iterate(dataset, label, batch_size, metrics):
+    start = time.time()
+    it = iter(dataset)
+    num_rows = 0
+    for batch in it:
+        # NOTE(swang): This will be slightly off if batch_size does not divide
+        # evenly into number of images but should be okay for large enough
+        # datasets.
+        num_rows += batch_size
+    end = time.time()
+    print(label, end - start, "epoch", i)  # `i`: module-level epoch loop variable.
+
+    tput = num_rows / (end - start)
+    metrics[label] = tput
+
+
+class MosaicDataset(LocalDataset):
+    def __init__(self, local: str, transforms: Callable) -> None:
+        super().__init__(local=local)
+        self.transforms = transforms
+
+    def __getitem__(self, idx: int) -> Any:
+        obj = super().__getitem__(idx)
+        image = obj["image"]
+        label = obj["label"]
+        return self.transforms(image), label
+
+
+def parse_and_decode_tfrecord(example_serialized):
+    feature_map = {
+        "image/encoded": tf.io.FixedLenFeature([], dtype=tf.string, default_value=""),
+        "image/class/label": tf.io.FixedLenFeature(
+            [], dtype=tf.int64, default_value=-1
+        ),
+    }
+
+    features = tf.io.parse_single_example(example_serialized, feature_map)
+    label = tf.cast(features["image/class/label"], dtype=tf.int32)
+
+    image_buffer = features["image/encoded"]
+    image_buffer = tf.reshape(image_buffer, shape=[])
+    image_buffer = tf.io.decode_jpeg(image_buffer, channels=3)
+    return image_buffer, label
+
+
+def tf_crop_and_flip(image_buffer, num_channels=3):
+    """Randomly crops, horizontally flips, and resizes an already-decoded image.
+
+    The crop box comes from tf.image.sample_distorted_bounding_box and is applied
+    with tf.image.crop_to_bounding_box; the fused decode_and_crop op is not used,
+    so the input must already be a decoded image.
+
+    Args:
+        image_buffer: decoded image Tensor; callers in this file pass the
+            output of tf.io.decode_jpeg(..., channels=3). The crop is sampled
+            against a bounding box that covers the entire image.
+        num_channels: Integer depth of the image; currently unused by the
+            function body.
+
+    Returns:
+        Tensor that has been cropped, randomly flipped left/right, and resized
+        to DEFAULT_IMAGE_SIZE x DEFAULT_IMAGE_SIZE (bilinear).
+    """
+    # A large fraction of image datasets contain a human-annotated bounding box
+    # delineating the region of the image containing the object of interest. 
We + # choose to create a new bounding box for the object which is a randomly + # distorted version of the human-annotated bounding box that obeys an + # allowed range of aspect ratios, sizes and overlap with the human-annotated + # bounding box. If no box is supplied, then we assume the bounding box is + # the entire image. + shape = tf.shape(image_buffer) + bbox = tf.constant( + [0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4] + ) # From the entire image + sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( + shape, + bounding_boxes=bbox, + min_object_covered=0.1, + aspect_ratio_range=[0.75, 1.33], + area_range=[0.05, 1.0], + max_attempts=100, + use_image_if_no_bounding_boxes=True, + ) + bbox_begin, bbox_size, _ = sample_distorted_bounding_box + + # Reassemble the bounding box in the format the crop op requires. + offset_y, offset_x, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + + image_buffer = tf.image.crop_to_bounding_box( + image_buffer, + offset_height=offset_y, + offset_width=offset_x, + target_height=target_height, + target_width=target_width, + ) + # Flip to add a little more random distortion in. 
+ image_buffer = tf.image.random_flip_left_right(image_buffer) + image_buffer = tf.compat.v1.image.resize( + image_buffer, + [DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE], + method=tf.image.ResizeMethod.BILINEAR, + align_corners=False, + ) + return image_buffer + + +def build_tf_dataset(data_root, batch_size): + filenames = [ + os.path.join(data_root, pathname) for pathname in os.listdir(data_root) + ] + ds = tf.data.Dataset.from_tensor_slices(filenames) + ds = ds.interleave(tf.data.TFRecordDataset).map( + parse_and_decode_tfrecord, num_parallel_calls=tf.data.experimental.AUTOTUNE + ) + ds = ds.map(lambda img, label: (tf_crop_and_flip(img), label)) + ds = ds.batch(batch_size) + return ds + + +def decode_crop_and_flip_tf_record_batch(tf_record_batch: pd.DataFrame) -> pd.DataFrame: + """ + This version of the preprocessor fuses the load step with the crop and flip + step, which should have better performance (at the cost of re-executing the + load step on each epoch): + - the reference tf.data implementation can use the fused decode_and_crop op + - ray.data doesn't have to materialize the intermediate decoded batch. + """ + + def process_images(): + for image_buffer in tf_record_batch["image/encoded"]: + # Each image output is ~600KB. 
+            image_buffer = tf.reshape(image_buffer, shape=[])
+            image_buffer = tf.io.decode_jpeg(image_buffer, channels=3)
+            yield tf_crop_and_flip(image_buffer).numpy()
+
+    labels = (tf_record_batch["image/class/label"]).astype("float32")
+    df = pd.DataFrame.from_dict({"image": process_images(), "label": labels})
+
+    return df
+
+
+def build_ray_dataset(data_root, batch_size):
+    filenames = [
+        os.path.join(data_root, pathname) for pathname in os.listdir(data_root)
+    ]
+    ds = ray.data.read_tfrecords(filenames)
+    ds = ds.map_batches(decode_crop_and_flip_tf_record_batch, batch_format="pandas")
+    return ds
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--data-root",
+        default="/tmp/imagenet-1gb-data",
+        type=str,
+        help="Raw image directory. Accepted but not read by this benchmark.",
+    )
+    parser.add_argument(
+        "--mosaic-data-root",
+        default="/tmp/mosaicml-data",
+        type=str,
+        help="Directory with the MDS dataset written by preprocess_images.py.",
+    )
+    parser.add_argument(
+        "--tf-data-root",
+        default="/tmp/tf-data",
+        type=str,
+        help="Directory with the TFRecord shards written by preprocess_images.py.",
+    )
+    parser.add_argument(
+        "--batch-size",
+        default=32,
+        type=int,
+        help="Batch size to use.",
+    )
+    parser.add_argument(
+        "--num-epochs",
+        default=3,
+        type=int,
+        help="Number of epochs to run. The throughput for the last epoch will be kept.",
+    )
+    args = parser.parse_args()
+
+    metrics = {}
+
+    # MosaicML streaming.
+ transform = torchvision.transforms.Compose( + [ + torchvision.transforms.RandomResizedCrop( + size=DEFAULT_IMAGE_SIZE, + scale=(0.05, 1.0), + ratio=(0.75, 1.33), + ), + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.ToTensor(), + ] + ) + mosaic_ds = MosaicDataset(args.mosaic_data_root, transforms=transform) + num_workers = os.cpu_count() + mosaic_dl = torch.utils.data.DataLoader( + mosaic_ds, batch_size=args.batch_size, num_workers=num_workers + ) + for i in range(args.num_epochs): + iterate(mosaic_dl, "mosaic", args.batch_size, metrics) + + # Tf.data. + tf_ds = build_tf_dataset(args.tf_data_root, args.batch_size) + for i in range(args.num_epochs): + iterate(tf_ds, "tf_data", args.batch_size, metrics) + + # ray.data. + ray_ds = build_ray_dataset(args.tf_data_root, args.batch_size) + for i in range(args.num_epochs): + iterate( + ray_ds.iter_batches(batch_size=args.batch_size), + "ray_tfrecords", + args.batch_size, + metrics, + ) + + metrics_dict = {} + for label, tput in metrics.items(): + metrics_dict[label] = { + "THROUGHPUT": tput, + } + result_dict = { + "perf_metrics": metrics_dict, + "success": 1, + } + + test_output_json = os.environ.get( + "TEST_OUTPUT_JSON", "/tmp/preprocessed_image_loader_microbenchmark.json" + ) + + with open(test_output_json, "wt") as f: + json.dump(result_dict, f) diff --git a/release/nightly_tests/dataset/run_image_loader_microbenchmark.sh b/release/nightly_tests/dataset/run_image_loader_microbenchmark.sh index 6b47faba71ec..5db752acb24a 100755 --- a/release/nightly_tests/dataset/run_image_loader_microbenchmark.sh +++ b/release/nightly_tests/dataset/run_image_loader_microbenchmark.sh @@ -1,5 +1,9 @@ #!/bin/bash + +# Exit if any of the test commands fail. 
+set -eo pipefail
+
 INPUT_DIR=~/imagenet-1gb
 OUTPUT_DIR=~/imagenet-1gb-data
 
diff --git a/release/nightly_tests/dataset/run_preprocessed_image_loader_microbenchmark.sh b/release/nightly_tests/dataset/run_preprocessed_image_loader_microbenchmark.sh
new file mode 100644
index 000000000000..0a1daf4fb9ad
--- /dev/null
+++ b/release/nightly_tests/dataset/run_preprocessed_image_loader_microbenchmark.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+
+# Echo commands (-x); exit if any command or any stage of a pipeline fails.
+set -x -eo pipefail
+
+DIR="/tmp/imagenet-1gb"
+MOSAIC_DIR="/tmp/mosaicml-data"
+TFRECORDS_DIR="/tmp/tf-data"
+
+rm -rf "$DIR"
+rm -rf "$MOSAIC_DIR"
+rm -rf "$TFRECORDS_DIR"
+
+mkdir -p "$DIR"
+mkdir -p "$MOSAIC_DIR"
+mkdir -p "$TFRECORDS_DIR"
+
+# Download 1GB dataset from S3 to local disk so we can preprocess with mosaic.
+aws s3 sync s3://air-cuj-imagenet-1gb "$DIR"
+
+python preprocess_images.py --data-root "$DIR" --mosaic-data-root "$MOSAIC_DIR" --tf-data-root "$TFRECORDS_DIR"
+
+python preprocessed_image_loader_microbenchmark.py --data-root "$DIR" --mosaic-data-root "$MOSAIC_DIR" --tf-data-root "$TFRECORDS_DIR"
diff --git a/release/release_tests.yaml b/release/release_tests.yaml
index 50ce061bf402..34877b1a28b1 100644
--- a/release/release_tests.yaml
+++ b/release/release_tests.yaml
@@ -6282,6 +6282,32 @@
         cluster_env: app_config.yaml
         cluster_compute: single_node_benchmark_compute_gce.yaml
 
+- name: read_preprocessed_images_comparison_microbenchmark_single_node
+  group: data-tests
+  working_dir: nightly_tests/dataset
+
+  frequency: nightly
+  team: data
+  python: "3.8"
+  cluster:
+    byod:
+      type: gpu
+    cluster_env: app_config.yaml
+    cluster_compute: single_node_benchmark_compute.yaml
+
+  run:
+    timeout: 1800
+    script: bash run_preprocessed_image_loader_microbenchmark.sh
+
+  variations:
+    - __suffix__: aws
+    - __suffix__: gce
+      env: gce
+      frequency: manual
+      cluster:
+        cluster_env: app_config.yaml
+        cluster_compute: single_node_benchmark_compute_gce.yaml
+
 - name: 
read_images_train_4_cpu group: data-tests working_dir: nightly_tests/dataset