
Cherrypick fixes to 0.5 (#2257)
justinxzhao authored Jul 12, 2022
1 parent 09f7056 commit dab4c24
Showing 11 changed files with 171 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
@@ -7,7 +7,7 @@ on:
push:
branches: [master]
pull_request:
branches: [master]
branches: [master, "*-stable"]

# we want an ongoing run of this workflow to be canceled by a later commit
# so that there is only one concurrent run of this workflow for each branch
3 changes: 3 additions & 0 deletions ludwig/api.py
@@ -1368,6 +1368,9 @@ def load(

config = backend.broadcast_return(lambda: load_json(os.path.join(model_dir, MODEL_HYPERPARAMETERS_FILE_NAME)))

# Upgrades deprecated fields and adds new required fields, in case the config loaded from disk is old.
config = merge_with_defaults(config)

if backend_param is None and "backend" in config:
# Reset backend from config
backend = initialize_backend(config.get("backend"))
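
The merge_with_defaults call added above upgrades deprecated fields and fills in newly required ones when a saved model's config is read back from disk. A minimal sketch of that load-time path, not part of the commit; the import locations and the hyperparameters filename constant are assumed from Ludwig 0.5's layout:

import os

from ludwig.globals import MODEL_HYPERPARAMETERS_FILE_NAME
from ludwig.utils.data_utils import load_json
from ludwig.utils.defaults import merge_with_defaults

model_dir = "results/experiment_run/model"  # hypothetical saved-model directory
config = load_json(os.path.join(model_dir, MODEL_HYPERPARAMETERS_FILE_NAME))

# Upgrade deprecated fields and add new required fields, so configs written by
# older Ludwig versions keep loading after schema changes.
config = merge_with_defaults(config)
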
28 changes: 14 additions & 14 deletions ludwig/data/dataset/ray.py
@@ -156,12 +156,12 @@ def __init__(
self.dataset_shard = dataset_shard
self.features = features
self.training_set_metadata = training_set_metadata
self.dataset_iter = dataset_shard.iter_datasets()
self.epoch_iter = dataset_shard.iter_epochs()

@contextlib.contextmanager
def initialize_batcher(self, batch_size=128, should_shuffle=True, seed=0, ignore_last=False, horovod=None):
yield RayDatasetBatcher(
self.dataset_iter,
self.epoch_iter,
self.features,
self.training_set_metadata,
batch_size,
@@ -171,7 +171,7 @@ def initialize_batcher(self, batch_size=128, should_shuffle=True, seed=0, ignore
@lru_cache(1)
def __len__(self):
# TODO(travis): find way to avoid calling this, as it's expensive
return next(self.dataset_iter).count()
return next(self.epoch_iter).count()

@property
def size(self):
@@ -181,7 +181,7 @@ def size(self):
class RayDatasetBatcher(Batcher):
def __init__(
self,
dataset_epoch_iterator: Iterator[ray.data.Dataset],
dataset_epoch_iterator: Iterator[DatasetPipeline],
features: Dict[str, Dict],
training_set_metadata: Dict[str, Any],
batch_size: int,
@@ -233,17 +233,17 @@ def steps_per_epoch(self):
return math.ceil(self.samples_per_epoch / self.batch_size)

def _fetch_next_epoch(self):
dataset = next(self.dataset_epoch_iterator)
pipeline = next(self.dataset_epoch_iterator)

read_parallelism = 1
if read_parallelism == 1:
self.dataset_batch_iter = self._create_async_reader(dataset)
self.dataset_batch_iter = self._create_async_reader(pipeline)
elif read_parallelism > 1:
# TODO: consider removing this. doesn't work currently and read performance seems generally
# very good with 1 parallelism
self.dataset_batch_iter = self._create_async_parallel_reader(dataset, read_parallelism)
self.dataset_batch_iter = self._create_async_parallel_reader(pipeline, read_parallelism)
else:
self.dataset_batch_iter = self._create_sync_reader(dataset)
self.dataset_batch_iter = self._create_sync_reader(pipeline)

self._step = 0
self._fetch_next_batch()
@@ -285,26 +285,26 @@ def _prepare_batch(self, batch: pd.DataFrame) -> Dict[str, np.ndarray]:

return res

def _create_sync_reader(self, dataset: ray.data.Dataset):
def _create_sync_reader(self, pipeline: DatasetPipeline):
to_tensors = self._to_tensors_fn()

def sync_read():
for batch in dataset.map_batches(to_tensors, batch_format="pandas").iter_batches(
for batch in pipeline.map_batches(to_tensors, batch_format="pandas").iter_batches(
prefetch_blocks=0, batch_size=self.batch_size, batch_format="pandas"
):
yield self._prepare_batch(batch)

return sync_read()

def _create_async_reader(self, dataset: ray.data.Dataset):
def _create_async_reader(self, pipeline: DatasetPipeline):
q = queue.Queue(maxsize=100)

batch_size = self.batch_size

to_tensors = self._to_tensors_fn()

def producer():
for batch in dataset.map_batches(to_tensors, batch_format="pandas").iter_batches(
for batch in pipeline.map_batches(to_tensors, batch_format="pandas").iter_batches(
prefetch_blocks=0, batch_size=batch_size, batch_format="pandas"
):
res = self._prepare_batch(batch)
@@ -323,13 +323,13 @@ def async_read():

return async_read()

def _create_async_parallel_reader(self, dataset: ray.data.Dataset, num_threads: int):
def _create_async_parallel_reader(self, pipeline: DatasetPipeline, num_threads: int):
q = queue.Queue(maxsize=100)

batch_size = self.batch_size

to_tensors = self._to_tensors_fn()
splits = dataset.split(n=num_threads)
splits = pipeline.split(n=num_threads)

def producer(i):
for batch in (
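
The change from iter_datasets() to iter_epochs() means each next() call hands the batcher exactly one pass over the shard rather than one pipeline window. A minimal sketch of the distinction, assuming Ray 1.12-era DatasetPipeline semantics (not part of this commit):

import ray

ray.init(ignore_reinit_error=True)

pipe = ray.data.range(8).repeat(2)  # DatasetPipeline that replays the dataset for 2 epochs

# iter_epochs() yields one pipeline per epoch, so next(...) in the batcher consumes
# exactly one full pass over the data; iter_datasets() would yield individual windows.
for epoch, epoch_pipe in enumerate(pipe.iter_epochs()):
    print(epoch, epoch_pipe.count())  # prints "0 8" then "1 8"
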
12 changes: 9 additions & 3 deletions ludwig/data/dataset_synthesizer.py
@@ -140,7 +140,7 @@ def build_feature_parameters(features):


def build_synthetic_dataset(dataset_size: int, features: List[dict]):
"""Symthesizes a dataset for testing purposes.
"""Synthesizes a dataset for testing purposes.
:param dataset_size: (int) size of the dataset
:param features: (List[dict]) list of features to generate in YAML format.
@@ -277,7 +277,9 @@ def generate_audio(feature):
return audio_dest_path


def generate_image(feature):
def generate_image(feature, save_as_numpy=False):
save_as_numpy = feature.get("save_as_numpy", save_as_numpy)

try:
from torchvision.io import write_png
except ImportError:
@@ -315,7 +317,11 @@ def generate_image(feature):

image_dest_path = os.path.join(destination_folder, image_filename)
# save_image(torch.from_numpy(img.astype("uint8")), image_dest_path)
write_png(img, image_dest_path)
if save_as_numpy:
with open(image_dest_path, "wb") as f:
np.save(f, img.detach().cpu().numpy())
else:
write_png(img, image_dest_path)

except OSError as e:
raise OSError("Unable to create a folder for images/save image to disk." "{}".format(e))
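
With save_as_numpy enabled, the synthesizer writes the raw array instead of a PNG. A short sketch of the two save paths, assuming img is a uint8 CHW tensor as the synthesizer produces:

import numpy as np
import torch
from torchvision.io import write_png

img = torch.randint(0, 255, (3, 8, 8), dtype=torch.uint8)  # stand-in synthetic image

# save_as_numpy=True: persist the raw array; np.load() returns it unchanged.
with open("/tmp/example_image.npy", "wb") as f:
    np.save(f, img.detach().cpu().numpy())

# save_as_numpy=False (the default): encode as a PNG via torchvision.
write_png(img, "/tmp/example_image.png")
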
19 changes: 16 additions & 3 deletions ludwig/features/image_feature.py
@@ -166,7 +166,7 @@ def get_feature_meta(column, preprocessing_parameters, backend):

@staticmethod
def _read_image_if_bytes_obj_and_resize(
img_entry: Union[bytes, torch.Tensor],
img_entry: Union[bytes, torch.Tensor, np.ndarray],
img_width: int,
img_height: int,
should_resize: bool,
@@ -175,7 +175,7 @@ def _read_image_if_bytes_obj_and_resize(
user_specified_num_channels: bool,
) -> Optional[np.ndarray]:
"""
:param img_entry Union[bytes, torch.Tensor]: if str file path to the
:param img_entry Union[bytes, torch.Tensor, np.ndarray]: if str file path to the
image else torch.Tensor of the image itself
:param img_width: expected width of the image
:param img_height: expected height of the image
@@ -194,8 +194,11 @@
If the user specifies a number of channels, we try to convert all the
images to the specifications by dropping channels/padding 0 channels
"""

if isinstance(img_entry, bytes):
img = read_image_from_bytes_obj(img_entry, num_channels)
elif isinstance(img_entry, np.ndarray):
img = torch.from_numpy(img_entry).permute(2, 0, 1)
else:
img = img_entry

@@ -333,16 +336,26 @@ def _finalize_preprocessing_parameters(
else:
sample_size = 1 # Take first image

failed_entries = []
for image_entry in column.head(sample_size):
if isinstance(image_entry, str):
# Tries to read image as PNG or numpy file from the path.
image = read_image_from_path(image_entry)
else:
image = image_entry

if isinstance(image, torch.Tensor):
sample.append(image)
elif isinstance(image, np.ndarray):
sample.append(torch.from_numpy(image).permute(2, 0, 1))
else:
failed_entries.append(image_entry)
if len(sample) == 0:
raise ValueError("No readable images in sample, image dimensions cannot be inferred")
failed_entries_repr = "\n\t- ".join(failed_entries)
raise ValueError(
f"Images dimensions cannot be inferred. Failed to read {sample_size} images as samples:\n\t- "
f"{failed_entries_repr}."
)

should_resize = False
if explicit_height_width:
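
The new np.ndarray branches exist because images loaded as numpy arrays (e.g. np.array(Image.open(path)) or np.load) are channel-last (H, W, C), while Ludwig's image tensors are channel-first (C, H, W); permute(2, 0, 1) performs the conversion. A tiny illustration:

import numpy as np
import torch

img_hwc = np.zeros((32, 32, 3), dtype=np.uint8)        # H x W x C, as numpy/PIL produce
img_chw = torch.from_numpy(img_hwc).permute(2, 0, 1)    # C x H x W, as the image feature expects
assert tuple(img_chw.shape) == (3, 32, 32)
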
4 changes: 3 additions & 1 deletion ludwig/hyperopt/execution.py
@@ -37,13 +37,15 @@
from ray import tune
from ray.tune import register_trainable, Stopper
from ray.tune.schedulers.resource_changing_scheduler import DistributeResources, ResourceChangingScheduler
from ray.tune.suggest import BasicVariantGenerator, ConcurrencyLimiter, SEARCH_ALG_IMPORT
from ray.tune.suggest import BasicVariantGenerator, ConcurrencyLimiter

_ray_114 = version.parse(ray.__version__) >= version.parse("1.14")
if _ray_114:
from ray.tune.search import SEARCH_ALG_IMPORT
from ray.tune.syncer import get_node_to_storage_syncer, SyncConfig
else:
from ray.tune.syncer import get_cloud_sync_client
from ray.tune.suggest import SEARCH_ALG_IMPORT

from ray.tune.utils import wait_for_gpu
from ray.tune.utils.placement_groups import PlacementGroupFactory
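
The flattened diff above interleaves the old and new import lines; reassembled, the version gate reads as follows (taken directly from the change — SEARCH_ALG_IMPORT moves from ray.tune.suggest to ray.tune.search for Ray >= 1.14):

import ray
from packaging import version

# packaging.version compares release/dev/rc suffixes correctly, unlike naive
# string comparison of ray.__version__.
_ray_114 = version.parse(ray.__version__) >= version.parse("1.14")

if _ray_114:
    from ray.tune.search import SEARCH_ALG_IMPORT
    from ray.tune.syncer import get_node_to_storage_syncer, SyncConfig
else:
    from ray.tune.suggest import SEARCH_ALG_IMPORT
    from ray.tune.syncer import get_cloud_sync_client
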
29 changes: 28 additions & 1 deletion ludwig/utils/image_utils.py
@@ -95,16 +95,43 @@ def read_image_from_path(path: str, num_channels: Optional[int] = None) -> Optio
def read_image_from_bytes_obj(
bytes_obj: Optional[bytes] = None, num_channels: Optional[int] = None
) -> Optional[torch.Tensor]:
"""Tries to read image as a tensor from the path.
If the path is not decodable as a PNG, attempts to read as a numpy file. If neither of these work, returns None.
"""
mode = get_image_read_mode_from_num_channels(num_channels)

image = read_image_as_png(bytes_obj, mode)
if image is None:
image = read_image_as_numpy(bytes_obj)
if image is None:
logger.warning("Unable to read image from bytes object.")
return image


def read_image_as_png(
bytes_obj: Optional[bytes] = None, mode: ImageReadMode = ImageReadMode.UNCHANGED
) -> Optional[torch.Tensor]:
"""Reads image from bytes object from a PNG file."""
try:
with BytesIO(bytes_obj) as buffer:
buffer_view = buffer.getbuffer()
image = decode_image(torch.frombuffer(buffer_view, dtype=torch.uint8), mode=mode)
del buffer_view
return image
except Exception as e:
logger.warning("Failed to read image from bytes object. Original exception: " + str(e))
logger.warning(f"Failed to read image from PNG file. Original exception: {e}")
return None


def read_image_as_numpy(bytes_obj: Optional[bytes] = None) -> Optional[torch.Tensor]:
"""Reads image from bytes object from a numpy file."""
try:
with BytesIO(bytes_obj) as buffer:
image = np.load(buffer)
return torch.from_numpy(image)
except Exception as e:
logger.warning(f"Failed to read image from numpy file. Original exception: {e}")
return None


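
A small usage sketch of the fallback introduced above: bytes that fail PNG decoding (for instance the contents of an .npy file) are retried through read_image_as_numpy. Illustrative only; note the tensor stays channel-last because the numpy path does not permute:

from io import BytesIO

import numpy as np
from ludwig.utils.image_utils import read_image_from_bytes_obj

buf = BytesIO()
np.save(buf, np.zeros((8, 8, 3), dtype=np.uint8))  # serialize an array the way save_as_numpy does

# PNG decoding fails (a warning is logged), then the numpy fallback succeeds.
tensor = read_image_from_bytes_obj(buf.getvalue())
print(tensor.shape)  # torch.Size([8, 8, 3])
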
2 changes: 1 addition & 1 deletion requirements.txt
@@ -19,7 +19,7 @@ requests
tables
fsspec[http]
dataclasses-json
jsonschema>=4.5.0
jsonschema>=4.5.0,<4.7
marshmallow
marshmallow-jsonschema
marshmallow-dataclass==8.5.5
1 change: 1 addition & 0 deletions requirements_test.txt
@@ -1,5 +1,6 @@
pytest
pytest-timeout
wget
six>=1.13.0
aim
wandb<0.12.11
64 changes: 61 additions & 3 deletions tests/integration_tests/test_preprocessing.py
@@ -5,9 +5,10 @@
import numpy as np
import pandas as pd
import pytest
from PIL import Image

from ludwig.api import LudwigModel
from ludwig.constants import COLUMN, PROC_COLUMN
from ludwig.constants import COLUMN, NAME, PROC_COLUMN, TRAINER
from ludwig.data.concatenate_datasets import concatenate_df
from tests.integration_tests.utils import (
audio_feature,
@@ -21,6 +22,8 @@
sequence_feature,
)

NUM_EXAMPLES = 10


@pytest.mark.parametrize("backend", ["local", "ray"])
@pytest.mark.distributed
@@ -80,7 +83,7 @@ def test_strip_whitespace_category(csv_filename, tmpdir):
@pytest.mark.parametrize("backend", ["local", "ray"])
@pytest.mark.distributed
def test_with_split(backend, csv_filename, tmpdir):
num_examples = 10
num_examples = NUM_EXAMPLES
train_set_size = int(num_examples * 0.8)
val_set_size = int(num_examples * 0.1)
test_set_size = int(num_examples * 0.1)
@@ -117,7 +120,7 @@ def test_with_split(backend, csv_filename, tmpdir):
def test_dask_known_divisions(feature_fn, csv_filename, tmpdir):
import dask.dataframe as dd

num_examples = 10
num_examples = NUM_EXAMPLES

input_features = [feature_fn(os.path.join(tmpdir, "generated_output"))]
output_features = [category_feature(vocab_size=5, reduce_input="sum")]
@@ -144,6 +147,61 @@ def test_dask_known_divisions(feature_fn, csv_filename, tmpdir):
)


@pytest.mark.parametrize("generate_images_as_numpy", [False, True])
def test_read_image_from_path(tmpdir, csv_filename, generate_images_as_numpy):
input_features = [image_feature(os.path.join(tmpdir, "generated_output"), save_as_numpy=generate_images_as_numpy)]
output_features = [category_feature(vocab_size=5, reduce_input="sum")]
data_csv = generate_data(
input_features, output_features, os.path.join(tmpdir, csv_filename), num_examples=NUM_EXAMPLES
)

config = {
"input_features": input_features,
"output_features": output_features,
"trainer": {"epochs": 2},
}

model = LudwigModel(config)
model.preprocess(
data_csv,
skip_save_processed_input=False,
)


def test_read_image_from_numpy_array(tmpdir, csv_filename):
input_features = [image_feature(os.path.join(tmpdir, "generated_output"))]
output_features = [category_feature(vocab_size=5, reduce_input="sum")]

config = {
"input_features": input_features,
"output_features": output_features,
TRAINER: {"epochs": 2},
}

data_csv = generate_data(
input_features, output_features, os.path.join(tmpdir, csv_filename), num_examples=NUM_EXAMPLES
)

df = pd.read_csv(data_csv)
processed_df_rows = []

for _, row in df.iterrows():
processed_df_rows.append(
{
input_features[0][NAME]: np.array(Image.open(row[input_features[0][NAME]])),
output_features[0][NAME]: row[output_features[0][NAME]],
}
)

df_with_images_as_numpy_arrays = pd.DataFrame(processed_df_rows)

model = LudwigModel(config)
model.preprocess(
df_with_images_as_numpy_arrays,
skip_save_processed_input=False,
)


def test_number_feature_wrong_dtype(csv_filename, tmpdir):
"""Tests that a number feature with all string values is treated as having missing values by default."""
data_csv_path = os.path.join(tmpdir, csv_filename)