diff --git a/application/llm/openai.py b/application/llm/openai.py
index b507a1da..36861584 100644
--- a/application/llm/openai.py
+++ b/application/llm/openai.py
@@ -1,3 +1,5 @@
+import json
+
from application.core.settings import settings
from application.llm.base import BaseLLM
@@ -15,6 +17,63 @@ class OpenAILLM(BaseLLM):
self.api_key = api_key
self.user_api_key = user_api_key
+ def _clean_messages_openai(self, messages):
+ cleaned_messages = []
+ for message in messages:
+ role = message.get("role")
+ content = message.get("content")
+
+ if role == "model":
+ role = "assistant"
+
+ if role and content is not None:
+ if isinstance(content, str):
+ cleaned_messages.append({"role": role, "content": content})
+ elif isinstance(content, list):
+ for item in content:
+ if "text" in item:
+ cleaned_messages.append(
+ {"role": role, "content": item["text"]}
+ )
+ elif "function_call" in item:
+ tool_call = {
+ "id": item["function_call"]["call_id"],
+ "type": "function",
+ "function": {
+ "name": item["function_call"]["name"],
+ "arguments": json.dumps(
+ item["function_call"]["args"]
+ ),
+ },
+ }
+ cleaned_messages.append(
+ {
+ "role": "assistant",
+ "content": None,
+ "tool_calls": [tool_call],
+ }
+ )
+ elif "function_response" in item:
+ cleaned_messages.append(
+ {
+ "role": "tool",
+ "tool_call_id": item["function_response"][
+ "call_id"
+ ],
+ "content": json.dumps(
+ item["function_response"]["response"]["result"]
+ ),
+ }
+ )
+ else:
+ raise ValueError(
+ f"Unexpected content dictionary format: {item}"
+ )
+ else:
+ raise ValueError(f"Unexpected content type: {type(content)}")
+
+ return cleaned_messages
+
def _raw_gen(
self,
baseself,
@@ -25,9 +84,14 @@ class OpenAILLM(BaseLLM):
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs,
):
+ messages = self._clean_messages_openai(messages)
if tools:
response = self.client.chat.completions.create(
- model=model, messages=messages, stream=stream, tools=tools, **kwargs
+ model=model,
+ messages=messages,
+ stream=stream,
+ tools=tools,
+ **kwargs,
)
return response.choices[0]
else:
@@ -46,6 +110,7 @@ class OpenAILLM(BaseLLM):
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs,
):
+ messages = self._clean_messages_openai(messages)
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
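
For reference, a minimal sketch of the translation `_clean_messages_openai` performs. The history below is illustrative (the tool name, ids, and values are made up, not taken from the repo):

```python
# Provider-neutral history, shaped like what the retriever/handler builds.
history = [
    {"role": "user", "content": "What is 2 + 2?"},
    {
        "role": "model",  # normalized to "assistant"
        "content": [
            {
                "function_call": {
                    "name": "calculator_add",  # hypothetical action name
                    "args": {"a": 2, "b": 2},
                    "call_id": "call_123",
                }
            }
        ],
    },
    {
        "role": "tool",
        "content": [
            {
                "function_response": {
                    "name": "calculator_add",
                    "response": {"result": 4},
                    "call_id": "call_123",
                }
            }
        ],
    },
]

# _clean_messages_openai(history) flattens this into OpenAI's chat schema:
# [
#     {"role": "user", "content": "What is 2 + 2?"},
#     {"role": "assistant", "content": None, "tool_calls": [
#         {"id": "call_123", "type": "function",
#          "function": {"name": "calculator_add",
#                       "arguments": '{"a": 2, "b": 2}'}}]},
#     {"role": "tool", "tool_call_id": "call_123", "content": "4"},
# ]
```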
diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py
index a0c03cbf..ca40f966 100644
--- a/application/retriever/classic_rag.py
+++ b/application/retriever/classic_rag.py
@@ -1,3 +1,5 @@
+import uuid
+
from application.core.settings import settings
from application.retriever.base import BaseRetriever
from application.tools.agent import Agent
@@ -86,21 +88,38 @@ class ClassicRAG(BaseRetriever):
)
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
- messages_combine.append(
- {
- "role": "assistant",
- "content": f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}",
- }
- )
- messages_combine.append({"role": "user", "content": self.question})
- # llm = LLMCreator.create_llm(
- # settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
- # )
- # completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
+ call_id = tool_call.get("call_id")
+ if call_id is None or call_id == "None":
+ call_id = str(uuid.uuid4())
+ function_call_dict = {
+ "function_call": {
+ "name": tool_call.get("action_name"),
+ "args": tool_call.get("arguments"),
+ "call_id": call_id,
+ }
+ }
+ function_response_dict = {
+ "function_response": {
+ "name": tool_call.get("action_name"),
+ "response": {"result": tool_call.get("result")},
+ "call_id": call_id,
+ }
+ }
+
+ messages_combine.append(
+ {"role": "assistant", "content": [function_call_dict]}
+ )
+ messages_combine.append(
+ {"role": "tool", "content": [function_response_dict]}
+ )
+
+ messages_combine.append({"role": "user", "content": self.question})
completion = self.agent.gen(messages_combine)
+
for line in completion:
yield {"answer": str(line)}
+
yield {"tool_calls": self.agent.tool_calls.copy()}
def search(self):
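
One note on the uuid fallback above: agent.py persists the literal string "None" when no call id was captured (see the agent.py hunk that follows), so both `None` and `"None"` must trigger a synthetic id, otherwise the replayed function_call/function_response pair cannot be matched up. A sketch with a made-up record:

```python
import uuid

# Hypothetical persisted record; older entries may carry the string "None".
stored = {
    "action_name": "add_abc123",
    "arguments": {"a": 2, "b": 2},
    "result": 4,
    "call_id": "None",
}

call_id = stored.get("call_id")
if call_id is None or call_id == "None":
    # Both halves of the replayed pair must share an id, even a synthetic one.
    call_id = str(uuid.uuid4())
```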
diff --git a/application/tools/agent.py b/application/tools/agent.py
index 148160db..10798862 100644
--- a/application/tools/agent.py
+++ b/application/tools/agent.py
@@ -127,9 +127,10 @@ class Agent:
tool_call_data = {
"tool_name": tool_data["name"],
- "action_name": action_name,
- "arguments": str(call_args),
- "result": str(result),
+ "call_id": call_id if call_id is not None else "None",
+ "action_name": f"{action_name}_{tool_id}",
+ "arguments": call_args,
+ "result": result,
}
self.tool_calls.append(tool_call_data)
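
The practical effect of this hunk: `arguments` and `result` stay structured instead of being stringified, which is what lets the frontend pretty-print them, and each record gains a `call_id` plus a tool-id-suffixed action name. A before/after sketch with made-up values:

```python
# Before: values flattened to strings at record time.
old_record = {
    "tool_name": "calculator",
    "action_name": "add",
    "arguments": "{'a': 2, 'b': 2}",
    "result": "4",
}

# After: structured values, a call_id (or the string "None" as a sentinel),
# and the action name suffixed with the tool id via f"{action_name}_{tool_id}".
new_record = {
    "tool_name": "calculator",
    "call_id": "call_123",
    "action_name": "add_abc123",
    "arguments": {"a": 2, "b": 2},
    "result": 4,
}
```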
diff --git a/application/tools/llm_handler.py b/application/tools/llm_handler.py
index cc7494c0..334d2c4c 100644
--- a/application/tools/llm_handler.py
+++ b/application/tools/llm_handler.py
@@ -24,13 +24,28 @@ class OpenAILLMHandler(LLMHandler):
tool_response, call_id = agent._execute_tool_action(
tools_dict, call
)
- messages.append(
- {
- "role": "tool",
- "content": str(tool_response),
- "tool_call_id": call_id,
+ function_call_dict = {
+ "function_call": {
+ "name": call.function.name,
+ "args": call.function.arguments,
+ "call_id": call_id,
}
+ }
+ function_response_dict = {
+ "function_response": {
+ "name": call.function.name,
+ "response": {"result": tool_response},
+ "call_id": call_id,
+ }
+ }
+
+ messages.append(
+ {"role": "assistant", "content": [function_call_dict]}
)
+ messages.append(
+ {"role": "tool", "content": [function_response_dict]}
+ )
+
except Exception as e:
messages.append(
{
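
Two details worth keeping in mind when reading this hunk: the assistant/tool pair is what `_clean_messages_openai` later relies on, since OpenAI's chat API only accepts a `role: "tool"` message that answers a preceding assistant message carrying the matching tool call id; and the OpenAI SDK surfaces `call.function.arguments` as a JSON string, not a dict, so that string is what lands in `args`. A minimal sketch of the second point:

```python
import json

# Hypothetical value as the OpenAI SDK delivers it on call.function.arguments:
raw_arguments = '{"query": "docs"}'

# It is stored verbatim in the function_call dict above; consumers that need
# a dict (e.g. the frontend's Record-typed `arguments`) must parse it first.
parsed = json.loads(raw_arguments)
assert parsed == {"query": "docs"}
```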
diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx
index 4d26a3e5..ab70584f 100644
--- a/frontend/src/conversation/ConversationBubble.tsx
+++ b/frontend/src/conversation/ConversationBubble.tsx
@@ -628,7 +628,7 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
{toolCalls.map((toolCall, index) => (
-                  {toolCall.arguments}
+                  {JSON.stringify(toolCall.arguments, null, 2)}
@@ -654,14 +656,16 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
                  Response
                  {' '}
-                  {toolCall.result}
+                  {JSON.stringify(toolCall.result, null, 2)}
diff --git a/frontend/src/conversation/types/index.ts b/frontend/src/conversation/types/index.ts
index ae58a81d..9b5f2365 100644
--- a/frontend/src/conversation/types/index.ts
+++ b/frontend/src/conversation/types/index.ts
@@ -1,6 +1,7 @@
 export type ToolCallsType = {
   tool_name: string;
   action_name: string;
-  arguments: string;
-  result: string;
+  call_id: string;
+  arguments: Record