From 4765f95ca643cc53c2b626d410cd33c759b784a0 Mon Sep 17 00:00:00 2001 From: PDillis Date: Sat, 23 Apr 2022 18:10:43 +0200 Subject: [PATCH] Update README, allow SGAN-NADA models to be used with the code (Issue #9), general code cleaning --- README.md | 15 +++++- generate.py | 153 +++++++++++++++++++++------------------------------- legacy.py | 73 +++++++++++++------------ train.py | 2 +- 4 files changed, 116 insertions(+), 127 deletions(-) diff --git a/README.md b/README.md index ed64487e..a3e72961 100644 --- a/README.md +++ b/README.md @@ -128,12 +128,25 @@ This repository adds/has the following changes (not yet the complete list): * Added the rest of the affine transformations * Added widget for class-conditional models (***TODO:*** mix classes with continuous values for `cls`!) * General model and code additions - * ***TODO:*** [Better sampling?](https://arxiv.org/abs/2110.08009) * [Multi-modal truncation trick](https://arxiv.org/abs/2202.12211): find the different clusters in your model and use the closest one to your dlatent, in order to increase the fidelity (TODO: finish skeleton implementation) * StyleGAN3: anchor the latent space for easier to follow interpolations (thanks to [Rivers Have Wings](https://github.com/crowsonkb) and [nshepperd](https://github.com/nshepperd)). * Use CPU instead of GPU if desired (not recommended, but perfectly fine for generating images, whenever the custom CUDA kernels fail to compile). * Add missing dependencies and channels so that the [`conda`](https://docs.conda.io/en/latest/) environment is correctly setup in Windows (PR's [#111](https://github.com/NVlabs/stylegan3/pull/111) /[#116](https://github.com/NVlabs/stylegan3/pull/116) /[#125](https://github.com/NVlabs/stylegan3/pull/125) and [#80](https://github.com/NVlabs/stylegan3/pull/80) /[#143](https://github.com/NVlabs/stylegan3/pull/143) from the base, respectively) + * ***TODO:*** The current environment fails to install the CUDA version of PyTorch, so make separate environment files for each OS (Ubuntu/Windows) + * Use [StyleGAN-NADA](https://github.com/rinongal/StyleGAN-nada) models with any part of the code (Issue [#9](https://github.com/PDillis/stylegan3-fun/issues/9)) + * The StyleGAN-NADA models must first be converted via [Vadim Epstein](https://github.com/eps696)'s conversion code found [here](https://github.com/eps696/stylegan2ada#tweaking-models). +* ***TODO*** list (this is a long one with more to come, so any help is appreciated): + * [Generate images/interpolations with the layers of the model](https://twitter.com/makeitrad1/status/1517251876504600576?s=20&t=X5Df8N2gG_zGh5jJLVkvvw) + * Access the layers via: `renderer.Renderer.run_synthesis_net(net, capture_layer=layer)` + * Multi-modal truncation trick: finish skeleton code, add + * [PTI](https://github.com/danielroich/PTI) for better inversion + * [Better sampling](https://arxiv.org/abs/2110.08009) + * [Progressive growing modules for StyleGAN-XL](https://github.com/autonomousvision/stylegan_xl) to be able to use the pretrained models + * [Add cross-model interpolation](https://twitter.com/arfafax/status/1297681537337446400?s=20&t=xspnTaLFTvd7y4krg8tkxA) + * Generate class labels automatically from the dataset structure (subfolders and such) + * Make it easy to download pretrained models from Drive; otherwise, many models can't be used with `dnnlib.util.open_url` + (e.g., [StyleGAN-Human](https://github.com/stylegan-human/StyleGAN-Human) models) ***TODO:*** Finish documentation for better user experience, add videos/images, code samples, visuals...
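A minimal loading sketch tying the README additions above to the `generate.py`/`legacy.py` changes below: shorthand model names are resolved through `torch_utils.gen_utils.resume_specs`, and anything not found there (a local file, a URL, or a converted StyleGAN-NADA pickle) is passed straight to `dnnlib.util.open_url`. This is only a sketch under the repository's assumed layout; the name `wikiart1024-C` is just an example taken from the `generate.py` docstring, and the actual keys in `resume_specs` are repository-specific.

```python
# Minimal sketch, assuming this repository's modules (dnnlib, legacy, torch_utils.gen_utils).
# 'wikiart1024-C' is only an example name from the generate.py docstring below.
import torch
import dnnlib
import legacy
from torch_utils import gen_utils

cfg = 'stylegan2'              # one of: 'stylegan2', 'stylegan3-t', 'stylegan3-r'
network_pkl = 'wikiart1024-C'  # shorthand name, local file, or URL (e.g., a converted StyleGAN-NADA pkl)

# If the name exists in gen_utils.resume_specs, swap it for the full URL; otherwise use it as-is
try:
    network_pkl = gen_utils.resume_specs[cfg][network_pkl]
except KeyError:
    pass

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
with dnnlib.util.open_url(network_pkl) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)  # legacy.py below is made tolerant of NADA-style pickles
```

Resolving the shorthand before printing the "Loading networks" message, as the diff below does, also means the log shows the real URL being loaded rather than the shorthand name.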
diff --git a/generate.py b/generate.py index 3f83a762..b0d89506 100644 --- a/generate.py +++ b/generate.py @@ -32,10 +32,7 @@ def main(): @click.pass_context @click.option('--network', 'network_pkl', help='Network pickle filename: can be URL, local file, or the name of the model in torch_utils.gen_utils.resume_specs', required=True) @click.option('--device', help='Device to use for image generation; using the CPU is slower than the GPU', type=click.Choice(['cpu', 'cuda']), default='cuda', show_default=True) -@click.option('--cfg', help='Config of the network, used only if you want to use one of the models that are in torch_utils.gen_utils.resume_specs') -# Recreate snapshot grid during training (doesn't work!!!) -@click.option('--recreate-snapshot-grid', 'training_snapshot', is_flag=True, help='Add flag if you wish to recreate the snapshot grid created during training') -@click.option('--snapshot-size', type=click.Choice(['1080p', '4k', '8k']), help='Size of the snapshot', default='4k', show_default=True) +@click.option('--cfg', type=click.Choice(['stylegan2', 'stylegan3-t', 'stylegan3-r']), help='Config of the network, used only if you want to use the pretrained models in torch_utils.gen_utils.resume_specs') # Synthesis options (feed a list of seeds or give the projected w to synthesize) @click.option('--seeds', type=gen_utils.num_range, help='List of random seeds') @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) @@ -56,8 +53,6 @@ def generate_images( network_pkl: str, device: Optional[str], cfg: str, - training_snapshot: bool, - snapshot_size: str, seeds: Optional[List[int]], truncation_psi: float, class_idx: Optional[int], @@ -100,9 +95,6 @@ def generate_images( python generate.py images --cfg=stylegan2 --network=wikiart1024-C --class=155 \\ --trunc=0.7 --seeds=10-50 --save-grid """ - print(f'Loading networks from "{network_pkl}"...') - device = torch.device('cuda') if torch.cuda.is_available() and device == 'cuda' else torch.device('cpu') - # If model name exists in the gen_utils.resume_specs dictionary, use it instead of the full url try: network_pkl = gen_utils.resume_specs[cfg][network_pkl] @@ -110,6 +102,9 @@ def generate_images( # Otherwise, it's a local file or an url pass + print(f'Loading networks from "{network_pkl}"...') + device = torch.device('cuda') if torch.cuda.is_available() and device == 'cuda' else torch.device('cpu') + with dnnlib.util.open_url(network_pkl) as f: G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore @@ -152,60 +147,33 @@ def generate_images( if class_idx is not None: print('warn: --class=lbl ignored when running on an unconditional network') - if training_snapshot: - # This doesn't really work, so more work is warranted; TODO: move it to torch_utils/gen_utils.py - print('Recreating the snapshot grid...') - size_dict = {'1080p': (1920, 1080, 3, 2), '4k': (3840, 2160, 7, 4), '8k': (7680, 4320, 7, 4)} - grid_width = int(np.clip(size_dict[snapshot_size][0] // G.img_resolution, size_dict[snapshot_size][2], 32)) - grid_height = int(np.clip(size_dict[snapshot_size][1] // G.img_resolution, size_dict[snapshot_size][3], 32)) - num_images = grid_width * grid_height - - rnd = np.random.RandomState(0) - torch.manual_seed(0) - all_indices = list(range(70000)) # irrelevant - rnd.shuffle(all_indices) - - grid_z = rnd.randn(num_images, G.z_dim) # TODO: generate with torch, as in the training_loop.py file - grid_img = gen_utils.z_to_img(G, torch.from_numpy(grid_z).to(device), label, 
truncation_psi, noise_mode) - PIL.Image.fromarray(gen_utils.create_image_grid(grid_img, (grid_width, grid_height)), - 'RGB').save(os.path.join(run_dir, 'fakes.jpg')) - print('Saving individual images...') - for idx, z in enumerate(grid_z): - z = torch.from_numpy(z).unsqueeze(0).to(device) - w = G.mapping(z, None) # to save the dlatent in .npy format - img = gen_utils.z_to_img(G, z, label, truncation_psi, noise_mode)[0] - PIL.Image.fromarray(img, 'RGB').save(os.path.join(run_dir, f'img{idx:04d}.jpg')) - np.save(os.path.join(run_dir, f'img{idx:04d}.npy'), w.unsqueeze(0).cpu().numpy()) - else: - if seeds is None: - ctx.fail('--seeds option is required when not using --projected-w') - - # Generate images. - images = [] - for seed_idx, seed in enumerate(seeds): - print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds))) - z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device) - img = gen_utils.z_to_img(G, z, label, truncation_psi, noise_mode)[0] - if save_grid: - images.append(img) - PIL.Image.fromarray(img, 'RGB').save(os.path.join(run_dir, f'seed{seed:04d}.jpg')) + if seeds is None: + ctx.fail('--seeds option is required when not using --projected-w') + # Generate images. + images = [] + for seed_idx, seed in enumerate(seeds): + print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds))) + z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device) + img = gen_utils.z_to_img(G, z, label, truncation_psi, noise_mode)[0] if save_grid: - print('Saving image grid...') - # We let the function infer the shape of the grid - if (grid_width, grid_height) == (None, None): - PIL.Image.fromarray(gen_utils.create_image_grid(np.array(images)), - 'RGB').save(os.path.join(run_dir, 'grid.jpg')) - # The user tells the specific shape of the grid, but one value may be None - else: - PIL.Image.fromarray(gen_utils.create_image_grid(np.array(images), (grid_width, grid_height)), - 'RGB').save(os.path.join(run_dir, 'grid.jpg')) + images.append(img) + PIL.Image.fromarray(img, 'RGB').save(os.path.join(run_dir, f'seed{seed:04d}.jpg')) + + if save_grid: + print('Saving image grid...') + # We let the function infer the shape of the grid + if (grid_width, grid_height) == (None, None): + PIL.Image.fromarray(gen_utils.create_image_grid(np.array(images)), + 'RGB').save(os.path.join(run_dir, 'grid.jpg')) + # The user tells the specific shape of the grid, but one value may be None + else: + PIL.Image.fromarray(gen_utils.create_image_grid(np.array(images), (grid_width, grid_height)), + 'RGB').save(os.path.join(run_dir, 'grid.jpg')) # Save the configuration used ctx.obj = { 'network_pkl': network_pkl, - 'training_snapshot': training_snapshot, - 'snapshot_size': snapshot_size, 'seeds': seeds, 'truncation_psi': truncation_psi, 'class_idx': class_idx, @@ -226,6 +194,7 @@ def generate_images( @main.command(name='random-video') @click.pass_context @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) +@click.option('--cfg', type=click.Choice(['stylegan2', 'stylegan3-t', 'stylegan3-r']), help='Config of the network, used only if you want to use the pretrained models in torch_utils.gen_utils.resume_specs') # Synthesis options @click.option('--seeds', type=gen_utils.num_range, help='List of random seeds', required=True) @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) @@ -242,10 +211,11 @@ def generate_images( @click.option('--compress', is_flag=True, help='Add 
flag to compress the final mp4 file with ffmpeg-python (same resolution, lower file size)') # Extra parameters for saving the results @click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'video'), show_default=True, metavar='DIR') -@click.option('--description', '-desc', type=str, help='Description name for the directory path to save results', default='', show_default=True) +@click.option('--description', '-desc', type=str, help='Description name for the directory path to save results') def random_interpolation_video( ctx: click.Context, network_pkl: Union[str, os.PathLike], + cfg: str, seeds: List[int], truncation_psi: float, new_center: Tuple[str, Union[int, np.ndarray]], @@ -278,18 +248,28 @@ def random_interpolation_video( --fps=60 -sec=60 --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl """ + # If model name exists in the gen_utils.resume_specs dictionary, use it instead of the full url + try: + network_pkl = gen_utils.resume_specs[cfg][network_pkl] + except KeyError: + # Otherwise, it's a local file or an url + pass + print(f'Loading networks from "{network_pkl}"...') device = torch.device('cuda') + with dnnlib.util.open_url(network_pkl) as f: G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore + # Stabilize/anchor the latent space if anchor_latent_space: gen_utils.anchor_latent_space(G) - # Create the run dir with the given name description; add slowdown if different than the default (1) - description = 'random-video' if len(description) == 0 else description - description = f'{description}-{slowdown}xslowdown' if slowdown != 1 else description - run_dir = gen_utils.make_run_dir(outdir, description) + # Create the run dir with the given name description; add slowdown if different from the default (1) + desc = 'random-video' + desc = f'random-video-{description}' if description is not None else desc + desc = f'{desc}-{slowdown}xslowdown' if slowdown != 1 else desc + run_dir = gen_utils.make_run_dir(outdir, desc) # Number of frames in the video and its total duration in seconds num_frames = int(np.rint(duration_sec * fps)) @@ -360,42 +340,33 @@ def random_interpolation_video( slowdown //= 2 if new_center is None: - def make_frame(t): - frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1)) - latents = torch.from_numpy(all_latents[frame_idx]).to(device) - # Get the images with the labels - images = gen_utils.z_to_img(G, latents, label, truncation_psi, noise_mode) - # Generate the grid for this timestamp - grid = gen_utils.create_image_grid(images, grid_size) - # Grayscale => RGB - if grid.shape[2] == 1: - grid = grid.repeat(3, 2) - return grid - + w_avg = G.mapping.w_avg else: new_center, new_center_value = new_center # We get the new center using the int (a seed) or recovered dlatent (an np.ndarray) if isinstance(new_center_value, int): - new_w_avg = gen_utils.get_w_from_seed(G, device, new_center_value, truncation_psi=1.0) # We want the pure dlatent + w_avg = gen_utils.get_w_from_seed(G, device, new_center_value, + truncation_psi=1.0) # We want the pure dlatent elif isinstance(new_center_value, np.ndarray): - new_w_avg = torch.from_numpy(new_center_value).to(device) + w_avg = torch.from_numpy(new_center_value).to(device) else: ctx.fail('Error: New center has strange format! 
Only an int (seed) or a file (.npy/.npz) are accepted!') - def make_frame(t): - frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1)) - latents = torch.from_numpy(all_latents[frame_idx]).to(device) - # Do the truncation trick with this new center - w = G.mapping(latents, None) - w = new_w_avg + (w - new_w_avg) * truncation_psi - # Get the images with the new center - images = gen_utils.w_to_img(G, w, noise_mode) - # Generate the grid for this timestamp - grid = gen_utils.create_image_grid(images, grid_size) - # Grayscale => RGB - if grid.shape[2] == 1: - grid = grid.repeat(3, 2) - return grid + # Auxiliary function for moviepy + def make_frame(t): + frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1)) + latents = torch.from_numpy(all_latents[frame_idx]).to(device) + # Do the truncation trick (with the global centroid or the new center provided by the user) + w = G.mapping(latents, None) + w = w_avg + (w - w_avg) * truncation_psi + # Get the images with the new center + images = gen_utils.w_to_img(G, w, noise_mode) + # Generate the grid for this timestamp + grid = gen_utils.create_image_grid(images, grid_size) + # Grayscale => RGB + if grid.shape[2] == 1: + grid = grid.repeat(3, 2) + return grid # Generate video using the respective make_frame function videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec) @@ -424,7 +395,7 @@ def make_frame(t): 'duration_sec': duration_sec, 'video_fps': fps, 'run_dir': run_dir, - 'description': description, + 'description': desc, 'compress': compress, 'smoothing_sec': smoothing_sec } diff --git a/legacy.py b/legacy.py index 8cf53cb9..6e52cdf6 100644 --- a/legacy.py +++ b/legacy.py @@ -21,40 +21,45 @@ def load_network_pkl(f, force_fp16=False): data = _LegacyUnpickler(f).load() - - # Legacy TensorFlow pickle => convert. - if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data): - tf_G, tf_D, tf_Gs = data - G = convert_tf_generator(tf_G) - D = convert_tf_discriminator(tf_D) - G_ema = convert_tf_generator(tf_Gs) - data = dict(G=G, D=D, G_ema=G_ema) - - # Add missing fields. - if 'training_set_kwargs' not in data: - data['training_set_kwargs'] = None - if 'augment_pipe' not in data: - data['augment_pipe'] = None - - # Validate contents. - assert isinstance(data['G'], torch.nn.Module) - assert isinstance(data['D'], torch.nn.Module) - assert isinstance(data['G_ema'], torch.nn.Module) - assert isinstance(data['training_set_kwargs'], (dict, type(None))) - assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None))) - - # Force FP16. - if force_fp16: - for key in ['G', 'D', 'G_ema']: - old = data[key] - kwargs = copy.deepcopy(old.init_kwargs) - fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs) - fp16_kwargs.num_fp16_res = 4 - fp16_kwargs.conv_clamp = 256 - if kwargs != old.init_kwargs: - new = type(old)(**kwargs).eval().requires_grad_(False) - misc.copy_params_and_buffers(old, new, require_all=True) - data[key] = new + try: + # Legacy TensorFlow pickle => convert. + if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data): + tf_G, tf_D, tf_Gs = data + G = convert_tf_generator(tf_G) + D = convert_tf_discriminator(tf_D) + G_ema = convert_tf_generator(tf_Gs) + data = dict(G=G, D=D, G_ema=G_ema) + + # Add missing fields. + if 'training_set_kwargs' not in data: + data['training_set_kwargs'] = None + if 'augment_pipe' not in data: + data['augment_pipe'] = None + + # Validate contents. 
+ if 'G' in data: + assert isinstance(data['G'], torch.nn.Module) + if 'D' in data: + assert isinstance(data['D'], torch.nn.Module) + assert isinstance(data['G_ema'], torch.nn.Module) + assert isinstance(data['training_set_kwargs'], (dict, type(None))) + assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None))) + + # Force FP16. + if force_fp16: + for key in ['G', 'D', 'G_ema']: + old = data[key] + kwargs = copy.deepcopy(old.init_kwargs) + fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs) + fp16_kwargs.num_fp16_res = 4 + fp16_kwargs.conv_clamp = 256 + if kwargs != old.init_kwargs: + new = type(old)(**kwargs).eval().requires_grad_(False) + misc.copy_params_and_buffers(old, new, require_all=True) + data[key] = new + except KeyError: + # Most likely a StyleGAN-NADA pkl, so pass and return data + pass return data #---------------------------------------------------------------------------- diff --git a/train.py b/train.py index 5e35feeb..7813c038 100644 --- a/train.py +++ b/train.py @@ -195,7 +195,7 @@ def main(**kwargs): c.G_kwargs = dnnlib.EasyDict(class_name=None, z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict()) c.D_kwargs = dnnlib.EasyDict(class_name='training.networks_stylegan2.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict()) c.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8) - c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8) + c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8) # TODO: Use ComplexSGD: https://arxiv.org/abs/2102.08431 c.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss') c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, prefetch_factor=2)
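To close, a self-contained sketch of the truncation trick with a replaceable center, as refactored in `random_interpolation_video` above (the `w = w_avg + (w - w_avg) * truncation_psi` step). The pickle path and seed are placeholders, and `gen_utils.get_w_from_seed`, `gen_utils.w_to_img`, and the `'const'` noise mode are assumed to behave as they are used in the diff.

```python
# Sketch of the truncation trick with a replaceable center, following the make_frame
# refactor in generate.py above; 'path/to/model.pkl' and the seed 42 are placeholders.
import numpy as np
import PIL.Image
import torch
import dnnlib
import legacy
from torch_utils import gen_utils

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
with dnnlib.util.open_url('path/to/model.pkl') as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)

truncation_psi = 0.7
z = torch.from_numpy(np.random.RandomState(42).randn(1, G.z_dim)).to(device)

# Default center: the generator's global dlatent centroid
w_avg = G.mapping.w_avg
# Alternative center derived from a seed (a "pure" dlatent, i.e., truncation_psi=1.0):
# w_avg = gen_utils.get_w_from_seed(G, device, 42, truncation_psi=1.0)

w = G.mapping(z, None)                      # map z -> w (unconditional model: label is None)
w = w_avg + (w - w_avg) * truncation_psi    # pull w toward the chosen center
img = gen_utils.w_to_img(G, w, 'const')[0]  # assumed noise_mode='const'
PIL.Image.fromarray(img, 'RGB').save('sample.jpg')
```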