Commit
Update README, allow SGAN-NADA models to be used with the code (Issue #9), general code cleaning
PDillis committed Apr 23, 2022
1 parent 6dc2867 commit 4765f95
Showing 4 changed files with 116 additions and 127 deletions.
15 changes: 14 additions & 1 deletion README.md
@@ -128,12 +128,25 @@ This repository adds/has the following changes (not yet the complete list):
* Added the rest of the affine transformations
* Added widget for class-conditional models (***TODO:*** mix classes with continuous values for `cls`!)
* General model and code additions
* ***TODO:*** [Better sampling?](https://arxiv.org/abs/2110.08009)
* [Multi-modal truncation trick](https://arxiv.org/abs/2202.12211): find the different clusters in your model and truncate toward the one closest to your dlatent, in order to increase fidelity (TODO: finish the skeleton implementation; a sketch follows this list)
* StyleGAN3: anchor the latent space for easier-to-follow interpolations (thanks to [Rivers Have Wings](https://github.com/crowsonkb) and [nshepperd](https://github.com/nshepperd)).
* Use the CPU instead of the GPU if desired (not recommended, but perfectly fine for generating images whenever the custom CUDA kernels fail to compile).
* Add missing dependencies and channels so that the [`conda`](https://docs.conda.io/en/latest/) environment is correctly set up in Windows
(PRs [#111](https://github.com/NVlabs/stylegan3/pull/111)/[#116](https://github.com/NVlabs/stylegan3/pull/116)/[#125](https://github.com/NVlabs/stylegan3/pull/125) and [#80](https://github.com/NVlabs/stylegan3/pull/80)/[#143](https://github.com/NVlabs/stylegan3/pull/143) from the base, respectively)
* ***TODO:*** The current setup fails to install the CUDA version of PyTorch, so make separate environment files for each OS (Ubuntu/Windows)
* Use [StyleGAN-NADA](https://github.com/rinongal/StyleGAN-nada) models with any part of the code (Issue [#9](https://github.com/PDillis/stylegan3-fun/issues/9))
* The StyleGAN-NADA models must first be converted via [Vadim Epstein](https://github.com/eps696)'s conversion code, found [here](https://github.com/eps696/stylegan2ada#tweaking-models).
* ***TODO*** list (this is a long one with more to come, so any help is appreciated):
* [Generate images/interpolations with the layers of the model](https://twitter.com/makeitrad1/status/1517251876504600576?s=20&t=X5Df8N2gG_zGh5jJLVkvvw)
* Access the layers via: `renderer.Renderer.run_synthesis_net(net, capture_layer=layer)` (a usage sketch follows this list)
* Multi-modal truncation trick: finish the skeleton code
* [PTI](https://github.com/danielroich/PTI) for better inversion
* [Better sampling](https://arxiv.org/abs/2110.08009)
* [Progressive growing modules for StyleGAN-XL](https://github.com/autonomousvision/stylegan_xl) to be able to use the pretrained models
* [Add cross-model interpolation](https://twitter.com/arfafax/status/1297681537337446400?s=20&t=xspnTaLFTvd7y4krg8tkxA)
* Generate class labels automatically from the dataset structure (subfolders and such)
* Make it easy to download pretrained models from Drive; otherwise, a lot of models can't be used with `dnnlib.util.open_url`
(e.g., [StyleGAN-Human](https://github.com/stylegan-human/StyleGAN-Human) models)
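
For the multi-modal truncation item above, here is a minimal sketch of how the clustering step could look. It assumes `scikit-learn` is available and that `G` and `device` are already loaded; the helper names, sample count, and cluster count are illustrative, not the repo's actual skeleton:

```python
import numpy as np
import torch
from sklearn.cluster import KMeans


@torch.no_grad()
def find_w_centroids(G, device, num_samples=10_000, num_clusters=64, seed=0):
    # Sample many dlatents and cluster them in W space (arXiv:2202.12211)
    z = torch.from_numpy(np.random.RandomState(seed).randn(num_samples, G.z_dim)).to(device)
    w = G.mapping(z, None)[:, 0, :]  # (N, w_dim); every layer shares the same w here
    kmeans = KMeans(n_clusters=num_clusters, random_state=seed).fit(w.cpu().numpy())
    return torch.from_numpy(kmeans.cluster_centers_).to(device).float()  # (K, w_dim)


def multimodal_truncate(w, centroids, psi=0.7):
    # w: (N, num_ws, w_dim). Truncate toward the centroid closest to each
    # dlatent instead of the single global G.mapping.w_avg
    idx = torch.cdist(w[:, 0, :], centroids).argmin(dim=1)  # nearest cluster per sample
    center = centroids[idx].unsqueeze(1)  # (N, 1, w_dim), broadcasts over the layer axis
    return center + psi * (w - center)
```

The point is only to convey the idea: replace the single global center with the nearest of several cluster centers before applying the usual truncation interpolation.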

***TODO:*** Finish documentation for better user experience, add videos/images, code samples, visuals...
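
For the layer-capture item in the list above, a rough usage sketch. The layer name `'L5_148_128'` is hypothetical (valid names depend on the loaded model), and the exact arguments (`G.synthesis` vs. `G`) and the `out, layers` return convention are assumptions based on the upstream viz renderer:

```python
import torch

from viz.renderer import Renderer

# Assumes G and device are already loaded, as in generate.py
z = torch.randn(1, G.z_dim, device=device)
w = G.mapping(z, None)
# Intercept the output of the named synthesis layer while running synthesis
out, layers = Renderer.run_synthesis_net(G.synthesis, w, capture_layer='L5_148_128')
```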

153 changes: 62 additions & 91 deletions generate.py
@@ -32,10 +32,7 @@ def main():
@click.pass_context
@click.option('--network', 'network_pkl', help='Network pickle filename: can be URL, local file, or the name of the model in torch_utils.gen_utils.resume_specs', required=True)
@click.option('--device', help='Device to use for image generation; using the CPU is slower than the GPU', type=click.Choice(['cpu', 'cuda']), default='cuda', show_default=True)
@click.option('--cfg', help='Config of the network, used only if you want to use one of the models that are in torch_utils.gen_utils.resume_specs')
# Recreate snapshot grid during training (doesn't work!!!)
@click.option('--recreate-snapshot-grid', 'training_snapshot', is_flag=True, help='Add flag if you wish to recreate the snapshot grid created during training')
@click.option('--snapshot-size', type=click.Choice(['1080p', '4k', '8k']), help='Size of the snapshot', default='4k', show_default=True)
@click.option('--cfg', type=click.Choice(['stylegan2', 'stylegan3-t', 'stylegan3-r']), help='Config of the network, used only if you want to use the pretrained models in torch_utils.gen_utils.resume_specs')
# Synthesis options (feed a list of seeds or give the projected w to synthesize)
@click.option('--seeds', type=gen_utils.num_range, help='List of random seeds')
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@@ -56,8 +53,6 @@ def generate_images(
network_pkl: str,
device: Optional[str],
cfg: str,
training_snapshot: bool,
snapshot_size: str,
seeds: Optional[List[int]],
truncation_psi: float,
class_idx: Optional[int],
@@ -100,16 +95,16 @@ def generate_images(
python generate.py images --cfg=stylegan2 --network=wikiart1024-C --class=155 \\
--trunc=0.7 --seeds=10-50 --save-grid
"""
print(f'Loading networks from "{network_pkl}"...')
device = torch.device('cuda') if torch.cuda.is_available() and device == 'cuda' else torch.device('cpu')

# If model name exists in the gen_utils.resume_specs dictionary, use it instead of the full url
try:
network_pkl = gen_utils.resume_specs[cfg][network_pkl]
except KeyError:
# Otherwise, it's a local file or an url
pass

print(f'Loading networks from "{network_pkl}"...')
device = torch.device('cuda') if torch.cuda.is_available() and device == 'cuda' else torch.device('cpu')

with dnnlib.util.open_url(network_pkl) as f:
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
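
For context, the `resume_specs` lookup above assumes a nested mapping from config name to model name to pickle URL, roughly of this shape (the entries below are illustrative, not the actual contents of `torch_utils.gen_utils.resume_specs`):

```python
resume_specs = {
    'stylegan2': {
        # model name -> pickle URL (illustrative entry)
        'ffhq1024': 'https://example.com/stylegan2-ffhq-1024x1024.pkl',
    },
    'stylegan3-t': {},
    'stylegan3-r': {},
}

# Equivalent fallback without try/except: keep the given name when it is not
# a known spec, treating it as a local path or URL instead
network_pkl = resume_specs.get(cfg, {}).get(network_pkl, network_pkl)
```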

@@ -152,60 +147,33 @@ def generate_images(
if class_idx is not None:
print('warn: --class=lbl ignored when running on an unconditional network')

if training_snapshot:
# This doesn't really work, so more work is warranted; TODO: move it to torch_utils/gen_utils.py
print('Recreating the snapshot grid...')
size_dict = {'1080p': (1920, 1080, 3, 2), '4k': (3840, 2160, 7, 4), '8k': (7680, 4320, 7, 4)}
grid_width = int(np.clip(size_dict[snapshot_size][0] // G.img_resolution, size_dict[snapshot_size][2], 32))
grid_height = int(np.clip(size_dict[snapshot_size][1] // G.img_resolution, size_dict[snapshot_size][3], 32))
num_images = grid_width * grid_height

rnd = np.random.RandomState(0)
torch.manual_seed(0)
all_indices = list(range(70000)) # irrelevant
rnd.shuffle(all_indices)

grid_z = rnd.randn(num_images, G.z_dim) # TODO: generate with torch, as in the training_loop.py file
grid_img = gen_utils.z_to_img(G, torch.from_numpy(grid_z).to(device), label, truncation_psi, noise_mode)
PIL.Image.fromarray(gen_utils.create_image_grid(grid_img, (grid_width, grid_height)),
'RGB').save(os.path.join(run_dir, 'fakes.jpg'))
print('Saving individual images...')
for idx, z in enumerate(grid_z):
z = torch.from_numpy(z).unsqueeze(0).to(device)
w = G.mapping(z, None) # to save the dlatent in .npy format
img = gen_utils.z_to_img(G, z, label, truncation_psi, noise_mode)[0]
PIL.Image.fromarray(img, 'RGB').save(os.path.join(run_dir, f'img{idx:04d}.jpg'))
np.save(os.path.join(run_dir, f'img{idx:04d}.npy'), w.unsqueeze(0).cpu().numpy())
else:
if seeds is None:
ctx.fail('--seeds option is required when not using --projected-w')

# Generate images.
images = []
for seed_idx, seed in enumerate(seeds):
print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
img = gen_utils.z_to_img(G, z, label, truncation_psi, noise_mode)[0]
if save_grid:
images.append(img)
PIL.Image.fromarray(img, 'RGB').save(os.path.join(run_dir, f'seed{seed:04d}.jpg'))
if seeds is None:
ctx.fail('--seeds option is required when not using --projected-w')

# Generate images.
images = []
for seed_idx, seed in enumerate(seeds):
print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
img = gen_utils.z_to_img(G, z, label, truncation_psi, noise_mode)[0]
if save_grid:
print('Saving image grid...')
# We let the function infer the shape of the grid
if (grid_width, grid_height) == (None, None):
PIL.Image.fromarray(gen_utils.create_image_grid(np.array(images)),
'RGB').save(os.path.join(run_dir, 'grid.jpg'))
# The user tells the specific shape of the grid, but one value may be None
else:
PIL.Image.fromarray(gen_utils.create_image_grid(np.array(images), (grid_width, grid_height)),
'RGB').save(os.path.join(run_dir, 'grid.jpg'))
images.append(img)
PIL.Image.fromarray(img, 'RGB').save(os.path.join(run_dir, f'seed{seed:04d}.jpg'))

if save_grid:
print('Saving image grid...')
# We let the function infer the shape of the grid
if (grid_width, grid_height) == (None, None):
PIL.Image.fromarray(gen_utils.create_image_grid(np.array(images)),
'RGB').save(os.path.join(run_dir, 'grid.jpg'))
# The user tells the specific shape of the grid, but one value may be None
else:
PIL.Image.fromarray(gen_utils.create_image_grid(np.array(images), (grid_width, grid_height)),
'RGB').save(os.path.join(run_dir, 'grid.jpg'))

# Save the configuration used
ctx.obj = {
'network_pkl': network_pkl,
'training_snapshot': training_snapshot,
'snapshot_size': snapshot_size,
'seeds': seeds,
'truncation_psi': truncation_psi,
'class_idx': class_idx,
@@ -226,6 +194,7 @@ def generate_images(
@main.command(name='random-video')
@click.pass_context
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--cfg', type=click.Choice(['stylegan2', 'stylegan3-t', 'stylegan3-r']), help='Config of the network, used only if you want to use the pretrained models in torch_utils.gen_utils.resume_specs')
# Synthesis options
@click.option('--seeds', type=gen_utils.num_range, help='List of random seeds', required=True)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@@ -242,10 +211,11 @@ def generate_images(
@click.option('--compress', is_flag=True, help='Add flag to compress the final mp4 file with ffmpeg-python (same resolution, lower file size)')
# Extra parameters for saving the results
@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'video'), show_default=True, metavar='DIR')
@click.option('--description', '-desc', type=str, help='Description name for the directory path to save results', default='', show_default=True)
@click.option('--description', '-desc', type=str, help='Description name for the directory path to save results')
def random_interpolation_video(
ctx: click.Context,
network_pkl: Union[str, os.PathLike],
cfg: str,
seeds: List[int],
truncation_psi: float,
new_center: Tuple[str, Union[int, np.ndarray]],
@@ -278,18 +248,28 @@ def random_interpolation_video(
--fps=60 -sec=60 --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
"""
# If model name exists in the gen_utils.resume_specs dictionary, use it instead of the full url
try:
network_pkl = gen_utils.resume_specs[cfg][network_pkl]
except KeyError:
# Otherwise, it's a local file or an url
pass

print(f'Loading networks from "{network_pkl}"...')
device = torch.device('cuda')

with dnnlib.util.open_url(network_pkl) as f:
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

# Stabilize/anchor the latent space
if anchor_latent_space:
gen_utils.anchor_latent_space(G)

# Create the run dir with the given name description; add slowdown if different than the default (1)
description = 'random-video' if len(description) == 0 else description
description = f'{description}-{slowdown}xslowdown' if slowdown != 1 else description
run_dir = gen_utils.make_run_dir(outdir, description)
# Create the run dir with the given name description; add slowdown if different from the default (1)
desc = 'random-video'
desc = f'random-video-{description}' if description is not None else desc
desc = f'{desc}-{slowdown}xslowdown' if slowdown != 1 else desc
run_dir = gen_utils.make_run_dir(outdir, desc)

# Number of frames in the video and its total duration in seconds
num_frames = int(np.rint(duration_sec * fps))
@@ -360,42 +340,33 @@ def random_interpolation_video(
slowdown //= 2

if new_center is None:
def make_frame(t):
frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1))
latents = torch.from_numpy(all_latents[frame_idx]).to(device)
# Get the images with the labels
images = gen_utils.z_to_img(G, latents, label, truncation_psi, noise_mode)
# Generate the grid for this timestamp
grid = gen_utils.create_image_grid(images, grid_size)
# Grayscale => RGB
if grid.shape[2] == 1:
grid = grid.repeat(3, 2)
return grid

w_avg = G.mapping.w_avg
else:
new_center, new_center_value = new_center
# We get the new center using the int (a seed) or recovered dlatent (an np.ndarray)
if isinstance(new_center_value, int):
new_w_avg = gen_utils.get_w_from_seed(G, device, new_center_value, truncation_psi=1.0) # We want the pure dlatent
w_avg = gen_utils.get_w_from_seed(G, device, new_center_value,
truncation_psi=1.0) # We want the pure dlatent
elif isinstance(new_center_value, np.ndarray):
new_w_avg = torch.from_numpy(new_center_value).to(device)
w_avg = torch.from_numpy(new_center_value).to(device)
else:
ctx.fail('Error: New center has strange format! Only an int (seed) or a file (.npy/.npz) are accepted!')

def make_frame(t):
frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1))
latents = torch.from_numpy(all_latents[frame_idx]).to(device)
# Do the truncation trick with this new center
w = G.mapping(latents, None)
w = new_w_avg + (w - new_w_avg) * truncation_psi
# Get the images with the new center
images = gen_utils.w_to_img(G, w, noise_mode)
# Generate the grid for this timestamp
grid = gen_utils.create_image_grid(images, grid_size)
# Grayscale => RGB
if grid.shape[2] == 1:
grid = grid.repeat(3, 2)
return grid
# Auxiliary function for moviepy
def make_frame(t):
frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1))
latents = torch.from_numpy(all_latents[frame_idx]).to(device)
# Do the truncation trick (with the global centroid or the new center provided by the user)
w = G.mapping(latents, None)
w = w_avg + (w - w_avg) * truncation_psi
# Get the images with the new center
images = gen_utils.w_to_img(G, w, noise_mode)
# Generate the grid for this timestamp
grid = gen_utils.create_image_grid(images, grid_size)
# Grayscale => RGB
if grid.shape[2] == 1:
grid = grid.repeat(3, 2)
return grid

# Generate video using the respective make_frame function
videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
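
As a side note, the consolidated `make_frame` above applies the standard truncation-trick interpolation toward whichever center ended up in `w_avg`; a standalone illustration with made-up shapes:

```python
import torch

w_center = torch.zeros(1, 512)              # global w_avg or a user-supplied dlatent
w = torch.randn(4, 512)                     # mapped dlatents for one frame
psi = 0.7
w_trunc = w_center + (w - w_center) * psi   # psi=1 keeps w; psi=0 collapses to the center
```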
Expand Down Expand Up @@ -424,7 +395,7 @@ def make_frame(t):
'duration_sec': duration_sec,
'video_fps': fps,
'run_dir': run_dir,
'description': description,
'description': desc,
'compress': compress,
'smoothing_sec': smoothing_sec
}
73 changes: 39 additions & 34 deletions legacy.py
@@ -21,40 +21,45 @@

def load_network_pkl(f, force_fp16=False):
data = _LegacyUnpickler(f).load()

# Legacy TensorFlow pickle => convert.
if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
tf_G, tf_D, tf_Gs = data
G = convert_tf_generator(tf_G)
D = convert_tf_discriminator(tf_D)
G_ema = convert_tf_generator(tf_Gs)
data = dict(G=G, D=D, G_ema=G_ema)

# Add missing fields.
if 'training_set_kwargs' not in data:
data['training_set_kwargs'] = None
if 'augment_pipe' not in data:
data['augment_pipe'] = None

# Validate contents.
assert isinstance(data['G'], torch.nn.Module)
assert isinstance(data['D'], torch.nn.Module)
assert isinstance(data['G_ema'], torch.nn.Module)
assert isinstance(data['training_set_kwargs'], (dict, type(None)))
assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

# Force FP16.
if force_fp16:
for key in ['G', 'D', 'G_ema']:
old = data[key]
kwargs = copy.deepcopy(old.init_kwargs)
fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
fp16_kwargs.num_fp16_res = 4
fp16_kwargs.conv_clamp = 256
if kwargs != old.init_kwargs:
new = type(old)(**kwargs).eval().requires_grad_(False)
misc.copy_params_and_buffers(old, new, require_all=True)
data[key] = new
try:
# Legacy TensorFlow pickle => convert.
if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
tf_G, tf_D, tf_Gs = data
G = convert_tf_generator(tf_G)
D = convert_tf_discriminator(tf_D)
G_ema = convert_tf_generator(tf_Gs)
data = dict(G=G, D=D, G_ema=G_ema)

# Add missing fields.
if 'training_set_kwargs' not in data:
data['training_set_kwargs'] = None
if 'augment_pipe' not in data:
data['augment_pipe'] = None

# Validate contents.
if 'G' in data:
assert isinstance(data['G'], torch.nn.Module)
if 'D' in data:
assert isinstance(data['D'], torch.nn.Module)
assert isinstance(data['G_ema'], torch.nn.Module)
assert isinstance(data['training_set_kwargs'], (dict, type(None)))
assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

# Force FP16.
if force_fp16:
for key in ['G', 'D', 'G_ema']:
old = data[key]
kwargs = copy.deepcopy(old.init_kwargs)
fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
fp16_kwargs.num_fp16_res = 4
fp16_kwargs.conv_clamp = 256
if kwargs != old.init_kwargs:
new = type(old)(**kwargs).eval().requires_grad_(False)
misc.copy_params_and_buffers(old, new, require_all=True)
data[key] = new
except KeyError:
# Most likely a StyleGAN-NADA pkl, so pass and return data
pass
return data

#----------------------------------------------------------------------------
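
With the `KeyError` guard above, a converted StyleGAN-NADA pickle (which may lack `G` and `D`) loads through the same entry point the rest of the code uses; a minimal sketch, with a hypothetical filename:

```python
import torch

import dnnlib
import legacy

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
with dnnlib.util.open_url('stylegan-nada-converted.pkl') as f:  # hypothetical local file
    G = legacy.load_network_pkl(f)['G_ema'].to(device)
```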