Skip to content

Commit

Permalink
Repair discord bot, update samplers
Browse files Browse the repository at this point in the history
  • Loading branch information
Stax124 committed Jun 19, 2024
1 parent b997205 commit 5aefc71
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 44 deletions.
18 changes: 9 additions & 9 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
{
"python.testing.pytestArgs": ["."],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.analysis.typeCheckingMode": "basic",
"python.languageServer": "Pylance",
"rust-analyzer.linkedProjects": ["./manager/Cargo.toml"],
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
}
"python.testing.pytestArgs": ["."],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.analysis.typeCheckingMode": "basic",
"python.languageServer": "Pylance",
"rust-analyzer.linkedProjects": ["./manager/Cargo.toml"],
"[python]": {
"editor.defaultFormatter": "charliermarsh.ruff"
}
}
11 changes: 7 additions & 4 deletions bot/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from dataclasses import dataclass, field

from dataclasses_json.api import DataClassJsonMixin
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers

from .types import Samplers

logger = logging.getLogger(__name__)

Expand All @@ -26,10 +27,12 @@ class Config(DataClassJsonMixin):
default_height: int = 512
default_count: int = 1
default_steps: int = 30
default_scheduler: KarrasDiffusionSchedulers = (
KarrasDiffusionSchedulers.DPMSolverMultistepScheduler
)
default_scheduler: Samplers = Samplers.DPMPP_2M
default_cfg: float = 7.0

default_adetailer: bool = False
default_hires_fix: bool = False

default_verbose: bool = False


Expand Down
34 changes: 21 additions & 13 deletions bot/helper.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
import asyncio
import difflib
from typing import Any, Dict, List, Literal
from typing import TYPE_CHECKING, Any, Dict, List, Literal

import aiohttp
from aiohttp import ClientSession

from bot import shared as shared_bot
from core.types import ModelResponse

if TYPE_CHECKING:
from ..core.types import ModelResponse


async def find_closest_model(model: str):
    """Return the loaded model name closest to *model*, or ``None``.

    Uses fuzzy matching (``difflib.get_close_matches``) against the cached
    list of loaded models so users don't need to type an exact name.

    Args:
        model: User-supplied (possibly approximate) model name.

    Returns:
        The best-matching model name, or ``None`` when nothing clears the
        (deliberately permissive) 0.1 similarity cutoff.
    """
    models, _ = await shared_bot.models.cached_loaded_models()
    # n=1: we only ever need the single best candidate.
    res = difflib.get_close_matches(model, models, n=1, cutoff=0.1)

    # Guard against an empty result instead of letting [0] raise IndexError
    # (the pre-fix behavior) — callers check for None explicitly.
    if not res:
        return None
    return res[0]


