Commit

use hydra logging (DDP is supported)

ryul99 committed Nov 10, 2020
1 parent ba676f4 commit 3f09593
Showing 8 changed files with 77 additions and 30 deletions.
22 changes: 22 additions & 0 deletions config/hydra/job_logging/custom.yaml
@@ -0,0 +1,22 @@
# @package hydra.job_logging
# python logging configuration for tasks
version: 1
formatters:
simple:
format: '%(message)s'
detailed:
format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
handlers:
console:
class: logging.StreamHandler
formatter: detailed
stream: ext://sys.stdout
file:
class: logging.FileHandler
formatter: detailed
filename: trainer.log
root:
level: INFO
handlers: [console, file]

disable_existing_loggers: False
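
Not part of the diff, but for context: the file above is a plain logging.config.dictConfig schema that Hydra places under hydra.job_logging (via the # @package directive), and that the new get_logger helper in utils/utils.py further below re-applies manually. A minimal sketch of applying it by hand, assuming it is run from the repository root:

import logging
import logging.config
from omegaconf import OmegaConf

# load the YAML above and install it as the process-wide logging configuration
conf = OmegaConf.load("config/hydra/job_logging/custom.yaml")
logging.config.dictConfig(OmegaConf.to_container(conf, resolve=True))
logging.getLogger("demo").info("hello")  # goes to stdout and trainer.log in the detailed format
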
19 changes: 10 additions & 9 deletions model/model.py
@@ -4,13 +4,11 @@

from collections import OrderedDict
import os.path as osp
from omegaconf import OmegaConf
import wandb
import logging
import os

from utils.utils import DotDict, is_logging_process


logger = logging.getLogger(osp.basename(__file__))
from utils.utils import is_logging_process, get_logger


class Model:
@@ -25,6 +23,7 @@ def __init__(self, cfg, net_arch, loss_f, rank=0):
self.GT = None
self.step = 0
self.epoch = -1
self._logger = get_logger(cfg, os.path.basename(__file__))

# init optimizer
optimizer_mode = self.cfg.train.optimizer.mode
@@ -78,7 +77,7 @@ def save_network(self, save_file=True):
if self.cfg.log.use_wandb:
wandb.save(save_path)
if is_logging_process():
logger.info("Saved network checkpoint to: %s" % save_path)
self._logger.info("Saved network checkpoint to: %s" % save_path)
return state_dict

def load_network(self, loaded_net=None):
@@ -103,7 +102,9 @@ def load_network(self, loaded_net=None):

self.net.load_state_dict(loaded_clean_net, strict=self.cfg.load.strict_load)
if is_logging_process() and add_log:
logger.info("Checkpoint %s is loaded" % self.cfg.load.network_chkpt_path)
self._logger.info(
"Checkpoint %s is loaded" % self.cfg.load.network_chkpt_path
)

def save_training_state(self):
if is_logging_process():
@@ -120,7 +121,7 @@ def save_training_state(self):
if self.cfg.log.use_wandb:
wandb.save(save_path)
if is_logging_process():
logger.info("Saved training state to: %s" % save_path)
self._logger.info("Saved training state to: %s" % save_path)

def load_training_state(self):
if self.cfg.load.wandb_load_path is not None:
@@ -138,6 +139,6 @@ def load_training_state(self):
self.step = resume_state["step"]
self.epoch = resume_state["epoch"]
if is_logging_process():
logger.info(
self._logger.info(
"Resuming from training state: %s" % self.cfg.load.resume_state_path
)
3 changes: 0 additions & 3 deletions tests/model/model_test.py
@@ -2,13 +2,10 @@
import torch
import torch.nn as nn
import os
import logging
from tests.test_case import ProjectTestCase
from model.model_arch import Net_arch
from model.model import Model

logger = logging.getLogger(os.path.basename(__file__))


class TestModel(ProjectTestCase):
@classmethod
8 changes: 6 additions & 2 deletions tests/test_case.py
@@ -4,7 +4,7 @@
import pathlib
import shutil
import tempfile
import logging
from utils.utils import get_logger
from hydra.experimental import initialize, compose

TEST_DIR = tempfile.mkdtemp(prefix="project_tests")
@@ -30,7 +30,11 @@ def setup_method(self):
self.cfg.log.use_tensorboard = False

# set logger
self.logger = logging.getLogger(os.path.basename(__file__))
self.logger = get_logger(
self.cfg,
os.path.basename(__file__),
str((self.working_dir / "trainer.log").resolve()),
)

def teardown_method(self):
shutil.rmtree(self.TEST_DIR)
9 changes: 3 additions & 6 deletions trainer.py
@@ -5,7 +5,6 @@
import traceback
import random
import os
import logging

import hydra
import torch
@@ -18,14 +17,11 @@
from model.model import Model
from utils.train_model import train_model
from utils.test_model import test_model
from utils.utils import set_random_seed, is_logging_process
from utils.utils import set_random_seed, is_logging_process, get_logger
from utils.writer import Writer
from dataset.dataloader import create_dataloader, DataloaderMode


logger = logging.getLogger(os.path.basename(__file__))


def setup(cfg, rank):
os.environ["MASTER_ADDR"] = cfg.dist.master_addr
os.environ["MASTER_PORT"] = cfg.dist.master_port
@@ -53,6 +49,7 @@ def distributed_run(fn, cfg):


def train_loop(rank, cfg):
logger = get_logger(cfg, os.path.basename(__file__))
if cfg.device == "cuda" and cfg.dist.gpus != 0:
cfg.device = rank
# turn off background generator when distributed run is on
@@ -139,7 +136,7 @@ def train_loop(rank, cfg):
cleanup()


@hydra.main(config_path="config/default.yaml")
@hydra.main(config_path="config", config_name="default")
def main(hydra_cfg):
hydra_cfg.device = hydra_cfg.device.lower()

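
The "DDP is supported" part of the commit message comes from the hunk above: the logger is now created inside train_loop, i.e. inside each spawned worker process, instead of at module import time. A hedged sketch of that flow (the worker/world_size wiring here is illustrative; in the repository it is handled by setup() and distributed_run(), and get_logger/is_logging_process are the helpers from utils/utils.py in this commit):

import os
import torch.distributed as dist
import torch.multiprocessing as mp
from utils.utils import get_logger, is_logging_process

def worker(rank, cfg, world_size):
    # illustrative process-group setup; the repo's setup() reads these values
    # from cfg.dist.master_addr / cfg.dist.master_port instead
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # the logger is configured inside the spawned process; get_logger installs
    # handlers only when is_logging_process() is True (rank 0), so a single
    # process writes trainer.log and the other ranks stay quiet
    logger = get_logger(cfg, os.path.basename(__file__))
    if is_logging_process():
        logger.info("rank %d is up", rank)
    dist.destroy_process_group()

# mp.spawn(worker, args=(cfg, world_size), nprocs=world_size)  # mirrors distributed_run
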
6 changes: 2 additions & 4 deletions utils/test_model.py
@@ -1,12 +1,10 @@
import torch
import logging
import os
from utils.utils import is_logging_process

logger = logging.getLogger(os.path.basename(__file__))
from utils.utils import is_logging_process, get_logger


def test_model(cfg, model, test_loader, writer):
logger = get_logger(cfg, os.path.basename(__file__))
model.net.eval()
total_test_loss = 0
test_loop_len = 0
6 changes: 2 additions & 4 deletions utils/train_model.py
@@ -1,12 +1,10 @@
import math
import logging
import os
from utils.utils import is_logging_process

logger = logging.getLogger(os.path.basename(__file__))
from utils.utils import is_logging_process, get_logger


def train_model(cfg, model, train_loader, writer):
logger = get_logger(cfg, os.path.basename(__file__))
model.net.train()
for input_, target in train_loader:
model.feed_data(input=input_, GT=target)
34 changes: 32 additions & 2 deletions utils/utils.py
@@ -1,11 +1,12 @@
import subprocess
import yaml
import random
import logging
import os.path as osp
import numpy as np
import torch
import torch.distributed as dist
from copy import deepcopy
from datetime import datetime
from omegaconf import OmegaConf


def set_random_seed(seed):
@@ -19,6 +20,35 @@ def is_logging_process():
return not dist.is_initialized() or dist.get_rank() == 0


def get_logger(cfg, name=None, log_file_path=None):
# log_file_path is used when unit testing
if is_logging_process():
project_root_path = osp.dirname(osp.dirname(osp.abspath(__file__)))
hydra_conf = OmegaConf.load(osp.join(project_root_path, "config/default.yaml"))

job_logging_name = None
for job_logging_name in hydra_conf.defaults:
if isinstance(job_logging_name, dict):
job_logging_name = job_logging_name.get("hydra/job_logging")
if job_logging_name is not None:
break
job_logging_name = None
if job_logging_name is None:
job_logging_name = "custom" # default name

logging_conf = OmegaConf.load(
osp.join(
project_root_path,
"config/hydra/job_logging",
job_logging_name + ".yaml",
)
)
if log_file_path is not None:
logging_conf.handlers.file.filename = log_file_path
logging.config.dictConfig(OmegaConf.to_container(logging_conf, resolve=True))
return logging.getLogger(name)


def get_timestamp():
return datetime.now().strftime("%y%m%d-%H%M%S")

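
A usage note on the helper above (illustrative, not part of the commit): callers pass the Hydra cfg and a display name, as trainer.py and utils/train_model.py do in this diff, and handlers are configured only on the logging process. Since the body reaches for logging.config.dictConfig, this sketch imports logging.config explicitly rather than relying on another module having pulled it in:

import os
import logging.config  # dictConfig lives in the logging.config submodule
from omegaconf import OmegaConf
from utils.utils import get_logger, is_logging_process

cfg = OmegaConf.create({"log": {"use_wandb": False}})  # placeholder cfg for illustration
logger = get_logger(cfg, os.path.basename(__file__))   # same call pattern as trainer.py
if is_logging_process():
    logger.info("logging configured from config/hydra/job_logging/*.yaml")
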
