-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathglobals.py
executable file
·118 lines (98 loc) · 3.52 KB
/
globals.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import struct
from enum import Enum
import aiohttp
from typing import List, Union, Any, Optional
from PIL import Image, ImageOps
from io import BytesIO
from pydantic import BaseModel as PydanticBaseModel
# Local base class for the models in this module: identical to pydantic's
# BaseModel except for the configuration flag set below.
class BaseModel(PydanticBaseModel):
    class Config:
        # Lets fields be annotated with types pydantic has no built-in
        # validator for (StreamingPrompt.inputs uses PIL's Image.Image).
        arbitrary_types_allowed = True
class Status(Enum):
    """Lifecycle states of a prompt; each member's value is a lowercase string."""

    # Initial state (SimplePrompt.status defaults to this member).
    NOT_STARTED = "not-started"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    UPLOADING = "uploading"
# Per-session state for a streaming prompt (stored by string key in
# streaming_prompt_metadata).
class StreamingPrompt(BaseModel):
    # Workflow definition; left untyped (Any), so its shape is not validated here.
    workflow_api: Any
    auth_token: str
    # Named inputs; a value may be text, raw bytes or a PIL image.
    inputs: dict[str, Union[str, bytes, Image.Image]]
    # Presumably the ids of prompts currently executing for this stream -- TODO confirm.
    running_prompt_ids: set[str] = set()
    # Optional endpoints; presumably used for status callbacks and file
    # uploads respectively -- verify against callers.
    status_endpoint: Optional[str]
    file_upload_endpoint: Optional[str]
# Tracking record for a single prompt (stored by string key in
# prompt_metadata).
#
# NOTE: three defaults below used to end with a trailing comma
# (`= None,` / `= False,`), which made each default a one-element tuple --
# e.g. `is_realtime` defaulted to `(False,)`, which is truthy. The commas
# are removed so the defaults are the plain values the annotations promise.
class SimplePrompt(BaseModel):
    # Optional endpoints / token; presumably used for status callbacks and
    # file uploads -- verify against callers.
    status_endpoint: Optional[str]
    file_upload_endpoint: Optional[str]
    token: Optional[str]
    workflow_api: dict
    status: Status = Status.NOT_STARTED
    # Presumably the set of node ids that have reported progress -- TODO confirm.
    progress: set = set()
    last_updated_node: Optional[str] = None
    uploading_nodes: set = set()
    done: bool = False
    is_realtime: bool = False
    start_time: Optional[float] = None
# Module-level shared registries.
# Open websocket connections keyed by session id (looked up by `sid` when
# sending binary messages).
sockets = {}
# Prompt records keyed by string; presumably the prompt id -- TODO confirm.
prompt_metadata: dict[str, SimplePrompt] = {}
# Streaming-prompt records keyed by string; presumably the session id -- TODO confirm.
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
class BinaryEventTypes:
    """Integer codes identifying the kind of a binary websocket message."""

    # Code used by send_image for its preview payloads.
    PREVIEW_IMAGE = 1
    UNENCODED_PREVIEW_IMAGE = 2
# Fixed width, in characters, of the output-id field that send_image
# truncates/pads to before encoding it into the preview message.
max_output_id_length = 24
async def send_image(image_data, sid=None, output_id: str = None):
    """Encode a preview image and send it as a PREVIEW_IMAGE binary event.

    The payload handed to ``send_bytes`` is laid out as:

    * 4 bytes: big-endian image-type code (1 = JPEG, 2 = PNG, 3 = WEBP;
      any other format name also maps to 1),
    * ``max_output_id_length`` bytes: the output id, truncated and
      NUL-padded to that width, ASCII-encoded,
    * the image data as written by ``image.save``.

    Args:
        image_data: Sequence indexed as ``(image_type, image, max_size,
            quality)``. ``image_type`` is the format name passed to
            ``image.save``; ``max_size`` of ``None`` skips resizing,
            otherwise the image is fitted into a ``max_size`` square.
        sid: Forwarded to ``send_bytes`` to select the recipient(s).
        output_id: Identifier embedded in the payload. ``None`` (the
            default) is sent as an all-NUL field; previously it raised
            ``TypeError`` when sliced.
    """
    id_width = max_output_id_length
    raw_id = "" if output_id is None else output_id
    # errors='replace' turns every non-ASCII character into a single '?',
    # so the encoded field is always exactly id_width bytes long.
    encoded_output_id = (
        raw_id[:id_width].ljust(id_width, "\x00").encode("ascii", "replace")
    )

    image_type = image_data[0]
    image = image_data[1]
    max_size = image_data[2]
    quality = image_data[3]

    if max_size is not None:
        # Pillow builds without the Image.Resampling namespace fall back to
        # the module-level ANTIALIAS constant.
        if hasattr(Image, "Resampling"):
            resampling = Image.Resampling.BILINEAR
        else:
            resampling = Image.ANTIALIAS
        image = ImageOps.contain(image, (max_size, max_size), resampling)

    # Unknown format names keep the historical default code of 1.
    type_num = {"JPEG": 1, "PNG": 2, "WEBP": 3}.get(image_type, 1)

    buffer = BytesIO()
    # 4 bytes for the image-type code.
    buffer.write(struct.pack(">I", type_num))
    # max_output_id_length bytes for the output id; write() returns the
    # number of bytes it wrote.
    bytes_written = buffer.write(encoded_output_id)
    print(f"Bytes written: {bytes_written}")
    image.save(buffer, format=image_type, quality=quality, compress_level=1)

    await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, buffer.getvalue(), sid=sid)
async def send_socket_catch_exception(function, message):
    """Await ``function(message)``, reporting connection-level send failures.

    aiohttp client errors and ``ConnectionResetError`` are printed and
    swallowed; any other exception propagates to the caller. Nothing is
    returned.
    """
    try:
        await function(message)
    except (
        aiohttp.ClientError,
        aiohttp.ClientPayloadError,
        ConnectionResetError,
    ) as send_failure:
        # Best-effort delivery: a failed send is logged, not raised.
        print("send error:", send_failure)
def encode_bytes(event, data):
    """Build a binary frame: a 4-byte big-endian event code followed by *data*.

    Args:
        event: Integer event code, packed as an unsigned 32-bit value.
        data: Payload appended after the header via ``bytearray.extend``.

    Returns:
        A new ``bytearray`` holding header plus payload.

    Raises:
        RuntimeError: If *event* is not an ``int``.
    """
    if isinstance(event, int):
        frame = bytearray(struct.pack(">I", event))
        frame.extend(data)
        return frame
    raise RuntimeError(f"Binary event types must be integers, got {event}")
async def send_bytes(event, data, sid=None):
    """Frame *data* with *event* and send it to one or all registered sockets.

    Args:
        event: Integer event code placed in the frame header.
        data: Payload bytes.
        sid: Key into ``sockets``. ``None`` broadcasts to every registered
            socket; a key that is not registered results in no send.
    """
    message = encode_bytes(event, data)
    print("sending image to ", event, sid)

    if sid is None:
        # Iterate over a copy of the values so the loop does not walk the
        # live dict while awaiting.
        recipients = list(sockets.values())
    elif sid in sockets:
        recipients = [sockets[sid]]
    else:
        recipients = []

    for ws in recipients:
        await send_socket_catch_exception(ws.send_bytes, message)