async def inference_call(
Expand Down Expand Up @@ -41,7 +49,7 @@ async def call():
return status, response


async def get_available_models():
async def get_available_models() -> tuple[list[ModelResponse], int]:
"List all available models"

from core import shared
Expand All @@ -51,15 +59,15 @@ async def get_available_models():
f"http://localhost:{shared.api_port}/api/models/available"
) as response:
status = response.status
data: List[Dict[str, Any]] = await response.json()
rec: List[Dict[str, Any]] = await response.json()
data = [ModelResponse(**i) for i in rec]
models = [
i["name"]
i
for i in filter(
lambda model: (
model["valid"] is True
model.valid is True
and (
model["backend"] == "PyTorch"
or model["backend"] == "AITemplate"
model.backend == "PyTorch" or model.backend == "AITemplate"
)
),
data,
Expand All @@ -79,15 +87,15 @@ async def get_loaded_models():
f"http://localhost:{shared.api_port}/api/models/loaded"
) as response:
status = response.status
data: List[Dict[str, Any]] = await response.json()
rec: List[Dict[str, Any]] = await response.json()
data = [ModelResponse(**i) for i in rec]
models = [
i["name"]
i.path
for i in filter(
lambda model: (
model["valid"] is True
model.valid is True
and (
model["backend"] == "PyTorch"
or model["backend"] == "AITemplate"
model.backend == "PyTorch" or model.backend == "AITemplate"
)
),
data,
Expand Down
31 changes: 22 additions & 9 deletions bot/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from bot import shared as shared_bot
from bot.helper import get_available_models, get_loaded_models
from core import shared
from core.types import InferenceBackend

if TYPE_CHECKING:
from bot.bot import ModularBot
Expand All @@ -26,6 +25,8 @@ def __init__(self, bot: "ModularBot") -> None:
async def loaded_models(self, ctx: Context) -> None:
"Show models loaded in the API"

await ctx.defer()

async with ClientSession() as session:
async with session.get(
f"http://localhost:{shared.api_port}/api/models/loaded"
Expand All @@ -52,6 +53,8 @@ async def loaded_models(self, ctx: Context) -> None:
async def available_models(self, ctx: Context) -> None:
"List all available models"

await ctx.defer()

available_models, status = await get_available_models()
shared_bot.models.set_cached_available_models(available_models)
available_models, status = await shared_bot.models.cached_available_models()
Expand All @@ -61,27 +64,37 @@ async def available_models(self, ctx: Context) -> None:

if status == 200:
await ctx.send(
"Available models:\n`{}`".format("\n".join(available_models))
"Available models:\n`{}`".format(
"\n".join([i.path for i in available_models])
if available_models
else "No models available"
)
)
else:
await ctx.send(f"Error: {status}")

@commands.hybrid_command(name="load")
@commands.has_permissions(administrator=True)
async def load_model(
self,
ctx: Context,
model: str,
backend: InferenceBackend = "PyTorch",
) -> None:
async def load_model(self, ctx: Context, model: str) -> None:
"Load a model"

message = await ctx.send(f"Loading model {model}...")

available_models, status = await get_available_models()
resolved_model = next((i for i in available_models if i.path == model), None)

if not resolved_model:
await message.edit(content=f"Model not found: {model}")
return

async with ClientSession() as session:
async with session.post(
f"http://localhost:{shared.api_port}/api/models/load",
params={"model": model, "backend": backend},
params={
"model": resolved_model.path,
"backend": resolved_model.backend,
"type": resolved_model.type,
},
) as response:
status = response.status
response = await response.json()
Expand Down
48 changes: 40 additions & 8 deletions bot/txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from uuid import uuid4

import discord
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
from discord import File
from discord.ext import commands
from discord.ext.commands import Cog, Context
Expand All @@ -14,6 +13,8 @@
from bot.shared import config
from core.utils import convert_base64_to_bytes

from .types import Samplers

if TYPE_CHECKING:
from bot.bot import ModularBot

Expand Down Expand Up @@ -41,7 +42,9 @@ async def dream_unsupported(
height: int = config.default_height,
count: int = config.default_count,
seed: Optional[int] = None,
scheduler: KarrasDiffusionSchedulers = config.default_scheduler,
adetailer: bool = config.default_adetailer,
hires_fix: bool = config.default_hires_fix,
scheduler: Samplers = config.default_scheduler,
verbose: bool = config.default_verbose,
):
"Generate an image from prompt"
Expand All @@ -62,12 +65,38 @@ async def dream_unsupported(
prompt = prompt + config.extra_prompt
negative_prompt = negative_prompt + config.extra_negative_prompt

try:
model = await find_closest_model(model)
except IndexError:
resolved_model = await find_closest_model(model)
if not resolved_model:
await ctx.send(f"No loaded model that is close to `{model}` found")
return

flags = {}
if adetailer:
flags["adetailer"] = {
"cfg_scale": 7,
"mask_blur": 4,
"mask_dilation": 4,
"mask_padding": 32,
"iterations": 1,
"upscale": 2,
"sampler": "dpmpp_2m",
"strength": 0.45,
"seed": 0,
"self_attention_scale": 0,
"sigmas": "exponential",
"steps": 30,
}
if hires_fix:
flags["highres_fix"] = {
"mode": "latent",
"image_upscaler": "RealESRGAN_x4plus_anime_6B",
"scale": 2,
"latent_scale_mode": "bislerp",
"strength": 0.7,
"steps": 30,
"antialiased": False,
}

payload = {
"data": {
"prompt": prompt,
Expand All @@ -82,11 +111,12 @@ async def dream_unsupported(
"batch_count": count,
"scheduler": scheduler.value,
},
"model": model,
"flags": flags,
"model": resolved_model,
"save_image": False,
}

message = await ctx.send(f"Generating image with `{model}`...")
message = await ctx.send(f"Generating image with `{resolved_model}`...")

try:
status, response = await inference_call(payload=payload)
Expand All @@ -100,12 +130,14 @@ async def dream_unsupported(
)
embed.add_field(name="Seed", value=seed)
embed.add_field(name="Time", value=f"{response.get('time'):.2f}s")
embed.add_field(name="Model", value=model)
embed.add_field(name="Model", value=resolved_model)
embed.add_field(name="Negative Prompt", value=negative_prompt)
embed.add_field(name="Guidance Scale", value=guidance_scale)
embed.add_field(name="Steps", value=steps)
embed.add_field(name="Width", value=width)
embed.add_field(name="Height", value=height)
embed.add_field(name="ADetailer", value=adetailer)
embed.add_field(name="Hires Fix", value=hires_fix)

await message.edit(embed=embed)

Expand Down
21 changes: 21 additions & 0 deletions bot/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from enum import Enum


class Samplers(Enum):
    """Sampler (scheduler) identifiers accepted by the inference API.

    Each member's value is the lowercase string the backend expects in the
    ``scheduler`` field of an inference payload (``Samplers.X.value``).
    """

    EULER_A = "euler_a"
    EULER = "euler"
    LMS = "lms"
    HEUN = "heun"
    HEUNPP = "heunpp"
    DPM_FAST = "dpm_fast"
    DPM_ADAPTIVE = "dpm_adaptive"
    DPM2 = "dpm2"
    DPM2_A = "dpm2_a"
    DPMPP_2S_A = "dpmpp_2s_a"
    DPMPP_2M = "dpmpp_2m"
    DPMPP_2M_SHARP = "dpmpp_2m_sharp"
    DPMPP_SDE = "dpmpp_sde"
    DPMPP_2M_SDE = "dpmpp_2m_sde"
    DPMPP_3M_SDE = "dpmpp_3m_sde"
    UNIPC_MULTISTEP = "unipc_multistep"
    RESTART = "restart"
2 changes: 1 addition & 1 deletion requirements/bot.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
discord.py==2.3.0
discord.py==2.3.2

0 comments on commit 5aefc71

Please sign in to comment.