Add support for logging in different trainer stages with DeviceStatsMonitor #15794

Closed
thesofakillers opened this issue Nov 23, 2022 · 5 comments · Fixed by #16002
Labels: callback: device stats, feature, good first issue

thesofakillers commented Nov 23, 2022

Bug description

I would like to use DeviceStatsMonitor during a trainer.test() call. I followed the relevant documentation, which makes no mention of whether this callback is exclusive to trainer.fit().

Despite following the docs, I get no device stats logs in TensorBoard.

How to reproduce the bug

Run the following script. No device stats are logged, even though the DeviceStatsMonitor callback is attached:

```python
import pytorch_lightning as pl
import numpy as np
from pytorch_lightning.callbacks import DeviceStatsMonitor
import torch
from torch.nn import MSELoss
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn


class SimpleDataset(Dataset):
    def __init__(self):
        X = np.arange(10000)
        y = X * 2
        X = [[x] for x in X]
        y = [[v] for v in y]
        self.X = torch.Tensor(X)
        self.y = torch.Tensor(y)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return {"X": self.X[idx], "y": self.y[idx]}


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1)
        self.criterion = MSELoss()

    def forward(self, inputs_id, labels=None):
        outputs = self.fc(inputs_id)
        loss = 0
        if labels is not None:
            loss = self.criterion(outputs, labels)
        return loss, outputs

    def test_dataloader(self):
        # 10,000 samples with batch_size=1000 -> 10 test batches
        dataset = SimpleDataset()
        return DataLoader(dataset, batch_size=1000)

    def test_step(self, batch, batch_idx):
        input_ids = batch["X"]
        labels = batch["y"]
        loss, outputs = self(input_ids, labels)
        return {"loss": loss}

    def configure_optimizers(self):
        optimizer = Adam(self.parameters())
        return optimizer


if __name__ == "__main__":
    model = MyModel()
    logger = pl.loggers.TensorBoardLogger(save_dir="example", name="test")
    trainer = pl.Trainer(
        logger=logger, max_epochs=5, callbacks=[DeviceStatsMonitor(cpu_stats=True)]
    )
    # Expected: device stats in TensorBoard. Observed on 1.7.7: nothing is logged.
    trainer.test(model)
```

Environment

* CUDA:
	- GPU:               None
	- available:         False
	- version:           None
* Lightning:
	- pytorch-lightning: 1.7.7
	- torch:             1.12.1
	- torchmetrics:      0.10.3
* Packages:
	- absl-py:           1.3.0
	- aiohttp:           3.8.3
	- aiosignal:         1.3.1
	- antlr4-python3-runtime: 4.9.3
	- anyio:             3.6.2
	- appnope:           0.1.3
	- argon2-cffi:       21.3.0
	- argon2-cffi-bindings: 21.2.0
	- asttokens:         2.1.0
	- astunparse:        1.6.3
	- async-timeout:     4.0.2
	- attrs:             22.1.0
	- babel:             2.11.0
	- backcall:          0.2.0
	- beautifulsoup4:    4.11.1
	- bigbench:          0.0.1
	- black:             22.10.0
	- bleach:            5.0.1
	- bleurt:            0.0.2
	- cachetools:        5.2.0
	- certifi:           2022.9.24
	- cffi:              1.15.1
	- chardet:           3.0.4
	- charset-normalizer: 2.1.1
	- claficle:          0.1.0
	- click:             8.1.3
	- colorama:          0.4.6
	- contourpy:         1.0.5
	- cycler:            0.11.0
	- datasets:          2.7.0
	- debugpy:           1.6.3
	- decorator:         5.1.1
	- defusedxml:        0.7.1
	- dill:              0.3.6
	- editdistance:      0.6.0
	- entrypoints:       0.4
	- etils:             0.8.0
	- executing:         1.2.0
	- fastjsonschema:    2.16.2
	- filelock:          3.8.0
	- flake8:            4.0.1
	- flatbuffers:       2.0.7
	- fonttools:         4.37.2
	- frozenlist:        1.3.3
	- fsspec:            2022.11.0
	- future:            0.18.2
	- gast:              0.4.0
	- gin-config:        0.5.0
	- gitdb:             4.0.9
	- gitpython:         3.1.29
	- google-auth:       2.14.1
	- google-auth-oauthlib: 0.4.6
	- google-pasta:      0.2.0
	- googleapis-common-protos: 1.56.4
	- googletrans:       3.1.0a0
	- grpcio:            1.50.0
	- h11:               0.9.0
	- h2:                3.2.0
	- h5py:              3.7.0
	- hpack:             3.0.0
	- hstspreload:       2022.11.1
	- httpcore:          0.9.1
	- httpx:             0.13.3
	- huggingface-hub:   0.11.0
	- hydra-core:        1.2.0
	- hyperframe:        5.2.0
	- idna:              2.10
	- immutabledict:     2.2.1
	- importlib-metadata: 5.0.0
	- importlib-resources: 5.10.0
	- iniconfig:         1.1.1
	- ipykernel:         6.17.1
	- ipython:           8.6.0
	- ipython-genutils:  0.2.0
	- ipywidgets:        7.7.2
	- jax:               0.3.17
	- jaxlib:            0.3.15
	- jedi:              0.18.1
	- jinja2:            3.1.2
	- joblib:            1.2.0
	- json5:             0.9.10
	- jsonschema:        4.17.0
	- jupyter-client:    7.4.7
	- jupyter-core:      5.0.0
	- jupyter-server:    1.23.2
	- jupyter-server-mathjax: 0.2.6
	- jupyterlab:        3.5.0
	- jupyterlab-code-formatter: 1.5.3
	- jupyterlab-pygments: 0.2.2
	- jupyterlab-server: 2.16.3
	- jupyterlab-widgets: 1.1.1
	- keras:             2.10.0
	- keras-preprocessing: 1.1.2
	- kiwisolver:        1.4.4
	- libclang:          14.0.6
	- markdown:          3.4.1
	- markupsafe:        2.1.1
	- matplotlib:        3.6.0
	- matplotlib-inline: 0.1.6
	- mccabe:            0.6.1
	- mesh-tensorflow:   0.1.21
	- mistune:           2.0.4
	- multidict:         6.0.2
	- multiprocess:      0.70.14
	- mypy:              0.971
	- mypy-extensions:   0.4.3
	- nbclassic:         0.4.8
	- nbclient:          0.7.0
	- nbconvert:         7.2.5
	- nbdime:            3.1.1
	- nbformat:          5.7.0
	- nest-asyncio:      1.5.6
	- nltk:              3.7
	- notebook:          6.5.2
	- notebook-shim:     0.2.2
	- numpy:             1.23.4
	- oauthlib:          3.2.2
	- omegaconf:         2.2.3
	- opt-einsum:        3.3.0
	- packaging:         21.3
	- pandas:            1.5.1
	- pandocfilters:     1.5.0
	- parso:             0.8.3
	- pathspec:          0.10.2
	- pexpect:           4.8.0
	- pickleshare:       0.7.5
	- pillow:            9.2.0
	- pip:               22.1.2
	- pkgutil-resolve-name: 1.3.10
	- platformdirs:      2.5.4
	- pluggy:            1.0.0
	- portalocker:       2.5.1
	- prometheus-client: 0.15.0
	- promise:           2.3
	- prompt-toolkit:    3.0.32
	- protobuf:          3.20.3
	- psutil:            5.9.4
	- ptyprocess:        0.7.0
	- pure-eval:         0.2.2
	- py:                1.11.0
	- pyarrow:           10.0.0
	- pyasn1:            0.4.8
	- pyasn1-modules:    0.2.8
	- pycodestyle:       2.8.0
	- pycparser:         2.21
	- pydeprecate:       0.3.2
	- pyflakes:          2.4.0
	- pygments:          2.13.0
	- pyparsing:         3.0.9
	- pyrsistent:        0.19.2
	- pytest:            7.1.3
	- python-dateutil:   2.8.2
	- pytorch-lightning: 1.7.7
	- pytz:              2022.6
	- pyyaml:            6.0
	- pyzmq:             24.0.1
	- regex:             2022.10.31
	- requests:          2.28.1
	- requests-oauthlib: 1.3.1
	- requests-unixsocket: 0.3.0
	- responses:         0.18.0
	- restrictedpython:  5.2
	- rfc3986:           1.5.0
	- rouge-score:       0.1.2
	- rsa:               4.9
	- sacrebleu:         2.2.1
	- scikit-learn:      1.1.2
	- scipy:             1.9.1
	- seaborn:           0.12.0
	- send2trash:        1.8.0
	- sentencepiece:     0.1.97
	- seqio:             0.0.10
	- setuptools:        65.5.1
	- six:               1.16.0
	- smmap:             5.0.0
	- sniffio:           1.3.0
	- soupsieve:         2.3.2.post1
	- stack-data:        0.6.1
	- t5:                0.9.3
	- tabulate:          0.8.10
	- tensorboard:       2.11.0
	- tensorboard-data-server: 0.6.1
	- tensorboard-plugin-wit: 1.8.1
	- tensorflow-datasets: 4.6.0
	- tensorflow-estimator: 2.10.0
	- tensorflow-hub:    0.12.0
	- tensorflow-io-gcs-filesystem: 0.27.0
	- tensorflow-metadata: 1.10.0
	- tensorflow-text:   2.10.0
	- termcolor:         2.1.0
	- terminado:         0.17.0
	- tf-slim:           1.1.0
	- tfds-nightly:      4.6.0.dev202209160046
	- threadpoolctl:     3.1.0
	- tinycss2:          1.2.1
	- tokenizers:        0.13.2
	- toml:              0.10.2
	- tomli:             2.0.1
	- torch:             1.12.1
	- torchmetrics:      0.10.3
	- tornado:           6.2
	- tqdm:              4.64.1
	- traitlets:         5.5.0
	- transformers:      4.24.0
	- types-pyyaml:      6.0.12.2
	- typing-extensions: 4.4.0
	- urllib3:           1.26.12
	- wcwidth:           0.2.5
	- webencodings:      0.5.1
	- websocket-client:  1.4.2
	- werkzeug:          2.2.2
	- wheel:             0.38.4
	- widgetsnbextension: 3.6.1
	- wrapt:             1.14.1
	- xxhash:            3.1.0
	- yarl:              1.8.1
	- zipp:              3.10.0
* System:
	- OS:                Darwin
	- architecture:
		- 64bit
		- 
	- processor:         i386
	- python:            3.8.13
	- version:           Darwin Kernel Version 19.6.0: Tue Jun 21 21:18:39 PDT 2022; root:xnu-6153.141.66~1/RELEASE_X86_64

More info

I have verified this on both GPU and CPU. The example above uses CPU.

cc @Borda @awaelchli

thesofakillers added the needs triage label Nov 23, 2022
awaelchli (Contributor) commented

Hi @thesofakillers
That's because the callback only implements the training hooks right now. Adding support for multiple stages would be welcome!
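To illustrate why a train-only callback stays silent during trainer.test(), here is a toy model in plain Python. It is not Lightning's code: only the hook names mirror Lightning's callback hooks, while ToyCallback, TrainOnlyStatsMonitor, run_stage and the dummy stat are hypothetical. A stand-in trainer calls only the batch hook for the running stage, so a callback that overrides only the training hook records nothing during a test run.

```python
class ToyCallback:
    """Base class: every hook is a no-op unless a subclass overrides it."""

    def on_train_batch_end(self, records, batch_idx):
        pass

    def on_test_batch_end(self, records, batch_idx):
        pass


class TrainOnlyStatsMonitor(ToyCallback):
    """Overrides only the training hook, like the callback described above."""

    def on_train_batch_end(self, records, batch_idx):
        # A real monitor would query the device here; this value is a dummy.
        records.append(("train", batch_idx, {"dummy_stat": 0.0}))


def run_stage(stage, callbacks, num_batches):
    """Toy trainer loop: calls only the hook that matches the running stage."""
    records = []
    hook_name = f"on_{stage}_batch_end"
    for batch_idx in range(num_batches):
        for callback in callbacks:
            getattr(callback, hook_name)(records, batch_idx)
    return records


monitor = TrainOnlyStatsMonitor()
train_records = run_stage("train", [monitor], num_batches=3)
test_records = run_stage("test", [monitor], num_batches=3)
print(len(train_records), len(test_records))  # 3 0
```

In this model, the fix lives in the callback (overriding the evaluation hooks as well), not in how the trainer dispatches hooks.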

awaelchli added the feature, good first issue and callback: device stats labels and removed the needs triage label Nov 25, 2022
awaelchli changed the title from "DeviceStatsMonitor not logging during trainer.test()" to "Add support for logging in multiple trainer stages with DeviceStatsMonitor" Nov 25, 2022
awaelchli changed the title from "Add support for logging in multiple trainer stages with DeviceStatsMonitor" to "Add support for logging in different trainer stages with DeviceStatsMonitor" Nov 25, 2022
albertwujj commented

Hi, I'd like to work on this. I'm new to this library, but I'm currently reading through all the related code: the Trainer run function, the train/eval/predict Loops and EpochLoops, and the logger connector.

albertwujj commented Nov 27, 2022

Currently, for fit runs, DeviceStatsMonitor only logs every n steps, as set by the Trainer's 'log_every_n_steps' argument.

How should we decide how often to log during test (i.e. evaluation) runs? With the same 'log_every_n_steps' setting, or something else?
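One option is to reuse the same every-N-steps rule for evaluation. Below is a minimal sketch of such a check; the function name should_log is hypothetical, and it assumes a zero-based batch counter to which 1 is added, so the check can also run before a batch is processed (e.g. in a batch-start hook). It also shows a consequence worth weighing: the reproduction script has 10 test batches (10,000 samples, batch_size=1000), so with the Trainer's default log_every_n_steps=50 an unchanged rule would never log during test.

```python
def should_log(step_idx: int, log_every_n_steps: int) -> bool:
    """Return True on every N-th batch, counting batches from 1.

    step_idx is a zero-based counter of completed batches; adding 1 lets the
    check run before the current batch is processed.
    """
    return (step_idx + 1) % log_every_n_steps == 0


num_test_batches = 10_000 // 1000  # 10 batches, as in the reproduction script

# Default setting: a 10-batch test run never reaches step 50.
print([i for i in range(num_test_batches) if should_log(i, 50)])  # []

# Every third batch.
print([i for i in range(num_test_batches) if should_log(i, 3)])  # [2, 5, 8]

# log_every_n_steps=1, as in the second script: every batch is logged.
print([i for i in range(num_test_batches) if should_log(i, 1)])  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

The alternatives (logging every evaluation batch, or a separate setting for evaluation) trade log volume against coverage of short evaluation runs.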

albertwujj commented Nov 27, 2022

As a starting point, here is a commit on my fork (f37d373) that enables DeviceStatsMonitor logging for eval runs.

Test code used:

```python
import pytorch_lightning as pl
import numpy as np
from pytorch_lightning.callbacks import DeviceStatsMonitor
import torch
from torch.nn import MSELoss
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn


class SimpleDataset(Dataset):
    def __init__(self):
        X = np.arange(10000)
        y = X * 2
        X = [[n] for n in X]
        y = [[n] for n in y]
        self.X = torch.Tensor(X)
        self.y = torch.Tensor(y)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return {"X": self.X[idx], "y": self.y[idx]}


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1)
        self.criterion = MSELoss()

    def forward(self, inputs_id, labels=None):
        outputs = self.fc(inputs_id)
        loss = 0
        if labels is not None:
            loss = self.criterion(outputs, labels)
        return loss, outputs

    def train_dataloader(self):
        dataset = SimpleDataset()
        return DataLoader(dataset, batch_size=1000)

    def test_dataloader(self):
        dataset = SimpleDataset()
        return DataLoader(dataset, batch_size=1000)

    def training_step(self, batch, batch_idx):
        input_ids = batch["X"]
        labels = batch["y"]
        loss, outputs = self(input_ids, labels)
        return {"loss": loss, "outputs": outputs}

    def test_step(self, batch, batch_idx):
        input_ids = batch["X"]
        labels = batch["y"]
        loss, outputs = self(input_ids, labels)
        return {"loss": loss}

    def configure_optimizers(self):
        optimizer = Adam(self.parameters())
        return optimizer


if __name__ == "__main__":
    model = MyModel()
    logger = pl.loggers.CSVLogger(save_dir="example", name="test")
    trainer = pl.Trainer(
        logger=logger,
        max_epochs=5,
        callbacks=[DeviceStatsMonitor(cpu_stats=True)],
        log_every_n_steps=1,  # make every batch eligible for logging
    )
    trainer.fit(model)   # stats were already logged during fit
    trainer.test(model)  # with the fork, stats should also be logged here
```

For eval runs I used the same every-N-steps rule as for fit runs. However, the fit epoch loop keeps track of a _batches_that_stepped counter, while the eval epoch loop does not. As far as I can tell (I'm not certain), the eval epoch loop's batch_progress.total.completed tracks the equivalent quantity.
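The counter choice described above can be sketched as a small per-stage lookup. This is a toy under stated assumptions: ToyCounters and logging_step are hypothetical, and the two fields only stand in for the counters named in the comment (_batches_that_stepped in the fit epoch loop, batch_progress.total.completed in the eval epoch loop). Predict is deliberately left out, since that loop had not been looked at yet.

```python
from dataclasses import dataclass


@dataclass
class ToyCounters:
    """Hypothetical stand-ins for the two loop counters discussed above."""

    fit_batches_that_stepped: int = 0  # role of fit_loop.epoch_loop._batches_that_stepped
    eval_batches_completed: int = 0    # role of the eval loop's batch_progress.total.completed


def logging_step(stage: str, counters: ToyCounters) -> int:
    """Pick the step value to pass to the logger for the running stage."""
    if stage == "train":
        return counters.fit_batches_that_stepped
    if stage in ("validation", "test"):
        return counters.eval_batches_completed
    raise ValueError(f"no step counter chosen for stage {stage!r}")


counters = ToyCounters(fit_batches_that_stepped=50, eval_batches_completed=3)
print(logging_step("train", counters))  # 50
print(logging_step("test", counters))   # 3
```

Because the two counters advance independently, train and test values can share step numbers; giving each stage its own metric-key prefix keeps them from landing on the same series in the logger.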

Clarifications/comments/instructions welcome!

Once I know the preferred approach, I'll continue, add support for logging in the predict loop as well, and eventually open a PR.
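Extending the callback to every stage could look roughly like the toy below: one shared logging helper plus thin per-stage hooks, with each metric key prefixed by the hook name so that train, validation, test and predict stats stay separate even when their step values overlap. MultiStageStatsMonitor, prefix_keys, the read_stats callable and the "/" separator are assumptions for illustration only. A real Lightning callback's batch hooks receive the trainer and LightningModule (among other arguments), which this sketch omits.

```python
from typing import Callable, Dict, List, Tuple

Stats = Dict[str, float]


def prefix_keys(stats: Stats, prefix: str, separator: str = "/") -> Stats:
    """Namespace each metric key, e.g. 'x' -> 'on_test_batch_end/x'."""
    return {f"{prefix}{separator}{key}": value for key, value in stats.items()}


class MultiStageStatsMonitor:
    """Hypothetical monitor: one shared helper, thin per-stage hooks."""

    def __init__(self, read_stats: Callable[[], Stats]):
        self.read_stats = read_stats
        self.logged: List[Tuple[int, Stats]] = []

    def _log(self, hook_name: str, step: int) -> None:
        self.logged.append((step, prefix_keys(self.read_stats(), hook_name)))

    def on_train_batch_end(self, step: int) -> None:
        self._log("on_train_batch_end", step)

    def on_validation_batch_end(self, step: int) -> None:
        self._log("on_validation_batch_end", step)

    def on_test_batch_end(self, step: int) -> None:
        self._log("on_test_batch_end", step)

    def on_predict_batch_end(self, step: int) -> None:
        self._log("on_predict_batch_end", step)


monitor = MultiStageStatsMonitor(read_stats=lambda: {"dummy_stat": 1.0})
monitor.on_train_batch_end(step=0)
monitor.on_test_batch_end(step=0)
for step, metrics in monitor.logged:
    print(step, metrics)
# 0 {'on_train_batch_end/dummy_stat': 1.0}
# 0 {'on_test_batch_end/dummy_stat': 1.0}
```

Keeping the stats collection in one helper means each new stage only adds a one-line hook, and the prefix makes the stage visible in the logged metric name.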

thesofakillers (Author) commented

not stale

Lightning-AI deleted a comment from stale bot Jan 9, 2023