Fix for --freezeM options and typos (Issue #6)
PDillis committed Apr 3, 2022
1 parent 2cd5e5c commit f43205a
Showing 3 changed files with 8 additions and 6 deletions.
6 changes: 3 additions & 3 deletions train.py
@@ -214,9 +214,9 @@ def main(**kwargs):
     c.G_kwargs.channel_base = c.D_kwargs.channel_base = opts.cbase
     c.G_kwargs.channel_max = c.D_kwargs.channel_max = opts.cmax
     c.G_kwargs.mapping_kwargs.num_layers = (8 if opts.cfg == 'stylegan2' else 2) if opts.map_depth is None else opts.map_depth
-    c.G_kwargs.mapping_kwargs.freeze_layers = opts.freezeM
-    c.G_kwargs.mapping_kwargs.freeze_embed = opts.freezeE
-    c.D_kwargs.block_kwargs.freeze_layers = opts.freezeD
+    c.G_kwargs.mapping_kwargs.freeze_layers = opts.freezem
+    c.G_kwargs.mapping_kwargs.freeze_embed = opts.freezee
+    c.D_kwargs.block_kwargs.freeze_layers = opts.freezed
     c.D_kwargs.epilogue_kwargs.mbstd_group_size = opts.mbstd_group
     if opts.gamma is not None:
         c.loss_kwargs.r1_gamma = float(opts.gamma)
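
Note on the rename: click derives each option's Python name by stripping the leading dashes and lowercasing it, so a `--freezeM` flag arrives in the command's kwargs as `freezem`; the old `opts.freezeM` lookups therefore raised AttributeError. A minimal, standalone sketch of this click behavior (the option body here is illustrative, not the fork's actual declaration):

    import click

    @click.command()
    @click.option('--freezeM', type=int, default=0)
    def main(**kwargs):
        # click strips the dashes and lowercases the derived parameter
        # name, so the value is delivered as 'freezem', not 'freezeM'.
        print(sorted(kwargs))  # -> ['freezem']

    if __name__ == '__main__':
        main()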
3 changes: 2 additions & 1 deletion training/networks_stylegan2.py
@@ -109,10 +109,11 @@ def __init__(self,
     self.bias_gain = lr_multiplier

     weight = torch.randn([out_features, in_features]) / lr_multiplier
+    bias = torch.full([out_features], np.float32(bias_init)) if bias else None

     if trainable:
         self.weight = torch.nn.Parameter(weight)
-        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
+        self.bias = torch.nn.Parameter(bias)
     else:
         self.register_buffer('weight', weight)
         if bias is not None:
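
Why the bias tensor is now built before the branch: in the old code, the non-trainable path reached `if bias is not None:` while `bias` was still the boolean constructor argument, so `register_buffer` received a bool rather than a tensor. A small standalone illustration of that failure mode (not the fork's code):

    import torch

    m = torch.nn.Module()
    m.register_buffer('weight', torch.randn(4, 4))  # a tensor: accepted
    try:
        m.register_buffer('bias', True)  # a bool, as the old else-branch effectively passed
    except TypeError as err:
        print(err)  # cannot assign 'bool' object to buffer 'bias' (torch Tensor or None required)

Hoisting the construction means both branches see the same `[out_features]` float32 tensor (or None when bias=False).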
5 changes: 3 additions & 2 deletions training/networks_stylegan3.py
@@ -85,11 +85,12 @@ def __init__(self,
     self.bias_gain = lr_multiplier

     weight = torch.randn([out_features, in_features]) * (weight_init / lr_multiplier)
+    bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features])
+    bias = torch.from_numpy(bias_init / lr_multiplier) if bias else None

     if trainable:
         self.weight = torch.nn.Parameter(weight)
-        bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features])
-        self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None
+        self.bias = torch.nn.Parameter(bias)
     else:
         self.register_buffer('weight', weight)
         if bias is not None:
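
This is the same hoist as in the stylegan2 version, with the broadcast `bias_init` and the `lr_multiplier` scaling moved along with it. A quick sanity check of a frozen layer after this fix (a sketch only; it assumes the repo root is on the path and the `FullyConnectedLayer` signature shown in the hunk above):

    import torch
    from training.networks_stylegan3 import FullyConnectedLayer

    frozen = FullyConnectedLayer(512, 512, trainable=False)
    # weight and bias were registered as buffers, so nothing is trainable
    print(sum(p.numel() for p in frozen.parameters()))  # 0
    print(frozen.weight.requires_grad)  # False: buffers do not track gradients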
