From bdc624168e6c992b9b8d4bf2d50044278216625c Mon Sep 17 00:00:00 2001 From: Hans Date: Sun, 29 Oct 2023 13:13:39 +0100 Subject: [PATCH] progress bar improvements --- src/animatediff/generate.py | 5 ---- src/animatediff/pipelines/animation.py | 35 ++++++++++++++++++-------- src/animatediff/rife/ffmpeg.py | 2 +- 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/animatediff/generate.py b/src/animatediff/generate.py index a78db6aa..b928f9ee 100644 --- a/src/animatediff/generate.py +++ b/src/animatediff/generate.py @@ -1079,7 +1079,6 @@ def save_output( out_file = out_file.with_suffix( f".{codec_extn(output_format)}" ) - logger.info("Creating ffmpeg encoder...") encoder = FfmpegEncoder( frames_dir=frame_dir, out_file=out_file, @@ -1089,7 +1088,6 @@ def save_output( lossless=False, param= output_map["encode_param"] if "encode_param" in output_map else {} ) - logger.info("Encoding interpolated frames with ffmpeg...") result = encoder.encode() logger.debug(f"ffmpeg result: {result}") @@ -1148,9 +1146,6 @@ def preview_callback(i: int, video: torch.Tensor, save_fn: Callable[[torch.Tenso seed_everything(seed) - logger.info(f"{len( region_condi_list )=}") - logger.info(f"{len( region_list )=}") - pipeline_output = pipeline( negative_prompt=n_prompt, num_inference_steps=steps, diff --git a/src/animatediff/pipelines/animation.py b/src/animatediff/pipelines/animation.py index ee7e86a2..6c12f2f9 100644 --- a/src/animatediff/pipelines/animation.py +++ b/src/animatediff/pipelines/animation.py @@ -864,21 +864,34 @@ def interpolate_latents(self, latents: torch.Tensor, interpolation_factor:int, d - def decode_latents(self, latents: torch.Tensor): + def decode_latents(self, latents: torch.Tensor, progress_bar: tqdm = None): video_length = latents.shape[2] latents = 1 / self.vae.config.scaling_factor * latents latents = rearrange(latents, "b c f h w -> (b f) c h w") - # video = self.vae.decode(latents).sample video = [] + + if progress_bar is not None: # if we have a progress bar, we close it (rich doesn't support multiple progress bars) + task_id = progress_bar._prog.add_task('Decoding latents...', start=True, total=latents.shape[0]) + remove_task = True + else: + progress_bar = self.progress_bar(total=latents.shape[0], desc="Decoding latents...") + task_id = progress_bar._task_id + remove_task = False + for frame_idx in range(latents.shape[0]): - video.append( - self.vae.decode(latents[frame_idx : frame_idx + 1].to(self.vae.device, self.vae.dtype)).sample.cpu() - ) + latent = latents[frame_idx : frame_idx + 1].to(self.vae.device, self.vae.dtype) + video.append(self.vae.decode(latent).sample.cpu()) + progress_bar._prog.update(task_id, advance=1, refresh=True) + video = torch.cat(video) video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) video = (video / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 video = video.float().numpy() + + if remove_task: + progress_bar._prog.remove_task(task_id) + return video def prepare_extra_step_kwargs(self, generator, eta): @@ -2533,7 +2546,7 @@ def get_controlnet_variable( # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=total_steps) as progress_bar: + with self.progress_bar(total=len(timesteps), desc="Diffusing video...") as progress_bar: for i, t in enumerate(timesteps): stopwatch_start() @@ -2823,7 +2836,6 @@ def sample_to_device( sample ): pred = pred.to(dtype=latents.dtype, device=latents.device) noise_pred[:, :, context] = noise_pred[:, :, context] + pred counter[:, :, context] = counter[:, :, context] + 1 - progress_bar.update() # perform guidance noise_size = prompt_encoder.get_condi_size() @@ -2842,7 +2854,7 @@ def sample_to_device( sample ): ): denoised = latents - noise_pred denoised = self.interpolate_latents(denoised, interpolation_factor, device) - video = torch.from_numpy(self.decode_latents(denoised)) + video = torch.from_numpy(self.decode_latents(denoised, progress_bar=progress_bar)) callback(i, video) # compute the previous noisy sample x_t -> x_t-1 @@ -2886,6 +2898,7 @@ def sample_to_device( sample ): tmp_latent = None stopwatch_stop("LOOP end") + progress_bar.update() if c_ref_enable: self.unload_controlnet_ref_only( @@ -2921,7 +2934,7 @@ def sample_to_device( sample ): return AnimationPipelineOutput(videos=video) - def progress_bar(self, iterable=None, total=None): + def progress_bar(self, iterable=None, total=None, desc=None): if not hasattr(self, "_progress_bar_config"): self._progress_bar_config = {} elif not isinstance(self._progress_bar_config, dict): @@ -2930,9 +2943,9 @@ def progress_bar(self, iterable=None, total=None): ) if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) + return tqdm(iterable, desc=desc, **self._progress_bar_config) elif total is not None: - return tqdm(total=total, **self._progress_bar_config) + return tqdm(total=total, desc=desc, **self._progress_bar_config) else: raise ValueError("Either `total` or `iterable` has to be defined.") diff --git a/src/animatediff/rife/ffmpeg.py b/src/animatediff/rife/ffmpeg.py index d3b825db..7eee335c 100644 --- a/src/animatediff/rife/ffmpeg.py +++ b/src/animatediff/rife/ffmpeg.py @@ -101,7 +101,7 @@ def __init__( def encode(self) -> tuple: self.input: InputNode = ffmpeg.input( - str(self.frames_dir.resolve().joinpath("%08d.png")), framerate=self.in_fps + str(self.frames_dir.resolve().joinpath("%08d.png")), framerate=self.in_fps, loglevel="warning" ).filter("fps", fps=self.in_fps) match self.codec: case VideoCodec.gif: