mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-12-01 17:43:15 +00:00
feat: enhance tool call handling with structured message cleaning and improved UI display
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import json
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.base import BaseLLM
|
||||
|
||||
@@ -15,6 +17,63 @@ class OpenAILLM(BaseLLM):
|
||||
self.api_key = api_key
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
def _clean_messages_openai(self, messages):
|
||||
cleaned_messages = []
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
|
||||
if role == "model":
|
||||
role = "assistant"
|
||||
|
||||
if role and content is not None:
|
||||
if isinstance(content, str):
|
||||
cleaned_messages.append({"role": role, "content": content})
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if "text" in item:
|
||||
cleaned_messages.append(
|
||||
{"role": role, "content": item["text"]}
|
||||
)
|
||||
elif "function_call" in item:
|
||||
tool_call = {
|
||||
"id": item["function_call"]["call_id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": item["function_call"]["name"],
|
||||
"arguments": json.dumps(
|
||||
item["function_call"]["args"]
|
||||
),
|
||||
},
|
||||
}
|
||||
cleaned_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [tool_call],
|
||||
}
|
||||
)
|
||||
elif "function_response" in item:
|
||||
cleaned_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": item["function_response"][
|
||||
"call_id"
|
||||
],
|
||||
"content": json.dumps(
|
||||
item["function_response"]["response"]["result"]
|
||||
),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected content dictionary format: {item}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
|
||||
return cleaned_messages
|
||||
|
||||
def _raw_gen(
|
||||
self,
|
||||
baseself,
|
||||
@@ -25,9 +84,15 @@ class OpenAILLM(BaseLLM):
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
**kwargs,
|
||||
):
|
||||
messages = self._clean_messages_openai(messages)
|
||||
print(messages)
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, tools=tools, **kwargs
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
**kwargs,
|
||||
)
|
||||
return response.choices[0]
|
||||
else:
|
||||
@@ -46,6 +111,7 @@ class OpenAILLM(BaseLLM):
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
**kwargs,
|
||||
):
|
||||
messages = self._clean_messages_openai(messages)
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import uuid
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.tools.agent import Agent
|
||||
@@ -86,21 +88,38 @@ class ClassicRAG(BaseRetriever):
|
||||
)
|
||||
if "tool_calls" in i:
|
||||
for tool_call in i["tool_calls"]:
|
||||
messages_combine.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}",
|
||||
}
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
# llm = LLMCreator.create_llm(
|
||||
# settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
|
||||
# )
|
||||
# completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
|
||||
call_id = tool_call.get("call_id")
|
||||
if call_id is None or call_id == "None":
|
||||
call_id = str(uuid.uuid4())
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"args": tool_call.get("arguments"),
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"response": {"result": tool_call.get("result")},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
messages_combine.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
messages_combine.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
completion = self.agent.gen(messages_combine)
|
||||
|
||||
for line in completion:
|
||||
yield {"answer": str(line)}
|
||||
|
||||
yield {"tool_calls": self.agent.tool_calls.copy()}
|
||||
|
||||
def search(self):
|
||||
|
||||
@@ -127,9 +127,10 @@ class Agent:
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": tool_data["name"],
|
||||
"action_name": action_name,
|
||||
"arguments": str(call_args),
|
||||
"result": str(result),
|
||||
"call_id": call_id if call_id is not None else "None",
|
||||
"action_name": f"{action_name}_{tool_id}",
|
||||
"arguments": call_args,
|
||||
"result": result,
|
||||
}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
|
||||
|
||||
@@ -24,13 +24,28 @@ class OpenAILLMHandler(LLMHandler):
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, call
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": str(tool_response),
|
||||
"tool_call_id": call_id,
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": call.function.name,
|
||||
"args": call.function.arguments,
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": call.function.name,
|
||||
"response": {"result": tool_response},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
messages.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
messages.append(
|
||||
{
|
||||
|
||||
@@ -628,7 +628,7 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
|
||||
{toolCalls.map((toolCall, index) => (
|
||||
<Accordion
|
||||
key={`tool-call-${index}`}
|
||||
title={`${toolCall.tool_name} - ${toolCall.action_name}`}
|
||||
title={`${toolCall.tool_name} - ${toolCall.action_name.substring(0, toolCall.action_name.lastIndexOf('_'))}`}
|
||||
className="w-full rounded-[20px] bg-gray-1000 dark:bg-gun-metal hover:bg-[#F1F1F1] dark:hover:bg-[#2C2E3C]"
|
||||
titleClassName="px-6 py-2 text-sm font-semibold"
|
||||
children={
|
||||
@@ -638,14 +638,16 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
|
||||
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
|
||||
Arguments
|
||||
</span>{' '}
|
||||
<CopyButton text={toolCall.arguments} />
|
||||
<CopyButton
|
||||
text={JSON.stringify(toolCall.arguments, null, 2)}
|
||||
/>
|
||||
</p>
|
||||
<p className="p-2 font-mono text-sm dark:tex dark:bg-[#222327] rounded-b-2xl break-words">
|
||||
<span
|
||||
className="text-black dark:text-gray-400 leading-[23px]"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{toolCall.arguments}
|
||||
{JSON.stringify(toolCall.arguments, null, 2)}
|
||||
</span>
|
||||
</p>
|
||||
</div>
|
||||
@@ -654,14 +656,16 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
|
||||
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
|
||||
Response
|
||||
</span>{' '}
|
||||
<CopyButton text={toolCall.result} />
|
||||
<CopyButton
|
||||
text={JSON.stringify(toolCall.result, null, 2)}
|
||||
/>
|
||||
</p>
|
||||
<p className="p-2 font-mono text-sm dark:tex dark:bg-[#222327] rounded-b-2xl break-words">
|
||||
<span
|
||||
className="text-black dark:text-gray-400 leading-[23px]"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{toolCall.result}
|
||||
{JSON.stringify(toolCall.result, null, 2)}
|
||||
</span>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
export type ToolCallsType = {
|
||||
tool_name: string;
|
||||
action_name: string;
|
||||
arguments: string;
|
||||
result: string;
|
||||
call_id: string;
|
||||
arguments: Record<string, any>;
|
||||
result: Record<string, any>;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user