From f82c9964c74e094d64e237c74048956af2511719 Mon Sep 17 00:00:00 2001
From: zsxkib
Date: Tue, 17 Oct 2023 12:42:12 +0000
Subject: [PATCH] Replicate init + working MVP

prompt walking working

Replicate (Internal Code Review Changes)

Remove nb.ipynb and update .gitignore

downloading custom weights from civit is now using pget (not wget)
---
 .dockerignore          |  17 +++
 .gitignore             |   6 +
 README.md              |   2 +
 cog.yaml               |  69 ++++++++++
 cog_download_models.sh |  14 ++
 predict.py             | 297 +++++++++++++++++++++++++++++++++++++++++
 6 files changed, 405 insertions(+)
 create mode 100644 .dockerignore
 create mode 100644 cog.yaml
 create mode 100755 cog_download_models.sh
 create mode 100644 predict.py

diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 00000000..4522d57d
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,17 @@
+# The .dockerignore file excludes files from the container build process.
+#
+# https://docs.docker.com/engine/reference/builder/#dockerignore-file
+
+# Exclude Git files
+.git
+.github
+.gitignore
+
+# Exclude Python cache files
+__pycache__
+.mypy_cache
+.pytest_cache
+.ruff_cache
+
+# Exclude Python virtual environment
+/venv
diff --git a/.gitignore b/.gitignore
index 9ccdb5e5..fbc9d4fc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -239,3 +239,9 @@ src/animatediff/_version.py
 # envrc
 .env*
 !.envrc.example
+
+# Cog
+output.gif
+output.mp4
+nb.ipynb
+.cog/
diff --git a/README.md b/README.md
index ddc2fb0d..0795b076 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,8 @@
 
 [AnimateDiff](https://github.com/guoyww/AnimateDiff) with prompt travel + [ControlNet](https://github.com/lllyasviel/ControlNet) + [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter)
 
+[![Replicate](https://replicate.com/zsxkib/animatediff-prompt-travel/badge)](https://replicate.com/zsxkib/animatediff-prompt-travel)
+
 I added a experimental feature to animatediff-cli to change the prompt in the middle of the frame. It seems to work surprisingly well!
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 00000000..15e16f9b
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,69 @@
+# Configuration for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+
+build:
+  # set to true if your model requires a GPU
+  gpu: true
+
+  # a list of ubuntu apt packages to install
+  system_packages:
+    - "libgl1-mesa-glx"
+    # - "libglib2.0-0"
+
+  # python version in the form '3.10'
+  python_version: "3.10"
+
+  # a list of packages in the format <package-name>==<version>
+  python_packages:
+    - "torch==2.0.1"
+    - "torchvision==0.15.2"
+    - "torchaudio==2.0.2"
+    - "accelerate>=0.20.3"
+    - "colorama>=0.4.3,<0.5.0"
+    - "cmake>=3.25.0"
+    - "diffusers==0.18.2"
+    - "einops>=0.6.1"
+    - "gdown>=4.6.6"
+    - "ninja>=1.11.0"
+    - "numpy>=1.22.4"
+    - "omegaconf>=2.3.0"
+    - "pillow>=9.4.0,<10.0.0"
+    - "pydantic>=1.10.0,<2.0.0"
+    - "rich>=13.0.0,<14.0.0"
+    - "safetensors>=0.3.1"
+    - "sentencepiece>=0.1.99"
+    - "shellingham>=1.5.0,<2.0.0"
+    - "torch>=2.0.0,<2.2.0"
+    - "torchaudio"
+    - "torchvision"
+    - "transformers>=4.30.2"
+    - "typer>=0.9.0,<1.0.0"
+    - "controlnet_aux"
+    - "matplotlib"
+    - "ffmpeg-python>=0.2.0"
+    - "black>=22.3.0"
+    - "ruff>=0.0.234"
+    - "setuptools-scm>=7.0.0"
+    - "pre-commit>=3.3.0"
+    - "ipython"
+    - "xformers>=0.0.21"
+    - "onnxruntime-gpu"
+    - "pandas"
+    - "segment-anything-hq==0.3"
+    - "groundingdino-py==0.4.0"
+    - "gitpython"
+    - "mediapipe"
+    - "xformers"
+    - "git+https://github.com/s9roll7/animatediff-cli-prompt-travel.git"
+    - "git+https://github.com/s9roll7/animatediff-cli-prompt-travel.git#egg=animatediff[stylize]"
+    - "git+https://github.com/s9roll7/animatediff-cli-prompt-travel.git#egg=animatediff[dwpose]"
+    - "git+https://github.com/s9roll7/animatediff-cli-prompt-travel.git#egg=animatediff[stylize_mask]"
+
+  run:
+    - pip install --upgrade pip
+    - apt-get update && apt-get install -y ffmpeg
+    - pip install imageio[ffmpeg]
+    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget
+
+# predict.py defines how predictions are run on your model
+predict: "predict.py:Predictor"
diff --git a/cog_download_models.sh b/cog_download_models.sh
new file mode 100755
index 00000000..d9bd37b0
--- /dev/null
+++ b/cog_download_models.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+mkdir -p data/share/Stable-diffusion/ data/models/motion-module/
+
+pget https://civitai.com/api/download/models/78775 data/share/Stable-diffusion/toonyou_beta3.safetensors || true
+pget https://civitai.com/api/download/models/72396 data/share/Stable-diffusion/lyriel_v16.safetensors || true
+pget https://civitai.com/api/download/models/71009 data/share/Stable-diffusion/rcnzCartoon3d_v10.safetensors || true
+pget https://civitai.com/api/download/models/79068 data/share/Stable-diffusion/majicmixRealistic_v5Preview.safetensors || true
+pget https://civitai.com/api/download/models/29460 data/share/Stable-diffusion/realisticVisionV40_v20Novae.safetensors || true
+
+# Download Motion_Module models
+wget -O data/models/motion-module/mm_sd_v14.ckpt https://huggingface.co/guoyww/animatediff/resolve/main/mm_sd_v14.ckpt || true
+wget -O data/models/motion-module/mm_sd_v15.ckpt https://huggingface.co/guoyww/animatediff/resolve/main/mm_sd_v15.ckpt || true
+wget -O data/models/motion-module/mm_sd_v15_v2.ckpt https://huggingface.co/guoyww/animatediff/resolve/main/mm_sd_v15_v2.ckpt || true
diff --git a/predict.py b/predict.py
new file mode 100644
index 00000000..5a3dba0d
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,297 @@
+# Prediction interface for Cog ⚙️
+# https://github.com/replicate/cog/blob/main/docs/python.md
+
+import os
+import re
+import subprocess
+from cog import BasePredictor, Input, Path
+
+FAKE_PROMPT_TRAVEL_JSON = """
+{{
+    "name": "sample",
+    "path": "{dreambooth_path}",
+    "motion_module": "models/motion-module/mm_sd_v15_v2.ckpt",
+    "compile": false,
+    "seed": [
+        {seed}
+    ],
+    "scheduler": "{scheduler}",
+    "steps": {steps},
+    "guidance_scale": {guidance_scale},
+    "clip_skip": {clip_skip},
+    "prompt_fixed_ratio": {prompt_fixed_ratio},
+    "head_prompt": "{head_prompt}",
+    "prompt_map": {{
+        {prompt_map}
+    }},
+    "tail_prompt": "{tail_prompt}",
+    "n_prompt": [
+        "{negative_prompt}"
+    ],
+    "output":{{
+        "format" : "{output_format}",
+        "fps" : {playback_frames_per_second},
+        "encode_param":{{
+            "crf": 10
+        }}
+    }}
+}}
+"""
+
+
+class Predictor(BasePredictor):
+    def setup(self) -> None:
+        """Load the model into memory to make running multiple predictions efficient"""
+        pass
+
+    def download_custom_model(self, custom_base_model_url: str):
+        # Validate the custom_base_model_url to ensure it's from "civitai.com"
+        if not re.match(r"^https://civitai\.com/api/download/models/\d+$", custom_base_model_url):
+            raise ValueError(
+                "Invalid URL. Only downloads from 'https://civitai.com/api/download/models/' are allowed."
+            )
+
+        cmd = ["pget", custom_base_model_url, "data/share/Stable-diffusion/custom.safetensors"]
+        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+        stdout_output, stderr_output = process.communicate()
+
+        print("Output from pget command:")
+        print(stdout_output)
+        if stderr_output:
+            print("Errors from pget command:")
+            print(stderr_output)
+
+        if process.returncode:
+            raise ValueError(f"Failed to download the custom model. pget returned code: {process.returncode}")
+        return "custom"
+
+    def transform_prompt_map(self, prompt_map_string: str):
+        """
+        Transform the given prompt_map string into a formatted string suitable for JSON injection.
+
+        Parameters
+        ----------
+        prompt_map_string : str
+            A string containing animation prompts in the format 'frame number : prompt at this frame',
+            separated by '|'. Colons inside the prompt description are allowed.
+
+        Returns
+        -------
+        str
+            A formatted string where each prompt is represented as '"frame": "description"'.
+        """
+
+        segments = prompt_map_string.split("|")
+
+        formatted_segments = []
+        for segment in segments:
+            frame, prompt = segment.split(":", 1)
+            frame = frame.strip()
+            prompt = prompt.strip()
+
+            formatted_segment = f'"{frame}": "{prompt}"'
+            formatted_segments.append(formatted_segment)
+
+        return ", ".join(formatted_segments)
+
+    def predict(
+        self,
+        head_prompt: str = Input(
+            description="Primary animation prompt. If a prompt map is provided, this will be prefixed at the start of every individual prompt in the map",
+            default="masterpiece, best quality, a haunting and detailed depiction of a ship at sea, battered by waves, ominous,((dark clouds:1.3)),distant lightning, rough seas, rain, silhouette of the ship against the stormy sky",
+        ),
+        prompt_map: str = Input(
+            description="Prompt for changes in animation. Provide 'frame number : prompt at this frame', separate different prompts with '|'. Make sure the frame number does not exceed the length of video (frames)",
+            default="0: ship steadily moving,((waves crashing against the ship:1.0)) | 32: (((lightning strikes))), distant thunder, ship rocked by waves | 64: ship silhouette,(((heavy rain))),wind howling, waves rising higher | 96: ship navigating through the storm, rain easing off",
+        ),
+        tail_prompt: str = Input(
+            description="Additional prompt that will be appended at the end of the main prompt or individual prompts in the map",
+            default="dark horizon, flashes of lightning illuminating the ship, sailors working hard, ship's lanterns flickering, eerie, mysterious, sails flapping loudly, stormy atmosphere",
+        ),
+        negative_prompt: str = Input(
+            default="(worst quality, low quality:1.4), black and white, b&w, sunny, clear skies, calm seas, beach, daytime, ((bright colors)), cartoonish, modern ships, sketchy, unfinished, modern buildings, trees, island",
+        ),
+        frames: int = Input(
+            description="Length of the video in frames (playback is at 8 fps e.g. 16 frames @ 8 fps is 2 seconds)",
+            default=128,
+            ge=1,
+            le=1024,
+        ),
+        width: int = Input(
+            description="Width of generated video in pixels, must be divisible by 8",
+            default=256,
+            ge=64,
+            le=2160,
+        ),
+        height: int = Input(
+            description="Height of generated video in pixels, must be divisible by 8",
+            default=384,
+            ge=64,
+            le=2160,
+        ),
+        base_model: str = Input(
+            description="Choose the base model for animation generation. If 'CUSTOM' is selected, provide a custom model URL in the next parameter",
+            default="majicmixRealistic_v5Preview",
+            choices=[
+                "realisticVisionV40_v20Novae",
+                "lyriel_v16",
+                "majicmixRealistic_v5Preview",
+                "rcnzCartoon3d_v10",
+                "toonyou_beta3",
+                "CUSTOM",
+            ],
+        ),
+        custom_base_model_url: str = Input(
+            description="URL of the custom model to download when base model is set to 'CUSTOM'. Only downloads from 'https://civitai.com/api/download/models/' are allowed",
+            default="",
+        ),
+        prompt_fixed_ratio: float = Input(
+            description="Defines the ratio of adherence to the fixed part of the prompt versus the dynamic part (from prompt map). Value should be between 0 (only dynamic) to 1 (only fixed).",
+            default=0.5,
+            ge=0,
+            le=1,
+        ),
+        scheduler: str = Input(
+            description="Diffusion scheduler",
+            default="k_dpmpp_sde",
+            choices=[
+                "ddim",
+                "pndm",
+                "heun",
+                "unipc",
+                "euler",
+                "euler_a",
+                "lms",
+                "k_lms",
+                "dpm_2",
+                "k_dpm_2",
+                "dpm_2_a",
+                "k_dpm_2_a",
+                "dpmpp_2m",
+                "k_dpmpp_2m",
+                "dpmpp_sde",
+                "k_dpmpp_sde",
+                "dpmpp_2m_sde",
+                "k_dpmpp_2m_sde",
+            ],
+        ),
+        steps: int = Input(
+            description="Number of inference steps",
+            ge=1,
+            le=100,
+            default=25,
+        ),
+        guidance_scale: float = Input(
+            description="Guidance scale. How closely the output adheres to the prompt and its contents",
+            ge=0.0,
+            le=20,
+            default=7.5,
+        ),
+        clip_skip: int = Input(
+            description="Skip the last N-1 layers of the CLIP text encoder (lower values follow the prompt more closely)",
+            default=2,
+            ge=1,
+            le=6,
+        ),
+        context: int = Input(
+            description="Number of frames to condition on at once (maximum 32; the maximum for motion module v1 is 24)",
+            default=16,
+            ge=1,
+            le=32,
+        ),
+        output_format: str = Input(
+            description="Output format of the video. Can be 'mp4' or 'gif'",
+            default="mp4",
+            choices=["mp4", "gif"],
+        ),
+        playback_frames_per_second: int = Input(default=8, ge=1, le=60),
+        seed: int = Input(
+            description="Seed for different images and reproducibility. Use -1 to randomise seed",
+            default=-1,
+        ),
+    ) -> Path:
+        """
+        Run a single prediction on the model
+        NOTE: lora_map, motion_lora_map, and controlnets are NOT supported (cut scope)
+        """
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if base_model.upper() == "CUSTOM":
+            base_model = self.download_custom_model(custom_base_model_url)
+
+        prompt_travel_json = FAKE_PROMPT_TRAVEL_JSON.format(
+            dreambooth_path=f"share/Stable-diffusion/{base_model}.safetensors",
+            output_format=output_format,
+            seed=seed,
+            steps=steps,
+            guidance_scale=guidance_scale,
+            prompt_fixed_ratio=prompt_fixed_ratio,
+            head_prompt=head_prompt,
+            tail_prompt=tail_prompt,
+            negative_prompt=negative_prompt,
+            playback_frames_per_second=playback_frames_per_second,
+            prompt_map=self.transform_prompt_map(prompt_map),
+            scheduler=scheduler,
+            clip_skip=clip_skip,
+        )
+
+        print(f"{'-'*80}")
+        print(prompt_travel_json)
+        print(f"{'-'*80}")
+
+        file_path = "config/prompts/custom_prompt_travel.json"
+        directory = os.path.dirname(file_path)
+        if not os.path.exists(directory):
+            os.makedirs(directory)
+        with open(file_path, "w") as file:
+            file.write(prompt_travel_json)
+
+        cmd = [
+            "animatediff",
+            "generate",
+            "-c",
+            str(file_path),
+            "-W",
+            str(width),
+            "-H",
+            str(height),
+            "-L",
+            str(frames),
+            "-C",
+            str(context),
+        ]
+        print(f"Running command: {' '.join(cmd)}")
+        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+        (
+            stdout_output,
+            stderr_output,
+        ) = process.communicate()
+
+        print(stdout_output)
+        if stderr_output:
+            print(f"Error: {stderr_output}")
+
+        if process.returncode:
+            raise ValueError(f"Command exited with code: {process.returncode}")
+
+        print("Identifying the output media path from the generated outputs...")
+        recent_dir = max(
+            (
+                os.path.join("output", d)
+                for d in os.listdir("output")
+                if os.path.isdir(os.path.join("output", d))
+            ),
+            key=os.path.getmtime,
+        )
+
+        print(f"Identified directory: {recent_dir}")
+        media_files = [f for f in os.listdir(recent_dir) if f.endswith((".gif", ".mp4"))]
+
+        if not media_files:
+            raise ValueError(f"No GIF or MP4 files found in directory: {recent_dir}")
+
+        media_path = os.path.join(recent_dir, media_files[0])
+        print(f"Identified Media Path: {media_path}")
+
+        return Path(media_path)
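
For anyone who wants to sanity-check the prompt_map input format outside of Cog, here is a minimal standalone sketch of the same parsing that Predictor.transform_prompt_map performs in the patch above; the sample prompt string is purely illustrative, not a project default.

# Standalone sketch mirroring Predictor.transform_prompt_map from predict.py:
# split on '|', then split each segment on the first ':' only, so prompts may
# themselves contain colons (e.g. attention weights like "(clouds:1.3)").
def transform_prompt_map(prompt_map_string: str) -> str:
    formatted_segments = []
    for segment in prompt_map_string.split("|"):
        frame, prompt = segment.split(":", 1)
        formatted_segments.append(f'"{frame.strip()}": "{prompt.strip()}"')
    return ", ".join(formatted_segments)


print(transform_prompt_map("0: calm seas | 32: (lightning strikes:1.3), rough waves"))
# -> "0": "calm seas", "32": "(lightning strikes:1.3), rough waves"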
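
The doubled braces in FAKE_PROMPT_TRAVEL_JSON exist so that str.format() substitutes only the named placeholders and leaves literal JSON braces intact. The following is a cut-down sketch of that idea (a reduced template, not the one shipped in predict.py) showing that the rendered config parses as JSON.

import json

# Reduced version of the template technique used in predict.py: {{ and }} render
# as literal braces, while {head_prompt}, {prompt_map} and {seed} are substituted.
template = """
{{
    "head_prompt": "{head_prompt}",
    "prompt_map": {{ {prompt_map} }},
    "seed": [ {seed} ]
}}
"""

rendered = template.format(
    head_prompt="a ship at sea",
    prompt_map='"0": "calm seas", "32": "lightning strikes"',
    seed=-1,
)
config = json.loads(rendered)  # raises json.JSONDecodeError if the rendered config is malformed
print(config["prompt_map"]["32"])  # -> lightning strikes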
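
download_custom_model only accepts direct civitai.com download URLs before it ever invokes pget. A small sketch of that allow-list check follows; the example URLs are illustrative.

import re

# Same pattern as the check in Predictor.download_custom_model: anything that is not
# a direct https://civitai.com/api/download/models/<id> URL is rejected up front.
CIVITAI_DOWNLOAD_URL = re.compile(r"^https://civitai\.com/api/download/models/\d+$")

print(bool(CIVITAI_DOWNLOAD_URL.match("https://civitai.com/api/download/models/78775")))  # True
print(bool(CIVITAI_DOWNLOAD_URL.match("https://example.com/models/other.safetensors")))   # False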