Skip to content
This repository has been archived by the owner on Apr 12, 2023. It is now read-only.

Commit

Permalink
Add option to install local package prior to running (#223)
Browse files Browse the repository at this point in the history
  • Loading branch information
lantiga authored Nov 23, 2022
1 parent b43c43d commit b7d4bea
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 17 deletions.
7 changes: 7 additions & 0 deletions lightning_hpo/commands/experiment/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ def run(self) -> None:
type=str,
help="Which framework you are using.",
)
parser.add_argument(
"--pip-install-source",
default=False,
action="store_true",
help="Run `pip install -e .` on the uploaded source before running",
)

hparams, args = parser.parse_known_args()

Expand Down Expand Up @@ -106,6 +112,7 @@ def run(self) -> None:
direction="minimize", # This won't be used
experiments={0: ExperimentConfig(name=name, params={})},
disk_size=hparams.disk_size,
pip_install_source=hparams.pip_install_source,
data=data,
username=str(getuser()),
)
Expand Down
13 changes: 11 additions & 2 deletions lightning_hpo/commands/sweep/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class SweepConfig(SQLModel, table=True):
stage: str = Stage.NOT_STARTED
desired_stage: str = Stage.RUNNING
disk_size: int = 80
pip_install_source: bool = False
data: Dict[str, Optional[str]] = Field(..., sa_column=Column(pydantic_column_type(Dict[str, Optional[str]])))
username: Optional[str] = None

Expand Down Expand Up @@ -263,7 +264,7 @@ def parse_grid_search(script_args, args):
if expected_value:
distributions[key] = expected_value
else:
script_args.append(f"--{key}={value}")
script_args.append(f"{key}={value}")
return distributions


Expand All @@ -290,7 +291,7 @@ def parse_random_search(script_args, args):
if expected_value:
distributions[key] = expected_value
else:
script_args.append(f"--{key}={value}")
script_args.append(f"{key}={value}")
return distributions


Expand Down Expand Up @@ -393,6 +394,13 @@ def run(self) -> None:
type=str,
help="Syntax for sweep parameters at the CLI.",
)
parser.add_argument(
"--pip-install-source",
default=False,
action="store_true",
help="Run `pip install -e .` on the uploaded source before running",
)

hparams, args = parser.parse_known_args()

if hparams.framework != "pytorch_lightning" and hparams.num_nodes > 1:
Expand Down Expand Up @@ -457,6 +465,7 @@ def run(self) -> None:
direction=hparams.direction,
experiments={},
disk_size=hparams.disk_size,
pip_install_source=hparams.pip_install_source,
data=data,
username=getuser(),
)
Expand Down
4 changes: 4 additions & 0 deletions lightning_hpo/components/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
experiments: Optional[Dict[int, Dict]] = None,
stage: Optional[str] = Stage.NOT_STARTED,
logger_url: str = "",
pip_install_source: bool = False,
data: Optional[List[str]] = None,
**objective_kwargs: Any,
):
Expand Down Expand Up @@ -105,6 +106,7 @@ def __init__(
self.experiments = experiments or {}
self.stage = stage
self.logger_url = logger_url
self.pip_install_source = pip_install_source
self.data = data

self._objective_cls = _resolve_objective_cls(objective_cls, framework)
Expand Down Expand Up @@ -257,6 +259,7 @@ def _get_objective(self, experiment_id: int):
experiment_name=experiment_config["name"],
cloud_compute=cloud_compute,
last_model_path=experiment_config["last_model_path"],
pip_install_source=self.pip_install_source,
**self._kwargs,
)
setattr(self, f"w_{experiment_id}", objective)
Expand Down Expand Up @@ -299,6 +302,7 @@ def from_config(cls, config: SweepConfig, code: Optional[Code] = None, mounts: O
stage=config.stage,
logger_url=config.logger_url,
data=config.data,
pip_install_source=config.pip_install_source,
requirements=config.requirements,
packages=config.packages,
)
Expand Down
17 changes: 17 additions & 0 deletions lightning_hpo/framework/agnostic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import os
import subprocess
import sys
import uuid
from abc import ABC
from typing import Any, Dict, Optional, TypedDict

