From 3f09593dfcec8e5d6facc5171596700c506fee7a Mon Sep 17 00:00:00 2001 From: Changmin Choi Date: Wed, 11 Nov 2020 02:08:50 +0900 Subject: [PATCH] use hydra logging (DDP is supported) --- config/hydra/job_logging/custom.yaml | 22 ++++++++++++++++++ model/model.py | 19 ++++++++-------- tests/model/model_test.py | 3 --- tests/test_case.py | 8 +++++-- trainer.py | 9 +++----- utils/test_model.py | 6 ++--- utils/train_model.py | 6 ++--- utils/utils.py | 34 ++++++++++++++++++++++++++-- 8 files changed, 77 insertions(+), 30 deletions(-) create mode 100644 config/hydra/job_logging/custom.yaml diff --git a/config/hydra/job_logging/custom.yaml b/config/hydra/job_logging/custom.yaml new file mode 100644 index 0000000..0ed9b05 --- /dev/null +++ b/config/hydra/job_logging/custom.yaml @@ -0,0 +1,22 @@ +# @package hydra.job_logging +# python logging configuration for tasks +version: 1 +formatters: + simple: + format: '%(message)s' + detailed: + format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' +handlers: + console: + class: logging.StreamHandler + formatter: detailed + stream: ext://sys.stdout + file: + class: logging.FileHandler + formatter: detailed + filename: trainer.log +root: + level: INFO + handlers: [console, file] + +disable_existing_loggers: False \ No newline at end of file diff --git a/model/model.py b/model/model.py index 14cb8a1..33e3bb9 100644 --- a/model/model.py +++ b/model/model.py @@ -4,13 +4,11 @@ from collections import OrderedDict import os.path as osp +from omegaconf import OmegaConf import wandb -import logging +import os -from utils.utils import DotDict, is_logging_process - - -logger = logging.getLogger(osp.basename(__file__)) +from utils.utils import is_logging_process, get_logger class Model: @@ -25,6 +23,7 @@ def __init__(self, cfg, net_arch, loss_f, rank=0): self.GT = None self.step = 0 self.epoch = -1 + self._logger = get_logger(cfg, os.path.basename(__file__)) # init optimizer optimizer_mode = self.cfg.train.optimizer.mode @@ -78,7 +77,7 @@ def save_network(self, save_file=True): if self.cfg.log.use_wandb: wandb.save(save_path) if is_logging_process(): - logger.info("Saved network checkpoint to: %s" % save_path) + self._logger.info("Saved network checkpoint to: %s" % save_path) return state_dict def load_network(self, loaded_net=None): @@ -103,7 +102,9 @@ def load_network(self, loaded_net=None): self.net.load_state_dict(loaded_clean_net, strict=self.cfg.load.strict_load) if is_logging_process() and add_log: - logger.info("Checkpoint %s is loaded" % self.cfg.load.network_chkpt_path) + self._logger.info( + "Checkpoint %s is loaded" % self.cfg.load.network_chkpt_path + ) def save_training_state(self): if is_logging_process(): @@ -120,7 +121,7 @@ def save_training_state(self): if self.cfg.log.use_wandb: wandb.save(save_path) if is_logging_process(): - logger.info("Saved training state to: %s" % save_path) + self._logger.info("Saved training state to: %s" % save_path) def load_training_state(self): if self.cfg.load.wandb_load_path is not None: @@ -138,6 +139,6 @@ def load_training_state(self): self.step = resume_state["step"] self.epoch = resume_state["epoch"] if is_logging_process(): - logger.info( + self._logger.info( "Resuming from training state: %s" % self.cfg.load.resume_state_path ) diff --git a/tests/model/model_test.py b/tests/model/model_test.py index 6a0983b..882c350 100644 --- a/tests/model/model_test.py +++ b/tests/model/model_test.py @@ -2,13 +2,10 @@ import torch import torch.nn as nn import os -import logging from tests.test_case import ProjectTestCase from model.model_arch import Net_arch from model.model import Model -logger = logging.getLogger(os.path.basename(__file__)) - class TestModel(ProjectTestCase): @classmethod diff --git a/tests/test_case.py b/tests/test_case.py index ec1f2bb..8162db2 100644 --- a/tests/test_case.py +++ b/tests/test_case.py @@ -4,7 +4,7 @@ import pathlib import shutil import tempfile -import logging +from utils.utils import get_logger from hydra.experimental import initialize, compose TEST_DIR = tempfile.mkdtemp(prefix="project_tests") @@ -30,7 +30,11 @@ def setup_method(self): self.cfg.log.use_tensorboard = False # set logger - self.logger = logging.getLogger(os.path.basename(__file__)) + self.logger = get_logger( + self.cfg, + os.path.basename(__file__), + str((self.working_dir / "trainer.log").resolve()), + ) def teardown_method(self): shutil.rmtree(self.TEST_DIR) diff --git a/trainer.py b/trainer.py index 8357a17..c6f179e 100644 --- a/trainer.py +++ b/trainer.py @@ -5,7 +5,6 @@ import traceback import random import os -import logging import hydra import torch @@ -18,14 +17,11 @@ from model.model import Model from utils.train_model import train_model from utils.test_model import test_model -from utils.utils import set_random_seed, is_logging_process +from utils.utils import set_random_seed, is_logging_process, get_logger from utils.writer import Writer from dataset.dataloader import create_dataloader, DataloaderMode -logger = logging.getLogger(os.path.basename(__file__)) - - def setup(cfg, rank): os.environ["MASTER_ADDR"] = cfg.dist.master_addr os.environ["MASTER_PORT"] = cfg.dist.master_port @@ -53,6 +49,7 @@ def distributed_run(fn, cfg): def train_loop(rank, cfg): + logger = get_logger(cfg, os.path.basename(__file__)) if cfg.device == "cuda" and cfg.dist.gpus != 0: cfg.device = rank # turn off background generator when distributed run is on @@ -139,7 +136,7 @@ def train_loop(rank, cfg): cleanup() -@hydra.main(config_path="config/default.yaml") +@hydra.main(config_path="config", config_name="default") def main(hydra_cfg): hydra_cfg.device = hydra_cfg.device.lower() diff --git a/utils/test_model.py b/utils/test_model.py index 72e923b..8275da4 100644 --- a/utils/test_model.py +++ b/utils/test_model.py @@ -1,12 +1,10 @@ import torch -import logging import os -from utils.utils import is_logging_process - -logger = logging.getLogger(os.path.basename(__file__)) +from utils.utils import is_logging_process, get_logger def test_model(cfg, model, test_loader, writer): + logger = get_logger(cfg, os.path.basename(__file__)) model.net.eval() total_test_loss = 0 test_loop_len = 0 diff --git a/utils/train_model.py b/utils/train_model.py index 06eebd0..01ac73c 100644 --- a/utils/train_model.py +++ b/utils/train_model.py @@ -1,12 +1,10 @@ import math -import logging import os -from utils.utils import is_logging_process - -logger = logging.getLogger(os.path.basename(__file__)) +from utils.utils import is_logging_process, get_logger def train_model(cfg, model, train_loader, writer): + logger = get_logger(cfg, os.path.basename(__file__)) model.net.train() for input_, target in train_loader: model.feed_data(input=input_, GT=target) diff --git a/utils/utils.py b/utils/utils.py index 7962f9d..ed37f26 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -1,11 +1,12 @@ import subprocess -import yaml import random +import logging +import os.path as osp import numpy as np import torch import torch.distributed as dist -from copy import deepcopy from datetime import datetime +from omegaconf import OmegaConf def set_random_seed(seed): @@ -19,6 +20,35 @@ def is_logging_process(): return not dist.is_initialized() or dist.get_rank() == 0 +def get_logger(cfg, name=None, log_file_path=None): + # log_file_path is used when unit testing + if is_logging_process(): + project_root_path = osp.dirname(osp.dirname(osp.abspath(__file__))) + hydra_conf = OmegaConf.load(osp.join(project_root_path, "config/default.yaml")) + + job_logging_name = None + for job_logging_name in hydra_conf.defaults: + if isinstance(job_logging_name, dict): + job_logging_name = job_logging_name.get("hydra/job_logging") + if job_logging_name is not None: + break + job_logging_name = None + if job_logging_name is None: + job_logging_name = "custom" # default name + + logging_conf = OmegaConf.load( + osp.join( + project_root_path, + "config/hydra/job_logging", + job_logging_name + ".yaml", + ) + ) + if log_file_path is not None: + logging_conf.handlers.file.filename = log_file_path + logging.config.dictConfig(OmegaConf.to_container(logging_conf, resolve=True)) + return logging.getLogger(name) + + def get_timestamp(): return datetime.now().strftime("%y%m%d-%H%M%S")