feat: enhance tool call handling with structured message cleaning and improved UI display

This commit is contained in:
Siddhant Rai
2025-02-14 00:15:01 +05:30
parent 8a3612e56c
commit 5cf5bed6a8
6 changed files with 133 additions and 27 deletions

View File

@@ -1,3 +1,5 @@
import json
from application.core.settings import settings
from application.llm.base import BaseLLM
@@ -15,6 +17,63 @@ class OpenAILLM(BaseLLM):
self.api_key = api_key
self.user_api_key = user_api_key
def _clean_messages_openai(self, messages):
cleaned_messages = []
for message in messages:
role = message.get("role")
content = message.get("content")
if role == "model":
role = "assistant"
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
elif isinstance(content, list):
for item in content:
if "text" in item:
cleaned_messages.append(
{"role": role, "content": item["text"]}
)
elif "function_call" in item:
tool_call = {
"id": item["function_call"]["call_id"],
"type": "function",
"function": {
"name": item["function_call"]["name"],
"arguments": json.dumps(
item["function_call"]["args"]
),
},
}
cleaned_messages.append(
{
"role": "assistant",
"content": None,
"tool_calls": [tool_call],
}
)
elif "function_response" in item:
cleaned_messages.append(
{
"role": "tool",
"tool_call_id": item["function_response"][
"call_id"
],
"content": json.dumps(
item["function_response"]["response"]["result"]
),
}
)
else:
raise ValueError(
f"Unexpected content dictionary format: {item}"
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
return cleaned_messages
def _raw_gen(
self,
baseself,
@@ -25,9 +84,15 @@ class OpenAILLM(BaseLLM):
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs,
):
messages = self._clean_messages_openai(messages)
print(messages)
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
model=model,
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
return response.choices[0]
else:
@@ -46,6 +111,7 @@ class OpenAILLM(BaseLLM):
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs,
):
messages = self._clean_messages_openai(messages)
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)

View File

@@ -1,3 +1,5 @@
import uuid
from application.core.settings import settings
from application.retriever.base import BaseRetriever
from application.tools.agent import Agent
@@ -86,21 +88,38 @@ class ClassicRAG(BaseRetriever):
)
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
messages_combine.append(
{
"role": "assistant",
"content": f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}",
}
)
messages_combine.append({"role": "user", "content": self.question})
# llm = LLMCreator.create_llm(
# settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
# )
# completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
call_id = tool_call.get("call_id")
if call_id is None or call_id == "None":
call_id = str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages_combine.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages_combine.append(
{"role": "tool", "content": [function_response_dict]}
)
messages_combine.append({"role": "user", "content": self.question})
completion = self.agent.gen(messages_combine)
for line in completion:
yield {"answer": str(line)}
yield {"tool_calls": self.agent.tool_calls.copy()}
def search(self):

View File

@@ -127,9 +127,10 @@ class Agent:
tool_call_data = {
"tool_name": tool_data["name"],
"action_name": action_name,
"arguments": str(call_args),
"result": str(result),
"call_id": call_id if call_id is not None else "None",
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
"result": result,
}
self.tool_calls.append(tool_call_data)

View File

@@ -24,13 +24,28 @@ class OpenAILLMHandler(LLMHandler):
tool_response, call_id = agent._execute_tool_action(
tools_dict, call
)
messages.append(
{
"role": "tool",
"content": str(tool_response),
"tool_call_id": call_id,
function_call_dict = {
"function_call": {
"name": call.function.name,
"args": call.function.arguments,
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": call.function.name,
"response": {"result": tool_response},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages.append(
{"role": "tool", "content": [function_response_dict]}
)
except Exception as e:
messages.append(
{

View File

@@ -628,7 +628,7 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
{toolCalls.map((toolCall, index) => (
<Accordion
key={`tool-call-${index}`}
title={`${toolCall.tool_name} - ${toolCall.action_name}`}
title={`${toolCall.tool_name} - ${toolCall.action_name.substring(0, toolCall.action_name.lastIndexOf('_'))}`}
className="w-full rounded-[20px] bg-gray-1000 dark:bg-gun-metal hover:bg-[#F1F1F1] dark:hover:bg-[#2C2E3C]"
titleClassName="px-6 py-2 text-sm font-semibold"
children={
@@ -638,14 +638,16 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
Arguments
</span>{' '}
<CopyButton text={toolCall.arguments} />
<CopyButton
text={JSON.stringify(toolCall.arguments, null, 2)}
/>
</p>
<p className="p-2 font-mono text-sm dark:tex dark:bg-[#222327] rounded-b-2xl break-words">
<span
className="text-black dark:text-gray-400 leading-[23px]"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{toolCall.arguments}
{JSON.stringify(toolCall.arguments, null, 2)}
</span>
</p>
</div>
@@ -654,14 +656,16 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
Response
</span>{' '}
<CopyButton text={toolCall.result} />
<CopyButton
text={JSON.stringify(toolCall.result, null, 2)}
/>
</p>
<p className="p-2 font-mono text-sm dark:tex dark:bg-[#222327] rounded-b-2xl break-words">
<span
className="text-black dark:text-gray-400 leading-[23px]"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{toolCall.result}
{JSON.stringify(toolCall.result, null, 2)}
</span>
</p>
</div>

View File

@@ -1,6 +1,7 @@
export type ToolCallsType = {
tool_name: string;
action_name: string;
arguments: string;
result: string;
call_id: string;
arguments: Record<string, any>;
result: Record<string, any>;
};