mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
refactor: tool calls sent when pending and after completion
This commit is contained in:
@@ -136,6 +136,15 @@ class BaseAgent(ABC):
|
||||
parser = ToolActionParser(self.llm.__class__.__name__)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||
tool_call_data = {
|
||||
"tool_name": tools_dict[tool_id]["name"],
|
||||
"call_id": call_id,
|
||||
"action_name": f"{action_name}_{tool_id}",
|
||||
"arguments": call_args,
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
action_data = (
|
||||
tool_data["config"]["actions"][action_name]
|
||||
@@ -188,19 +197,26 @@ class BaseAgent(ABC):
|
||||
else:
|
||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||
result = tool.execute_action(action_name, **parameters)
|
||||
call_id = getattr(call, "id", None)
|
||||
tool_call_data["result"] = result
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": tool_data["name"],
|
||||
"call_id": call_id if call_id is not None else "None",
|
||||
"action_name": f"{action_name}_{tool_id}",
|
||||
"arguments": call_args,
|
||||
"result": result,
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
|
||||
return result, call_id
|
||||
|
||||
def _get_truncated_tool_calls(self):
|
||||
return [
|
||||
{
|
||||
**tool_call,
|
||||
"result": (
|
||||
f"{str(tool_call['result'])[:50]}..."
|
||||
if len(str(tool_call["result"])) > 50
|
||||
else tool_call["result"]
|
||||
),
|
||||
}
|
||||
for tool_call in self.tool_calls
|
||||
]
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
system_prompt: str,
|
||||
@@ -264,13 +280,11 @@ class BaseAgent(ABC):
|
||||
and self.tools
|
||||
):
|
||||
gen_kwargs["tools"] = self.tools
|
||||
|
||||
resp = self.llm.gen_stream(**gen_kwargs)
|
||||
|
||||
if log_context:
|
||||
data = build_stack_data(self.llm, exclude_attributes=["client"])
|
||||
log_context.stacks.append({"component": "llm", "data": data})
|
||||
|
||||
return resp
|
||||
|
||||
def _llm_handler(
|
||||
@@ -288,3 +302,23 @@ class BaseAgent(ABC):
|
||||
data = build_stack_data(self.llm_handler, exclude_attributes=["tool_calls"])
|
||||
log_context.stacks.append({"component": "llm_handler", "data": data})
|
||||
return resp
|
||||
|
||||
def _handle_response(self, response, tools_dict, messages, log_context):
|
||||
if isinstance(response, str):
|
||||
yield {"answer": response}
|
||||
return
|
||||
if hasattr(response, "message") and getattr(response.message, "content", None):
|
||||
yield {"answer": response.message.content}
|
||||
return
|
||||
|
||||
processed_response_gen = self._llm_handler(
|
||||
response, tools_dict, messages, log_context, self.attachments
|
||||
)
|
||||
|
||||
for event in processed_response_gen:
|
||||
if isinstance(event, str):
|
||||
yield {"answer": event}
|
||||
elif hasattr(event, "message") and getattr(event.message, "content", None):
|
||||
yield {"answer": event.message.content}
|
||||
elif isinstance(event, dict) and "type" in event:
|
||||
yield event
|
||||
|
||||
@@ -23,7 +23,6 @@ class ClassicAgent(BaseAgent):
|
||||
def _gen_inner(
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""Main execution flow for the agent."""
|
||||
# Step 1: Retrieve relevant data
|
||||
retrieved_data = self._retriever_search(retriever, query, log_context)
|
||||
|
||||
@@ -52,46 +51,3 @@ class ClassicAgent(BaseAgent):
|
||||
log_context.stacks.append(
|
||||
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||
)
|
||||
|
||||
def _handle_response(self, response, tools_dict, messages, log_context):
|
||||
"""Handle different types of LLM responses consistently."""
|
||||
# Handle simple string responses
|
||||
if isinstance(response, str):
|
||||
yield {"answer": response}
|
||||
return
|
||||
|
||||
# Handle content from message objects
|
||||
if hasattr(response, "message") and getattr(response.message, "content", None):
|
||||
yield {"answer": response.message.content}
|
||||
return
|
||||
|
||||
# Handle complex responses that may require tool use
|
||||
processed_response = self._llm_handler(
|
||||
response, tools_dict, messages, log_context, self.attachments
|
||||
)
|
||||
|
||||
# Yield the final processed response
|
||||
if isinstance(processed_response, str):
|
||||
yield {"answer": processed_response}
|
||||
elif hasattr(processed_response, "message") and getattr(
|
||||
processed_response.message, "content", None
|
||||
):
|
||||
yield {"answer": processed_response.message.content}
|
||||
else:
|
||||
for line in processed_response:
|
||||
if isinstance(line, str):
|
||||
yield {"answer": line}
|
||||
|
||||
def _get_truncated_tool_calls(self):
|
||||
"""Return tool calls with truncated results for cleaner output."""
|
||||
return [
|
||||
{
|
||||
**tool_call,
|
||||
"result": (
|
||||
f"{str(tool_call['result'])[:50]}..."
|
||||
if len(str(tool_call["result"])) > 50
|
||||
else tool_call["result"]
|
||||
),
|
||||
}
|
||||
for tool_call in self.tool_calls
|
||||
]
|
||||
|
||||
@@ -310,12 +310,13 @@ def complete_stream(
|
||||
yield f"data: {data}\n\n"
|
||||
elif "tool_calls" in line:
|
||||
tool_calls = line["tool_calls"]
|
||||
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "thought" in line:
|
||||
thought += line["thought"]
|
||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "type" in line:
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
|
||||
@@ -180,7 +180,7 @@ class LLMHandler(ABC):
|
||||
|
||||
def handle_tool_calls(
|
||||
self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
|
||||
) -> List[Dict]:
|
||||
) -> Generator:
|
||||
"""
|
||||
Execute tool calls and update conversation history.
|
||||
|
||||
@@ -198,7 +198,13 @@ class LLMHandler(ABC):
|
||||
for call in tool_calls:
|
||||
try:
|
||||
self.tool_calls.append(call)
|
||||
tool_response, call_id = agent._execute_tool_action(tools_dict, call)
|
||||
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
|
||||
while True:
|
||||
try:
|
||||
yield next(tool_executor_gen)
|
||||
except StopIteration as e:
|
||||
tool_response, call_id = e.value
|
||||
break
|
||||
|
||||
updated_messages.append(
|
||||
{
|
||||
@@ -231,7 +237,7 @@ class LLMHandler(ABC):
|
||||
|
||||
def handle_non_streaming(
|
||||
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
|
||||
) -> Union[str, Dict]:
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle non-streaming response flow.
|
||||
|
||||
@@ -248,9 +254,15 @@ class LLMHandler(ABC):
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
while parsed.requires_tool_call:
|
||||
messages = self.handle_tool_calls(
|
||||
tool_handler_gen = self.handle_tool_calls(
|
||||
agent, parsed.tool_calls, tools_dict, messages
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
yield next(tool_handler_gen)
|
||||
except StopIteration as e:
|
||||
messages = e.value
|
||||
break
|
||||
|
||||
response = agent.llm.gen(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
@@ -297,9 +309,15 @@ class LLMHandler(ABC):
|
||||
if call.arguments:
|
||||
existing.arguments += call.arguments
|
||||
if parsed.finish_reason == "tool_calls":
|
||||
messages = self.handle_tool_calls(
|
||||
tool_handler_gen = self.handle_tool_calls(
|
||||
agent, list(tool_calls.values()), tools_dict, messages
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
yield next(tool_handler_gen)
|
||||
except StopIteration as e:
|
||||
messages = e.value
|
||||
break
|
||||
tool_calls = {}
|
||||
|
||||
response = agent.llm.gen_stream(
|
||||
|
||||
@@ -14,6 +14,7 @@ import {
|
||||
ConversationState,
|
||||
Attachment,
|
||||
} from './conversationModels';
|
||||
import { ToolCallsType } from './types';
|
||||
|
||||
const initialState: ConversationState = {
|
||||
queries: [],
|
||||
@@ -110,11 +111,11 @@ export const fetchAnswer = createAsyncThunk<
|
||||
query: { sources: data.source ?? [] },
|
||||
}),
|
||||
);
|
||||
} else if (data.type === 'tool_calls') {
|
||||
} else if (data.type === 'tool_call') {
|
||||
dispatch(
|
||||
updateToolCalls({
|
||||
updateToolCall({
|
||||
index: targetIndex,
|
||||
query: { tool_calls: data.tool_calls },
|
||||
tool_call: data.data as ToolCallsType,
|
||||
}),
|
||||
);
|
||||
} else if (data.type === 'error') {
|
||||
@@ -280,12 +281,23 @@ export const conversationSlice = createSlice({
|
||||
state.queries[index].sources!.push(query.sources![0]);
|
||||
}
|
||||
},
|
||||
updateToolCalls(
|
||||
state,
|
||||
action: PayloadAction<{ index: number; query: Partial<Query> }>,
|
||||
) {
|
||||
const { index, query } = action.payload;
|
||||
state.queries[index].tool_calls = query?.tool_calls ?? [];
|
||||
updateToolCall(state, action) {
|
||||
const { index, tool_call } = action.payload;
|
||||
|
||||
if (!state.queries[index].tool_calls) {
|
||||
state.queries[index].tool_calls = [];
|
||||
}
|
||||
|
||||
const existingIndex = state.queries[index].tool_calls.findIndex(
|
||||
(call) => call.call_id === tool_call.call_id,
|
||||
);
|
||||
|
||||
if (existingIndex !== -1) {
|
||||
Object.assign(
|
||||
state.queries[index].tool_calls[existingIndex],
|
||||
tool_call,
|
||||
);
|
||||
} else state.queries[index].tool_calls.push(tool_call);
|
||||
},
|
||||
updateQuery(
|
||||
state,
|
||||
@@ -378,7 +390,7 @@ export const {
|
||||
updateConversationId,
|
||||
updateThought,
|
||||
updateStreamingSource,
|
||||
updateToolCalls,
|
||||
updateToolCall,
|
||||
setConversation,
|
||||
setAttachments,
|
||||
addAttachment,
|
||||
|
||||
@@ -3,5 +3,6 @@ export type ToolCallsType = {
|
||||
action_name: string;
|
||||
call_id: string;
|
||||
arguments: Record<string, any>;
|
||||
result: Record<string, any>;
|
||||
result?: Record<string, any>;
|
||||
status?: 'pending' | 'completed';
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user