Expand Down Expand Up @@ -28,6 +32,7 @@ def __init__(
function_name: str = "objective",
num_nodes: int = 1, # TODO # Add support for multi node
last_model_path: Optional[str] = None,
pip_install_source: bool = False,
**kwargs,
):
super().__init__(*args, raise_exception=raise_exception, **kwargs)
Expand All @@ -49,6 +54,8 @@ def __init__(
self.num_nodes = num_nodes
self.progress = None
self.last_model_path = last_model_path
self.pip_install_source = pip_install_source
self._rootwd = os.getcwd()

def configure_tracer(self):
assert self.params is not None
Expand All @@ -63,12 +70,22 @@ def configure_tracer(self):
return tracer

def run(self, params: Optional[Dict[str, Any]] = None, restart_count: int = 0):
if self.pip_install_source:
os.chdir(self._rootwd)
uid = uuid.uuid4().hex[:8]
dirname = f"uploaded-{uid}"
os.makedirs(dirname)
os.chdir(dirname)
self.params = params or {}
if is_overridden("objective", self, Objective):
self.objective(**self.params)
else:
return super().run(params=params)

def on_before_run(self):
if self.pip_install_source:
subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."])

def on_after_run(self, global_scripts: Any):
objective_fn = global_scripts.get(self.function_name, None)
if objective_fn:
Expand Down
22 changes: 21 additions & 1 deletion lightning_hpo/framework/pytorch_lightning.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import os
import subprocess
import sys
import time
import uuid
from typing import Any, Dict, Optional

import pytorch_lightning
Expand All @@ -22,6 +26,7 @@ def __init__(
experiment_name: str,
num_nodes: int,
last_model_path: Optional[str] = None,
pip_install_source: bool = False,
**kwargs,
):
Objective.__init__(
Expand All @@ -38,6 +43,8 @@ def __init__(
self.start_time = None
self.end_time = None
self.last_model_path = Path(last_model_path) if last_model_path else None
self.pip_install_source = pip_install_source
self._rootwd = os.getcwd()

def configure_tracer(self):
tracer = Objective.configure_tracer(self)
Expand All @@ -51,10 +58,21 @@ def run(
restart_count: int = 0,
**kwargs,
):
code_dir = "."
if self.pip_install_source:
os.chdir(self._rootwd)
uid = uuid.uuid4().hex[:8]
code_dir = f"code-{uid}"
os.makedirs(code_dir)

if self.last_model_path:
self.last_model_path.get(overwrite=True)
self.params = params
return PyTorchLightningScriptRunner.run(self, params=params, **kwargs)
return PyTorchLightningScriptRunner.run(self, params=params, code_dir=code_dir, **kwargs)

def on_before_run(self):
if self.pip_install_source:
subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."])

def on_after_run(self, script_globals):
self.end_time = time.time()
Expand Down Expand Up @@ -179,6 +197,7 @@ def __init__(
logger: str,
sweep_id: str,
num_nodes: int = 1,
pip_install_source: bool = False,
**kwargs,
):
super().__init__(
Expand All @@ -189,6 +208,7 @@ def __init__(
experiment_id=experiment_id,
experiment_name=experiment_name,
num_nodes=num_nodes,
pip_install_source=pip_install_source,
**kwargs,
)
self.experiment_id = experiment_id
Expand Down
56 changes: 56 additions & 0 deletions tests/commands/experiment/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def check(config: SweepConfig):
stage="not_started",
desired_stage="running",
disk_size=80,
pip_install_source=False,
data={},
)
expected.sweep_id = config.sweep_id
Expand Down Expand Up @@ -126,6 +127,7 @@ def check(config: SweepConfig):
stage="not_started",
desired_stage="running",
disk_size=80,
pip_install_source=False,
data={"example": None},
)
expected.sweep_id = config.sweep_id
Expand Down Expand Up @@ -225,6 +227,60 @@ def check(config: SweepConfig):
stage="not_started",
desired_stage="running",
disk_size=80,
pip_install_source=False,
data={},
)
expected.sweep_id = config.sweep_id
expected.username = config.username
assert config == expected

command = _create_client_command_mock(run.RunExperimentCommand, None, MagicMock(), check)
command.run()


def test_experiment_run_parsing_pip_install(monkeypatch):

monkeypatch.setattr(run, "CustomLocalSourceCodeDir", MagicMock())

monkeypatch.setattr(sys, "argv", ["", __file__, "--pip-install-source"])

def check(config: SweepConfig):
exp = run.ExperimentConfig(
name="",
best_model_score=None,
monitor=None,
best_model_path=None,
stage="not_started",
params={},
exception=None,
progress=None,
total_parameters=None,
start_time=None,
end_time=None,
)
exp.name = config.experiments[0].name
expected = SweepConfig(
sweep_id="",
script_path=__file__,
total_experiments=1,
parallel_experiments=1,
total_experiments_done=0,
requirements=["pytorch_lightning", "optuna", "deepspeed"],
packages=[],
script_args=[],
algorithm="",
distributions={},
logger_url="",
experiments={0: exp},
framework="pytorch_lightning",
cloud_compute="cpu",
num_nodes=1,
logger="tensorboard",
direction="minimize",
stage="not_started",
desired_stage="running",
disk_size=80,
pip_install_source=True,
data={},
)
expected.sweep_id = config.sweep_id
Expand Down
28 changes: 14 additions & 14 deletions tests/commands/sweep/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_sweep_run_parsing_file_single_list(monkeypatch):

def check_0(config):
assert config.distributions == {
"lr": Distributions(distribution="categorical", params={"choices": [0.0, 1.0, 2.0]})
"--lr": Distributions(distribution="categorical", params={"choices": [0.0, 1.0, 2.0]})
}
assert config.algorithm == "grid_search"

Expand All @@ -61,8 +61,8 @@ def test_sweep_run_parsing_file_two_lists(monkeypatch):

def check_0(config):
assert config.distributions == {
"lr": Distributions(distribution="categorical", params={"choices": [0.0, 1.0, 2.0]}),
"gamma": Distributions(distribution="categorical", params={"choices": [0.0, 1.0, 2.0]}),
"--lr": Distributions(distribution="categorical", params={"choices": [0.0, 1.0, 2.0]}),
"--gamma": Distributions(distribution="categorical", params={"choices": [0.0, 1.0, 2.0]}),
}
assert config.algorithm == "grid_search"

Expand Down Expand Up @@ -91,7 +91,7 @@ def test_sweep_run_parsing_file_list_and_script_arguments(monkeypatch):

def check_0(config):
assert config.distributions == {
"lr": Distributions(distribution="categorical", params={"choices": [0.0, 1.0, 2.0]}),
"--lr": Distributions(distribution="categorical", params={"choices": [0.0, 1.0, 2.0]}),
}
assert config.script_args == ["--data.batch=something"]

Expand All @@ -105,7 +105,7 @@ def test_sweep_run_parsing_range(monkeypatch):

def check_0(config):
assert config.distributions == {
"lr": Distributions(
"--lr": Distributions(
distribution="categorical", params={"choices": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]}
),
}
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_sweep_run_parsing_random_search(monkeypatch):

def check_0(config):
assert config.distributions == {
"lr": Distributions(
"--lr": Distributions(
distribution="categorical", params={"choices": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]}
),
}
Expand All @@ -159,7 +159,7 @@ def check_0(config):

def check_1(config):
assert config.distributions == {
"lr": Distributions(distribution="categorical", params={"choices": [0.0, 1.0, 2.0]}),
"--lr": Distributions(distribution="categorical", params={"choices": [0.0, 1.0, 2.0]}),
}

argv = [
Expand All @@ -178,8 +178,8 @@ def check_1(config):

def check_2(config):
assert {
"lr": Distributions(distribution="categorical", params={"choices": [0.0, 2.0]}),
"batch_size": Distributions(
"--lr": Distributions(distribution="categorical", params={"choices": [0.0, 2.0]}),
"--batch_size": Distributions(
distribution="categorical", params={"choices": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]}
),
}
Expand Down Expand Up @@ -278,11 +278,11 @@ def test_sweep_run_parsing_random_search_further_distributions(monkeypatch):

def check_1(config):
assert config.distributions == {
"lr": Distributions(distribution="categorical", params={"choices": [0.0, 2.0]}),
"batch_size": Distributions(
"--lr": Distributions(distribution="categorical", params={"choices": [0.0, 2.0]}),
"--batch_size": Distributions(
distribution="categorical", params={"choices": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]}
),
"gamma": Distributions(distribution="log_uniform", params={"low": 32.0, "high": 64.0}),
"--gamma": Distributions(distribution="log_uniform", params={"low": 32.0, "high": 64.0}),
}
assert config.script_args == ["--data.batch=something"]

Expand Down Expand Up @@ -311,8 +311,8 @@ def test_parsing(monkeypatch):

def check(config):
assert config.distributions == {
"model.lr": Distributions(distribution="categorical", params={"choices": [0.001]}),
"data.batch": Distributions(distribution="categorical", params={"choices": [32.0, 64.0]}),
"--model.lr": Distributions(distribution="categorical", params={"choices": [0.001]}),
"--data.batch": Distributions(distribution="categorical", params={"choices": [32.0, 64.0]}),
}
assert config.script_args == ["--something=else"]

Expand Down

0 comments on commit b7d4bea

Please sign in to comment.