From 6dc2867beb9012dc3700b496932cf840b8d7b22c Mon Sep 17 00:00:00 2001 From: Diego Porres Date: Fri, 22 Apr 2022 19:43:37 +0200 Subject: [PATCH] Update projector script; alternative losses still unstable --- projector.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/projector.py b/projector.py index 13297d14..c2fdb6a9 100755 --- a/projector.py +++ b/projector.py @@ -226,7 +226,7 @@ def project( n_digits = int(np.log10(num_steps)) + 1 if num_steps > 0 else 1 message = f'step {step + 1:{n_digits}d}/{num_steps}: percept loss {percept_error.item():.7e} | ' \ f'pixel mse {mse_error.item():.7e} | ssim {ssim_loss.item():.7e} | loss {loss.item():.7e}' - print(message)#, end='\r') + print(message, end='\r') last_status = {'percept_error': percept_error.item(), 'pixel_mse': mse_error.item(), @@ -351,6 +351,7 @@ def project( @click.command() @click.pass_context @click.option('--network', '-net', 'network_pkl', help='Network pickle filename', required=True) +@click.option('--cfg', help='Config of the network, used only if you want to use one of the models that are in torch_utils.gen_utils.resume_specs', type=click.Choice(['stylegan2', 'stylegan3-t', 'stylegan3-r'])) @click.option('--target', '-t', 'target_fname', type=click.Path(exists=True, dir_okay=False), help='Target image file to project to', required=True, metavar='FILE') # Optimization options @click.option('--num-steps', '-nsteps', help='Number of optimization steps', type=click.IntRange(min=0), default=1000, show_default=True) @@ -380,6 +381,7 @@ def project( def run_projection( ctx: click.Context, network_pkl: str, + cfg: str, target_fname: str, num_steps: int, initial_learning_rate: float, @@ -422,6 +424,12 @@ def run_projection( loss_paper = 'sgan2' # Load networks. + # If model name exists in the gen_utils.resume_specs dictionary, use it instead of the full url + try: + network_pkl = gen_utils.resume_specs[cfg][network_pkl] + except KeyError: + # Otherwise, it's a local file or an url + pass print('Loading networks from "%s"...' % network_pkl) device = torch.device('cuda') with dnnlib.util.open_url(network_pkl) as fp: