diff --git a/pl_bolts/datasets/cifar10_dataset.py b/pl_bolts/datasets/cifar10_dataset.py
index 9e2c68ac6f..7a2eaa424a 100644
--- a/pl_bolts/datasets/cifar10_dataset.py
+++ b/pl_bolts/datasets/cifar10_dataset.py
@@ -7,6 +7,7 @@
 from torch import Tensor
 
 from pl_bolts.datasets import LightDataset
+from pl_bolts.datasets.utils import safe_extract_tarfile
 from pl_bolts.utils import _PIL_AVAILABLE
 from pl_bolts.utils.stability import under_review
 from pl_bolts.utils.warnings import warn_missing_pkg
@@ -118,7 +119,7 @@ def _unpickle(self, path_folder: str, file_name: str) -> Tuple[Tensor, Tensor]:
     def _extract_archive_save_torch(self, download_path):
         # extract achieve
         with tarfile.open(os.path.join(download_path, self.FILE_NAME), "r:gz") as tar:
-            tar.extractall(path=download_path)
+            safe_extract_tarfile(tar, path=download_path)
 
         # this is internal path in the archive
         path_content = os.path.join(download_path, "cifar-10-batches-py")
diff --git a/pl_bolts/datasets/imagenet_dataset.py b/pl_bolts/datasets/imagenet_dataset.py
index a150a0eb55..518c022682 100644
--- a/pl_bolts/datasets/imagenet_dataset.py
+++ b/pl_bolts/datasets/imagenet_dataset.py
@@ -1,22 +1,17 @@
-import gzip
 import hashlib
 import os
 import shutil
-import sys
-import tarfile
 import tempfile
-import zipfile
 from contextlib import contextmanager
 
 import numpy as np
 import torch
 
+from pl_bolts.datasets.utils import extract_archive
 from pl_bolts.utils import _TORCHVISION_AVAILABLE
 from pl_bolts.utils.stability import under_review
 from pl_bolts.utils.warnings import warn_missing_pkg
 
-PY3 = sys.version_info[0] == 3
-
 if _TORCHVISION_AVAILABLE:
     from torchvision.datasets import ImageNet
     from torchvision.datasets.imagenet import load_meta_file
@@ -247,59 +242,3 @@ def get_tmp_dir():
         META_FILE = "meta.bin"
 
         torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
-
-
-@under_review()
-def extract_archive(from_path, to_path=None, remove_finished=False):
-    if to_path is None:
-        to_path = os.path.dirname(from_path)
-
-    PY3 = sys.version_info[0] == 3
-
-    if _is_tar(from_path):
-        with tarfile.open(from_path, "r") as tar:
-            tar.extractall(path=to_path)
-    elif _is_targz(from_path):
-        with tarfile.open(from_path, "r:gz") as tar:
-            tar.extractall(path=to_path)
-    elif _is_tarxz(from_path) and PY3:
-        # .tar.xz archive only supported in Python 3.x
-        with tarfile.open(from_path, "r:xz") as tar:
-            tar.extractall(path=to_path)
-    elif _is_gzip(from_path):
-        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
-        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
-            out_f.write(zip_f.read())
-    elif _is_zip(from_path):
-        with zipfile.ZipFile(from_path, "r") as z:
-            z.extractall(to_path)
-    else:
-        raise ValueError(f"Extraction of {from_path} not supported")
-
-    if remove_finished:
-        os.remove(from_path)
-
-
-@under_review()
-def _is_targz(filename):
-    return filename.endswith(".tar.gz")
-
-
-@under_review()
-def _is_tarxz(filename):
-    return filename.endswith(".tar.xz")
-
-
-@under_review()
-def _is_gzip(filename):
-    return filename.endswith(".gz") and not filename.endswith(".tar.gz")
-
-
-@under_review()
-def _is_tar(filename):
-    return filename.endswith(".tar")
-
-
-@under_review()
-def _is_zip(filename):
-    return filename.endswith(".zip")
diff --git a/pl_bolts/datasets/utils.py b/pl_bolts/datasets/utils.py
index 3c0214bc21..f6fcf1a7f2 100644
--- a/pl_bolts/datasets/utils.py
+++ b/pl_bolts/datasets/utils.py
@@ -1,3 +1,9 @@
+import gzip
+import os
+import tarfile
+import zipfile
+from typing import List, Optional
+
 import torch
 from torch.utils.data.dataset import random_split
 
