Skip to content

Commit

Permalink
Add DNS optimize support (#19429)
Browse files Browse the repository at this point in the history
* update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas <[email protected]>
(cherry picked from commit 4c2fc3b)
  • Loading branch information
tchaton authored and Borda committed Feb 12, 2024
1 parent 2e36719 commit e30d787
Show file tree
Hide file tree
Showing 15 changed files with 111 additions and 26 deletions.
File renamed without changes.
10 changes: 5 additions & 5 deletions src/lightning/data/processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@
from tqdm.auto import tqdm as _tqdm

from lightning import seed_everything
from lightning.data.processing.readers import BaseReader
from lightning.data.streaming import Cache
from lightning.data.streaming.cache import Dir
from lightning.data.streaming.client import S3Client
from lightning.data.streaming.constants import (
from lightning.data.constants import (
_BOTO3_AVAILABLE,
_DEFAULT_FAST_DEV_RUN_ITEMS,
_INDEX_FILENAME,
_IS_IN_STUDIO,
_LIGHTNING_CLOUD_LATEST,
_TORCH_GREATER_EQUAL_2_1_0,
)
from lightning.data.processing.readers import BaseReader
from lightning.data.streaming import Cache
from lightning.data.streaming.cache import Dir
from lightning.data.streaming.client import S3Client
from lightning.data.streaming.resolver import _resolve_dir
from lightning.data.utilities.broadcast import broadcast_object
from lightning.data.utilities.packing import _pack_greedily
Expand Down
47 changes: 47 additions & 0 deletions src/lightning/data/processing/dns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from contextlib import contextmanager
from subprocess import Popen
from typing import Any

from lightning.data.constants import _IS_IN_STUDIO


@contextmanager
def optimize_dns_context(enable: bool) -> Any:
    """Context manager toggling the Studio DNS optimization for its duration.

    Calls ``optimize_dns(enable)`` on entry and always restores the default
    resolver (``optimize_dns(False)``) on exit, whether the body succeeded or
    raised.

    Args:
        enable: Whether to enable the DNS optimization while the context is active.
    """
    optimize_dns(enable)
    try:
        yield
    finally:
        # Always disable the optimization on the way out; `finally` replaces the
        # original duplicated cleanup in both the success path and the `except`
        # handler, and preserves the in-flight exception's traceback.
        optimize_dns(False)

def optimize_dns(enable: bool) -> None:
    """Request the resolver switch by re-running ``_optimize_dns`` as root.

    No-op outside a Lightning Studio machine. Shells out only when
    ``/etc/resolv.conf`` is not already in the requested state.
    """
    if not _IS_IN_STUDIO:
        return

    with open("/etc/resolv.conf") as resolv_file:
        resolv_lines = resolv_file.readlines()

    # 127.0.0.53 is present when the default resolver is active;
    # 127.0.0.1 is present once the optimization has been applied.
    has_default_resolver = any("127.0.0.53" in line for line in resolv_lines)
    has_local_resolver = any("127.0.0.1" in line for line in resolv_lines)

    should_switch = has_default_resolver if enable else has_local_resolver
    if should_switch:
        # Rewriting /etc/resolv.conf needs root, hence the sudo re-exec.
        cmd = f"sudo /home/zeus/miniconda3/envs/cloudspace/bin/python -c 'from lightning.data.processing.dns import _optimize_dns; _optimize_dns({enable})'"  # noqa E501
        Popen(cmd, shell=True).wait()

def _optimize_dns(enable: bool) -> None:
    """Rewrite ``/etc/resolv.conf`` in place, swapping the loopback nameserver.

    Invoked as root via ``sudo`` from ``optimize_dns``. Lines containing
    ``nameserver 127`` are replaced; every other line is kept verbatim.
    """
    with open("/etc/resolv.conf") as resolv_file:
        original_lines = resolv_file.readlines()

    replacement = "nameserver 127.0.0.1\n" if enable else "nameserver 127.0.0.53\n"
    updated_lines = [
        replacement if "nameserver 127" in line else line for line in original_lines
    ]

    with open("/etc/resolv.conf", "w") as resolv_file:
        resolv_file.writelines(updated_lines)
25 changes: 15 additions & 10 deletions src/lightning/data/processing/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@

import torch

from lightning.data.constants import _IS_IN_STUDIO, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from lightning.data.processing.dns import optimize_dns_context
from lightning.data.processing.readers import BaseReader
from lightning.data.streaming.constants import _IS_IN_STUDIO, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.resolver import (
Dir,
_assert_dir_has_index_file,
Expand Down Expand Up @@ -218,7 +219,8 @@ def map(
weights=weights,
reader=reader,
)
return data_processor.run(LambdaDataTransformRecipe(fn, inputs))
with optimize_dns_context(True):
return data_processor.run(LambdaDataTransformRecipe(fn, inputs))
return _execute(
f"data-prep-map-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}",
num_nodes,
Expand Down Expand Up @@ -303,15 +305,18 @@ def optimize(
reorder_files=reorder_files,
reader=reader,
)
return data_processor.run(
LambdaDataChunkRecipe(
fn,
inputs,
chunk_size=chunk_size,
chunk_bytes=chunk_bytes,
compression=compression,

with optimize_dns_context(True):
data_processor.run(
LambdaDataChunkRecipe(
fn,
inputs,
chunk_size=chunk_size,
chunk_bytes=chunk_bytes,
compression=compression,
)
)
)
return None
return _execute(
f"data-prep-optimize-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}",
num_nodes,
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
from typing import Any, Dict, List, Optional, Tuple, Union

from lightning.data.streaming.constants import (
from lightning.data.constants import (
_INDEX_FILENAME,
_LIGHTNING_CLOUD_LATEST,
_TORCH_GREATER_EQUAL_2_1_0,
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from time import time
from typing import Any, Optional

from lightning.data.streaming.constants import _BOTO3_AVAILABLE
from lightning.data.constants import _BOTO3_AVAILABLE

if _BOTO3_AVAILABLE:
import boto3
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
from typing import Any, Dict, List, Optional, Tuple

from lightning.data.streaming.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.downloader import get_downloader_cls
from lightning.data.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader
from lightning.data.streaming.sampler import ChunkedIndex
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@
)
from torch.utils.data.sampler import BatchSampler, Sampler

from lightning.data.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE
from lightning.data.streaming import Cache
from lightning.data.streaming.combined import (
__NUM_SAMPLES_YIELDED_KEY__,
__SAMPLES_KEY__,
CombinedStreamingDataset,
)
from lightning.data.streaming.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.sampler import CacheBatchSampler
from lightning.data.utilities.env import _DistributedEnv
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/data/streaming/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import numpy as np
from torch.utils.data import IterableDataset

from lightning.data.streaming import Cache
from lightning.data.streaming.constants import (
from lightning.data.constants import (
_DEFAULT_CACHE_DIR,
_INDEX_FILENAME,
)
from lightning.data.streaming import Cache
from lightning.data.streaming.item_loader import BaseItemLoader
from lightning.data.streaming.resolver import Dir, _resolve_dir
from lightning.data.streaming.sampler import ChunkedIndex
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

from filelock import FileLock, Timeout

from lightning.data.constants import _INDEX_FILENAME
from lightning.data.streaming.client import S3Client
from lightning.data.streaming.constants import _INDEX_FILENAME


class Downloader(ABC):
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/item_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np
import torch

from lightning.data.streaming.constants import (
from lightning.data.constants import (
_TORCH_DTYPES_MAPPING,
_TORCH_GREATER_EQUAL_2_1_0,
)
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from threading import Thread
from typing import Any, Dict, List, Optional, Tuple, Union

from lightning.data.constants import _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.config import ChunksConfig
from lightning.data.streaming.constants import _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.item_loader import BaseItemLoader, PyTreeLoader
from lightning.data.streaming.sampler import ChunkedIndex
from lightning.data.streaming.serializers import Serializer, _get_serializers
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import torch
from lightning_utilities.core.imports import RequirementCache

from lightning.data.streaming.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING
from lightning.data.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING

_PIL_AVAILABLE = RequirementCache("PIL")
_TORCH_VISION_AVAILABLE = RequirementCache("torchvision")
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import numpy as np
import torch

from lightning.data.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.compression import _COMPRESSORS, Compressor
from lightning.data.streaming.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.serializers import Serializer, _get_serializers
from lightning.data.utilities.env import _DistributedEnv, _WorkerEnv
from lightning.data.utilities.format import _convert_bytes_to_int, _human_readable_bytes
Expand Down
33 changes: 33 additions & 0 deletions tests/tests_data/processing/test_dns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from unittest.mock import MagicMock

from lightning.data.processing import dns as dns_module
from lightning.data.processing.dns import optimize_dns_context


def test_optimize_dns_context(monkeypatch):
    """`optimize_dns_context(True)` should shell out once to enable the DNS optimization."""
    import builtins

    popen_mock = MagicMock()

    monkeypatch.setattr(dns_module, "_IS_IN_STUDIO", True)
    monkeypatch.setattr(dns_module, "Popen", popen_mock)

    class FakeFile:
        """Minimal stand-in for the file object returned by `open`."""

        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args, **kwargs):
            # Return a falsy value so exceptions inside `with` are never suppressed
            # (the original returned `self`, which is truthy and swallowed errors).
            return None

        def readlines(self):
            # Mimic the systemd-resolved default so the optimization triggers.
            return ["127.0.0.53"]

    # Patch the `open` builtin on the `builtins` module rather than via
    # `monkeypatch.setitem(__builtins__, ...)` — `__builtins__` is a dict in
    # imported modules but the `builtins` module in `__main__`, so the setitem
    # form is fragile.
    monkeypatch.setattr(builtins, "open", MagicMock(return_value=FakeFile()))

    with optimize_dns_context(True):
        pass

    # Use the public `call_args_list` API instead of the private
    # `_mock_call_args_list` attribute.
    cmd = popen_mock.call_args_list[0].args[0]
    assert cmd == "sudo /home/zeus/miniconda3/envs/cloudspace/bin/python -c 'from lightning.data.processing.dns import _optimize_dns; _optimize_dns(True)'"  # noqa: E501

0 comments on commit e30d787

Please sign in to comment.