From e50659a5c3b26c7861895f8e3af9c66a7958b04a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 12 Jan 2023 16:09:28 +0100 Subject: [PATCH] Fix configuration validation error message in Lite CLI (#16334) --- src/lightning_fabric/CHANGELOG.md | 2 ++ src/lightning_fabric/connector.py | 2 +- tests/tests_fabric/test_connector.py | 15 +++++---------- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/lightning_fabric/CHANGELOG.md b/src/lightning_fabric/CHANGELOG.md index c720d2b3a9fe10..93350b54a85599 100644 --- a/src/lightning_fabric/CHANGELOG.md +++ b/src/lightning_fabric/CHANGELOG.md @@ -71,6 +71,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Restored sampling parity between PyTorch and Fabric dataloaders when using the `DistributedSampler` ([#16101](https://github.com/Lightning-AI/lightning/issues/16101)) +- Fixes an issue where the error message wouldn't tell the user the real value that was passed through the CLI ([#16334](https://github.com/Lightning-AI/lightning/issues/16334)) + ## [1.8.6] - 2022-12-21 diff --git a/src/lightning_fabric/connector.py b/src/lightning_fabric/connector.py index 208fb9f00dfd6f..e5bda1faa168b9 100644 --- a/src/lightning_fabric/connector.py +++ b/src/lightning_fabric/connector.py @@ -539,7 +539,7 @@ def _argument_from_env(name: str, current: Any, default: Any) -> Any: if env_value is not None and env_value != str(current) and str(current) != str(default): raise ValueError( f"Your code has `Fabric({name}={current!r}, ...)` but it conflicts with the value " - f"`--{name}={current}` set through the CLI. " + f"`--{name}={env_value}` set through the CLI. " " Remove it either from the CLI or from the Lightning Fabric object." ) if env_value is None: diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index 2f5164854eed02..0ea1295e229f8e 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -13,7 +13,6 @@ # limitations under the License import os -from re import escape from typing import Any, Dict from unittest import mock @@ -808,27 +807,23 @@ def test_devices_from_environment(*_): def test_arguments_from_environment_collision(): """Test that the connector raises an error when the CLI settings conflict with settings in the code.""" with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}): - with pytest.raises( - ValueError, match=escape("Your code has `Fabric(accelerator='cuda', ...)` but it conflicts") - ): + with pytest.raises(ValueError, match="`Fabric\\(accelerator='cuda', ...\\)` but .* `--accelerator=cpu`"): _Connector(accelerator="cuda") with mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp"}): - with pytest.raises( - ValueError, match=escape("Your code has `Fabric(strategy='ddp_spawn', ...)` but it conflicts") - ): + with pytest.raises(ValueError, match="`Fabric\\(strategy='ddp_spawn', ...\\)` but .* `--strategy=ddp`"): _Connector(strategy="ddp_spawn") with mock.patch.dict(os.environ, {"LT_DEVICES": "2"}): - with pytest.raises(ValueError, match=escape("Your code has `Fabric(devices=3, ...)` but it conflicts")): + with pytest.raises(ValueError, match="`Fabric\\(devices=3, ...\\)` but .* `--devices=2`"): _Connector(devices=3) with mock.patch.dict(os.environ, {"LT_NUM_NODES": "3"}): - with pytest.raises(ValueError, match=escape("Your code has `Fabric(num_nodes=2, ...)` but it conflicts")): + with pytest.raises(ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`"): _Connector(num_nodes=2) with mock.patch.dict(os.environ, {"LT_PRECISION": "16"}): - with pytest.raises(ValueError, match=escape("Your code has `Fabric(precision=64, ...)` but it conflicts")): + with pytest.raises(ValueError, match="`Fabric\\(precision=64, ...\\)` but .* `--precision=16`"): _Connector(precision=64)