forked from starpig1129/AI-Data-Analysis-MultiAgent
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnode.py
279 lines (234 loc) · 11.9 KB
/
node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage,ToolMessage
from openai import InternalServerError
from state import State
import logging
import json
import re
import os
from pathlib import Path
from langchain.agents import AgentExecutor
# Set up logger
logger = logging.getLogger(__name__)
def agent_node(state: State, agent: AgentExecutor, name: str) -> State:
    """
    Run one agent against the shared state and merge its output back in.

    Args:
        state: Current workflow state (message history plus per-agent fields).
        agent: The AgentExecutor to invoke with the state.
        name: Agent identifier; selects which state field stores the output
            and is recorded as the message sender.

    Returns:
        The updated state. On failure the error is appended to the message
        history so the rest of the state survives.
    """
    logger.info(f"Processing agent: {name}")
    try:
        result = agent.invoke(state)
        logger.debug(f"Agent {name} result: {result}")

        # Normalize the agent result to plain text.
        output = result["output"] if isinstance(result, dict) and "output" in result else str(result)

        ai_message = AIMessage(content=output, name=name)
        state["messages"].append(ai_message)
        state["sender"] = name

        # Route the output into the agent-specific state slot.
        # Use .get() so a missing "hypothesis" key is treated as empty
        # rather than raising KeyError.
        if name == "hypothesis_agent" and not state.get("hypothesis"):
            state["hypothesis"] = ai_message
            logger.info("Hypothesis updated")
        elif name == "process_agent":
            state["process_decision"] = ai_message
            logger.info("Process decision updated")
        elif name == "visualization_agent":
            state["visualization_state"] = ai_message
            logger.info("Visualization state updated")
        elif name == "searcher_agent":
            state["searcher_state"] = ai_message
            logger.info("Searcher state updated")
        elif name == "report_agent":
            state["report_section"] = ai_message
            logger.info("Report section updated")
        elif name == "quality_review_agent":
            state["quality_review"] = ai_message
            # The reviewer signals a rework loop by including this phrase.
            state["needs_revision"] = "revision needed" in output.lower()
            logger.info(f"Quality review updated. Needs revision: {state['needs_revision']}")

        logger.info(f"Agent {name} processing completed")
        return state
    except Exception as e:
        logger.error(f"Error occurred while processing agent {name}: {str(e)}", exc_info=True)
        # BUGFIX: preserve the full state instead of returning a bare
        # {"messages": [...]} dict, which silently dropped every other
        # state field (hypothesis, process, sender, ...).
        state["messages"].append(AIMessage(content=f"Error: {str(e)}", name=name))
        state["sender"] = name
        return state
def human_choice_node(state: State) -> State:
    """
    Ask the user whether to regenerate the hypothesis or continue research.

    Choosing regeneration also prompts for the specific parts of the
    hypothesis to change, clears the stored hypothesis, and records the
    requested modification areas; otherwise the research process proceeds.
    """
    logger.info("Prompting for human choice")
    print("Please choose the next step:")
    print("1. Regenerate hypothesis")
    print("2. Continue the research process")

    # Re-prompt until a valid menu option is entered.
    choice = input("Please enter your choice (1 or 2): ")
    while choice not in ("1", "2"):
        logger.warning(f"Invalid input received: {choice}")
        print("Invalid input, please try again.")
        choice = input("Please enter your choice (1 or 2): ")

    if choice == "1":
        # Regenerate: capture what should change and reset the hypothesis.
        modification_areas = input("Please specify which parts of the hypothesis you want to modify: ")
        content = f"Regenerate hypothesis. Areas to modify: {modification_areas}"
        state["hypothesis"] = ""
        state["modification_areas"] = modification_areas
        logger.info("Hypothesis cleared for regeneration")
        logger.info(f"Areas to modify: {modification_areas}")
    else:
        content = "Continue the research process"
        state["process"] = "Continue the research process"
        logger.info("Continuing research process")

    state["messages"].append(HumanMessage(content=content))
    state["sender"] = 'human'
    logger.info("Human choice processed")
    return state
def create_message(message: dict[str, Any], name: str) -> BaseMessage:
    """
    Build a HumanMessage or AIMessage from a serialized message dict.

    Args:
        message: Dict with optional "content" and "type" keys; a "type" of
            "human" (case-insensitive) yields a HumanMessage, anything else
            an AIMessage.
        name: Name attached to AIMessages (HumanMessages carry no name).

    Returns:
        The reconstructed BaseMessage.
    """
    # BUGFIX: the annotation was dict[str], which is a malformed generic —
    # dict takes a key and a value type.
    content = message.get("content", "")
    message_type = message.get("type", "").lower()
    logger.debug(f"Creating message of type {message_type} for {name}")
    return HumanMessage(content=content) if message_type == "human" else AIMessage(content=content, name=name)
