Skip to content

Commit

Permalink
Fix precision default from environment (#18928)
Browse files Browse the repository at this point in the history
Co-authored-by: awaelchli <[email protected]>
(cherry picked from commit 466f772)
  • Loading branch information
carmocca authored and Borda committed Nov 14, 2023
1 parent 0c21338 commit ba03469
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/lightning/fabric/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(
strategy = self._argument_from_env("strategy", strategy, default="auto")
devices = self._argument_from_env("devices", devices, default="auto")
num_nodes = int(self._argument_from_env("num_nodes", num_nodes, default=1))
precision = self._argument_from_env("precision", precision, default="32-true")
precision = self._argument_from_env("precision", precision, default=None)

# 1. Parsing flags
# Get registered strategies, built-in accelerators and precision plugins
Expand Down
21 changes: 15 additions & 6 deletions tests/tests_fabric/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,16 +870,25 @@ def test_strategy_str_passed_being_case_insensitive(_, strategy, strategy_cls):
assert isinstance(connector.strategy, strategy_cls)


@pytest.mark.parametrize("precision", [None, "64-true", "32-true", "16-mixed", "bf16-mixed"])
@pytest.mark.parametrize(
("precision", "expected"),
[
(None, Precision),
("64-true", DoublePrecision),
("32-true", Precision),
("16-true", HalfPrecision),
("16-mixed", MixedPrecision),
],
)
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1)
def test_precision_from_environment(_, precision):
def test_precision_from_environment(_, precision, expected):
"""Test that the precision input can be set through the environment variable."""
env_vars = {}
env_vars = {"LT_CLI_USED": "1"}
if precision is not None:
env_vars["LT_PRECISION"] = precision
with mock.patch.dict(os.environ, env_vars):
connector = _Connector(accelerator="cuda") # need to use cuda, because AMP not available on CPU
assert isinstance(connector.precision, Precision)
assert isinstance(connector.precision, expected)


@pytest.mark.parametrize(
Expand All @@ -897,7 +906,7 @@ def test_precision_from_environment(_, precision):
)
def test_accelerator_strategy_from_environment(accelerator, strategy, expected_accelerator, expected_strategy):
"""Test that the accelerator and strategy input can be set through the environment variables."""
env_vars = {}
env_vars = {"LT_CLI_USED": "1"}
if accelerator is not None:
env_vars["LT_ACCELERATOR"] = accelerator
if strategy is not None:
Expand All @@ -912,7 +921,7 @@ def test_accelerator_strategy_from_environment(accelerator, strategy, expected_a
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=8)
def test_devices_from_environment(*_):
"""Test that the devices and number of nodes can be set through the environment variables."""
with mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_NUM_NODES": "3"}):
with mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_NUM_NODES": "3", "LT_CLI_USED": "1"}):
connector = _Connector(accelerator="cuda")
assert isinstance(connector.accelerator, CUDAAccelerator)
assert isinstance(connector.strategy, DDPStrategy)
Expand Down

0 comments on commit ba03469

Please sign in to comment.