diff --git a/application/llm/openai.py b/application/llm/openai.py index cc2285a1..b507a1da 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -1,6 +1,5 @@ -from application.llm.base import BaseLLM from application.core.settings import settings - +from application.llm.base import BaseLLM class OpenAILLM(BaseLLM): @@ -10,10 +9,7 @@ class OpenAILLM(BaseLLM): super().__init__(*args, **kwargs) if settings.OPENAI_BASE_URL: - self.client = OpenAI( - api_key=api_key, - base_url=settings.OPENAI_BASE_URL - ) + self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL) else: self.client = OpenAI(api_key=api_key) self.api_key = api_key @@ -27,8 +23,8 @@ class OpenAILLM(BaseLLM): stream=False, tools=None, engine=settings.AZURE_DEPLOYMENT_NAME, - **kwargs - ): + **kwargs, + ): if tools: response = self.client.chat.completions.create( model=model, messages=messages, stream=stream, tools=tools, **kwargs @@ -48,18 +44,16 @@ class OpenAILLM(BaseLLM): stream=True, tools=None, engine=settings.AZURE_DEPLOYMENT_NAME, - **kwargs - ): + **kwargs, + ): response = self.client.chat.completions.create( model=model, messages=messages, stream=stream, **kwargs ) for line in response: - # import sys - # print(line.choices[0].delta.content, file=sys.stderr) if line.choices[0].delta.content is not None: yield line.choices[0].delta.content - + def _supports_tools(self): return True diff --git a/application/tools/agent.py b/application/tools/agent.py index a6aec2e9..f4b37d9b 100644 --- a/application/tools/agent.py +++ b/application/tools/agent.py @@ -1,8 +1,7 @@ -import json -import logging - from application.core.mongo_db import MongoDB from application.llm.llm_creator import LLMCreator +from application.tools.llm_handler import get_llm_handler +from application.tools.tool_action_parser import ToolActionParser from application.tools.tool_manager import ToolManager @@ -12,6 +11,7 @@ class Agent: self.llm = LLMCreator.create_llm( llm_name, api_key=api_key, 
user_api_key=user_api_key ) + self.llm_handler = get_llm_handler(llm_name) self.gpt_model = gpt_model # Static tool configuration (to be replaced later) self.tools = [] @@ -61,10 +61,8 @@ class Agent: ] def _execute_tool_action(self, tools_dict, call): - call_id = call.id - call_args = json.loads(call.function.arguments) - tool_id = call.function.name.split("_")[-1] - action_name = call.function.name.rsplit("_", 1)[0] + parser = ToolActionParser(self.llm.__class__.__name__) + tool_id, action_name, call_args = parser.parse_args(call) tool_data = tools_dict[tool_id] action_data = next( @@ -78,26 +76,9 @@ class Agent: tm = ToolManager(config={}) tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"]) print(f"Executing tool: {action_name} with args: {call_args}") - return tool.execute_action(action_name, **call_args), call_id - - def _execute_tool_action_google(self, tools_dict, call): - call_args = json.loads(call.args) - tool_id = call.name.split("_")[-1] - action_name = call.name.rsplit("_", 1)[0] - - tool_data = tools_dict[tool_id] - action_data = next( - action for action in tool_data["actions"] if action["name"] == action_name - ) - - for param, details in action_data["parameters"]["properties"].items(): - if param not in call_args and "value" in details: - call_args[param] = details["value"] - - tm = ToolManager(config={}) - tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"]) - print(f"Executing tool: {action_name} with args: {call_args}") - return tool.execute_action(action_name, **call_args) + result = tool.execute_action(action_name, **call_args) + call_id = getattr(call, "id", None) + return result, call_id def _simple_tool_agent(self, messages): tools_dict = self._get_user_tools() @@ -111,47 +92,8 @@ class Agent: if resp.message.content: yield resp.message.content return - # check if self.llm class is GoogleLLM - while self.llm.__class__.__name__ == "GoogleLLM" and resp.content.parts[0].function_call: - from google.genai 
import types - function_call_part = resp.candidates[0].content.parts[0] - tool_response = self._execute_tool_action_google(tools_dict, function_call_part.function_call) - function_response_part = types.Part.from_function_response( - name=function_call_part.function_call.name, - response=tool_response - ) - - while self.llm.__class__.__name__ == "OpenAILLM" and resp.finish_reason == "tool_calls": - message = json.loads(resp.model_dump_json())["message"] - keys_to_remove = {"audio", "function_call", "refusal"} - filtered_data = { - k: v for k, v in message.items() if k not in keys_to_remove - } - messages.append(filtered_data) - tool_calls = resp.message.tool_calls - for call in tool_calls: - try: - tool_response, call_id = self._execute_tool_action(tools_dict, call) - messages.append( - { - "role": "tool", - "content": str(tool_response), - "tool_call_id": call_id, - } - ) - except Exception as e: - messages.append( - { - "role": "tool", - "content": f"Error executing tool: {str(e)}", - "tool_call_id": call.id, - } - ) - # Generate a new response from the LLM after processing tools - resp = self.llm.gen( - model=self.gpt_model, messages=messages, tools=self.tools - ) + resp = self.llm_handler.handle_response(self, resp, tools_dict, messages) # If no tool calls are needed, generate the final response if isinstance(resp, str): diff --git a/application/tools/llm_handler.py b/application/tools/llm_handler.py new file mode 100644 index 00000000..58fce56e --- /dev/null +++ b/application/tools/llm_handler.py @@ -0,0 +1,74 @@ +import json +from abc import ABC, abstractmethod + + +class LLMHandler(ABC): + @abstractmethod + def handle_response(self, agent, resp, tools_dict, messages, **kwargs): + pass + + +class OpenAILLMHandler(LLMHandler): + def handle_response(self, agent, resp, tools_dict, messages): + while resp.finish_reason == "tool_calls": + message = json.loads(resp.model_dump_json())["message"] + keys_to_remove = {"audio", "function_call", "refusal"} + filtered_data 
= { + k: v for k, v in message.items() if k not in keys_to_remove + } + messages.append(filtered_data) + + tool_calls = resp.message.tool_calls + for call in tool_calls: + try: + tool_response, call_id = agent._execute_tool_action( + tools_dict, call + ) + messages.append( + { + "role": "tool", + "content": str(tool_response), + "tool_call_id": call_id, + } + ) + except Exception as e: + messages.append( + { + "role": "tool", + "content": f"Error executing tool: {str(e)}", + "tool_call_id": call.id, + } + ) + resp = agent.llm.gen( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + return resp + + +class GoogleLLMHandler(LLMHandler): + def handle_response(self, agent, resp, tools_dict, messages): + from google.genai import types + + while resp.content.parts[0].function_call: + function_call_part = resp.candidates[0].content.parts[0] + tool_response, call_id = agent._execute_tool_action( + tools_dict, function_call_part.function_call + ) + function_response_part = types.Part.from_function_response( + name=function_call_part.function_call.name, response=tool_response + ) + + messages.extend([function_call_part, function_response_part]) + resp = agent.llm.gen( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + + return resp + + +def get_llm_handler(llm_type): + handlers = { + "openai": OpenAILLMHandler(), + "google": GoogleLLMHandler(), + } + return handlers.get(llm_type, OpenAILLMHandler()) diff --git a/application/tools/tool_action_parser.py b/application/tools/tool_action_parser.py new file mode 100644 index 00000000..b708992a --- /dev/null +++ b/application/tools/tool_action_parser.py @@ -0,0 +1,26 @@ +import json + + +class ToolActionParser: + def __init__(self, llm_type): + self.llm_type = llm_type + self.parsers = { + "OpenAILLM": self._parse_openai_llm, + "GoogleLLM": self._parse_google_llm, + } + + def parse_args(self, call): + parser = self.parsers.get(self.llm_type, self._parse_openai_llm) + return parser(call) + + def
_parse_openai_llm(self, call): + call_args = json.loads(call.function.arguments) + tool_id = call.function.name.split("_")[-1] + action_name = call.function.name.rsplit("_", 1)[0] + return tool_id, action_name, call_args + + def _parse_google_llm(self, call): + call_args = json.loads(call.args) + tool_id = call.name.split("_")[-1] + action_name = call.name.rsplit("_", 1)[0] + return tool_id, action_name, call_args