Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

progress bar improvements #132

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/animatediff/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,6 @@ def save_output(

out_file = out_file.with_suffix( f".{codec_extn(output_format)}" )

logger.info("Creating ffmpeg encoder...")
encoder = FfmpegEncoder(
frames_dir=frame_dir,
out_file=out_file,
Expand All @@ -1089,7 +1088,6 @@ def save_output(
lossless=False,
param= output_map["encode_param"] if "encode_param" in output_map else {}
)
logger.info("Encoding interpolated frames with ffmpeg...")
result = encoder.encode()
logger.debug(f"ffmpeg result: {result}")

Expand Down Expand Up @@ -1148,9 +1146,6 @@ def preview_callback(i: int, video: torch.Tensor, save_fn: Callable[[torch.Tenso

seed_everything(seed)

logger.info(f"{len( region_condi_list )=}")
logger.info(f"{len( region_list )=}")

pipeline_output = pipeline(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a general rule, please don't delete log statements that are already in the code.
Most of the time they are there because that code is still being debugged, or they were left in to investigate behavior that could be improved.

negative_prompt=n_prompt,
num_inference_steps=steps,
Expand Down
35 changes: 24 additions & 11 deletions src/animatediff/pipelines/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,21 +864,34 @@ def interpolate_latents(self, latents: torch.Tensor, interpolation_factor:int, d



def decode_latents(self, latents: torch.Tensor):
def decode_latents(self, latents: torch.Tensor, progress_bar: tqdm = None):
video_length = latents.shape[2]
latents = 1 / self.vae.config.scaling_factor * latents
latents = rearrange(latents, "b c f h w -> (b f) c h w")
# video = self.vae.decode(latents).sample
video = []

if progress_bar is not None: # if we have a progress bar, we close it (rich doesn't support multiple progress bars)
task_id = progress_bar._prog.add_task('Decoding latents...', start=True, total=latents.shape[0])
remove_task = True
else:
progress_bar = self.progress_bar(total=latents.shape[0], desc="Decoding latents...")
task_id = progress_bar._task_id
remove_task = False

for frame_idx in range(latents.shape[0]):
video.append(
self.vae.decode(latents[frame_idx : frame_idx + 1].to(self.vae.device, self.vae.dtype)).sample.cpu()
)
latent = latents[frame_idx : frame_idx + 1].to(self.vae.device, self.vae.dtype)
video.append(self.vae.decode(latent).sample.cpu())
progress_bar._prog.update(task_id, advance=1, refresh=True)

video = torch.cat(video)
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
video = (video / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
video = video.float().numpy()

if remove_task:
progress_bar._prog.remove_task(task_id)

return video

def prepare_extra_step_kwargs(self, generator, eta):
Expand Down Expand Up @@ -2533,7 +2546,7 @@ def get_controlnet_variable(

# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=total_steps) as progress_bar:
with self.progress_bar(total=len(timesteps), desc="Diffusing video...") as progress_bar:
for i, t in enumerate(timesteps):
stopwatch_start()

Expand Down Expand Up @@ -2823,7 +2836,6 @@ def sample_to_device( sample ):
pred = pred.to(dtype=latents.dtype, device=latents.device)
noise_pred[:, :, context] = noise_pred[:, :, context] + pred
counter[:, :, context] = counter[:, :, context] + 1
progress_bar.update()

# perform guidance
noise_size = prompt_encoder.get_condi_size()
Expand All @@ -2842,7 +2854,7 @@ def sample_to_device( sample ):
):
denoised = latents - noise_pred
denoised = self.interpolate_latents(denoised, interpolation_factor, device)
video = torch.from_numpy(self.decode_latents(denoised))
video = torch.from_numpy(self.decode_latents(denoised, progress_bar=progress_bar))
callback(i, video)

# compute the previous noisy sample x_t -> x_t-1
Expand Down Expand Up @@ -2886,6 +2898,7 @@ def sample_to_device( sample ):
tmp_latent = None

stopwatch_stop("LOOP end")
progress_bar.update()

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

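I do not want to change the unit of progress. Please keep the bar counting `total_steps`, updated inside the context loop, rather than switching it to one update per timestep.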

if c_ref_enable:
self.unload_controlnet_ref_only(
Expand Down Expand Up @@ -2921,7 +2934,7 @@ def sample_to_device( sample ):

return AnimationPipelineOutput(videos=video)

def progress_bar(self, iterable=None, total=None):
def progress_bar(self, iterable=None, total=None, desc=None):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
elif not isinstance(self._progress_bar_config, dict):
Expand All @@ -2930,9 +2943,9 @@ def progress_bar(self, iterable=None, total=None):
)

if iterable is not None:
return tqdm(iterable, **self._progress_bar_config)
return tqdm(iterable, desc=desc, **self._progress_bar_config)
elif total is not None:
return tqdm(total=total, **self._progress_bar_config)
return tqdm(total=total, desc=desc, **self._progress_bar_config)
else:
raise ValueError("Either `total` or `iterable` has to be defined.")

Expand Down
2 changes: 1 addition & 1 deletion src/animatediff/rife/ffmpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(

def encode(self) -> tuple:
self.input: InputNode = ffmpeg.input(
str(self.frames_dir.resolve().joinpath("%08d.png")), framerate=self.in_fps
str(self.frames_dir.resolve().joinpath("%08d.png")), framerate=self.in_fps, loglevel="warning"
).filter("fps", fps=self.in_fps)
match self.codec:
case VideoCodec.gif:
Expand Down