Updating truss based on feedback
Het Trivedi committed Nov 13, 2023
1 parent e4a2dbb commit 87b079f
Showing 5 changed files with 273 additions and 189 deletions.
63 changes: 36 additions & 27 deletions comfyui-truss/README.md
This truss is designed to allow ComfyUI users to easily convert their workflows …

## Exporting the ComfyUI workflow

This Truss is designed to run a ComfyUI workflow exported as a JSON file.

Inside ComfyUI, you can save workflows as a JSON file. However, the regular JSON format that ComfyUI uses will not work. Instead, the workflow has to be saved in the API format. Here is how you can do that:

For your ComfyUI workflow, you probably used one or more models. Those models need …

In this case, I have 2 models: SDXL and a controlnet. Each model entry needs two fields, `url` and `path`. The `url` is the download location for the model. The `path` is where the model will be stored inside the Truss. For the path, follow the same conventions as ComfyUI: checkpoints go inside `checkpoints`, controlnets inside `controlnet`, and so on.
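For reference, the `data/model.json` shipped with this commit defines exactly these two models:

```json
[
  {
    "url": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors",
    "path": "checkpoints/sd_xl_base_1.0.safetensors"
  },
  {
    "url": "https://huggingface.co/diffusers/controlnet-canny-sdxl-1.0/resolve/main/diffusion_pytorch_model.fp16.safetensors",
    "path": "controlnet/diffusers_xl_canny_full.safetensors"
  }
]
```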

We also need to place the JSON workflow from step 1 inside the data directory. Create a file called `data/comfy_ui_workflow.json` and paste the entire JSON workflow saved in step 1 into it.

In the JSON workflow file, some inputs, such as the positive or negative prompt, may be hard-coded. We want these inputs to be sent to the model dynamically, so we can use handlebars syntax to templatize them. Here is an example of a JSON workflow with templatized inputs:

```json
{
  …
}
```

This is not the entire JSON workflow file, but nodes 6, 7, and 11 accept variable inputs. You can mark an input as variable by using the handlebars format `{{variable_name_here}}`.
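As an illustrative sketch (the node id and field names here follow the usual ComfyUI API export and are assumptions, not necessarily the exact workflow above), a templatized prompt node looks like:

```json
{
  "6": {
    "inputs": {
      "text": "{{positive_prompt}}",
      "clip": ["4", 1]
    },
    "class_type": "CLIPTextEncode"
  }
}
```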

Once you have both `data/comfy_ui_workflow.json` and `data/model.json` set up correctly, we can begin deployment.

## Deployment

Before deployment:

1. Make sure you have a [Baseten account](https://app.baseten.co/signup) and [API key](https://app.baseten.co/settings/account/api_keys).
2. Install the latest version of Truss: `pip install --upgrade truss`

With `comfyui-truss` as your working directory, you can deploy the model with:

```sh
truss push
```

Paste your Baseten API key if prompted.

For more information, see [Truss documentation](https://truss.baseten.co).

## Model Inference

When an inference request is sent to the Truss, the `comfy_ui_workflow.json` in the data directory is sent to ComfyUI. Recall that this JSON file contains templatized variables in the handlebars format `{{variable_name_here}}`. At inference time, we can pass values for those variables in the prediction request like so:

```python
values = {
    "positive_prompt": "An igloo on a snowy day, 4k, hd",
    "negative_prompt": "blurry, text, low quality",
    "controlnet_image": "https://storage.googleapis.com/logos-bucket-01/baseten_logo.png"
}
```

Just be sure that the variable names in the `comfy_ui_workflow.json` template match the names inside the values object.
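Under the hood, the Truss walks the workflow recursively and swaps any `{{…}}` string for its value, leaving unmatched placeholders untouched (see `fill_template` in `model/helpers.py`). A condensed, runnable sketch of that behavior, using made-up inputs:

```python
def fill_template(workflow, template_values):
    # Condensed version of the helper in model/helpers.py
    if isinstance(workflow, dict):
        return {k: fill_template(v, template_values) for k, v in workflow.items()}
    if isinstance(workflow, list):
        return [fill_template(item, template_values) for item in workflow]
    if isinstance(workflow, str) and workflow.startswith("{{") and workflow.endswith("}}"):
        # Replace a known placeholder; leave unknown ones as-is
        return template_values.get(workflow[2:-2], workflow)
    return workflow

workflow = {
    "6": {"inputs": {"text": "{{positive_prompt}}"}, "class_type": "CLIPTextEncode"},
    "7": {"inputs": {"text": "{{missing_var}}"}, "class_type": "CLIPTextEncode"},
}
values = {"positive_prompt": "An igloo on a snowy day, 4k, hd"}
filled = fill_template(workflow, values)
print(filled["6"]["inputs"]["text"])  # An igloo on a snowy day, 4k, hd
print(filled["7"]["inputs"]["text"])  # {{missing_var}}
```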

Here is a complete example of how to make a prediction request to your Truss in Python:

This is the content of `data/comfy_ui_workflow.json`:
```json
{
  "3": {
    "inputs": { … },
    …
  },
  …
}
```

Here is the actual API request sent to Truss:
```python
import requests

headers = {"Authorization": "Api-Key YOUR-BASETEN-API-KEY-HERE"}

values = {
"positive_prompt": "An igloo on a snowy day, 4k, hd",
"negative_prompt": "blurry, text, low quality",
"controlnet_image": "https://storage.googleapis.com/logos-bucket-01/baseten_logo.png"
}

data = {"workflow_values": values}
res = requests.post("https://model-{MODEL_ID}.api.baseten.co/development/predict", headers=headers, json=data)
res = res.json()
model_output = res.get("result")
```
6 changes: 3 additions & 3 deletions comfyui-truss/config.yaml
environment_variables: {}
external_package_dirs: []
description: Deploy a ComfyUI workflow as a Truss
model_metadata:
example_model_input: {"workflow_values": {"positive_prompt": "An igloo on a snowy day, 4k, hd", "negative_prompt": "blurry, text, low quality", "controlnet_image": "https://storage.googleapis.com/logos-bucket-01/baseten_logo.png"}}
model_name: comfyui-truss
python_version: py310
requirements:
resources:
memory: 14Gi
use_gpu: true
accelerator: T4
secrets: {}
system_packages: []
14 changes: 9 additions & 5 deletions comfyui-truss/data/model.json
[
{
"url": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors",
"path": "checkpoints/sd_xl_base_1.0.safetensors"
},
{
"url": "https://huggingface.co/diffusers/controlnet-canny-sdxl-1.0/resolve/main/diffusion_pytorch_model.fp16.safetensors",
"path": "controlnet/diffusers_xl_canny_full.safetensors"
}
]
186 changes: 186 additions & 0 deletions comfyui-truss/model/helpers.py
import json
import os
import subprocess
import sys
import tempfile
import urllib.parse
import urllib.request

import requests


def download_model(model_url, destination_path):
    print(f"Downloading model {model_url} ...")
    try:
        response = requests.get(model_url, stream=True)
        response.raise_for_status()
        print("download response: ", response)

        # Open the destination file and write the content in chunks
        print("opening: ", destination_path)
        with open(destination_path, "wb") as file:
            print("writing chunks...")
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:  # Filter out keep-alive new chunks
                    file.write(chunk)
            print("done writing chunks!")

        print(f"Downloaded file to: {destination_path}")
    except requests.exceptions.RequestException as e:
        print(f"Download failed: {e}")


def download_tempfile(file_url, filename):
    try:
        response = requests.get(file_url)
        response.raise_for_status()
        filetype = filename.split(".")[-1]
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f".{filetype}")
        temp_file.write(response.content)
        return temp_file.name, temp_file
    except Exception as e:
        print("Error downloading and saving image:", e)
        # Return a pair so callers that unpack two values do not raise on failure
        return None, None


def setup_comfyui(original_working_directory, data_dir):
    git_repo_url = "https://github.com/comfyanonymous/ComfyUI.git"
    commit_hash = "248aa3e56355d75ac3d8632af769e6c700d9bfac"
    git_clone_command = ["git", "clone", git_repo_url]

    try:
        # clone the repo
        subprocess.run(git_clone_command, check=True)
        print("Git repository cloned successfully!")

        os.chdir(os.path.join(original_working_directory, "ComfyUI"))

        # Pin ComfyUI to a specific commit
        checkout_command = ["git", "checkout", commit_hash]
        subprocess.run(checkout_command, check=True)

        model_json = os.path.join(original_working_directory, data_dir, "model.json")
        with open(model_json, "r") as file:
            data = json.load(file)

        print(f"model json file: {data}")

        if data and len(data) > 0:
            for model in data:
                download_model(
                    model_url=model.get("url"),
                    destination_path=os.path.join(
                        os.getcwd(), "models", model.get("path")
                    ),
                )

        print("Finished downloading models!")

        # run the comfy-ui server
        subprocess.run([sys.executable, "main.py"], check=True)

    except Exception as e:
        print(e)
        raise Exception("Error setting up comfy UI repo")


def queue_prompt(prompt, client_id, server_address):
    p = {"prompt": prompt, "client_id": client_id}
    data = json.dumps(p).encode("utf-8")
    req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
    return json.loads(urllib.request.urlopen(req).read())


def get_image(filename, subfolder, folder_type, server_address):
    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    url_values = urllib.parse.urlencode(data)
    with urllib.request.urlopen(
        "http://{}/view?{}".format(server_address, url_values)
    ) as response:
        return response.read()


def get_history(prompt_id, server_address):
    with urllib.request.urlopen(
        "http://{}/history/{}".format(server_address, prompt_id)
    ) as response:
        return json.loads(response.read())


def get_images(ws, prompt, client_id, server_address):
    prompt_id = queue_prompt(prompt, client_id, server_address)["prompt_id"]
    output_images = {}
    while True:
        out = ws.recv()
        if isinstance(out, str):
            message = json.loads(out)
            if message["type"] == "executing":
                data = message["data"]
                if data["node"] is None and data["prompt_id"] == prompt_id:
                    break  # Execution is done
        else:
            continue  # previews are binary data

    history = get_history(prompt_id, server_address)[prompt_id]
    for node_id in history["outputs"]:
        node_output = history["outputs"][node_id]
        if "images" in node_output:
            images_output = []
            for image in node_output["images"]:
                image_data = get_image(
                    image["filename"],
                    image["subfolder"],
                    image["type"],
                    server_address,
                )
                images_output.append(image_data)
            output_images[node_id] = images_output

    return output_images


def fill_template(workflow, template_values):
    if isinstance(workflow, dict):
        # If it's a dictionary, recursively process its keys and values
        for key, value in workflow.items():
            workflow[key] = fill_template(value, template_values)
        return workflow
    elif isinstance(workflow, list):
        # If it's a list, recursively process its elements
        return [fill_template(item, template_values) for item in workflow]
    elif (
        isinstance(workflow, str)
        and workflow.startswith("{{")
        and workflow.endswith("}}")
    ):
        # If it's a placeholder, replace it with the corresponding value
        placeholder = workflow[2:-2]
        if placeholder in template_values:
            return template_values[placeholder]
        else:
            return workflow  # Placeholder not found in values
    else:
        # If it's neither a dictionary, list, nor a placeholder, leave it unchanged
        return workflow


def convert_request_file_url_to_path(template_values):
    tempfiles = []
    new_template_values = template_values.copy()
    for key, value in template_values.items():
        if isinstance(value, str) and (
            value.startswith("https://") or value.startswith("http://")
        ):
            if value[-1] == "/":
                value = value[:-1]
            filename = value.split("/")[-1]

            file_destination_path, file_object = download_tempfile(
                file_url=value, filename=filename
            )
            tempfiles.append(file_object)
            new_template_values[key] = file_destination_path

    return new_template_values, tempfiles
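The temp file's name is derived from the last path segment of the URL after stripping a trailing slash, as in the loop above; a tiny sketch of just that extraction:

```python
def url_to_filename(url: str) -> str:
    # Same extraction convert_request_file_url_to_path performs before downloading
    if url.endswith("/"):
        url = url[:-1]
    return url.split("/")[-1]

print(url_to_filename("https://storage.googleapis.com/logos-bucket-01/baseten_logo.png"))
# baseten_logo.png
```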