forked from open-mmlab/Live2Diff
-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
121 lines (113 loc) · 2.99 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import argparse
import os
from typing import List, NamedTuple, Optional
class Args(NamedTuple):
    """Typed container for the demo server's command-line options.

    Field names match the argparse destinations in this module, so an
    instance can be built with ``Args(**vars(parser.parse_args()))``.
    Fields whose parser default is ``None`` are annotated ``Optional``.
    """

    host: str
    port: int
    reload: bool
    max_queue_size: int
    timeout: float
    safety_checker: bool
    taesd: bool
    ssl_certfile: Optional[str]  # None unless --ssl-certfile is given
    ssl_keyfile: Optional[str]  # None unless --ssl-keyfile is given
    debug: bool
    acceleration: str  # parser restricts this to "none", "xformers", "tensorrt"
    engine_dir: str
    config: str
    seed: int
    num_inference_steps: Optional[int]  # None when not passed on the CLI
    strength: Optional[float]  # None when not passed on the CLI
    t_index_list: Optional[List[int]]  # None when not passed on the CLI
    prompt: Optional[str]  # None when not passed on the CLI

    def pretty_print(self) -> None:
        """Print every field as ``name: value``, framed by blank lines."""
        print("\n")
        for name, value in self._asdict().items():
            print(f"{name}: {value}")
        print("\n")
def _env_flag(name, default=None):
    """Return True only when environment variable *name* equals ``"True"`` exactly."""
    return os.getenv(name, default) == "True"


# Settings that can be overridden through environment variables. The parser in
# this module uses them as defaults for the matching command-line options.
MAX_QUEUE_SIZE = int(os.getenv("MAX_QUEUE_SIZE", 0))
TIMEOUT = float(os.getenv("TIMEOUT", 0))
SAFETY_CHECKER = _env_flag("SAFETY_CHECKER")
USE_TAESD = _env_flag("USE_TAESD", default="True")
ENGINE_DIR = os.getenv("ENGINE_DIR", "engines")
ACCELERATION = os.getenv("ACCELERATION", "tensorrt")

# Network / mode defaults (MODE is read here but not consumed by the parser).
default_host = os.getenv("HOST", "0.0.0.0")
default_port = int(os.getenv("PORT", "7860"))
default_mode = os.getenv("MODE", "default")
def _parse_int_list(text: str) -> List[int]:
    """Parse a comma-separated list of integers for ``--t-index-list``.

    Accepts forms such as ``"32,45"``, ``"32, 45"`` and ``"[32,45]"``.
    Raises ``argparse.ArgumentTypeError`` (reported by argparse as a usage
    error) when an item is not an integer.
    """
    body = text.strip().strip("[]")
    try:
        return [int(item) for item in body.split(",") if item.strip()]
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            f"expected comma-separated integers, got {text!r}"
        ) from err


parser = argparse.ArgumentParser(description="Run the app")
parser.add_argument("--host", type=str, default=default_host, help="Host address")
parser.add_argument("--port", type=int, default=default_port, help="Port number")
parser.add_argument("--reload", action="store_true", help="Reload code on change")
parser.add_argument(
    "--max-queue-size",
    dest="max_queue_size",
    type=int,
    default=MAX_QUEUE_SIZE,
    help="Max Queue Size",
)
parser.add_argument("--timeout", type=float, default=TIMEOUT, help="Timeout")
parser.add_argument(
    "--safety-checker",
    dest="safety_checker",
    action="store_true",
    default=SAFETY_CHECKER,
    help="Safety Checker",
)
# --taesd / --no-taesd share one destination; the default comes from
# USE_TAESD via set_defaults() below.
parser.add_argument(
    "--taesd",
    dest="taesd",
    action="store_true",
    help="Use Tiny Autoencoder",
)
parser.add_argument(
    "--no-taesd",
    dest="taesd",
    action="store_false",
    help="Do not use Tiny Autoencoder",
)
parser.add_argument(
    "--ssl-certfile",
    dest="ssl_certfile",
    type=str,
    default=None,
    help="SSL certfile",
)
parser.add_argument(
    "--ssl-keyfile",
    dest="ssl_keyfile",
    type=str,
    default=None,
    help="SSL keyfile",
)
parser.add_argument(
    "--debug",
    action="store_true",
    default=False,
    help="Debug",
)
parser.add_argument(
    "--acceleration",
    type=str,
    default=ACCELERATION,
    choices=["none", "xformers", "tensorrt"],
    help="Acceleration",
)
parser.add_argument(
    "--engine-dir",
    dest="engine_dir",
    type=str,
    default=ENGINE_DIR,
    help="Engine Dir",
)
parser.add_argument(
    "--config",
    default="./demo_cfg.yaml",
    help="Path to the demo config file",
)
parser.add_argument(
    "--num-inference-steps",
    type=int,
    default=None,
    help="Number of inference steps (left as None when omitted)",
)
parser.add_argument(
    "--strength",
    type=float,
    default=None,
    help="Strength (left as None when omitted)",
)
parser.add_argument(
    "--t-index-list",
    dest="t_index_list",
    # type=list would split the raw string into single characters
    # ("32,45" -> ['3', '2', ',', '4', '5']); parse real integers instead.
    type=_parse_int_list,
    default=None,
    help="Comma-separated list of integers, e.g. 32,45",
)
# type=int so a CLI-supplied seed is an int like the default, not a str.
parser.add_argument("--seed", type=int, default=42, help="Seed")
parser.add_argument("--prompt", type=str, help="Prompt")
parser.set_defaults(taesd=USE_TAESD)

# Parsing happens at import time, so importing this module consumes sys.argv.
config = Args(**vars(parser.parse_args()))
config.pretty_print()