import torch

import comfy.samplers
from nodes import common_ksampler

# Detail Daemon adapted from https://github.com/muerrilla/sd-webui-detail-daemon
# and https://github.com/Jonseed/ComfyUI-Detail-Daemon
def parse_string_to_list(value):
    """Parse a string into a list of values, handling both numeric and string inputs."""
    if isinstance(value, (int, float)):
        return [int(value) if isinstance(value, int) or value.is_integer() else float(value)]
    # Split on commas and newlines, then keep only numeric tokens.
    value = value.replace("\n", ",").split(",")
    value = [v.strip() for v in value if v.strip()]
    value = [int(float(v)) if float(v).is_integer() else float(v) for v in value if v.replace(".", "").isdigit()]
    return value if value else [0]
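
# Illustrative examples: parse_string_to_list("10, 20\n7.5") -> [10, 20, 7.5],
# and parse_string_to_list("") -> [0]. Note that the isdigit() filter drops
# signed tokens such as "-2", so only non-negative numbers survive parsing.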
class SDstarsampler:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "model": ("MODEL", ),
                    "positive": ("CONDITIONING", ),
                    "negative": ("CONDITIONING", ),
                    "latent": ("LATENT", ),
                    "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                    "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                    "cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0}),
                    "sampler_name": (comfy.samplers.KSampler.SAMPLERS, {"default": "euler"}),
                    "scheduler": (comfy.samplers.KSampler.SCHEDULERS, {"default": "normal"}),
                    "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                },
                "optional": {
                    "detail_schedule": ("DETAIL_SCHEDULE",),
                }}

    RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING", "LATENT", "DETAIL_SCHEDULE")
    RETURN_NAMES = ("model", "positive", "negative", "latent", "detail_schedule")
    FUNCTION = "execute"
    CATEGORY = "sampling"
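
    # model, positive, and negative are returned unchanged so the node can be
    # chained; only the latent (and the optional detail_schedule passthrough)
    # carry new data out of execute() below.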
    def make_detail_schedule(self, steps, detail_amount, detail_start, detail_end, detail_bias, detail_exponent):
        """Build a per-step multiplier curve that ramps up to detail_amount and back down to zero."""
        start = min(detail_start, detail_end)
        mid = start + detail_bias * (detail_end - start)
        multipliers = torch.zeros(steps + 1)
        start_idx, mid_idx, end_idx = [
            int(round(x * steps)) for x in [start, mid, detail_end]
        ]
        # Keep the three indices strictly ordered and within array bounds
        if start_idx == mid_idx:
            mid_idx = start_idx + 1
        if mid_idx == end_idx:
            end_idx = mid_idx + 1
        end_idx = min(end_idx, steps)
        mid_idx = min(mid_idx, end_idx - 1)
        start_idx = min(start_idx, mid_idx)
        print(f"Detail Schedule Indices - start: {start_idx}, mid: {mid_idx}, end: {end_idx}")
        # Cosine-eased ramp from 0 up to detail_amount...
        start_values = torch.linspace(0, 1, mid_idx - start_idx + 1)
        start_values = 0.5 * (1 - torch.cos(start_values * torch.pi))
        start_values = start_values**detail_exponent
        if len(start_values) > 0:
            start_values *= detail_amount
            print(f"Start values range: {start_values[0]:.4f} to {start_values[-1]:.4f}")
        # ...and a matching ramp back down to 0
        end_values = torch.linspace(1, 0, end_idx - mid_idx + 1)
        end_values = 0.5 * (1 - torch.cos(end_values * torch.pi))
        end_values = end_values**detail_exponent
        if len(end_values) > 0:
            end_values *= detail_amount
            print(f"End values range: {end_values[0]:.4f} to {end_values[-1]:.4f}")
        multipliers[start_idx : mid_idx + 1] = start_values
        multipliers[mid_idx : end_idx + 1] = end_values
        print(f"Final multipliers shape: {multipliers.shape}, non-zero elements: {torch.count_nonzero(multipliers)}")
        print(f"Multipliers range: {multipliers.min():.4f} to {multipliers.max():.4f}")
        return multipliers
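
    # Worked example: with steps=20, amount=0.5, start=0.2, end=0.8, bias=0.5,
    # exponent=1.0 the indices come out as (4, 10, 16), so the multipliers rise
    # from 0 at step 4 to the peak of 0.5 at step 10 and fall back to 0 by step 16.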
    def get_dd_schedule(self, sigma, sigmas, dd_schedule):
        # Find the neighboring sigma values
        dists = torch.abs(sigmas - sigma)
        idxlow = torch.argmin(dists)
        nlow = float(sigmas[idxlow])
        # If we're at the last sigma, return the last schedule value
        if idxlow == len(sigmas) - 1:
            return float(dd_schedule[idxlow])
        # Get the high neighbor
        idxhigh = idxlow + 1
        nhigh = float(sigmas[idxhigh])
        # Ratio of how close we are to the high neighbor
        ratio = float((sigma - nlow) / (nhigh - nlow))
        ratio = max(0.0, min(1.0, ratio))  # Clamp between 0 and 1
        # Mix the DD schedule high/low items according to the ratio
        result = float(torch.lerp(dd_schedule[idxlow], dd_schedule[idxhigh], torch.tensor(ratio)).item())
        # Log every 5th adjustment to avoid spam
        if not hasattr(self, '_log_counter'):
            self._log_counter = 0
        self._log_counter += 1
        if self._log_counter % 5 == 0:
            print(f"DD Schedule - sigma: {sigma:.4f}, adjustment: {result:.4f}, ratio: {ratio:.4f}")
        return result
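
    # Note: execute() below indexes the detail schedule by a progress estimate
    # directly; get_dd_schedule is kept as a sigma-interpolating alternative
    # and is not called on the current code path.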
    def execute(self, model, positive, negative, latent, seed, steps, cfg, sampler_name, scheduler, denoise, detail_schedule=None):
        print("\n=== Starting SDstarsampler execution ===")
        print(f"Parameters: steps={steps}, cfg={cfg}, sampler={sampler_name}, scheduler={scheduler}, denoise={denoise}")

        if detail_schedule:
            print(f"⭐ Detail Daemon Active with Settings: amount={detail_schedule['detail_amount']:.2f}, start={detail_schedule['detail_start']:.2f}, end={detail_schedule['detail_end']:.2f}, bias={detail_schedule['detail_bias']:.2f}, exponent={detail_schedule['detail_exponent']:.2f}")
            # Create a copy of the input latent to avoid modifying it
            current_latent = {"samples": latent["samples"].clone()}

            # Initialize a sampler just to get the sigmas for detail daemon adjustments
            k_sampler = comfy.samplers.KSampler(model, steps=steps, device=latent["samples"].device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)

            # Create the detail schedule (make_detail_schedule already returns a tensor)
            detail_schedule_tensor = self.make_detail_schedule(
                len(k_sampler.sigmas) - 1,
                detail_schedule["detail_amount"],
                detail_schedule["detail_start"],
                detail_schedule["detail_end"],
                detail_schedule["detail_bias"],
                detail_schedule["detail_exponent"]
            ).to(dtype=torch.float32, device=latent["samples"].device)

            # Keep a copy of the original sigmas (not modified below; retained for reference)
            original_sigmas = k_sampler.sigmas.clone()
            sigmas_cpu = original_sigmas.detach().cpu()

            # Store the original forward method so it can be restored afterwards
            if hasattr(model.model, 'diffusion_model'):
                original_forward = model.model.diffusion_model.forward
                target_module = model.model.diffusion_model
            else:
                original_forward = model.model.forward
                target_module = model.model

            def wrapped_forward(x, sigma, **extra_args):
                # Get the maximum sigma value for this batch
                sigma_float = float(sigma.max().detach().cpu())
                # Calculate progress in log space, since sigmas are logarithmically distributed
                log_sigma = torch.log(torch.tensor(sigma_float + 1e-10))
                log_sigma_max = torch.log(torch.tensor(1000.0))
                log_sigma_min = torch.log(torch.tensor(0.1))
                progress = 1.0 - (log_sigma - log_sigma_min) / (log_sigma_max - log_sigma_min)
                progress = float(progress.clamp(0.0, 1.0))
                # Map progress onto a schedule index
                schedule_idx = int(progress * (len(detail_schedule_tensor) - 1))
                schedule_idx = max(0, min(schedule_idx, len(detail_schedule_tensor) - 1))
                # Get the base adjustment from the schedule and amplify it
                dd_adjustment = float(detail_schedule_tensor[schedule_idx])
                final_adjustment = dd_adjustment * 2.0
                # Fade the adjustment in over the first half of sampling
                adjustment_scale = max(0.0, min(1.0, progress * 2))
                final_adjustment = final_adjustment * adjustment_scale
                # Only apply adjustments that are significant
                if final_adjustment > 0.001:
                    adjusted_sigma = sigma * max(1e-06, 1.0 - final_adjustment * cfg)
                else:
                    adjusted_sigma = sigma
                return original_forward(x, adjusted_sigma, **extra_args)

            # Temporarily replace the forward method
            target_module.forward = wrapped_forward
            try:
                # Use common_ksampler for sampling
                samples = common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, current_latent, denoise=denoise)[0]
            finally:
                # Restore the original forward method
                target_module.forward = original_forward
        else:
            # Create a copy of the input latent to avoid modifying it
            current_latent = {"samples": latent["samples"].clone()}
            # Use common_ksampler for sampling without a detail schedule
            samples = common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, current_latent, denoise=denoise)[0]

        return (model, positive, negative, samples, detail_schedule)
# Mapping for ComfyUI to recognize the node
NODE_CLASS_MAPPINGS = {
    "SDstarsampler": SDstarsampler
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "SDstarsampler": "⭐ StarSampler SD / SDXL"
}
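
# Minimal local sanity check, a sketch assuming the file is run inside a
# ComfyUI environment (the comfy/nodes imports above must resolve). It
# exercises only the pure helpers, without a model or GPU.
if __name__ == "__main__":
    print(parse_string_to_list("10, 20\n7.5"))  # expected: [10, 20, 7.5]
    sched = SDstarsampler().make_detail_schedule(
        steps=20, detail_amount=0.5, detail_start=0.2,
        detail_end=0.8, detail_bias=0.5, detail_exponent=1.0,
    )
    print(sched.shape, float(sched.max()))  # expected: torch.Size([21]) 0.5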