def note_agent_node(state: State, agent: AgentExecutor, name: str) -> State:
    """
    Process the note agent's action and rebuild the entire state from it.

    The note agent is expected to return a JSON document describing the new
    state; every field is taken from that JSON when present and falls back
    to the previous state value otherwise.

    Args:
        state: Current workflow state.
        agent: The note-taking AgentExecutor to invoke.
        name: Agent identifier, used when reconstructing messages.

    Returns:
        A fully rebuilt State on success, or an error state (previous fields
        preserved, error appended to messages) on failure.
    """
    logger.info(f"Processing note agent: {name}")
    try:
        current_messages = state.get("messages", [])
        head_messages, tail_messages = [], []
        # For long histories, keep the first two and last two messages out
        # of the agent's context to limit prompt size; they are re-attached
        # around the result below.
        if len(current_messages) > 6:
            head_messages = current_messages[:2]
            tail_messages = current_messages[-2:]
            state = {**state, "messages": current_messages[2:-2]}
            logger.debug("Trimmed messages for processing")
        result = agent.invoke(state)
        logger.debug(f"Note agent {name} result: {result}")
        output = result["output"] if isinstance(result, dict) and "output" in result else str(result)
        # Strip ASCII and C1 control characters that would make json.loads
        # reject the agent's output.
        cleaned_output = re.sub(r'[\x00-\x1F\x7F-\x9F]', '', output)
        parsed_output = json.loads(cleaned_output)
        logger.debug(f"Parsed output: {parsed_output}")
        # Rebuild the message list from the agent's JSON; if the agent
        # produced no messages, fall back to the untrimmed originals.
        new_messages = [create_message(msg, name) for msg in parsed_output.get("messages", [])]
        messages = new_messages if new_messages else current_messages
        combined_messages = head_messages + messages + tail_messages
        # Each scalar field prefers the parsed JSON value and falls back to
        # the previous state value when the key is absent.
        updated_state: State = {
            "messages": combined_messages,
            "hypothesis": str(parsed_output.get("hypothesis", state.get("hypothesis", ""))),
            "process": str(parsed_output.get("process", state.get("process", ""))),
            "process_decision": str(parsed_output.get("process_decision", state.get("process_decision", ""))),
            "visualization_state": str(parsed_output.get("visualization_state", state.get("visualization_state", ""))),
            "searcher_state": str(parsed_output.get("searcher_state", state.get("searcher_state", ""))),
            "code_state": str(parsed_output.get("code_state", state.get("code_state", ""))),
            "report_section": str(parsed_output.get("report_section", state.get("report_section", ""))),
            "quality_review": str(parsed_output.get("quality_review", state.get("quality_review", ""))),
            "needs_revision": bool(parsed_output.get("needs_revision", state.get("needs_revision", False))),
            "sender": 'note_agent'
        }
        logger.info("Updated state successfully")
        return updated_state
    except json.JSONDecodeError as e:
        # `output` is always bound here: json.loads runs after its assignment.
        logger.error(f"JSON decode error: {e}", exc_info=True)
        return _create_error_state(state, AIMessage(content=f"Error parsing output: {output}", name=name), name, "JSON decode error")
    except InternalServerError as e:
        logger.error(f"OpenAI Internal Server Error: {e}", exc_info=True)
        return _create_error_state(state, AIMessage(content=f"OpenAI Error: {str(e)}", name=name), name, "OpenAI error")
    except Exception as e:
        logger.error(f"Unexpected error in note_agent_node: {e}", exc_info=True)
        return _create_error_state(state, AIMessage(content=f"Unexpected error: {str(e)}", name=name), name, "Unexpected error")
def _create_error_state(state: State, error_message: AIMessage, name: str, error_type: str) -> State:
    """
    Build a fully-populated fallback state after a failure.

    The error message is appended to the conversation while every other
    field is carried over from the previous state (coerced to str/bool),
    with the sender reset to 'note_agent'.
    """
    logger.info(f"Creating error state for {name}: {error_type}")
    # Carry each text field forward, defaulting missing keys to "".
    text_fields = (
        "hypothesis", "process", "process_decision", "visualization_state",
        "searcher_state", "code_state", "report_section", "quality_review",
    )
    error_state: State = {field: str(state.get(field, "")) for field in text_fields}
    error_state["messages"] = state.get("messages", []) + [error_message]
    error_state["needs_revision"] = bool(state.get("needs_revision", False))
    error_state["sender"] = 'note_agent'
    return error_state
