Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix precision default from environment #18928

Merged
merged 4 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/fabric/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(
strategy = self._argument_from_env("strategy", strategy, default="auto")
devices = self._argument_from_env("devices", devices, default="auto")
num_nodes = int(self._argument_from_env("num_nodes", num_nodes, default=1))
precision = self._argument_from_env("precision", precision, default="32-true")
precision = self._argument_from_env("precision", precision, default=None)

# 1. Parsing flags
# Get registered strategies, built-in accelerators and precision plugins
Expand Down
21 changes: 15 additions & 6 deletions tests/tests_fabric/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,16 +870,25 @@ def test_strategy_str_passed_being_case_insensitive(_, strategy, strategy_cls):
assert isinstance(connector.strategy, strategy_cls)


@pytest.mark.parametrize("precision", [None, "64-true", "32-true", "16-mixed", "bf16-mixed"])
@pytest.mark.parametrize(
("precision", "expected"),
[
(None, Precision),
("64-true", DoublePrecision),
("32-true", Precision),
("16-true", HalfPrecision),
("16-mixed", MixedPrecision),
],
)
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1)
def test_precision_from_environment(_, precision):
def test_precision_from_environment(_, precision, expected):
"""Test that the precision input can be set through the environment variable."""
env_vars = {}
env_vars = {"LT_CLI_USED": "1"}
if precision is not None:
env_vars["LT_PRECISION"] = precision
with mock.patch.dict(os.environ, env_vars):
connector = _Connector(accelerator="cuda") # need to use cuda, because AMP not available on CPU
assert isinstance(connector.precision, Precision)
assert isinstance(connector.precision, expected)


@pytest.mark.parametrize(
Expand All @@ -897,7 +906,7 @@ def test_precision_from_environment(_, precision):
)
def test_accelerator_strategy_from_environment(accelerator, strategy, expected_accelerator, expected_strategy):
"""Test that the accelerator and strategy input can be set through the environment variables."""
env_vars = {}
env_vars = {"LT_CLI_USED": "1"}
if accelerator is not None:
env_vars["LT_ACCELERATOR"] = accelerator
if strategy is not None:
Expand All @@ -912,7 +921,7 @@ def test_accelerator_strategy_from_environment(accelerator, strategy, expected_a
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=8)
def test_devices_from_environment(*_):
"""Test that the devices and number of nodes can be set through the environment variables."""
with mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_NUM_NODES": "3"}):
with mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_NUM_NODES": "3", "LT_CLI_USED": "1"}):
connector = _Connector(accelerator="cuda")
assert isinstance(connector.accelerator, CUDAAccelerator)
assert isinstance(connector.strategy, DDPStrategy)
Expand Down
Loading