Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix HF LR schedulers #1307

Merged
merged 17 commits into from
Apr 27, 2022
43 changes: 19 additions & 24 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,26 @@ jobs:
matrix:
# PyTorch 1.5 is failing on Win and bolts requires torchvision>=0.5
os: [ubuntu-20.04, macOS-10.15, windows-2019]
python-version: [3.6, 3.8]
python-version: [3.7, 3.9]
requires: ['oldest', 'latest']
topic: [['devel']]
topic: [['core']]
release: [ 'stable' ]
exclude:
- os: ubuntu-20.04
python-version: 3.8
requires: 'latest'
# Skip if torch<1.8 and py3.9 on Linux: https://github.com/pytorch/pytorch/issues/50014
- { os: ubuntu-20.04, python-version: 3.9, requires: 'oldest' }
- { os: ubuntu-20.04, python-version: 3.9, requires: 'latest' }
include:
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'pre', topic: [ 'devel' ] }
- { os: 'ubuntu-20.04', python-version: 3.8, requires: 'latest', release: 'stable', topic: [ 'image' ] }
- { os: 'ubuntu-20.04', python-version: 3.8, requires: 'latest', release: 'stable', topic: [ 'image','image_extras' ] }
- { os: 'ubuntu-20.04', python-version: 3.8, requires: 'latest', release: 'stable', topic: [ 'video' ] }
- { os: 'ubuntu-20.04', python-version: 3.8, requires: 'latest', release: 'stable', topic: [ 'video','video_extras' ] }
- { os: 'ubuntu-20.04', python-version: 3.8, requires: 'latest', release: 'stable', topic: [ 'tabular' ] }
- { os: 'ubuntu-20.04', python-version: 3.8, requires: 'latest', release: 'stable', topic: [ 'text' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'pre', topic: [ 'core' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image','image_extras' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'video' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'video','video_extras' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'tabular' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'text' ] }
- { os: 'ubuntu-20.04', python-version: 3.8, requires: 'latest', release: 'stable', topic: [ 'pointcloud' ] }
- { os: 'ubuntu-20.04', python-version: 3.8, requires: 'latest', release: 'stable', topic: [ 'serve' ] }
- { os: 'ubuntu-20.04', python-version: 3.8, requires: 'latest', release: 'stable', topic: [ 'graph' ] }
- { os: 'ubuntu-20.04', python-version: 3.8, requires: 'latest', release: 'stable', topic: [ 'audio' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'serve' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'graph' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'audio' ] }

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 35
Expand All @@ -67,15 +67,10 @@ jobs:
- name: Set min. dependencies
if: matrix.requires == 'oldest'
run: |
python -c "req = open('requirements.txt').read().replace('>', '=') ; open('requirements.txt', 'w').write(req)"

- name: Filter requirements
run: |
import sys
if sys.version_info.minor < 7:
fname = 'requirements.txt'
lines = [line for line in open(fname).readlines() if not line.startswith('pytorchvideo')]
open(fname, 'w').writelines(lines)
fname = 'requirements.txt'
ignore = ['pandas', 'torchmetrics']
lines = [line if any([line.startswith(package) for package in ignore]) else line.replace('>', '=') for line in open(fname).readlines()]
open(fname, 'w').writelines(lines)
shell: python

- run: echo "::set-output name=period::$(python -c 'import time ; days = time.time() / 60 / 60 / 24 ; print(int(days / 7))' 2>&1)"
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where the default Flash zero configurations for `ObjectDetector`, `InstanceSegmentation`, and `KeypointDetector` would error with the latest version of some requirements ([#1306](https://github.com/PyTorchLightning/lightning-flash/pull/1306))

- Fixed a bug where LR schedulers from HuggingFace could not be used with newer versions of PyTorch Lightning ([#1307](https://github.com/PyTorchLightning/lightning-flash/pull/1307))

## [0.7.0] - 2022-02-15

### Added
Expand Down
20 changes: 4 additions & 16 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,25 +650,13 @@ def get_num_training_steps(self) -> int:
"""Total training steps inferred from datamodule and devices."""
if not getattr(self, "trainer", None):
raise MisconfigurationException("The LightningModule isn't attached to the trainer yet.")
if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
dataset_size = self.trainer.limit_train_batches
elif isinstance(self.trainer.limit_train_batches, float):
# limit_train_batches is a percentage of batches
dataset_size = len(self.train_dataloader())
dataset_size = int(dataset_size * self.trainer.limit_train_batches)
else:
dataset_size = len(self.train_dataloader())

num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
if self.trainer.tpu_cores:
num_devices = max(num_devices, self.trainer.tpu_cores)
if hasattr(self.trainer, "estimated_stepping_batches"):
return self.trainer.estimated_stepping_batches

effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs
from flash.core.trainer import Trainer

if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
return self.trainer.max_steps
return max_estimated_steps
return Trainer.estimated_stepping_batches.fget(self.trainer)

@staticmethod
def _compute_warmup(num_training_steps: int, num_warmup_steps: Union[int, float]) -> int:
Expand Down
61 changes: 61 additions & 0 deletions flash/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
import warnings
from argparse import ArgumentParser, Namespace
from functools import wraps
Expand All @@ -21,6 +22,7 @@
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning import Trainer as PlTrainer
from pytorch_lightning.callbacks import BaseFinetuning
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.argparse import add_argparse_args, get_init_arguments_and_types, parse_env_variables
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader
Expand All @@ -31,6 +33,7 @@
from flash.core.data.io.transform_predictions import TransformPredictions
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_0, _PL_GREATER_EQUAL_1_5_0, _PL_GREATER_EQUAL_1_6_0


def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
Expand Down Expand Up @@ -239,3 +242,61 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) ->
# the lightning trainer implementation does not support subclasses.
# context: https://github.com/PyTorchLightning/lightning-flash/issues/342#issuecomment-848892447
return from_argparse_args(Trainer, args, **kwargs)

@property
def estimated_stepping_batches(self) -> Union[int, float]:
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
"""Estimated stepping batches for the complete training inferred from DataLoaders, gradient accumulation
factor and distributed setup.

Examples
________

.. code-block:: python

def configure_optimizers(self):
optimizer = ...
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
)
return [optimizer], [scheduler]
"""
if _PL_GREATER_EQUAL_1_6_0:
return super().estimated_stepping_batches
# Copied from PL 1.6
accumulation_scheduler = self.accumulation_scheduler

if accumulation_scheduler.epochs != [0]:
raise MisconfigurationException(
"Estimated stepping batches cannot be computed with different"
" `accumulate_grad_batches` at different epochs."
)

# infinite training
if self.max_epochs == -1 and self.max_steps == -1:
return float("inf")

if self.train_dataloader is None:
rank_zero_info("Loading `train_dataloader` to estimate number of stepping batches.")
if _PL_GREATER_EQUAL_1_5_0:
self.reset_train_dataloader()
else:
self.reset_train_dataloader(self.lightning_module)

total_batches = self.num_training_batches

# iterable dataset
if total_batches == float("inf"):
return self.max_steps

if _PL_GREATER_EQUAL_1_4_0:
self.accumulate_grad_batches = accumulation_scheduler.get_accumulate_grad_batches(self.current_epoch)
else:
# Call the callback hook manually to guarantee that `self.accumulate_grad_batches` has been set
accumulation_scheduler.on_train_epoch_start(self, self.lightning_module)
effective_batch_size = self.accumulate_grad_batches
max_estimated_steps = math.ceil(total_batches / effective_batch_size) * max(self.max_epochs, 1)

max_estimated_steps = (
min(max_estimated_steps, self.max_steps) if self.max_steps not in [None, -1] else max_estimated_steps
)
return max_estimated_steps
6 changes: 2 additions & 4 deletions flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,14 +734,12 @@ def from_labelstudio(

Args:
export_json: path to label studio export file
train_export_json: path to label studio export file for train set,
overrides export_json if specified
train_export_json: path to label studio export file for train set, overrides export_json if specified
val_export_json: path to label studio export file for validation
test_export_json: path to label studio export file for test
predict_export_json: path to label studio export file for predict
data_folder: path to label studio data folder
train_data_folder: path to label studio data folder for train data set,
overrides data_folder if specified
train_data_folder: path to label studio data folder for train data set, overrides data_folder if specified
val_data_folder: path to label studio data folder for validation data
test_data_folder: path to label studio data folder for test data
predict_data_folder: path to label studio data folder for predict data
Expand Down
5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,8 @@ def _expand_reqs(extras: dict, keys: list) -> list:
base_req = setup_tools._load_requirements(path_dir=_PATH_ROOT, file_name="requirements.txt")
# find all extra requirements
_load_req = partial(setup_tools._load_requirements, path_dir=_PATH_REQUIRE)
SKIP_REQ_FILES = "devel.txt"
found_req_files = sorted(os.path.basename(p) for p in glob.glob(os.path.join(_PATH_REQUIRE, "*.txt")))
# filter unwanted files
found_req_files = [n for n in found_req_files if n not in SKIP_REQ_FILES]
# remove datatype prefix
found_req_names = [os.path.splitext(req)[0].replace("datatype_", "") for req in found_req_files]
# define basic and extra extras
extras_req = {
Expand All @@ -71,6 +69,7 @@ def _expand_reqs(extras: dict, keys: list) -> list:
)
# some extra combinations
extras_req["vision"] = _expand_reqs(extras_req, ["image", "video"])
extras_req["core"] = _expand_reqs(extras_req, ["image", "tabular", "text"])
extras_req["all"] = _expand_reqs(extras_req, ["vision", "tabular", "text", "audio"])
extras_req["dev"] = _expand_reqs(extras_req, ["all", "test", "docs"])
# filter the uniques
Expand Down
28 changes: 24 additions & 4 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import pytest
import pytorch_lightning as pl
import torch
from pytorch_lightning import LightningDataModule
from pytorch_lightning.callbacks import Callback
from torch import nn, Tensor
from torch.nn import functional as F
Expand Down Expand Up @@ -406,13 +407,32 @@ def test_external_optimizers_torch_optimizer(tmpdir, optim):
("cosine_with_hard_restarts_schedule_with_warmup", {"num_warmup_steps": 0.1, "num_cycles": 3}),
],
)
def test_external_schedulers_provider_hf_transformers(tmpdir, optim, sched):
@pytest.mark.parametrize("use_datamodule", [False, True])
@pytest.mark.parametrize("limit", [None, 5, 0.1])
def test_external_schedulers_provider_hf_transformers(tmpdir, optim, sched, use_datamodule, limit):
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax())
task = ClassificationTask(model, optimizer=deepcopy(optim), lr_scheduler=deepcopy(sched), loss_fn=F.nll_loss)
trainer = flash.Trainer(max_epochs=1, limit_train_batches=10, gpus=torch.cuda.device_count())
ds = DummyDataset()
trainer.fit(task, train_dataloader=DataLoader(ds))

if limit is not None:
batch_count = limit if isinstance(limit, int) else int(limit * 10)
trainer = flash.Trainer(max_epochs=1, limit_train_batches=limit)
else:
batch_count = 10
trainer = flash.Trainer(max_epochs=1)

ds = DummyDataset(num_samples=10)

if use_datamodule:

class TestDataModule(LightningDataModule):
def train_dataloader(self):
return DataLoader(ds)

trainer.fit(task, datamodule=TestDataModule())
else:
trainer.fit(task, train_dataloader=DataLoader(ds))

assert task.get_num_training_steps() == batch_count
assert isinstance(trainer.optimizers[0], torch.optim.Adadelta)
assert isinstance(trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.LambdaLR)

Expand Down
9 changes: 8 additions & 1 deletion tests/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from pathlib import Path
from unittest import mock

Expand Down Expand Up @@ -114,7 +115,13 @@
"tabular_forecasting.py",
marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed"),
),
pytest.param("template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")),
pytest.param(
"template.py",
marks=[
pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed"),
pytest.mark.skipif(sys.version_info >= (3, 9), reason="Undiagnosed segmentation fault in 3.9"),
],
),
pytest.param(
"text_classification.py",
marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed"),
Expand Down