diff --git a/application/llm/openai.py b/application/llm/openai.py index b507a1da..36861584 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -1,3 +1,5 @@ +import json + from application.core.settings import settings from application.llm.base import BaseLLM @@ -15,6 +17,63 @@ class OpenAILLM(BaseLLM): self.api_key = api_key self.user_api_key = user_api_key + def _clean_messages_openai(self, messages): + cleaned_messages = [] + for message in messages: + role = message.get("role") + content = message.get("content") + + if role == "model": + role = "assistant" + + if role and content is not None: + if isinstance(content, str): + cleaned_messages.append({"role": role, "content": content}) + elif isinstance(content, list): + for item in content: + if "text" in item: + cleaned_messages.append( + {"role": role, "content": item["text"]} + ) + elif "function_call" in item: + tool_call = { + "id": item["function_call"]["call_id"], + "type": "function", + "function": { + "name": item["function_call"]["name"], + "arguments": json.dumps( + item["function_call"]["args"] + ), + }, + } + cleaned_messages.append( + { + "role": "assistant", + "content": None, + "tool_calls": [tool_call], + } + ) + elif "function_response" in item: + cleaned_messages.append( + { + "role": "tool", + "tool_call_id": item["function_response"][ + "call_id" + ], + "content": json.dumps( + item["function_response"]["response"]["result"] + ), + } + ) + else: + raise ValueError( + f"Unexpected content dictionary format: {item}" + ) + else: + raise ValueError(f"Unexpected content type: {type(content)}") + + return cleaned_messages + def _raw_gen( self, baseself, @@ -25,9 +84,15 @@ class OpenAILLM(BaseLLM): engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs, ): + messages = self._clean_messages_openai(messages) + print(messages) if tools: response = self.client.chat.completions.create( - model=model, messages=messages, stream=stream, tools=tools, **kwargs + model=model, + messages=messages, + stream=stream, + tools=tools, + **kwargs, ) return response.choices[0] else: @@ -46,6 +111,7 @@ class OpenAILLM(BaseLLM): engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs, ): + messages = self._clean_messages_openai(messages) response = self.client.chat.completions.create( model=model, messages=messages, stream=stream, **kwargs ) diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index a0c03cbf..ca40f966 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -1,3 +1,5 @@ +import uuid + from application.core.settings import settings from application.retriever.base import BaseRetriever from application.tools.agent import Agent @@ -86,21 +88,38 @@ class ClassicRAG(BaseRetriever): ) if "tool_calls" in i: for tool_call in i["tool_calls"]: - messages_combine.append( - { - "role": "assistant", - "content": f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}", - } - ) - messages_combine.append({"role": "user", "content": self.question}) - # llm = LLMCreator.create_llm( - # settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key - # ) - # completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) + call_id = tool_call.get("call_id") + if call_id is None or call_id == "None": + call_id = str(uuid.uuid4()) + function_call_dict = { + "function_call": { + "name": tool_call.get("action_name"), + "args": tool_call.get("arguments"), + "call_id": call_id, + } + } + function_response_dict = { + "function_response": { + "name": tool_call.get("action_name"), + "response": {"result": tool_call.get("result")}, + "call_id": call_id, + } + } + + messages_combine.append( + {"role": "assistant", "content": [function_call_dict]} + ) + messages_combine.append( + {"role": "tool", "content": [function_response_dict]} + ) + + messages_combine.append({"role": "user", "content": self.question}) completion = self.agent.gen(messages_combine) + for line in completion: yield {"answer": str(line)} + yield {"tool_calls": self.agent.tool_calls.copy()} def search(self): diff --git a/application/tools/agent.py b/application/tools/agent.py index 148160db..10798862 100644 --- a/application/tools/agent.py +++ b/application/tools/agent.py @@ -127,9 +127,10 @@ class Agent: tool_call_data = { "tool_name": tool_data["name"], - "action_name": action_name, - "arguments": str(call_args), - "result": str(result), + "call_id": call_id if call_id is not None else "None", + "action_name": f"{action_name}_{tool_id}", + "arguments": call_args, + "result": result, } self.tool_calls.append(tool_call_data) diff --git a/application/tools/llm_handler.py b/application/tools/llm_handler.py index cc7494c0..334d2c4c 100644 --- a/application/tools/llm_handler.py +++ b/application/tools/llm_handler.py @@ -24,13 +24,28 @@ class OpenAILLMHandler(LLMHandler): tool_response, call_id = agent._execute_tool_action( tools_dict, call ) - messages.append( - { - "role": "tool", - "content": str(tool_response), - "tool_call_id": call_id, + function_call_dict = { + "function_call": { + "name": call.function.name, + "args": call.function.arguments, + "call_id": call_id, } + } + function_response_dict = { + "function_response": { + "name": call.function.name, + "response": {"result": tool_response}, + "call_id": call_id, + } + } + + messages.append( + {"role": "assistant", "content": [function_call_dict]} ) + messages.append( + {"role": "tool", "content": [function_response_dict]} + ) + except Exception as e: messages.append( { diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index 4d26a3e5..ab70584f 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -628,7 +628,7 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) { {toolCalls.map((toolCall, index) => ( Arguments {' '} - +

- {toolCall.arguments} + {JSON.stringify(toolCall.arguments, null, 2)}

@@ -654,14 +656,16 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) { Response {' '} - +

- {toolCall.result} + {JSON.stringify(toolCall.result, null, 2)}

diff --git a/frontend/src/conversation/types/index.ts b/frontend/src/conversation/types/index.ts index ae58a81d..9b5f2365 100644 --- a/frontend/src/conversation/types/index.ts +++ b/frontend/src/conversation/types/index.ts @@ -1,6 +1,7 @@ export type ToolCallsType = { tool_name: string; action_name: string; - arguments: string; - result: string; + call_id: string; + arguments: Record; + result: Record; };