From f43205afaf2f6dbb2c82017413a42c96e82212a6 Mon Sep 17 00:00:00 2001 From: PDillis Date: Mon, 4 Apr 2022 01:20:38 +0200 Subject: [PATCH] Fix for `--freezeM` options and typos (Issue #6) --- train.py | 6 +++--- training/networks_stylegan2.py | 3 ++- training/networks_stylegan3.py | 5 +++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index e6776d02..5e35feeb 100644 --- a/train.py +++ b/train.py @@ -214,9 +214,9 @@ def main(**kwargs): c.G_kwargs.channel_base = c.D_kwargs.channel_base = opts.cbase c.G_kwargs.channel_max = c.D_kwargs.channel_max = opts.cmax c.G_kwargs.mapping_kwargs.num_layers = (8 if opts.cfg == 'stylegan2' else 2) if opts.map_depth is None else opts.map_depth - c.G_kwargs.mapping_kwargs.freeze_layers = opts.freezeM - c.G_kwargs.mapping_kwargs.freeze_embed = opts.freezeE - c.D_kwargs.block_kwargs.freeze_layers = opts.freezeD + c.G_kwargs.mapping_kwargs.freeze_layers = opts.freezem + c.G_kwargs.mapping_kwargs.freeze_embed = opts.freezee + c.D_kwargs.block_kwargs.freeze_layers = opts.freezed c.D_kwargs.epilogue_kwargs.mbstd_group_size = opts.mbstd_group if opts.gamma is not None: c.loss_kwargs.r1_gamma = float(opts.gamma) diff --git a/training/networks_stylegan2.py b/training/networks_stylegan2.py index 3b5394f2..2b4ca516 100644 --- a/training/networks_stylegan2.py +++ b/training/networks_stylegan2.py @@ -109,10 +109,11 @@ def __init__(self, self.bias_gain = lr_multiplier weight = torch.randn([out_features, in_features]) / lr_multiplier + bias = torch.full([out_features], np.float32(bias_init)) if bias else None if trainable: self.weight = torch.nn.Parameter(weight) - self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.bias = torch.nn.Parameter(bias) else: self.register_buffer('weight', weight) if bias is not None: diff --git a/training/networks_stylegan3.py b/training/networks_stylegan3.py index 6db382c1..45f416b0 100644 --- a/training/networks_stylegan3.py +++ b/training/networks_stylegan3.py @@ -85,11 +85,12 @@ def __init__(self, self.bias_gain = lr_multiplier weight = torch.randn([out_features, in_features]) * (weight_init / lr_multiplier) + bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features]) + bias = torch.from_numpy(bias_init / lr_multiplier) if bias else None if trainable: self.weight = torch.nn.Parameter(weight) - bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features]) - self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None + self.bias = torch.nn.Parameter(bias) else: self.register_buffer('weight', weight) if bias is not None: