Skip to content

Commit

Permalink
Update projector script; alternative losses still unstable
Browse files Browse the repository at this point in the history
  • Loading branch information
PDillis committed Apr 22, 2022
1 parent f43205a commit 6dc2867
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def project(
n_digits = int(np.log10(num_steps)) + 1 if num_steps > 0 else 1
message = f'step {step + 1:{n_digits}d}/{num_steps}: percept loss {percept_error.item():.7e} | ' \
f'pixel mse {mse_error.item():.7e} | ssim {ssim_loss.item():.7e} | loss {loss.item():.7e}'
print(message)#, end='\r')
print(message, end='\r')

last_status = {'percept_error': percept_error.item(),
'pixel_mse': mse_error.item(),
Expand Down Expand Up @@ -351,6 +351,7 @@ def project(
@click.command()
@click.pass_context
@click.option('--network', '-net', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--cfg', help='Config of the network, used only if you want to use one of the models that are in torch_utils.gen_utils.resume_specs', type=click.Choice(['stylegan2', 'stylegan3-t', 'stylegan3-r']))
@click.option('--target', '-t', 'target_fname', type=click.Path(exists=True, dir_okay=False), help='Target image file to project to', required=True, metavar='FILE')
# Optimization options
@click.option('--num-steps', '-nsteps', help='Number of optimization steps', type=click.IntRange(min=0), default=1000, show_default=True)
Expand Down Expand Up @@ -380,6 +381,7 @@ def project(
def run_projection(
ctx: click.Context,
network_pkl: str,
cfg: str,
target_fname: str,
num_steps: int,
initial_learning_rate: float,
Expand Down Expand Up @@ -422,6 +424,12 @@ def run_projection(
loss_paper = 'sgan2'

# Load networks.
# If model name exists in the gen_utils.resume_specs dictionary, use it instead of the full url
try:
network_pkl = gen_utils.resume_specs[cfg][network_pkl]
except KeyError:
# Otherwise, it's a local file or an url
pass
print('Loading networks from "%s"...' % network_pkl)
device = torch.device('cuda')
with dnnlib.util.open_url(network_pkl) as fp:
Expand Down

0 comments on commit 6dc2867

Please sign in to comment.