def human_review_node(state: State) -> State:
    """
    Display current state to the user and update the state based on user input.
    Includes error handling for robustness.

    If the user answers 'yes', a non-empty follow-up request is appended to
    the messages and needs_revision is set; 'no' clears needs_revision.

    Returns:
        The updated state, or None on KeyboardInterrupt or any other error.
        NOTE(review): returning None diverges from the declared `State`
        return type — confirm downstream nodes tolerate a None result.
    """
    try:
        print("Current research progress:")
        print(state)
        print("\nDo you need additional analysis or modifications?")
        # Loop until a recognizable yes/no answer is given.
        while True:
            user_input = input("Enter 'yes' to continue analysis, or 'no' to end the research: ").lower()
            if user_input in ['yes', 'no']:
                break
            print("Invalid input. Please enter 'yes' or 'no'.")
        if user_input == 'yes':
            # Require a non-empty follow-up request before continuing.
            while True:
                additional_request = input("Please enter your additional analysis request: ").strip()
                if additional_request:
                    state["messages"].append(HumanMessage(content=additional_request))
                    state["needs_revision"] = True
                    break
                print("Request cannot be empty. Please try again.")
        else:
            state["needs_revision"] = False
        state["sender"] = "human"
        logger.info("Human review completed successfully.")
        return state
    except KeyboardInterrupt:
        logger.warning("Human review interrupted by user.")
        return None
    except Exception as e:
        logger.error(f"An error occurred during human review: {str(e)}", exc_info=True)
        return None
def refiner_node(state: State, agent: AgentExecutor, name: str) -> State:
    """
    Read MD file contents and PNG file names from the storage path, feed
    them to the refiner agent as report materials, and append the agent's
    answer to the conversation.

    If the first invocation fails (typically because the full MD contents
    exceed the model's token limit), retry with file names only.

    Args:
        state: Current workflow state; mutated in place.
        agent: The refiner AgentExecutor.
        name: Agent identifier recorded as the sender.

    Returns:
        The updated state; on error an AIMessage describing the failure is
        appended instead of the refined report.
    """
    try:
        # Storage directory is configurable through the STORAGE_PATH env var.
        storage_path = Path(os.getenv('STORAGE_PATH', './data_storage/'))

        md_files = list(storage_path.glob("*.md"))
        png_files = list(storage_path.glob("*.png"))

        # Full materials: complete Markdown contents plus PNG file names.
        materials = [
            f"MD file '{md_file.name}':\n{md_file.read_text(encoding='utf-8')}"
            for md_file in md_files
        ]
        materials.extend(f"PNG file: '{png_file.name}'" for png_file in png_files)
        combined_materials = "\n\n".join(materials)
        report_content = f"Report materials:\n{combined_materials}"

        # Work on a shallow copy so the original message history survives.
        # BUGFIX: BaseMessage cannot be instantiated directly — langchain_core
        # requires its `type` discriminator — so use HumanMessage instead.
        refiner_state = state.copy()
        refiner_state["messages"] = [HumanMessage(content=report_content)]

        try:
            # First attempt: full material contents.
            result = agent.invoke(refiner_state)
        except Exception:
            # Fallback: assume a token-limit failure and retry with file
            # names only, which is always small.
            logger.warning("Token limit exceeded. Retrying with MD file names only.")
            md_file_names = [f"MD file: '{md_file.name}'" for md_file in md_files]
            png_file_names = [f"PNG file: '{png_file.name}'" for png_file in png_files]
            simplified_materials = "\n".join(md_file_names + png_file_names)
            simplified_report_content = f"Report materials (file names only):\n{simplified_materials}"
            refiner_state["messages"] = [HumanMessage(content=simplified_report_content)]
            result = agent.invoke(refiner_state)

        # BUGFIX: agent.invoke usually returns a dict; extract the text
        # output (same normalization as agent_node) instead of wrapping the
        # raw result object in an AIMessage.
        output = result["output"] if isinstance(result, dict) and "output" in result else str(result)
        state["messages"].append(AIMessage(content=output, name=name))
        state["sender"] = name
        logger.info("Refiner node processing completed")
        return state

    except Exception as e:
        logger.error(f"Error occurred while processing refiner node: {str(e)}", exc_info=True)
        state["messages"].append(AIMessage(content=f"Error: {str(e)}", name=name))
        return state
# Emitted once at import time to confirm the module (and its logger) loaded.
logger.info("Agent processing module initialized")