-
Notifications
You must be signed in to change notification settings - Fork 34
/
prompt.py
153 lines (136 loc) · 5.95 KB
/
prompt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import json
import os
import aiofiles
from api2d import Main
from chatgpt import Main as GPT
from load_config import get_yaml_config, check_file_exists, print_tip
from translate import Sample as translate
# Load the YAML configuration once at import time and expose the flags the
# functions below read:
#   memory       -> book.memory: skip paragraphs whose prompt is already saved
#   is_translate -> potential.translate: use machine translation instead of ChatGPT
#   role_enabled -> stable_diffusion.role: tag paragraphs with roles from role.json
#   max_token    -> chatgpt.max_token: token budget before chat history is reset
config = get_yaml_config()
memory, is_translate, role_enabled, max_token = (
    config["book"]["memory"],
    config["potential"]["translate"],
    config["stable_diffusion"]["role"],
    config["chatgpt"]["max_token"],
)
def _last_non_space(file, end):
    """Return ``(offset, byte)`` of the last non-whitespace byte before *end*.

    *file* must be opened in binary mode. Only trailing whitespace is walked,
    so this is cheap for well-formed files. Returns ``(-1, b"")`` when every
    byte before *end* is whitespace (or *end* is 0).
    """
    pos = end
    while pos > 0:
        pos -= 1
        file.seek(pos)
        byte = file.read(1)
        if not byte.isspace():
            return pos, byte
    return -1, b""


def write_to_json(data, filename):
    """Append *data* as a new element of the JSON array stored in *filename*.

    The file is patched in place instead of being re-serialised: the closing
    ``]`` is removed, a separator plus the new object is written, and ``]`` is
    restored. Whitespace after the closing bracket is tolerated. A missing,
    empty or whitespace-only file is (re)initialised as a one-element,
    indented array.

    Args:
        data: JSON-serialisable object to append.
        filename: Path of the JSON array file.

    Raises:
        ValueError: if the file has content that does not end with ``]``
            (i.e. it is not a JSON array); the file is left untouched.
    """
    try:
        with open(filename, "rb+") as file:
            end = file.seek(0, 2)  # seek() returns the new absolute offset
            close_pos, close_byte = _last_non_space(file, end)
            if close_pos < 0:
                # Empty or whitespace-only file: start a fresh array.
                file.seek(0)
                file.truncate()
                file.write(json.dumps([data], indent=4).encode())
                return
            if close_byte != b"]":
                raise ValueError(f"{filename!r} does not contain a JSON array")
            # Drop the "]" together with anything after it; truncating only
            # the final byte would corrupt files ending in a newline.
            _, before_close = _last_non_space(file, close_pos)
            file.seek(close_pos)
            file.truncate()
            # "[" directly before "]" means the array is empty: no comma.
            file.write(b"\n" if before_close == b"[" else b",\n")
            file.write(json.dumps(data).encode())
            file.write(b"\n]")
    except FileNotFoundError:
        # No file yet: create it with a one-element array.
        with open(filename, "wb") as file:
            file.write(json.dumps([data], indent=4).encode())
def extract_str(text):
    """Split a model reply into ``(prompt, negative_prompt)``.

    The reply is split at the first ``**Negative Prompt:**`` marker. If that
    is absent, the plain ``Negative Prompt:`` marker is tried. If neither
    occurs, the whole text is the prompt and the negative prompt is ``""``.
    Label words (``**Prompt:**``, ``Prompt:``, ``Negative``) and newlines are
    removed from each part; other whitespace is kept as-is.
    """
    for marker in ("**Negative Prompt:**", "Negative Prompt:"):
        if marker in text:
            head, tail = text.split(marker, 1)
            break
    else:
        # No negative section at all: avoid an IndexError on the missing half.
        head, tail = text, ""
    prompt = (
        head
        .replace("**Negative Prompt:**", "")
        .replace("**Prompt:**", "")
        .replace("Prompt:", "")
        .replace("\n", "")
    )
    # "**Prompt:**" must be removed before "Prompt:"; otherwise the bold
    # variant degrades to a stray "****".
    negative_prompt = (
        tail
        .replace("**Negative Prompt:**", "")
        .replace("Negative", "")
        .replace("**Prompt:**", "")
        .replace("Prompt:", "")
        .replace("\n", "")
    )
    return prompt, negative_prompt
