refactor: emit tool call events both while pending and after completion

This commit is contained in:
Siddhant Rai
2025-06-11 12:40:32 +05:30
parent dd9d18208d
commit 3351f71813
6 changed files with 94 additions and 72 deletions

View File

@@ -136,6 +136,15 @@ class BaseAgent(ABC):
parser = ToolActionParser(self.llm.__class__.__name__)
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
tool_data = tools_dict[tool_id]
action_data = (
tool_data["config"]["actions"][action_name]
@@ -188,19 +197,26 @@ class BaseAgent(ABC):
else:
print(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
call_id = getattr(call, "id", None)
tool_call_data["result"] = result
tool_call_data = {
"tool_name": tool_data["name"],
"call_id": call_id if call_id is not None else "None",
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
"result": result,
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
def _get_truncated_tool_calls(self):
return [
{
**tool_call,
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
else tool_call["result"]
),
}
for tool_call in self.tool_calls
]
def _build_messages(
self,
system_prompt: str,
@@ -264,13 +280,11 @@ class BaseAgent(ABC):
and self.tools
):
gen_kwargs["tools"] = self.tools
resp = self.llm.gen_stream(**gen_kwargs)
if log_context:
data = build_stack_data(self.llm, exclude_attributes=["client"])
log_context.stacks.append({"component": "llm", "data": data})
return resp
def _llm_handler(
@@ -288,3 +302,23 @@ class BaseAgent(ABC):
data = build_stack_data(self.llm_handler, exclude_attributes=["tool_calls"])
log_context.stacks.append({"component": "llm_handler", "data": data})
return resp
def _handle_response(self, response, tools_dict, messages, log_context):
if isinstance(response, str):
yield {"answer": response}
return
if hasattr(response, "message") and getattr(response.message, "content", None):
yield {"answer": response.message.content}
return
processed_response_gen = self._llm_handler(
response, tools_dict, messages, log_context, self.attachments
)
for event in processed_response_gen:
if isinstance(event, str):
yield {"answer": event}
elif hasattr(event, "message") and getattr(event.message, "content", None):
yield {"answer": event.message.content}
elif isinstance(event, dict) and "type" in event:
yield event

View File

@@ -23,7 +23,6 @@ class ClassicAgent(BaseAgent):
def _gen_inner(
self, query: str, retriever: BaseRetriever, log_context: LogContext
) -> Generator[Dict, None, None]:
"""Main execution flow for the agent."""
# Step 1: Retrieve relevant data
retrieved_data = self._retriever_search(retriever, query, log_context)
@@ -52,46 +51,3 @@ class ClassicAgent(BaseAgent):
log_context.stacks.append(
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
)
def _handle_response(self, response, tools_dict, messages, log_context):
"""Handle different types of LLM responses consistently."""
# Handle simple string responses
if isinstance(response, str):
yield {"answer": response}
return
# Handle content from message objects
if hasattr(response, "message") and getattr(response.message, "content", None):
yield {"answer": response.message.content}
return
# Handle complex responses that may require tool use
processed_response = self._llm_handler(
response, tools_dict, messages, log_context, self.attachments
)
# Yield the final processed response
if isinstance(processed_response, str):
yield {"answer": processed_response}
elif hasattr(processed_response, "message") and getattr(
processed_response.message, "content", None
):
yield {"answer": processed_response.message.content}
else:
for line in processed_response:
if isinstance(line, str):
yield {"answer": line}
def _get_truncated_tool_calls(self):
"""Return tool calls with truncated results for cleaner output."""
return [
{
**tool_call,
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
else tool_call["result"]
),
}
for tool_call in self.tool_calls
]

View File

@@ -310,12 +310,13 @@ def complete_stream(
yield f"data: {data}\n\n"
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
yield f"data: {data}\n\n"
elif "thought" in line:
thought += line["thought"]
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
data = json.dumps(line)
yield f"data: {data}\n\n"
if isNoneDoc:
for doc in source_log_docs:

View File

@@ -180,7 +180,7 @@ class LLMHandler(ABC):
def handle_tool_calls(
self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
) -> List[Dict]:
) -> Generator:
"""
Execute tool calls and update conversation history.
@@ -198,7 +198,13 @@ class LLMHandler(ABC):
for call in tool_calls:
try:
self.tool_calls.append(call)
tool_response, call_id = agent._execute_tool_action(tools_dict, call)
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
while True:
try:
yield next(tool_executor_gen)
except StopIteration as e:
tool_response, call_id = e.value
break
updated_messages.append(
{
@@ -231,7 +237,7 @@ class LLMHandler(ABC):
def handle_non_streaming(
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
) -> Union[str, Dict]:
) -> Generator:
"""
Handle non-streaming response flow.
@@ -248,9 +254,15 @@ class LLMHandler(ABC):
self.llm_calls.append(build_stack_data(agent.llm))
while parsed.requires_tool_call:
messages = self.handle_tool_calls(
tool_handler_gen = self.handle_tool_calls(
agent, parsed.tool_calls, tools_dict, messages
)
while True:
try:
yield next(tool_handler_gen)
except StopIteration as e:
messages = e.value
break
response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
@@ -297,9 +309,15 @@ class LLMHandler(ABC):
if call.arguments:
existing.arguments += call.arguments
if parsed.finish_reason == "tool_calls":
messages = self.handle_tool_calls(
tool_handler_gen = self.handle_tool_calls(
agent, list(tool_calls.values()), tools_dict, messages
)
while True:
try:
yield next(tool_handler_gen)
except StopIteration as e:
messages = e.value
break
tool_calls = {}
response = agent.llm.gen_stream(

View File

@@ -14,6 +14,7 @@ import {
ConversationState,
Attachment,
} from './conversationModels';
import { ToolCallsType } from './types';
const initialState: ConversationState = {
queries: [],
@@ -110,11 +111,11 @@ export const fetchAnswer = createAsyncThunk<
query: { sources: data.source ?? [] },
}),
);
} else if (data.type === 'tool_calls') {
} else if (data.type === 'tool_call') {
dispatch(
updateToolCalls({
updateToolCall({
index: targetIndex,
query: { tool_calls: data.tool_calls },
tool_call: data.data as ToolCallsType,
}),
);
} else if (data.type === 'error') {
@@ -280,12 +281,23 @@ export const conversationSlice = createSlice({
state.queries[index].sources!.push(query.sources![0]);
}
},
updateToolCalls(
state,
action: PayloadAction<{ index: number; query: Partial<Query> }>,
) {
const { index, query } = action.payload;
state.queries[index].tool_calls = query?.tool_calls ?? [];
updateToolCall(state, action) {
const { index, tool_call } = action.payload;
if (!state.queries[index].tool_calls) {
state.queries[index].tool_calls = [];
}
const existingIndex = state.queries[index].tool_calls.findIndex(
(call) => call.call_id === tool_call.call_id,
);
if (existingIndex !== -1) {
Object.assign(
state.queries[index].tool_calls[existingIndex],
tool_call,
);
} else state.queries[index].tool_calls.push(tool_call);
},
updateQuery(
state,
@@ -378,7 +390,7 @@ export const {
updateConversationId,
updateThought,
updateStreamingSource,
updateToolCalls,
updateToolCall,
setConversation,
setAttachments,
addAttachment,

View File

@@ -3,5 +3,6 @@ export type ToolCallsType = {
action_name: string;
call_id: string;
arguments: Record<string, any>;
result: Record<string, any>;
result?: Record<string, any>;
status?: 'pending' | 'completed';
};