From 3351f71813d3057251bce3756d8a83adfee4959e Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 11 Jun 2025 12:40:32 +0530 Subject: [PATCH] refactor: tool calls sent when pending and after completion --- application/agents/base.py | 54 +++++++++++++++---- application/agents/classic_agent.py | 44 --------------- application/api/answer/routes.py | 5 +- application/llm/handlers/base.py | 28 ++++++++-- .../src/conversation/conversationSlice.ts | 32 +++++++---- frontend/src/conversation/types/index.ts | 3 +- 6 files changed, 94 insertions(+), 72 deletions(-) diff --git a/application/agents/base.py b/application/agents/base.py index f48418b3..adebc125 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -136,6 +136,15 @@ class BaseAgent(ABC): parser = ToolActionParser(self.llm.__class__.__name__) tool_id, action_name, call_args = parser.parse_args(call) + call_id = getattr(call, "id", None) or str(uuid.uuid4()) + tool_call_data = { + "tool_name": tools_dict[tool_id]["name"], + "call_id": call_id, + "action_name": f"{action_name}_{tool_id}", + "arguments": call_args, + } + yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}} + tool_data = tools_dict[tool_id] action_data = ( tool_data["config"]["actions"][action_name] @@ -188,19 +197,26 @@ class BaseAgent(ABC): else: print(f"Executing tool: {action_name} with args: {call_args}") result = tool.execute_action(action_name, **parameters) - call_id = getattr(call, "id", None) + tool_call_data["result"] = result - tool_call_data = { - "tool_name": tool_data["name"], - "call_id": call_id if call_id is not None else "None", - "action_name": f"{action_name}_{tool_id}", - "arguments": call_args, - "result": result, - } + yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}} self.tool_calls.append(tool_call_data) return result, call_id + def _get_truncated_tool_calls(self): + return [ + { + **tool_call, + "result": ( + f"{str(tool_call['result'])[:50]}..." + if len(str(tool_call["result"])) > 50 + else tool_call["result"] + ), + } + for tool_call in self.tool_calls + ] + def _build_messages( self, system_prompt: str, @@ -264,13 +280,11 @@ class BaseAgent(ABC): and self.tools ): gen_kwargs["tools"] = self.tools - resp = self.llm.gen_stream(**gen_kwargs) if log_context: data = build_stack_data(self.llm, exclude_attributes=["client"]) log_context.stacks.append({"component": "llm", "data": data}) - return resp def _llm_handler( @@ -288,3 +302,23 @@ class BaseAgent(ABC): data = build_stack_data(self.llm_handler, exclude_attributes=["tool_calls"]) log_context.stacks.append({"component": "llm_handler", "data": data}) return resp + + def _handle_response(self, response, tools_dict, messages, log_context): + if isinstance(response, str): + yield {"answer": response} + return + if hasattr(response, "message") and getattr(response.message, "content", None): + yield {"answer": response.message.content} + return + + processed_response_gen = self._llm_handler( + response, tools_dict, messages, log_context, self.attachments + ) + + for event in processed_response_gen: + if isinstance(event, str): + yield {"answer": event} + elif hasattr(event, "message") and getattr(event.message, "content", None): + yield {"answer": event.message.content} + elif isinstance(event, dict) and "type" in event: + yield event diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py index d0576511..6fe73de0 100644 --- a/application/agents/classic_agent.py +++ b/application/agents/classic_agent.py @@ -23,7 +23,6 @@ class ClassicAgent(BaseAgent): def _gen_inner( self, query: str, retriever: BaseRetriever, log_context: LogContext ) -> Generator[Dict, None, None]: - """Main execution flow for the agent.""" # Step 1: Retrieve relevant data retrieved_data = self._retriever_search(retriever, query, log_context) @@ -52,46 +51,3 @@ class ClassicAgent(BaseAgent): log_context.stacks.append( {"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}} ) - - def _handle_response(self, response, tools_dict, messages, log_context): - """Handle different types of LLM responses consistently.""" - # Handle simple string responses - if isinstance(response, str): - yield {"answer": response} - return - - # Handle content from message objects - if hasattr(response, "message") and getattr(response.message, "content", None): - yield {"answer": response.message.content} - return - - # Handle complex responses that may require tool use - processed_response = self._llm_handler( - response, tools_dict, messages, log_context, self.attachments - ) - - # Yield the final processed response - if isinstance(processed_response, str): - yield {"answer": processed_response} - elif hasattr(processed_response, "message") and getattr( - processed_response.message, "content", None - ): - yield {"answer": processed_response.message.content} - else: - for line in processed_response: - if isinstance(line, str): - yield {"answer": line} - - def _get_truncated_tool_calls(self): - """Return tool calls with truncated results for cleaner output.""" - return [ - { - **tool_call, - "result": ( - f"{str(tool_call['result'])[:50]}..." - if len(str(tool_call["result"])) > 50 - else tool_call["result"] - ), - } - for tool_call in self.tool_calls - ] diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 44ba035b..83c3db6f 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -310,12 +310,13 @@ def complete_stream( yield f"data: {data}\n\n" elif "tool_calls" in line: tool_calls = line["tool_calls"] - data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls}) - yield f"data: {data}\n\n" elif "thought" in line: thought += line["thought"] data = json.dumps({"type": "thought", "thought": line["thought"]}) yield f"data: {data}\n\n" + elif "type" in line: + data = json.dumps(line) + yield f"data: {data}\n\n" if isNoneDoc: for doc in source_log_docs: diff --git a/application/llm/handlers/base.py b/application/llm/handlers/base.py index ede7cec3..43205472 100644 --- a/application/llm/handlers/base.py +++ b/application/llm/handlers/base.py @@ -180,7 +180,7 @@ class LLMHandler(ABC): def handle_tool_calls( self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict] - ) -> List[Dict]: + ) -> Generator: """ Execute tool calls and update conversation history. @@ -198,7 +198,13 @@ class LLMHandler(ABC): for call in tool_calls: try: self.tool_calls.append(call) - tool_response, call_id = agent._execute_tool_action(tools_dict, call) + tool_executor_gen = agent._execute_tool_action(tools_dict, call) + while True: + try: + yield next(tool_executor_gen) + except StopIteration as e: + tool_response, call_id = e.value + break updated_messages.append( { @@ -231,7 +237,7 @@ class LLMHandler(ABC): def handle_non_streaming( self, agent, response: Any, tools_dict: Dict, messages: List[Dict] - ) -> Union[str, Dict]: + ) -> Generator: """ Handle non-streaming response flow. @@ -248,9 +254,15 @@ class LLMHandler(ABC): self.llm_calls.append(build_stack_data(agent.llm)) while parsed.requires_tool_call: - messages = self.handle_tool_calls( + tool_handler_gen = self.handle_tool_calls( agent, parsed.tool_calls, tools_dict, messages ) + while True: + try: + yield next(tool_handler_gen) + except StopIteration as e: + messages = e.value + break response = agent.llm.gen( model=agent.gpt_model, messages=messages, tools=agent.tools @@ -297,9 +309,15 @@ class LLMHandler(ABC): if call.arguments: existing.arguments += call.arguments if parsed.finish_reason == "tool_calls": - messages = self.handle_tool_calls( + tool_handler_gen = self.handle_tool_calls( agent, list(tool_calls.values()), tools_dict, messages ) + while True: + try: + yield next(tool_handler_gen) + except StopIteration as e: + messages = e.value + break tool_calls = {} response = agent.llm.gen_stream( diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts index 961260ea..03532792 100644 --- a/frontend/src/conversation/conversationSlice.ts +++ b/frontend/src/conversation/conversationSlice.ts @@ -14,6 +14,7 @@ import { ConversationState, Attachment, } from './conversationModels'; +import { ToolCallsType } from './types'; const initialState: ConversationState = { queries: [], @@ -110,11 +111,11 @@ export const fetchAnswer = createAsyncThunk< query: { sources: data.source ?? [] }, }), ); - } else if (data.type === 'tool_calls') { + } else if (data.type === 'tool_call') { dispatch( - updateToolCalls({ + updateToolCall({ index: targetIndex, - query: { tool_calls: data.tool_calls }, + tool_call: data.data as ToolCallsType, }), ); } else if (data.type === 'error') { @@ -280,12 +281,23 @@ export const conversationSlice = createSlice({ state.queries[index].sources!.push(query.sources![0]); } }, - updateToolCalls( - state, - action: PayloadAction<{ index: number; query: Partial }>, - ) { - const { index, query } = action.payload; - state.queries[index].tool_calls = query?.tool_calls ?? []; + updateToolCall(state, action) { + const { index, tool_call } = action.payload; + + if (!state.queries[index].tool_calls) { + state.queries[index].tool_calls = []; + } + + const existingIndex = state.queries[index].tool_calls.findIndex( + (call) => call.call_id === tool_call.call_id, + ); + + if (existingIndex !== -1) { + Object.assign( + state.queries[index].tool_calls[existingIndex], + tool_call, + ); + } else state.queries[index].tool_calls.push(tool_call); }, updateQuery( state, @@ -378,7 +390,7 @@ export const { updateConversationId, updateThought, updateStreamingSource, - updateToolCalls, + updateToolCall, setConversation, setAttachments, addAttachment, diff --git a/frontend/src/conversation/types/index.ts b/frontend/src/conversation/types/index.ts index 9b5f2365..4ccb04a1 100644 --- a/frontend/src/conversation/types/index.ts +++ b/frontend/src/conversation/types/index.ts @@ -3,5 +3,6 @@ export type ToolCallsType = { action_name: string; call_id: string; arguments: Record; - result: Record; + result?: Record; + status?: 'pending' | 'completed'; };