Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove image size limits. #3062

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
4 changes: 3 additions & 1 deletion nerfstudio/data/dataparsers/colmap_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Data parser for nerfstudio datasets. """
"""Data parser for nerfstudio datasets."""

from __future__ import annotations

Expand Down Expand Up @@ -39,10 +39,12 @@
get_train_eval_split_interval,
)
from nerfstudio.process_data.colmap_utils import parse_colmap_camera_params
from nerfstudio.utils.misc import set_pil_image_size_limit
from nerfstudio.utils.rich_utils import CONSOLE, status
from nerfstudio.utils.scripts import run_command

MAX_AUTO_RESOLUTION = 1600
set_pil_image_size_limit(None)


@dataclass
Expand Down
4 changes: 3 additions & 1 deletion nerfstudio/data/dataparsers/nerfstudio_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Data parser for nerfstudio datasets. """
"""Data parser for nerfstudio datasets."""

from __future__ import annotations

Expand All @@ -34,9 +34,11 @@
get_train_eval_split_interval,
)
from nerfstudio.utils.io import load_from_json
from nerfstudio.utils.misc import set_pil_image_size_limit
from nerfstudio.utils.rich_utils import CONSOLE

MAX_AUTO_RESOLUTION = 1600
set_pil_image_size_limit(None)


@dataclass
Expand Down
4 changes: 4 additions & 0 deletions nerfstudio/data/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""
Dataset.
"""

from __future__ import annotations

from copy import deepcopy
Expand All @@ -32,6 +33,9 @@
from nerfstudio.cameras.cameras import Cameras
from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.utils.data_utils import get_image_mask_tensor_from_path
from nerfstudio.utils.misc import set_pil_image_size_limit

set_pil_image_size_limit(None)


class InputDataset(Dataset):
Expand Down
4 changes: 3 additions & 1 deletion nerfstudio/data/datasets/depth_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.utils.data_utils import get_depth_image_from_path
from nerfstudio.model_components import losses
from nerfstudio.utils.misc import torch_compile
from nerfstudio.utils.misc import set_pil_image_size_limit, torch_compile
from nerfstudio.utils.rich_utils import CONSOLE

set_pil_image_size_limit(None)


class DepthDataset(InputDataset):
"""Dataset that returns images and depths. If no depths are found, then we generate them with Zoe Depth.
Expand Down
5 changes: 5 additions & 0 deletions nerfstudio/data/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Utility functions to allow easy re-use of common operations across dataloaders"""

from pathlib import Path
from typing import List, Tuple, Union

Expand All @@ -21,6 +22,10 @@
import torch
from PIL import Image

from nerfstudio.utils.misc import set_pil_image_size_limit

set_pil_image_size_limit(None)


def get_image_mask_tensor_from_path(filepath: Path, scale_factor: float = 1.0) -> torch.Tensor:
"""
Expand Down
7 changes: 4 additions & 3 deletions nerfstudio/generative/deepfloyd.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
from torch import Generator, Tensor, nn
from torch.cuda.amp.grad_scaler import GradScaler

from nerfstudio.utils.misc import set_pil_image_size_limit
from nerfstudio.utils.rich_utils import CONSOLE

IMG_DIM = 64
set_pil_image_size_limit(None)


class DeepFloyd(nn.Module):
Expand Down Expand Up @@ -206,16 +208,15 @@ def prompt_to_image(
Returns:
The generated image.
"""

from diffusers import DiffusionPipeline, IFPipeline as IFOrig
from diffusers import DiffusionPipeline, IFPipeline
from diffusers.pipelines.deepfloyd_if import IFPipelineOutput as IFOutputOrig

prompts = [prompts] if isinstance(prompts, str) else prompts
negative_prompts = [negative_prompts] if isinstance(negative_prompts, str) else negative_prompts
assert isinstance(self.pipe, DiffusionPipeline)
prompt_embeds, negative_embeds = self.pipe.encode_prompt(prompts, negative_prompt=negative_prompts)

assert isinstance(self.pipe, IFOrig)
assert isinstance(self.pipe, IFPipeline)
model_output = self.pipe(
prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator
)
Expand Down
3 changes: 3 additions & 0 deletions nerfstudio/process_data/realitycapture_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@
from PIL import Image

from nerfstudio.process_data.process_data_utils import CAMERA_MODELS
from nerfstudio.utils.misc import set_pil_image_size_limit
from nerfstudio.utils.rich_utils import CONSOLE

set_pil_image_size_limit(None)


def realitycapture_to_json(
image_filename_map: Dict[str, Path],
Expand Down
3 changes: 3 additions & 0 deletions nerfstudio/scripts/datasets/process_project_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import tyro
from PIL import Image

from nerfstudio.utils.misc import set_pil_image_size_limit

try:
from projectaria_tools.core import mps
from projectaria_tools.core.data_provider import VrsDataProvider, create_vrs_data_provider
Expand All @@ -34,6 +36,7 @@
sys.exit(1)

ARIA_CAMERA_MODEL = "FISHEYE624"
set_pil_image_size_limit(None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

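If we define `set_pil_image_size_limit()` as a context manager, it needs to be used in a `with` statement. Called as a bare statement, as it is here, it only creates the context manager object; the function body never runs, so `Image.MAX_IMAGE_PIXELS` is not changed.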

there's an example in the Python docs here: https://docs.python.org/3/library/contextlib.html#contextlib.contextmanager


# The Aria coordinate system is different than the Blender/NerfStudio coordinate system.
# Blender / Nerfstudio: +Z = back, +Y = up, +X = right
Expand Down
17 changes: 16 additions & 1 deletion nerfstudio/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
Miscellaneous helper code.
"""


import contextlib
import platform
import typing
import warnings
from inspect import currentframe
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

import torch
from PIL import Image

T = TypeVar("T")
TKey = TypeVar("TKey")
Expand Down Expand Up @@ -219,3 +220,17 @@ def get_orig_class(obj, default=None):
finally:
del frame
return default


@contextlib.contextmanager
def set_pil_image_size_limit(max_pixels: Optional[int]) -> typing.Iterator[None]:
    """Temporarily override PIL's image size limit for the duration of a ``with`` block.

    By default PIL limits the max image size, preventing processing or training with high
    resolution images. Use this context manager to disable the limit or set a custom one while
    the block runs; the previous limit is restored when the block exits, even if it raises.

    This only takes effect inside a ``with`` statement::

        with set_pil_image_size_limit(None):
            image = Image.open(path)

    Calling it as a plain statement creates the context manager without entering it, so
    ``Image.MAX_IMAGE_PIXELS`` is left untouched.

    :param max_pixels: Max number of pixels for image processing in PIL. ``None`` disables the limit.
    """
    orig = Image.MAX_IMAGE_PIXELS
    Image.MAX_IMAGE_PIXELS = max_pixels
    try:
        yield
    finally:
        # Restore the previous limit even when the with-body raises, so the override
        # never leaks into the rest of the process.
        Image.MAX_IMAGE_PIXELS = orig
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ pythonPlatform = "Linux"
[tool.ruff]
line-length = 120
respect-gitignore = false
select = [
lint.select = [
"E", # pycodestyle errors.
"F", # Pyflakes rules.
"I", # isort formatting.
Expand All @@ -173,7 +173,7 @@ select = [
"PLR", # Pylint refactor recommendations.
"PLW", # Pylint warnings.
]
ignore = [
lint.ignore = [
"E501", # Line too long.
"F722", # Forward annotation false positive from jaxtyping. Should be caught by pyright.
"F821", # Forward annotation false positive from jaxtyping. Should be caught by pyright.
Expand Down
Loading