Skip to content

Commit

Permalink
REF: refactor controlnet for image model (#2346)
Browse files Browse the repository at this point in the history
  • Loading branch information
qinxuye authored Oct 11, 2024
1 parent 92fc84b commit 91493d5
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 74 deletions.
8 changes: 4 additions & 4 deletions examples/StableDiffusionControlNet.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
"from diffusers.utils import load_image\n",
"\n",
"mlsd = MLSDdetector.from_pretrained(\"lllyasviel/ControlNet\")\n",
"image_path = os.path.expanduser(\"~/draft.png\")\n",
"image_path = os.path.expanduser(\"draft.png\")\n",
"image = load_image(image_path)\n",
"image = mlsd(image)\n",
"image"
Expand Down Expand Up @@ -181,7 +181,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -195,9 +195,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
11 changes: 6 additions & 5 deletions xinference/model/image/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,18 +210,19 @@ def create_image_model_instance(
for name in controlnet:
for cn_model_spec in model_spec.controlnet:
if cn_model_spec.model_name == name:
if not model_path:
model_path = cache(cn_model_spec)
controlnet_model_paths.append(model_path)
controlnet_model_path = cache(cn_model_spec)
controlnet_model_paths.append(controlnet_model_path)
break
else:
raise ValueError(
f"controlnet `{name}` is not supported for model `{model_name}`."
)
if len(controlnet_model_paths) == 1:
kwargs["controlnet"] = controlnet_model_paths[0]
kwargs["controlnet"] = (controlnet[0], controlnet_model_paths[0])
else:
kwargs["controlnet"] = controlnet_model_paths
kwargs["controlnet"] = [
(n, path) for n, path in zip(controlnet, controlnet_model_paths)
]
if not model_path:
model_path = cache(model_spec)
if peft_model_config is not None:
Expand Down
37 changes: 33 additions & 4 deletions xinference/model/image/sdapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import io
import warnings

from PIL import Image
from PIL import Image, ImageOps


class SDAPIToDiffusersConverter:
Expand All @@ -30,7 +31,7 @@ class SDAPIToDiffusersConverter:
txt2img_arg_mapping = {
"steps": "num_inference_steps",
"cfg_scale": "guidance_scale",
# "denoising_strength": "strength",
"denoising_strength": "strength",
}
img2img_identical_args = {
"prompt",
Expand All @@ -42,9 +43,11 @@ class SDAPIToDiffusersConverter:
}
img2img_arg_mapping = {
"init_images": "image",
"mask": "mask_image",
"steps": "num_inference_steps",
"cfg_scale": "guidance_scale",
"denoising_strength": "strength",
"inpaint_full_res_padding": "padding_mask_crop",
}

@staticmethod
Expand Down Expand Up @@ -121,12 +124,38 @@ def _decode_b64_img(img_str: str) -> Image:

def img2img(self, **kwargs):
init_images = kwargs.pop("init_images", [])
kwargs["init_images"] = [self._decode_b64_img(i) for i in init_images]
kwargs["init_images"] = init_images = [
self._decode_b64_img(i) for i in init_images
]
if len(init_images) == 1:
kwargs["init_images"] = init_images[0]
mask_image = kwargs.pop("mask", None)
if mask_image:
if kwargs.pop("inpainting_mask_invert"):
mask_image = ImageOps.invert(mask_image)

kwargs["mask"] = self._decode_b64_img(mask_image)

# process inpaint_full_res and inpaint_full_res_padding
if kwargs.pop("inpaint_full_res", None):
kwargs["inpaint_full_res_padding"] = kwargs.pop(
"inpaint_full_res_padding", 0
)
else:
# inpaint_full_res_padding is turned into `padding_mask_crop`;
# in diffusers, if padding_mask_crop is passed, it will do inpaint_full_res,
# so if not inpaint_full_res, we need to pop this option
kwargs.pop("inpaint_full_res_padding", None)

clip_skip = kwargs.get("override_settings", {}).get("clip_skip")
converted_kwargs = self._check_kwargs("img2img", kwargs)
if clip_skip:
converted_kwargs["clip_skip"] = clip_skip
result = self.image_to_image(response_format="b64_json", **converted_kwargs) # type: ignore

if not converted_kwargs.get("mask_image"):
result = self.image_to_image(response_format="b64_json", **converted_kwargs) # type: ignore
else:
result = self.inpainting(response_format="b64_json", **converted_kwargs) # type: ignore

# convert to SD API result
return {
Expand Down
Loading

0 comments on commit 91493d5

Please sign in to comment.