@@ -55,3 +61,103 @@ def to_tensor(arrays: TArrays) -> torch.Tensor:
         Tensor of the integers
     """
     return torch.tensor(arrays)
+
+
+def is_within_directory(directory: str, target: str) -> bool:
+    """Return ``True`` if ``target`` is ``directory`` itself or is located below it.
+
+    Both paths are made absolute and normalised with ``os.path.abspath``; symbolic links are not resolved.
+    """
+    abs_directory = os.path.abspath(directory)
+    abs_target = os.path.abspath(target)
+
+    try:
+        # ``commonpath`` compares whole path components; the character-wise ``commonprefix``
+        # would wrongly accept a sibling such as ``/data-evil`` for ``/data``.
+        return os.path.commonpath([abs_directory, abs_target]) == abs_directory
+    except ValueError:
+        # raised when the paths share no common root, e.g. different drives on Windows
+        return False
+
+
+def safe_extract_tarfile(
+    tar: tarfile.TarFile,
+    path: str = ".",
+    members: Optional[List[tarfile.TarInfo]] = None,
+    *,
+    numeric_owner: bool = False,
+) -> None:
+    """Extract ``tar`` into ``path`` after checking that no member name escapes ``path``.
+
+    Args:
+        tar: opened archive to extract
+        path: destination directory
+        members: optional subset of members handed to ``TarFile.extractall``
+        numeric_owner: forwarded to ``TarFile.extractall``
+
+    Raises:
+        RuntimeError: if the name of any member of the archive resolves outside of ``path``
+
+    Note:
+        Only member names are validated; link targets of symlink/hardlink members are not inspected.
+    """
+    for member in tar.getmembers():
+        member_path = os.path.join(path, member.name)
+        if not is_within_directory(path, member_path):
+            raise RuntimeError(f"Attempted Path Traversal in Tar File {tar.name} with member: {member.name}")
+
+    tar.extractall(path, members, numeric_owner=numeric_owner)
+
+
+def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None:
+    """Extract a tar (optionally compressed), gzip or zip archive, detecting the format from its content.
+
+    Args:
+        from_path: path of the archive
+        to_path: destination directory, defaults to the directory containing ``from_path``
+        remove_finished: if ``True``, delete the archive after a successful extraction
+
+    Raises:
+        ValueError: if none of the supported extractors could process ``from_path``
+    """
+    if to_path is None:
+        to_path = os.path.dirname(from_path)
+
+    extracted = False
+    # try each extractor in turn; a format mismatch surfaces as one of the exceptions caught below
+    for fn in (_extract_tar, _extract_gzip, _extract_zip):
+        try:
+            fn(from_path, to_path)
+            extracted = True
+            break
+        except (tarfile.TarError, zipfile.BadZipFile, OSError):
+            continue
+
+    if not extracted:
+        raise ValueError(f"Extraction of {from_path} not supported")
+
+    if remove_finished:
+        os.remove(from_path)
+
+
+def _extract_tar(from_path: str, to_path: str) -> None:
+    """Extract a tar archive (any compression ``tarfile`` recognises) with the traversal check applied."""
+    with tarfile.open(from_path, "r:*") as tar:
+        safe_extract_tarfile(tar, path=to_path)
+
+
+def _extract_gzip(from_path: str, to_path: str) -> None:
+    """Decompress a single gzip file into ``to_path``, named after the archive without its last extension."""
+    to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
+    # decompress before creating the destination: a non-gzip input must not leave an empty file behind,
+    # which could e.g. block the zip extractor from creating a directory of the same name
+    with gzip.GzipFile(from_path) as zip_f:
+        data = zip_f.read()
+    with open(to_path, "wb") as out_f:
+        out_f.write(data)
+
+
+def _extract_zip(from_path: str, to_path: str) -> None:
+    """Extract every member of a zip archive into ``to_path``."""
+    with zipfile.ZipFile(from_path, "r") as z:
+        z.extractall(to_path)