refactor: tool calls sent when pending and after completion

This commit is contained in:
Siddhant Rai
2025-06-11 12:40:32 +05:30
parent dd9d18208d
commit 3351f71813
6 changed files with 94 additions and 72 deletions

View File

@@ -136,6 +136,15 @@ class BaseAgent(ABC):
parser = ToolActionParser(self.llm.__class__.__name__)
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
tool_data = tools_dict[tool_id]
action_data = (
tool_data["config"]["actions"][action_name]
@@ -188,19 +197,26 @@ class BaseAgent(ABC):
else:
print(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
call_id = getattr(call, "id", None)
tool_call_data["result"] = result
tool_call_data = {
"tool_name": tool_data["name"],
"call_id": call_id if call_id is not None else "None",
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
"result": result,
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
def _get_truncated_tool_calls(self):
return [
{
**tool_call,
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
else tool_call["result"]
),
}
for tool_call in self.tool_calls
]
def _build_messages(
self,
system_prompt: str,
@@ -264,13 +280,11 @@ class BaseAgent(ABC):
and self.tools
):
gen_kwargs["tools"] = self.tools
resp = self.llm.gen_stream(**gen_kwargs)
if log_context:
data = build_stack_data(self.llm, exclude_attributes=["client"])
log_context.stacks.append({"component": "llm", "data": data})
return resp
def _llm_handler(
@@ -288,3 +302,23 @@ class BaseAgent(ABC):
data = build_stack_data(self.llm_handler, exclude_attributes=["tool_calls"])
log_context.stacks.append({"component": "llm_handler", "data": data})
return resp
def _handle_response(self, response, tools_dict, messages, log_context):
if isinstance(response, str):
yield {"answer": response}
return
if hasattr(response, "message") and getattr(response.message, "content", None):
yield {"answer": response.message.content}
return
processed_response_gen = self._llm_handler(
response, tools_dict, messages, log_context, self.attachments
)
for event in processed_response_gen:
if isinstance(event, str):
yield {"answer": event}
elif hasattr(event, "message") and getattr(event.message, "content", None):
yield {"answer": event.message.content}
elif isinstance(event, dict) and "type" in event:
yield event