async def process_line2(line, line_number, prompt_json_save_path, messages_save_path, name, path):
    """Ask ChatGPT for a Stable Diffusion prompt for one paragraph and append it.

    Args:
        line: Raw paragraph text; surrounding whitespace is stripped.
        line_number: 1-based paragraph index. With ``memory`` enabled, the
            paragraph is skipped when the prompt file already holds at least
            this many entries.
        prompt_json_save_path: JSON array file the prompt object is appended to.
        messages_save_path: JSON file holding the chat history carried from
            paragraph to paragraph.
        name: Prefix of the ``{name}prompt.txt`` system-prompt file used when
            no saved history exists.
        path: Directory that may contain ``role.json``.

    When the reported ``total_tokens`` reaches ``max_token``, the saved
    history is deleted and the paragraph is retried with a fresh context. If
    a fresh context is itself over the limit, its result is kept and no
    history is persisted, so the next paragraph also starts fresh.
    """
    await print_tip(f"正在处理第{line_number}段")
    if memory and await check_file_exists(prompt_json_save_path):
        with open(prompt_json_save_path, "r", encoding="utf-8") as file:
            prompt_data = json.load(file)
        if line_number <= len(prompt_data):
            await print_tip(f"使用缓存:跳过第{line_number}段")
            return
    text = f"第{line_number}段:" + line.strip()
    while True:
        # Load history whenever it exists (independent of ``memory``) so that
        # ``messages`` is always bound before the request.
        has_history = await check_file_exists(messages_save_path)
        if has_history:
            async with aiofiles.open(messages_save_path, "r", encoding="utf-8") as f:
                messages = json.loads(await f.read())
        else:
            with open(f"{name}prompt.txt", "r", encoding="utf-8") as f:
                messages = [{"role": "system", "content": f.read()}]
        result, message, total_tokens = GPT().chat(text, messages)
        await print_tip(f"当前total_tokens:{total_tokens}")
        if total_tokens < max_token or not has_history:
            break
        # Token limit reached: drop the accumulated history and retry.
        os.remove(messages_save_path)

    prompt, negative_prompt = extract_str(message)
    obj = {"prompt": prompt, "negative_prompt": negative_prompt}
    if role_enabled:
        # Fixed characters: tag every role from role.json named in the text.
        role_path = os.path.join(path, "role.json")
        roles = []
        if os.path.exists(role_path):
            with open(role_path, "r", encoding="utf-8") as file:
                role_data = json.load(file)
            roles = [role["name"] for role in role_data if role["name"] in text]
        obj["role"] = roles
    write_to_json(obj, prompt_json_save_path)
    if total_tokens < max_token:
        with open(messages_save_path, "w", encoding="utf-8") as f:
            f.write(json.dumps(result))
async def translates(text, line_number, prompt_json_save_path, path):
    """Build a prompt for one paragraph by machine translation and append it.

    Args:
        text: Raw paragraph text, passed unchanged to ``translate.main``.
        line_number: 1-based paragraph index. With ``memory`` enabled, the
            paragraph is skipped when the prompt file already holds at least
            this many entries.
        prompt_json_save_path: JSON array file the prompt object is appended to.
        path: Directory that may contain ``role.json``.

    The negative prompt is a fixed quality/watermark blacklist.
    """
    is_exists = await check_file_exists(prompt_json_save_path)
    if memory and is_exists:
        with open(prompt_json_save_path, "r", encoding="utf-8") as file:
            prompt_data = json.load(file)
        if line_number <= len(prompt_data):
            # Same cache notice as the ChatGPT path.
            await print_tip(f"使用缓存:跳过第{line_number}段")
            return
    prompt = translate.main(text)
    obj = {"prompt": prompt, "negative_prompt": "nsfw,(low quality,normal quality,worst quality,jpeg artifacts),cropped,monochrome,lowres,low saturation,((watermark)),(white letters)"}
    if role_enabled:
        # Fixed characters: tag every role from role.json named in the text.
        # os.path.exists is needed here; a joined path string is always truthy.
        role_path = os.path.join(path, "role.json")
        roles = []
        if os.path.exists(role_path):
            with open(role_path, "r", encoding="utf-8") as file:
                role_data = json.load(file)
            roles = [role["name"] for role in role_data if role["name"] in text]
        obj["role"] = roles
    write_to_json(obj, prompt_json_save_path)
async def generate_prompt(path, save_path, name):
    """Generate a prompt entry for every line of ``{path}/{name}.txt``.

    Each line is a paragraph; its 1-based position is passed along as the
    paragraph number. Depending on ``is_translate``, each paragraph goes to
    the translation path or the ChatGPT path. Results are appended to
    ``{save_path}/{name}.json``. The ChatGPT path keeps its history in
    ``{save_path}/{name}messages.json``.
    """
    await print_tip("开始生成提示词")
    async with aiofiles.open(f"{path}/{name}.txt", "r", encoding="utf8") as source:
        paragraphs = await source.readlines()
    json_path = os.path.join(save_path, f"{name}.json")
    history_path = os.path.join(save_path, f"{name}messages.json")
    for index, paragraph in enumerate(paragraphs, start=1):
        if not paragraph:
            continue
        if is_translate:
            await translates(paragraph, index, json_path, path)
            continue
        await process_line2(paragraph, index, json_path, history_path, name, path)
if __name__ == "__main__":
    # generate_prompt is a coroutine with three required arguments; calling it
    # bare would raise TypeError (and never run). Take the arguments from the
    # command line and drive it with an event loop.
    import asyncio
    import sys

    if len(sys.argv) != 4:
        sys.exit("usage: python prompt.py <path> <save_path> <name>")
    asyncio.run(generate_prompt(sys.argv[1], sys.argv[2], sys.argv[3]))