Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Init fp16 support #139

Merged
merged 5 commits into from
Sep 22, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch.nn as nn
from mmcv.runner import load_checkpoint
from mmcv.runner import auto_fp16, load_checkpoint

from mmedit.models.builder import build_component
from mmedit.models.registry import BACKBONES
Expand Down Expand Up @@ -31,6 +31,10 @@ def __init__(self,
self.decoder = build_component(decoder)
self.dilation_neck = build_component(dilation_neck)

# support fp16
self.fp16_enabled = False

@auto_fp16(apply_to=('x', ))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@auto_fp16() should be enough

def forward(self, x):
"""Forward Function.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch.nn as nn
from mmcv.runner import load_checkpoint
from mmcv.runner import auto_fp16, load_checkpoint

from mmedit.models.builder import build_component
from mmedit.models.registry import BACKBONES
Expand All @@ -20,6 +20,10 @@ def __init__(self, encoder, decoder):
self.encoder = build_component(encoder)
self.decoder = build_component(decoder)

# support fp16
self.fp16_enabled = False

@auto_fp16(apply_to=('x', 'mask_in'))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here

def forward(self, x, mask_in):
"""Forward Function.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from mmcv.cnn import constant_init, normal_init
from mmcv.runner import load_checkpoint
from mmcv.runner import auto_fp16, load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmedit.models.builder import build_backbone, build_component
Expand Down Expand Up @@ -42,6 +42,10 @@ def __init__(self,

self.return_offset = return_offset

# support fp16
self.fp16_enabled = False

@auto_fp16(apply_to=('x', ))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here

def forward(self, x):
"""Forward function.

Expand Down
5 changes: 5 additions & 0 deletions mmedit/models/inpaintors/one_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import mmcv
import torch
from mmcv.runner import auto_fp16
from torchvision.utils import save_image

from mmedit.core import L1Evaluation, psnr, ssim, tensor2img
Expand Down Expand Up @@ -76,6 +77,9 @@ def __init__(self,

self.generator = build_backbone(encdec)

# support fp16
self.fp16_enabled = False

# build loss modules
if self.with_gan:
self.disc = build_component(disc)
Expand Down Expand Up @@ -113,6 +117,7 @@ def init_weights(self, pretrained=None):
if self.with_gan:
self.disc.init_weights(pretrained=pretrained)

@auto_fp16(apply_to=('masked_img', 'mask'))
def forward(self, masked_img, mask, test_mode=True, **kwargs):
"""Forward function.

Expand Down
5 changes: 5 additions & 0 deletions mmedit/models/mattors/dim.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from mmcv.runner import auto_fp16

from ..builder import build_loss
from ..registry import MODELS
Expand Down Expand Up @@ -58,6 +59,10 @@ def __init__(self,
if loss_refine is not None:
self.loss_refine = build_loss(loss_refine)

# support fp16
self.fp16_enabled = False

@auto_fp16(apply_to=('x', 'refine'))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here

def _forward(self, x, refine):
raw_alpha = self.backbone(x)
pred_alpha = raw_alpha.sigmoid()
Expand Down
4 changes: 4 additions & 0 deletions mmedit/models/mattors/gca.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from mmcv.runner import auto_fp16

from ..builder import build_loss
from ..registry import MODELS
Expand Down Expand Up @@ -32,7 +33,10 @@ def __init__(self,
super(GCA, self).__init__(backbone, None, train_cfg, test_cfg,
pretrained)
self.loss_alpha = build_loss(loss_alpha)
# support fp16
self.fp16_enabled = False

@auto_fp16(apply_to=('x', ))
def _forward(self, x):
raw_alpha = self.backbone(x)
pred_alpha = (raw_alpha.tanh() + 1.0) / 2.0
Expand Down
5 changes: 5 additions & 0 deletions mmedit/models/mattors/indexnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from mmcv.runner import auto_fp16

from ..builder import build_loss
from ..registry import MODELS
Expand Down Expand Up @@ -38,9 +39,13 @@ def __init__(self,
self.loss_comp = (
build_loss(loss_comp) if loss_comp is not None else None)

# support fp16
self.fp16_enabled = False

def forward_dummy(self, inputs):
return self.backbone(inputs)

@auto_fp16(apply_to=('merged', 'trimap'))
def forward_train(self, merged, trimap, meta, alpha, ori_merged, fg, bg):
"""Forward function for training IndexNet model.

Expand Down
5 changes: 5 additions & 0 deletions mmedit/models/restorers/basic_restorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os.path as osp

import mmcv
from mmcv.runner import auto_fp16

from mmedit.core import psnr, ssim, tensor2img
from ..base import BaseModel
Expand Down Expand Up @@ -39,6 +40,9 @@ def __init__(self,
self.train_cfg = train_cfg
self.test_cfg = test_cfg

# support fp16
self.fp16_enabled = False

# generator
self.generator = build_backbone(generator)
self.init_weights(pretrained)
Expand All @@ -55,6 +59,7 @@ def init_weights(self, pretrained=None):
"""
self.generator.init_weights(pretrained)

@auto_fp16(apply_to=('lq', ))
def forward(self, lq, gt=None, test_mode=False, **kwargs):
"""Forward function.

Expand Down
6 changes: 6 additions & 0 deletions mmedit/models/restorers/srgan.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from mmcv.runner import auto_fp16

from ..builder import build_backbone, build_component, build_loss
from ..common import set_requires_grad
from ..registry import MODELS
Expand Down Expand Up @@ -50,6 +52,9 @@ def __init__(self,
self.discriminator = build_component(
discriminator) if discriminator else None

# support fp16
self.fp16_enabled = False

# loss
self.gan_loss = build_loss(gan_loss) if gan_loss else None
self.pixel_loss = build_loss(pixel_loss) if pixel_loss else None
Expand All @@ -75,6 +80,7 @@ def init_weights(self, pretrained=None):
if self.discriminator:
self.discriminator.init_weights(pretrained=pretrained)

@auto_fp16(apply_to=('lq', ))
def forward(self, lq, gt=None, test_mode=False, **kwargs):
"""Forward function.

Expand Down
5 changes: 5 additions & 0 deletions mmedit/models/synthesizers/cycle_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import torch.nn as nn
from mmcv.parallel import MMDistributedDataParallel
from mmcv.runner import auto_fp16

from mmedit.core import tensor2img
from ..base import BaseModel
Expand Down Expand Up @@ -114,6 +115,9 @@ def __init__(self,
self.test_direction = ('b2a' if self.test_direction == 'a2b'
else 'a2b')

# support fp16
self.fp16_enabled = False

self.init_weights(pretrained)

def init_weights(self, pretrained=None):
Expand Down Expand Up @@ -162,6 +166,7 @@ def setup(self, img_a, img_b, meta):

return real_a, real_b, image_path

@auto_fp16(apply_to=('img_a', 'img_b'))
def forward_train(self, img_a, img_b, meta):
"""Forward function for training.

Expand Down
4 changes: 4 additions & 0 deletions mmedit/models/synthesizers/pix2pix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import mmcv
import numpy as np
import torch
from mmcv.runner import auto_fp16

from mmedit.core import tensor2img
from ..base import BaseModel
Expand Down Expand Up @@ -77,6 +78,8 @@ def __init__(self,
self.show_input = (False if self.test_cfg is None else
self.test_cfg.get('show_input', False))

# support fp16
self.fp16_enabled = False
self.init_weights(pretrained)

def init_weights(self, pretrained=None):
Expand Down Expand Up @@ -108,6 +111,7 @@ def setup(self, img_a, img_b, meta):

return real_a, real_b, image_path

@auto_fp16(apply_to=('img_a', 'img_b'))
def forward_train(self, img_a, img_b, meta):
"""Forward function for training.

Expand Down