Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate num_processes,gpus, tpu_cores, and ipus from the Trainer constructor #11040

Merged
merged 37 commits into from
Apr 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
212a799
Deprecate num_processes, gpus, tpu_cores, and ipus from the Trainer c…
daniellepintz Dec 11, 2021
5368b14
fix accel_con tests
daniellepintz Dec 15, 2021
5124e3e
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Dec 22, 2021
841aed7
change removal to 2.0
daniellepintz Dec 22, 2021
ad02054
fix
daniellepintz Dec 22, 2021
ca0b6a2
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Jan 20, 2022
151da46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 20, 2022
3186c2c
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Feb 4, 2022
e87b79d
Merge branch 'accel_con' of github.com:daniellepintz/pytorch-lightnin…
daniellepintz Feb 4, 2022
315b5e9
fix tpu test
daniellepintz Feb 10, 2022
f6819c3
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Feb 10, 2022
71e8011
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2022
6cb44d6
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Mar 2, 2022
87accbd
Merge branch 'accel_con' of github.com:daniellepintz/pytorch-lightnin…
daniellepintz Mar 2, 2022
5970989
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 2, 2022
33b959a
fix gpu test
daniellepintz Mar 2, 2022
4ecc5b0
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Mar 26, 2022
a6ab1a3
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Mar 27, 2022
38fc4d6
update final tests
daniellepintz Mar 27, 2022
cf1e981
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Mar 28, 2022
45fbee6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2022
d792bf2
fix more tests
daniellepintz Mar 28, 2022
369152f
Merge branch 'accel_con' of github.com:daniellepintz/pytorch-lightnin…
daniellepintz Mar 28, 2022
1a07d7b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2022
f5db9f0
update dep strings
daniellepintz Mar 28, 2022
97fcdc3
update test_combined_data_loader_validation_test
daniellepintz Mar 28, 2022
54d3f15
mock ipus and tpus
daniellepintz Mar 28, 2022
87761f0
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Mar 29, 2022
69c1756
update to 1.7 and fix test
daniellepintz Mar 29, 2022
99af06c
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Mar 29, 2022
97ecfa8
fix tests
daniellepintz Mar 29, 2022
5f93ffc
fix test_v2_0_0_deprecated_tpu_cores
daniellepintz Mar 29, 2022
1d477f0
fix test
daniellepintz Mar 29, 2022
b8e00c5
Merge branch 'master' into accel_con
awaelchli Apr 3, 2022
6763a46
add missing deprecation messages in docs
awaelchli Apr 3, 2022
7dc4b50
update usage in nemo examples
awaelchli Apr 3, 2022
45b85b9
merge master
awaelchli Apr 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `pytorch_lightning.loggers.base.LightningLoggerBase` in favor of `pytorch_lightning.loggers.logger.Logger`, and deprecated `pytorch_lightning.loggers.base` in favor of `pytorch_lightning.loggers.logger` ([#120148](https://github.com/PyTorchLightning/pytorch-lightning/pull/12014))


-

- Deprecated `num_processes`, `gpus`, `tpu_cores,` and `ipus` from the `Trainer` constructor in favor of using the `accelerator` and `devices` arguments ([#11040](https://github.com/PyTorchLightning/pytorch-lightning/pull/11040))


-
Expand Down
9 changes: 9 additions & 0 deletions docs/source/common/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,9 @@ See Also:
gpus
^^^^

.. warning:: ``gpus=x`` has been deprecated in v1.7 and will be removed in v2.0.
Please use ``accelerator='gpu'`` and ``devices=x`` instead.
akihironitta marked this conversation as resolved.
Show resolved Hide resolved

.. raw:: html

<video width="50%" max-width="400px" controls
Expand Down Expand Up @@ -1055,6 +1058,9 @@ Number of GPU nodes for distributed training.
num_processes
^^^^^^^^^^^^^

.. warning:: ``num_processes=x`` has been deprecated in v1.7 and will be removed in v2.0.
Please use ``accelerator='cpu'`` and ``devices=x`` instead.

.. raw:: html

<video width="50%" max-width="400px" controls
Expand Down Expand Up @@ -1457,6 +1463,9 @@ track_grad_norm
tpu_cores
^^^^^^^^^

.. warning:: ``tpu_cores=x`` has been deprecated in v1.7 and will be removed in v2.0.
Please use ``accelerator='tpu'`` and ``devices=x`` instead.

.. raw:: html

<video width="50%" max-width="400px" controls
Expand Down
9 changes: 6 additions & 3 deletions docs/source/ecosystem/asr_nlp_tts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ including the PyTorch Lightning Trainer, customizable from the command line.
.. code-block:: bash

python NeMo/examples/asr/speech_to_text.py --config-name=quartznet_15x5 \
trainer.gpus=4 \
trainer.accelerator=gpu \
trainer.devices=4 \
trainer.max_epochs=128 \
+trainer.precision=16 \
model.train_ds.manifest_filepath=<PATH_TO_DATA>/librispeech-train-all.json \
Expand Down Expand Up @@ -433,7 +434,8 @@ Hydra makes every aspect of the NeMo model, including the PyTorch Lightning Trai
model.head.num_fc_layers=2 \
model.dataset.data_dir=/path/to/my/data \
trainer.max_epochs=5 \
trainer.gpus=[0,1]
trainer.accelerator=gpu \
trainer.devices=[0,1]

-----------

Expand Down Expand Up @@ -643,7 +645,8 @@ Hydra makes every aspect of the NeMo model, including the PyTorch Lightning Trai
.. code-block:: bash

python NeMo/examples/tts/glow_tts.py \
trainer.gpus=4 \
trainer.accelerator=gpu \
trainer.devices=4 \
trainer.max_epochs=400 \
...
train_dataset=/path/to/train/data \
Expand Down
6 changes: 6 additions & 0 deletions docs/source/starter/lightning_lite.rst
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,9 @@ Configure the devices to run on. Can be of type:
gpus
====

.. warning:: ``gpus=x`` has been deprecated in v1.7 and will be removed in v2.0.
Please use ``accelerator='gpu'`` and ``devices=x`` instead.

Shorthand for setting ``devices=X`` and ``accelerator="gpu"``.

.. code-block:: python
Expand All @@ -445,6 +448,9 @@ Shorthand for setting ``devices=X`` and ``accelerator="gpu"``.
tpu_cores
=========

.. warning:: ``tpu_cores=x`` has been deprecated in v1.7 and will be removed in v2.0.
Please use ``accelerator='tpu'`` and ``devices=x`` instead.

Shorthand for ``devices=X`` and ``accelerator="tpu"``.

.. code-block:: python
Expand Down
27 changes: 24 additions & 3 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def _check_device_config_and_set_final_flags(
self._devices_flag = devices

# TODO: Delete this method when num_processes, gpus, ipus and tpu_cores gets removed
self._map_deprecated_devices_specfic_info_to_accelerator_and_device_flag(
self._map_deprecated_devices_specific_info_to_accelerator_and_device_flag(
devices, num_processes, gpus, ipus, tpu_cores
)

Expand All @@ -424,15 +424,36 @@ def _check_device_config_and_set_final_flags(
" `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu'|'hpu)` for the devices mapping"
)

def _map_deprecated_devices_specfic_info_to_accelerator_and_device_flag(
def _map_deprecated_devices_specific_info_to_accelerator_and_device_flag(
self,
devices: Optional[Union[List[int], str, int]],
num_processes: Optional[int],
gpus: Optional[Union[List[int], str, int]],
ipus: Optional[int],
tpu_cores: Optional[Union[List[int], str, int]],
) -> None:
"""Sets the `devices_flag` and `accelerator_flag` based on num_processes, gpus, ipus, tpu_cores."""
"""Emit deprecation warnings for num_processes, gpus, ipus, tpu_cores and set the `devices_flag` and
`accelerator_flag`."""
if num_processes is not None:
rank_zero_deprecation(
f"Setting `Trainer(num_processes={num_processes})` is deprecated in v1.7 and will be removed"
f" in v2.0. Please use `Trainer(accelerator='cpu', devices={num_processes})` instead."
)
if gpus is not None:
rank_zero_deprecation(
f"Setting `Trainer(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
f" in v2.0. Please use `Trainer(accelerator='gpu', devices={gpus!r})` instead."
)
if tpu_cores is not None:
rank_zero_deprecation(
f"Setting `Trainer(tpu_cores={tpu_cores!r})` is deprecated in v1.7 and will be removed"
f" in v2.0. Please use `Trainer(accelerator='tpu', devices={tpu_cores!r})` instead."
)
if ipus is not None:
rank_zero_deprecation(
f"Setting `Trainer(ipus={ipus})` is deprecated in v1.7 and will be removed"
f" in v2.0. Please use `Trainer(accelerator='ipu', devices={ipus})` instead."
)
self._gpus: Optional[Union[List[int], str, int]] = gpus
self._tpu_cores: Optional[Union[List[int], str, int]] = tpu_cores
deprecated_devices_specific_flag = num_processes or gpus or ipus or tpu_cores
Expand Down
24 changes: 20 additions & 4 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,12 @@ def __init__(
gradient_clip_algorithm: Optional[str] = None,
process_position: int = 0,
num_nodes: int = 1,
num_processes: Optional[int] = None,
num_processes: Optional[int] = None, # TODO: Remove in 2.0
devices: Optional[Union[List[int], str, int]] = None,
gpus: Optional[Union[List[int], str, int]] = None,
gpus: Optional[Union[List[int], str, int]] = None, # TODO: Remove in 2.0
auto_select_gpus: bool = False,
tpu_cores: Optional[Union[List[int], str, int]] = None,
ipus: Optional[int] = None,
tpu_cores: Optional[Union[List[int], str, int]] = None, # TODO: Remove in 2.0
ipus: Optional[int] = None, # TODO: Remove in 2.0
enable_progress_bar: bool = True,
overfit_batches: Union[int, float] = 0.0,
track_grad_norm: Union[int, float, str] = -1,
Expand Down Expand Up @@ -275,6 +275,10 @@ def __init__(
gpus: Number of GPUs to train on (int) or which GPUs to train on (list or str) applied per node
Default: ``None``.

.. deprecated:: v1.7
``gpus`` has been deprecated in v1.7 and will be removed in v2.0.
Please use ``accelerator='gpu'`` and ``devices=x`` instead.

gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before.
Default: ``None``.
Expand Down Expand Up @@ -351,6 +355,10 @@ def __init__(
num_processes: Number of processes for distributed training with ``accelerator="cpu"``.
Default: ``1``.

.. deprecated:: v1.7
``num_processes`` has been deprecated in v1.7 and will be removed in v2.0.
Please use ``accelerator='cpu'`` and ``devices=x`` instead.

num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
Set it to `-1` to run all batches in all validation dataloaders.
Default: ``2``.
Expand Down Expand Up @@ -381,9 +389,17 @@ def __init__(
tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on (1)
Default: ``None``.

.. deprecated:: v1.7
``tpu_cores`` has been deprecated in v1.7 and will be removed in v2.0.
Please use ``accelerator='tpu'`` and ``devices=x`` instead.

ipus: How many IPUs to train on.
Default: ``None``.

.. deprecated:: v1.7
``ipus`` has been deprecated in v1.7 and will be removed in v2.0.
Please use ``accelerator='ipu'`` and ``devices=x`` instead.

track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm. If using
Automatic Mixed Precision (AMP), the gradients will be unscaled before logging them.
Default: ``-1``.
Expand Down
1 change: 0 additions & 1 deletion tests/accelerators/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
from tests.helpers.runif import RunIf


# TODO: please modify/sunset any test that has accelerator=ddp/ddp2/ddp_cpu/ddp_spawn @daniellepintz
def test_accelerator_choice_cpu(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
assert isinstance(trainer.accelerator, CPUAccelerator)
Expand Down
16 changes: 7 additions & 9 deletions tests/accelerators/test_ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,7 @@ def test_epoch_end(self, outputs) -> None:
@mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True)
def test_fail_if_no_ipus(mock_ipu_acc_avail, tmpdir):
with pytest.raises(MisconfigurationException, match="IPU Accelerator requires IPU devices to run"):
Trainer(default_root_dir=tmpdir, ipus=1)

with pytest.raises(MisconfigurationException, match="IPU Accelerator requires IPU devices to run"):
Trainer(default_root_dir=tmpdir, ipus=1, accelerator="ipu")
Trainer(default_root_dir=tmpdir, accelerator="ipu", devices=1)


@RunIf(ipu=True)
Expand Down Expand Up @@ -398,7 +395,8 @@ def test_manual_poptorch_opts(tmpdir):

trainer = Trainer(
default_root_dir=tmpdir,
ipus=2,
accelerator="ipu",
devices=2,
fast_dev_run=True,
strategy=IPUStrategy(inference_opts=inference_opts, training_opts=training_opts),
)
Expand Down Expand Up @@ -552,13 +550,13 @@ def test_precision_plugin(tmpdir):

@RunIf(ipu=True)
def test_accelerator_ipu():
trainer = Trainer(accelerator="ipu", ipus=1)
trainer = Trainer(accelerator="ipu", devices=1)
assert isinstance(trainer.accelerator, IPUAccelerator)

trainer = Trainer(accelerator="ipu")
assert isinstance(trainer.accelerator, IPUAccelerator)

trainer = Trainer(accelerator="auto", ipus=8)
trainer = Trainer(accelerator="auto", devices=8)
assert isinstance(trainer.accelerator, IPUAccelerator)


Expand Down Expand Up @@ -592,8 +590,8 @@ def test_accelerator_ipu_with_ipus_priority():

@RunIf(ipu=True)
def test_set_devices_if_none_ipu():

trainer = Trainer(accelerator="ipu", ipus=8)
with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
trainer = Trainer(accelerator="ipu", ipus=8)
assert trainer.num_devices == 8


Expand Down
3 changes: 2 additions & 1 deletion tests/accelerators/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def test_accelerator_tpu_with_tpu_cores_priority():

@RunIf(tpu=True)
def test_set_devices_if_none_tpu():
trainer = Trainer(accelerator="tpu", tpu_cores=8)
with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
trainer = Trainer(accelerator="tpu", tpu_cores=8)
assert isinstance(trainer.accelerator, TPUAccelerator)
assert trainer.num_devices == 8

Expand Down
3 changes: 2 additions & 1 deletion tests/benchmarks/test_basic_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def lightning_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):
enable_progress_bar=False,
enable_model_summary=False,
enable_checkpointing=False,
gpus=1 if device_type == "cuda" else 0,
accelerator="gpu" if device_type == "cuda" else "cpu",
devices=1,
logger=False,
replace_sampler_ddp=False,
)
Expand Down
13 changes: 11 additions & 2 deletions tests/benchmarks/test_sharded_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,24 @@ def plugin_parity_test(
ddp_model = model_cls()
use_cuda = gpus > 0

trainer = Trainer(fast_dev_run=True, max_epochs=1, gpus=gpus, precision=precision, strategy="ddp_spawn")
trainer = Trainer(
fast_dev_run=True, max_epochs=1, accelerator="gpu", devices=gpus, precision=precision, strategy="ddp_spawn"
)

max_memory_ddp, ddp_time = record_ddp_fit_model_stats(trainer=trainer, model=ddp_model, use_cuda=use_cuda)

# Reset and train Custom DDP
seed_everything(seed)
custom_plugin_model = model_cls()

trainer = Trainer(fast_dev_run=True, max_epochs=1, gpus=gpus, precision=precision, strategy="ddp_sharded_spawn")
trainer = Trainer(
fast_dev_run=True,
max_epochs=1,
accelerator="gpu",
devices=gpus,
precision=precision,
strategy="ddp_sharded_spawn",
)
assert isinstance(trainer.strategy, DDPSpawnShardedStrategy)

max_memory_custom, custom_model_time = record_ddp_fit_model_stats(
Expand Down
31 changes: 30 additions & 1 deletion tests/deprecated_api/test_remove_2-0.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in v2.0."""
"""Test deprecated functionality which will be removed in v2.0.0."""
from unittest import mock

import pytest

import pytorch_lightning
from pytorch_lightning import Trainer
from tests.callbacks.test_callbacks import OldStatefulCallback
from tests.helpers import BoringModel


def test_v2_0_0_deprecated_num_processes():
with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
_ = Trainer(num_processes=2)


@mock.patch("torch.cuda.is_available", return_value=True)
@mock.patch("torch.cuda.device_count", return_value=2)
def test_v2_0_0_deprecated_gpus(*_):
with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
_ = Trainer(gpus=0)


@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
def test_v2_0_0_deprecated_tpu_cores(*_):
with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
_ = Trainer(tpu_cores=8)


@mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True)
def test_v2_0_0_deprecated_ipus(_, monkeypatch):
monkeypatch.setattr(pytorch_lightning.strategies.ipu, "_IPU_AVAILABLE", True)
with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
_ = Trainer(ipus=4)


def test_v2_0_resume_from_checkpoint_trainer_constructor(tmpdir):
# test resume_from_checkpoint still works until v2.0 deprecation
model = BoringModel()
Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_coun
def test_torchelastic_gpu_parsing(mocked_device_count, mocked_is_available, gpus):
"""Ensure when using torchelastic and nproc_per_node is set to the default of 1 per GPU device That we omit
sanitizing the gpus as only one of the GPUs is visible."""
trainer = Trainer(gpus=gpus)
with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
trainer = Trainer(gpus=gpus)
assert isinstance(trainer._accelerator_connector.cluster_environment, TorchElasticEnvironment)
# when use gpu
if device_parser.parse_gpu_ids(gpus) is not None:
Expand Down
11 changes: 7 additions & 4 deletions tests/models/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,10 +438,13 @@ def _predict_batch(trainer, model, batches):
[
{},
# these precision plugins modify the optimization flow, so testing them explicitly
pytest.param(dict(gpus=1, precision=16, amp_backend="native"), marks=RunIf(min_gpus=1)),
pytest.param(dict(gpus=1, precision=16, amp_backend="apex"), marks=RunIf(min_gpus=1, amp_apex=True)),
pytest.param(dict(accelerator="gpu", devices=1, precision=16, amp_backend="native"), marks=RunIf(min_gpus=1)),
pytest.param(
dict(gpus=1, precision=16, strategy="deepspeed"), marks=RunIf(min_gpus=1, standalone=True, deepspeed=True)
dict(accelerator="gpu", devices=1, precision=16, amp_backend="apex"), marks=RunIf(min_gpus=1, amp_apex=True)
),
pytest.param(
dict(accelerator="gpu", devices=1, precision=16, strategy="deepspeed"),
marks=RunIf(min_gpus=1, standalone=True, deepspeed=True),
),
],
)
Expand Down Expand Up @@ -496,7 +499,7 @@ def training_step(self, batch, batch_idx):
}
if kwargs.get("amp_backend") == "native" or kwargs.get("amp_backend") == "apex":
saved_ckpt[trainer.precision_plugin.__class__.__qualname__] = ANY
device = torch.device("cuda:0" if "gpus" in kwargs else "cpu")
device = torch.device("cuda:0" if "accelerator" in kwargs and kwargs["accelerator"] == "gpu" else "cpu")
expected = [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,8 @@ def test_tpu_cores_with_argparse(cli_args, expected):

for k, v in expected.items():
assert getattr(args, k) == v
assert Trainer.from_argparse_args(args)
with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
assert Trainer.from_argparse_args(args)


@RunIf(tpu=True)
Expand Down
Loading