From e30d787f8adccb221fa89010e77827c083d06b98 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 8 Feb 2024 11:14:57 +0000 Subject: [PATCH] Add DNS optimize support (#19429) * update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update * update * update * update * update * update --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: thomas (cherry picked from commit 4c2fc3b0cb293a776a39f69e59e6bdc03adc2ef2) --- .../data/{streaming => }/constants.py | 0 .../data/processing/data_processor.py | 10 ++-- src/lightning/data/processing/dns.py | 47 +++++++++++++++++++ src/lightning/data/processing/functions.py | 25 ++++++---- src/lightning/data/streaming/cache.py | 2 +- src/lightning/data/streaming/client.py | 2 +- src/lightning/data/streaming/config.py | 2 +- src/lightning/data/streaming/dataloader.py | 2 +- src/lightning/data/streaming/dataset.py | 4 +- src/lightning/data/streaming/downloader.py | 2 +- src/lightning/data/streaming/item_loader.py | 2 +- src/lightning/data/streaming/reader.py | 2 +- src/lightning/data/streaming/serializers.py | 2 +- src/lightning/data/streaming/writer.py | 2 +- tests/tests_data/processing/test_dns.py | 33 +++++++++++++ 15 files changed, 111 insertions(+), 26 deletions(-) rename src/lightning/data/{streaming => }/constants.py (100%) create mode 100644 src/lightning/data/processing/dns.py create mode 100644 tests/tests_data/processing/test_dns.py diff --git a/src/lightning/data/streaming/constants.py b/src/lightning/data/constants.py similarity index 100% rename from src/lightning/data/streaming/constants.py rename to src/lightning/data/constants.py diff --git a/src/lightning/data/processing/data_processor.py b/src/lightning/data/processing/data_processor.py index 0449a1218623a..3a757857b08f8 100644 --- a/src/lightning/data/processing/data_processor.py +++ b/src/lightning/data/processing/data_processor.py @@ -20,11 +20,7 @@ from tqdm.auto import tqdm as _tqdm from lightning import seed_everything -from lightning.data.processing.readers import BaseReader -from lightning.data.streaming import Cache -from lightning.data.streaming.cache import Dir -from lightning.data.streaming.client import S3Client -from lightning.data.streaming.constants import ( +from lightning.data.constants import ( _BOTO3_AVAILABLE, _DEFAULT_FAST_DEV_RUN_ITEMS, _INDEX_FILENAME, @@ -32,6 +28,10 @@ _LIGHTNING_CLOUD_LATEST, _TORCH_GREATER_EQUAL_2_1_0, ) +from lightning.data.processing.readers import BaseReader +from lightning.data.streaming import Cache +from lightning.data.streaming.cache import Dir +from lightning.data.streaming.client import S3Client from lightning.data.streaming.resolver import _resolve_dir from lightning.data.utilities.broadcast import broadcast_object from lightning.data.utilities.packing import _pack_greedily diff --git a/src/lightning/data/processing/dns.py b/src/lightning/data/processing/dns.py new file mode 100644 index 0000000000000..f1ca83dbc0e2e --- /dev/null +++ b/src/lightning/data/processing/dns.py @@ -0,0 +1,47 @@ +from contextlib import contextmanager +from subprocess import Popen +from typing import Any + +from lightning.data.constants import _IS_IN_STUDIO + + +@contextmanager +def optimize_dns_context(enable: bool) -> Any: + optimize_dns(enable) + try: + yield + optimize_dns(False) # always disable the optimize DNS + except Exception as e: + optimize_dns(False) # always disable the optimize DNS + raise e + +def optimize_dns(enable: bool) -> None: + if not _IS_IN_STUDIO: + return + + with open("/etc/resolv.conf") as f: + lines = f.readlines() + + if ( + (enable and any("127.0.0.53" in line for line in lines)) + or (not enable and any("127.0.0.1" in line for line in lines)) + ): # noqa E501 + Popen(f"sudo /home/zeus/miniconda3/envs/cloudspace/bin/python -c 'from lightning.data.processing.dns import _optimize_dns; _optimize_dns({enable})'", shell=True).wait() # noqa E501 + +def _optimize_dns(enable: bool) -> None: + with open("/etc/resolv.conf") as f: + lines = f.readlines() + + write_lines = [] + for line in lines: + if "nameserver 127" in line: + if enable: + write_lines.append('nameserver 127.0.0.1\n') + else: + write_lines.append('nameserver 127.0.0.53\n') + else: + write_lines.append(line) + + with open("/etc/resolv.conf", "w") as f: + for line in write_lines: + f.write(line) diff --git a/src/lightning/data/processing/functions.py b/src/lightning/data/processing/functions.py index 6939418ad617c..00905aa40dcd4 100644 --- a/src/lightning/data/processing/functions.py +++ b/src/lightning/data/processing/functions.py @@ -22,9 +22,10 @@ import torch +from lightning.data.constants import _IS_IN_STUDIO, _TORCH_GREATER_EQUAL_2_1_0 from lightning.data.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe +from lightning.data.processing.dns import optimize_dns_context from lightning.data.processing.readers import BaseReader -from lightning.data.streaming.constants import _IS_IN_STUDIO, _TORCH_GREATER_EQUAL_2_1_0 from lightning.data.streaming.resolver import ( Dir, _assert_dir_has_index_file, @@ -218,7 +219,8 @@ def map( weights=weights, reader=reader, ) - return data_processor.run(LambdaDataTransformRecipe(fn, inputs)) + with optimize_dns_context(True): + return data_processor.run(LambdaDataTransformRecipe(fn, inputs)) return _execute( f"data-prep-map-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}", num_nodes, @@ -303,15 +305,18 @@ def optimize( reorder_files=reorder_files, reader=reader, ) - return data_processor.run( - LambdaDataChunkRecipe( - fn, - inputs, - chunk_size=chunk_size, - chunk_bytes=chunk_bytes, - compression=compression, + + with optimize_dns_context(True): + data_processor.run( + LambdaDataChunkRecipe( + fn, + inputs, + chunk_size=chunk_size, + chunk_bytes=chunk_bytes, + compression=compression, + ) ) - ) + return None return _execute( f"data-prep-optimize-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}", num_nodes, diff --git a/src/lightning/data/streaming/cache.py b/src/lightning/data/streaming/cache.py index 667409c81b8d8..305f393d276ac 100644 --- a/src/lightning/data/streaming/cache.py +++ b/src/lightning/data/streaming/cache.py @@ -15,7 +15,7 @@ import os from typing import Any, Dict, List, Optional, Tuple, Union -from lightning.data.streaming.constants import ( +from lightning.data.constants import ( _INDEX_FILENAME, _LIGHTNING_CLOUD_LATEST, _TORCH_GREATER_EQUAL_2_1_0, diff --git a/src/lightning/data/streaming/client.py b/src/lightning/data/streaming/client.py index c42d3af9f6ed1..4a572548854de 100644 --- a/src/lightning/data/streaming/client.py +++ b/src/lightning/data/streaming/client.py @@ -2,7 +2,7 @@ from time import time from typing import Any, Optional -from lightning.data.streaming.constants import _BOTO3_AVAILABLE +from lightning.data.constants import _BOTO3_AVAILABLE if _BOTO3_AVAILABLE: import boto3 diff --git a/src/lightning/data/streaming/config.py b/src/lightning/data/streaming/config.py index 386ea475e7681..4a5a4ba8c55b2 100644 --- a/src/lightning/data/streaming/config.py +++ b/src/lightning/data/streaming/config.py @@ -15,7 +15,7 @@ import os from typing import Any, Dict, List, Optional, Tuple -from lightning.data.streaming.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0 +from lightning.data.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0 from lightning.data.streaming.downloader import get_downloader_cls from lightning.data.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader from lightning.data.streaming.sampler import ChunkedIndex diff --git a/src/lightning/data/streaming/dataloader.py b/src/lightning/data/streaming/dataloader.py index 942d2f98a3090..04793296ca30f 100644 --- a/src/lightning/data/streaming/dataloader.py +++ b/src/lightning/data/streaming/dataloader.py @@ -33,13 +33,13 @@ ) from torch.utils.data.sampler import BatchSampler, Sampler +from lightning.data.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE from lightning.data.streaming import Cache from lightning.data.streaming.combined import ( __NUM_SAMPLES_YIELDED_KEY__, __SAMPLES_KEY__, CombinedStreamingDataset, ) -from lightning.data.streaming.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE from lightning.data.streaming.dataset import StreamingDataset from lightning.data.streaming.sampler import CacheBatchSampler from lightning.data.utilities.env import _DistributedEnv diff --git a/src/lightning/data/streaming/dataset.py b/src/lightning/data/streaming/dataset.py index e281bbc0a1e2d..da0028184e123 100644 --- a/src/lightning/data/streaming/dataset.py +++ b/src/lightning/data/streaming/dataset.py @@ -19,11 +19,11 @@ import numpy as np from torch.utils.data import IterableDataset -from lightning.data.streaming import Cache -from lightning.data.streaming.constants import ( +from lightning.data.constants import ( _DEFAULT_CACHE_DIR, _INDEX_FILENAME, ) +from lightning.data.streaming import Cache from lightning.data.streaming.item_loader import BaseItemLoader from lightning.data.streaming.resolver import Dir, _resolve_dir from lightning.data.streaming.sampler import ChunkedIndex diff --git a/src/lightning/data/streaming/downloader.py b/src/lightning/data/streaming/downloader.py index 982b4d25142b3..03d7b9302068a 100644 --- a/src/lightning/data/streaming/downloader.py +++ b/src/lightning/data/streaming/downloader.py @@ -19,8 +19,8 @@ from filelock import FileLock, Timeout +from lightning.data.constants import _INDEX_FILENAME from lightning.data.streaming.client import S3Client -from lightning.data.streaming.constants import _INDEX_FILENAME class Downloader(ABC): diff --git a/src/lightning/data/streaming/item_loader.py b/src/lightning/data/streaming/item_loader.py index 779a683146182..2a1a02da67293 100644 --- a/src/lightning/data/streaming/item_loader.py +++ b/src/lightning/data/streaming/item_loader.py @@ -19,7 +19,7 @@ import numpy as np import torch -from lightning.data.streaming.constants import ( +from lightning.data.constants import ( _TORCH_DTYPES_MAPPING, _TORCH_GREATER_EQUAL_2_1_0, ) diff --git a/src/lightning/data/streaming/reader.py b/src/lightning/data/streaming/reader.py index 50705cb18f663..298452ea685f7 100644 --- a/src/lightning/data/streaming/reader.py +++ b/src/lightning/data/streaming/reader.py @@ -20,8 +20,8 @@ from threading import Thread from typing import Any, Dict, List, Optional, Tuple, Union +from lightning.data.constants import _TORCH_GREATER_EQUAL_2_1_0 from lightning.data.streaming.config import ChunksConfig -from lightning.data.streaming.constants import _TORCH_GREATER_EQUAL_2_1_0 from lightning.data.streaming.item_loader import BaseItemLoader, PyTreeLoader from lightning.data.streaming.sampler import ChunkedIndex from lightning.data.streaming.serializers import Serializer, _get_serializers diff --git a/src/lightning/data/streaming/serializers.py b/src/lightning/data/streaming/serializers.py index b689429953c21..9e5a92f520741 100644 --- a/src/lightning/data/streaming/serializers.py +++ b/src/lightning/data/streaming/serializers.py @@ -23,7 +23,7 @@ import torch from lightning_utilities.core.imports import RequirementCache -from lightning.data.streaming.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING +from lightning.data.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") diff --git a/src/lightning/data/streaming/writer.py b/src/lightning/data/streaming/writer.py index 44e6a8951773f..98b7a31e07d1b 100644 --- a/src/lightning/data/streaming/writer.py +++ b/src/lightning/data/streaming/writer.py @@ -21,8 +21,8 @@ import numpy as np import torch +from lightning.data.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0 from lightning.data.streaming.compression import _COMPRESSORS, Compressor -from lightning.data.streaming.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0 from lightning.data.streaming.serializers import Serializer, _get_serializers from lightning.data.utilities.env import _DistributedEnv, _WorkerEnv from lightning.data.utilities.format import _convert_bytes_to_int, _human_readable_bytes diff --git a/tests/tests_data/processing/test_dns.py b/tests/tests_data/processing/test_dns.py new file mode 100644 index 0000000000000..0703e346b02fd --- /dev/null +++ b/tests/tests_data/processing/test_dns.py @@ -0,0 +1,33 @@ +from unittest.mock import MagicMock + +from lightning.data.processing import dns as dns_module +from lightning.data.processing.dns import optimize_dns_context + + +def test_optimize_dns_context(monkeypatch): + popen_mock = MagicMock() + + monkeypatch.setattr(dns_module, "_IS_IN_STUDIO", True) + monkeypatch.setattr(dns_module, "Popen", popen_mock) + + class FakeFile: + + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + return self + + def readlines(self): + return ["127.0.0.53"] + + monkeypatch.setitem(__builtins__, "open", MagicMock(return_value=FakeFile())) + + with optimize_dns_context(True): + pass + + cmd = popen_mock._mock_call_args_list[0].args[0] + assert cmd == "sudo /home/zeus/miniconda3/envs/cloudspace/bin/python -c 'from lightning.data.processing.dns import _optimize_dns; _optimize_dns(True)'" # noqa: E501