refactor: tool calls sent when pending and after completion

This commit is contained in:
Siddhant Rai
2025-06-11 12:40:32 +05:30
parent dd9d18208d
commit 3351f71813
6 changed files with 94 additions and 72 deletions

View File

@@ -180,7 +180,7 @@ class LLMHandler(ABC):
def handle_tool_calls(
self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
) -> List[Dict]:
) -> Generator:
"""
Execute tool calls and update conversation history.
@@ -198,7 +198,13 @@ class LLMHandler(ABC):
for call in tool_calls:
try:
self.tool_calls.append(call)
tool_response, call_id = agent._execute_tool_action(tools_dict, call)
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
while True:
try:
yield next(tool_executor_gen)
except StopIteration as e:
tool_response, call_id = e.value
break
updated_messages.append(
{
@@ -231,7 +237,7 @@ class LLMHandler(ABC):
def handle_non_streaming(
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
) -> Union[str, Dict]:
) -> Generator:
"""
Handle non-streaming response flow.
@@ -248,9 +254,15 @@ class LLMHandler(ABC):
self.llm_calls.append(build_stack_data(agent.llm))
while parsed.requires_tool_call:
messages = self.handle_tool_calls(
tool_handler_gen = self.handle_tool_calls(
agent, parsed.tool_calls, tools_dict, messages
)
while True:
try:
yield next(tool_handler_gen)
except StopIteration as e:
messages = e.value
break
response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
@@ -297,9 +309,15 @@ class LLMHandler(ABC):
if call.arguments:
existing.arguments += call.arguments
if parsed.finish_reason == "tool_calls":
messages = self.handle_tool_calls(
tool_handler_gen = self.handle_tool_calls(
agent, list(tool_calls.values()), tools_dict, messages
)
while True:
try:
yield next(tool_handler_gen)
except StopIteration as e:
messages = e.value
break
tool_calls = {}
response = agent.llm.gen_stream(