fix: enhance ReActAgent's response handling and update planning prompt

This commit is contained in:
Alex
2025-05-23 14:21:02 +01:00
parent 7445928c7e
commit 5475e6f7c5
2 changed files with 154 additions and 71 deletions

View File

@@ -1,33 +1,94 @@
import os
from typing import Dict, Generator, List
from typing import Dict, Generator, List, Any
import logging
from application.agents.base import BaseAgent
from application.logging import build_stack_data, LogContext
from application.retriever.base import BaseRetriever
logger = logging.getLogger(__name__)
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
with open(
os.path.join(current_dir, "application/prompts", "react_planning_prompt.txt"), "r"
) as f:
planning_prompt = f.read()
planning_prompt_template = f.read()
with open(
os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"),
"r",
) as f:
final_prompt = f.read()
final_prompt_template = f.read()
class ReActAgent(BaseAgent):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.plan = ""
self.plan: str = ""
self.observations: List[str] = []
def _extract_content_from_llm_response(self, resp: Any) -> str:
"""
Helper to extract string content from various LLM response types.
Handles strings, message objects (OpenAI-like), and streams.
Adapt stream handling for your specific LLM client if not OpenAI.
"""
collected_content = []
if isinstance(resp, str):
collected_content.append(resp)
elif ( # OpenAI non-streaming or Anthropic non-streaming (older SDK style)
hasattr(resp, "message")
and hasattr(resp.message, "content")
and resp.message.content is not None
):
collected_content.append(resp.message.content)
elif ( # OpenAI non-streaming (Pydantic model), Anthropic new SDK non-streaming
hasattr(resp, "choices") and resp.choices and
hasattr(resp.choices[0], "message") and
hasattr(resp.choices[0].message, "content") and
resp.choices[0].message.content is not None
):
collected_content.append(resp.choices[0].message.content) # OpenAI
elif ( # Anthropic new SDK non-streaming content block
hasattr(resp, "content") and isinstance(resp.content, list) and resp.content and
hasattr(resp.content[0], "text")
):
collected_content.append(resp.content[0].text) # Anthropic
else:
# Assume resp is a stream if not a recognized object
try:
for chunk in resp: # This will fail if resp is not iterable (e.g. a non-streaming response object)
content_piece = ""
# OpenAI-like stream
if hasattr(chunk, 'choices') and len(chunk.choices) > 0 and \
hasattr(chunk.choices[0], 'delta') and \
hasattr(chunk.choices[0].delta, 'content') and \
chunk.choices[0].delta.content is not None:
content_piece = chunk.choices[0].delta.content
# Anthropic-like stream (ContentBlockDelta)
elif hasattr(chunk, 'type') and chunk.type == 'content_block_delta' and \
hasattr(chunk, 'delta') and hasattr(chunk.delta, 'text'):
content_piece = chunk.delta.text
elif isinstance(chunk, str): # Simplest case: stream of strings
content_piece = chunk
# Add other stream chunk formats as needed
if content_piece:
collected_content.append(content_piece)
except TypeError: # If resp is not iterable (e.g. a final response object that wasn't caught above)
logger.debug(f"Response type {type(resp)} could not be iterated as a stream. It might be a non-streaming object not handled by specific checks.")
except Exception as e:
logger.error(f"Error processing potential stream chunk: {e}, chunk was: {getattr(chunk, '__dict__', chunk)}")
return "".join(collected_content)
def _gen_inner(
self, query: str, retriever: BaseRetriever, log_context: LogContext
) -> Generator[Dict, None, None]:
# Reset state for this generation call
self.plan = ""
self.observations = []
retrieved_data = self._retriever_search(retriever, query, log_context)
if self.user_api_key:
@@ -37,96 +98,117 @@ class ReActAgent(BaseAgent):
self._prepare_tools(tools_dict)
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
plan = self._create_plan(query, docs_together, log_context)
for line in plan:
if isinstance(line, str):
self.plan += line
yield {"thought": line}
prompt = self.prompt + f"\nFollow this plan: {self.plan}"
messages = self._build_messages(prompt, query, retrieved_data)
# 1. Create Plan
logger.info("ReActAgent: Creating plan...")
plan_stream = self._create_plan(query, docs_together, log_context)
current_plan_parts = []
for line_chunk in plan_stream:
current_plan_parts.append(line_chunk)
yield {"thought": line_chunk}
self.plan = "".join(current_plan_parts)
if self.plan:
self.observations.append(f"Plan: {self.plan}")
resp = self._llm_gen(messages, log_context)
# 2. Execute Plan (First Reasoning Step)
execution_prompt_str = (self.prompt or "") + f"\n\nFollow this plan:\n{self.plan}"
messages = self._build_messages(execution_prompt_str, query, retrieved_data)
if isinstance(resp, str):
self.observations.append(resp)
if (
hasattr(resp, "message")
and hasattr(resp.message, "content")
and resp.message.content is not None
):
self.observations.append(resp.message.content)
resp_from_llm_gen = self._llm_gen(messages, log_context)
resp = self._llm_handler(resp, tools_dict, messages, log_context)
for tool_call in self.tool_calls:
observation = (
f"Action '{tool_call['action_name']}' of tool '{tool_call['tool_name']}' "
f"with arguments '{tool_call['arguments']}' returned: '{tool_call['result']}'"
)
self.observations.append(observation)
if isinstance(resp, str):
self.observations.append(resp)
elif (
hasattr(resp, "message")
and hasattr(resp.message, "content")
and resp.message.content is not None
):
self.observations.append(resp.message.content)
initial_llm_thought_content = self._extract_content_from_llm_response(resp_from_llm_gen)
if initial_llm_thought_content:
self.observations.append(f"Initial thought/response: {initial_llm_thought_content}")
else:
completion = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=self.tools
)
for line in completion:
if isinstance(line, str):
self.observations.append(line)
logger.info("ReActAgent: Initial LLM response (before handler) had no textual content (might be only tool calls).")
log_context.stacks.append(
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
)
logger.info("Executing plan")
resp_after_handler = self._llm_handler(resp_from_llm_gen, tools_dict, messages, log_context)
for tool_call_info in self.tool_calls: # Iterate over self.tool_calls populated by _llm_handler
observation_string = (
f"Executed Action: Tool '{tool_call_info.get('tool_name', 'N/A')}' "
f"with arguments '{tool_call_info.get('arguments', '{}')}'. Result: '{str(tool_call_info.get('result', ''))[:200]}...'"
)
self.observations.append(observation_string)
content_after_handler = self._extract_content_from_llm_response(resp_after_handler)
if content_after_handler:
self.observations.append(f"Response after tool execution: {content_after_handler}")
logger.info(f"ReActAgent: LLM response after tool execution: {content_after_handler[:500]}...")
else:
logger.info("ReActAgent: LLM response after handler had no textual content.")
if log_context:
log_context.stacks.append(
{"component": "agent_tool_calls", "data": {"tool_calls": self.tool_calls.copy()}}
)
yield {"sources": retrieved_data}
# clean tool_call_data only send first 50 characters of tool_call['result']
for tool_call in self.tool_calls:
if len(str(tool_call["result"])) > 50:
tool_call["result"] = str(tool_call["result"])[:50] + "..."
yield {"tool_calls": self.tool_calls.copy()}
final_answer = self._create_final_answer(query, self.observations, log_context)
for line in final_answer:
if isinstance(line, str):
yield {"answer": line}
display_tool_calls = []
for tc in self.tool_calls:
cleaned_tc = tc.copy()
if len(str(cleaned_tc.get("result", ""))) > 50:
cleaned_tc["result"] = str(cleaned_tc["result"])[:50] + "..."
display_tool_calls.append(cleaned_tc)
if display_tool_calls:
yield {"tool_calls": display_tool_calls}
# 3. Create Final Answer based on all observations
final_answer_stream = self._create_final_answer(query, self.observations, log_context)
for answer_chunk in final_answer_stream:
yield {"answer": answer_chunk}
logger.info("ReActAgent: Finished generating final answer.")
def _create_plan(
self, query: str, docs_data: str, log_context: LogContext = None
) -> Generator[str, None, None]:
plan_prompt = planning_prompt.replace("{query}", query)
if "{summaries}" in planning_prompt:
summaries = docs_data
plan_prompt = plan_prompt.replace("{summaries}", summaries)
plan_prompt_filled = planning_prompt_template.replace("{query}", query)
if "{summaries}" in plan_prompt_filled:
summaries = docs_data if docs_data else "No documents retrieved."
plan_prompt_filled = plan_prompt_filled.replace("{summaries}", summaries)
messages = [{"role": "user", "content": plan_prompt}]
print(self.tools)
plan = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=self.tools
messages = [{"role": "user", "content": plan_prompt_filled}]
plan_stream_from_llm = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=getattr(self, 'tools', None) # Use self.tools
)
if log_context:
data = build_stack_data(self.llm)
log_context.stacks.append({"component": "planning_llm", "data": data})
return plan
for chunk in plan_stream_from_llm:
content_piece = self._extract_content_from_llm_response(chunk)
if content_piece:
yield content_piece
def _create_final_answer(
self, query: str, observations: List[str], log_context: LogContext = None
) -> str:
) -> Generator[str, None, None]:
observation_string = "\n".join(observations)
final_answer_prompt = final_prompt.format(
max_obs_len = 10000
if len(observation_string) > max_obs_len:
observation_string = observation_string[:max_obs_len] + "\n...[observations truncated]"
logger.warning("ReActAgent: Truncated observations for final answer prompt due to length.")
final_answer_prompt_filled = final_prompt_template.format(
query=query, observations=observation_string
)
messages = [{"role": "user", "content": final_answer_prompt}]
final_answer = self.llm.gen_stream(model=self.gpt_model, messages=messages)
messages = [{"role": "user", "content": final_answer_prompt_filled}]
# Final answer should synthesize, not call tools.
final_answer_stream_from_llm = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=None
)
if log_context:
data = build_stack_data(self.llm)
log_context.stacks.append({"component": "final_answer_llm", "data": data})
return final_answer
for chunk in final_answer_stream_from_llm:
content_piece = self._extract_content_from_llm_response(chunk)
if content_piece:
yield content_piece

View File

@@ -1,10 +1,11 @@
You are an AI assistant and talk like you're thinking out loud. Given the following query, outline a concise thought process that includes key steps and considerations necessary for effective analysis and response. Avoid pointwise formatting. The goal is to break down the query into manageable components without excessive detail, focusing on clarity and logical progression.
Include the following elements in your thought process:
Include the following elements in your thought and execution process:
1. Identify the main objective of the query.
2. Determine any relevant context or background information needed to understand the query.
3. List potential approaches or methods to address the query.
4. Highlight any critical factors or constraints that may influence the outcome.
5. Use available tools to help you with the analysis and execute them. Execute tools needed now instead of later.
Query: {query}
Summaries: {summaries}