refactor: tool calls sent when pending and after completion

This commit is contained in:
Siddhant Rai
2025-06-11 12:40:32 +05:30
parent dd9d18208d
commit 3351f71813
6 changed files with 94 additions and 72 deletions

View File

@@ -23,7 +23,6 @@ class ClassicAgent(BaseAgent):
def _gen_inner(
self, query: str, retriever: BaseRetriever, log_context: LogContext
) -> Generator[Dict, None, None]:
"""Main execution flow for the agent."""
# Step 1: Retrieve relevant data
retrieved_data = self._retriever_search(retriever, query, log_context)
@@ -52,46 +51,3 @@ class ClassicAgent(BaseAgent):
log_context.stacks.append(
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
)
def _handle_response(self, response, tools_dict, messages, log_context):
"""Handle different types of LLM responses consistently."""
# Handle simple string responses
if isinstance(response, str):
yield {"answer": response}
return
# Handle content from message objects
if hasattr(response, "message") and getattr(response.message, "content", None):
yield {"answer": response.message.content}
return
# Handle complex responses that may require tool use
processed_response = self._llm_handler(
response, tools_dict, messages, log_context, self.attachments
)
# Yield the final processed response
if isinstance(processed_response, str):
yield {"answer": processed_response}
elif hasattr(processed_response, "message") and getattr(
processed_response.message, "content", None
):
yield {"answer": processed_response.message.content}
else:
for line in processed_response:
if isinstance(line, str):
yield {"answer": line}
def _get_truncated_tool_calls(self):
"""Return tool calls with truncated results for cleaner output."""
return [
{
**tool_call,
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
else tool_call["result"]
),
}
for tool_call in self.tool_calls
]