mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
refactor: tool calls sent when pending and after completion
This commit is contained in:
@@ -180,7 +180,7 @@ class LLMHandler(ABC):
|
||||
|
||||
def handle_tool_calls(
|
||||
self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
|
||||
) -> List[Dict]:
|
||||
) -> Generator:
|
||||
"""
|
||||
Execute tool calls and update conversation history.
|
||||
|
||||
@@ -198,7 +198,13 @@ class LLMHandler(ABC):
|
||||
for call in tool_calls:
|
||||
try:
|
||||
self.tool_calls.append(call)
|
||||
tool_response, call_id = agent._execute_tool_action(tools_dict, call)
|
||||
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
|
||||
while True:
|
||||
try:
|
||||
yield next(tool_executor_gen)
|
||||
except StopIteration as e:
|
||||
tool_response, call_id = e.value
|
||||
break
|
||||
|
||||
updated_messages.append(
|
||||
{
|
||||
@@ -231,7 +237,7 @@ class LLMHandler(ABC):
|
||||
|
||||
def handle_non_streaming(
|
||||
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
|
||||
) -> Union[str, Dict]:
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle non-streaming response flow.
|
||||
|
||||
@@ -248,9 +254,15 @@ class LLMHandler(ABC):
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
while parsed.requires_tool_call:
|
||||
messages = self.handle_tool_calls(
|
||||
tool_handler_gen = self.handle_tool_calls(
|
||||
agent, parsed.tool_calls, tools_dict, messages
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
yield next(tool_handler_gen)
|
||||
except StopIteration as e:
|
||||
messages = e.value
|
||||
break
|
||||
|
||||
response = agent.llm.gen(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
@@ -297,9 +309,15 @@ class LLMHandler(ABC):
|
||||
if call.arguments:
|
||||
existing.arguments += call.arguments
|
||||
if parsed.finish_reason == "tool_calls":
|
||||
messages = self.handle_tool_calls(
|
||||
tool_handler_gen = self.handle_tool_calls(
|
||||
agent, list(tool_calls.values()), tools_dict, messages
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
yield next(tool_handler_gen)
|
||||
except StopIteration as e:
|
||||
messages = e.value
|
||||
break
|
||||
tool_calls = {}
|
||||
|
||||
response = agent.llm.gen_stream(
|
||||
|
||||
Reference in New Issue
Block a user