refactor: tool agent for action parser and handlers

This commit is contained in:
Siddhant Rai
2025-01-15 16:35:26 +05:30
parent 51225b18b2
commit 811dfecf98
4 changed files with 116 additions and 80 deletions

View File

@@ -1,6 +1,5 @@
from application.llm.base import BaseLLM
from application.core.settings import settings
from application.llm.base import BaseLLM
class OpenAILLM(BaseLLM):
@@ -10,10 +9,7 @@ class OpenAILLM(BaseLLM):
super().__init__(*args, **kwargs)
if settings.OPENAI_BASE_URL:
self.client = OpenAI(
api_key=api_key,
base_url=settings.OPENAI_BASE_URL
)
self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
else:
self.client = OpenAI(api_key=api_key)
self.api_key = api_key
@@ -27,7 +23,7 @@ class OpenAILLM(BaseLLM):
stream=False,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs
**kwargs,
):
if tools:
response = self.client.chat.completions.create(
@@ -48,15 +44,13 @@ class OpenAILLM(BaseLLM):
stream=True,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs
**kwargs,
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
for line in response:
# import sys
# print(line.choices[0].delta.content, file=sys.stderr)
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content

View File

@@ -1,8 +1,7 @@
import json
import logging
from application.core.mongo_db import MongoDB
from application.llm.llm_creator import LLMCreator
from application.tools.llm_handler import get_llm_handler
from application.tools.tool_action_parser import ToolActionParser
from application.tools.tool_manager import ToolManager
@@ -12,6 +11,7 @@ class Agent:
self.llm = LLMCreator.create_llm(
llm_name, api_key=api_key, user_api_key=user_api_key
)
self.llm_handler = get_llm_handler(llm_name)
self.gpt_model = gpt_model
# Static tool configuration (to be replaced later)
self.tools = []
@@ -61,10 +61,8 @@ class Agent:
]
def _execute_tool_action(self, tools_dict, call):
call_id = call.id
call_args = json.loads(call.function.arguments)
tool_id = call.function.name.split("_")[-1]
action_name = call.function.name.rsplit("_", 1)[0]
parser = ToolActionParser(self.llm.__class__.__name__)
tool_id, action_name, call_args = parser.parse_args(call)
tool_data = tools_dict[tool_id]
action_data = next(
@@ -78,26 +76,9 @@ class Agent:
tm = ToolManager(config={})
tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"])
print(f"Executing tool: {action_name} with args: {call_args}")
return tool.execute_action(action_name, **call_args), call_id
def _execute_tool_action_google(self, tools_dict, call):
call_args = json.loads(call.args)
tool_id = call.name.split("_")[-1]
action_name = call.name.rsplit("_", 1)[0]
tool_data = tools_dict[tool_id]
action_data = next(
action for action in tool_data["actions"] if action["name"] == action_name
)
for param, details in action_data["parameters"]["properties"].items():
if param not in call_args and "value" in details:
call_args[param] = details["value"]
tm = ToolManager(config={})
tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"])
print(f"Executing tool: {action_name} with args: {call_args}")
return tool.execute_action(action_name, **call_args)
result = tool.execute_action(action_name, **call_args)
call_id = getattr(call, "id", None)
return result, call_id
def _simple_tool_agent(self, messages):
tools_dict = self._get_user_tools()
@@ -111,47 +92,8 @@ class Agent:
if resp.message.content:
yield resp.message.content
return
# check if self.llm class is GoogleLLM
while self.llm.__class__.__name__ == "GoogleLLM" and resp.content.parts[0].function_call:
from google.genai import types
function_call_part = resp.candidates[0].content.parts[0]
tool_response = self._execute_tool_action_google(tools_dict, function_call_part.function_call)
function_response_part = types.Part.from_function_response(
name=function_call_part.function_call.name,
response=tool_response
)
while self.llm.__class__.__name__ == "OpenAILLM" and resp.finish_reason == "tool_calls":
message = json.loads(resp.model_dump_json())["message"]
keys_to_remove = {"audio", "function_call", "refusal"}
filtered_data = {
k: v for k, v in message.items() if k not in keys_to_remove
}
messages.append(filtered_data)
tool_calls = resp.message.tool_calls
for call in tool_calls:
try:
tool_response, call_id = self._execute_tool_action(tools_dict, call)
messages.append(
{
"role": "tool",
"content": str(tool_response),
"tool_call_id": call_id,
}
)
except Exception as e:
messages.append(
{
"role": "tool",
"content": f"Error executing tool: {str(e)}",
"tool_call_id": call.id,
}
)
# Generate a new response from the LLM after processing tools
resp = self.llm.gen(
model=self.gpt_model, messages=messages, tools=self.tools
)
resp = self.llm_handler.handle_response(self, resp, tools_dict, messages)
# If no tool calls are needed, generate the final response
if isinstance(resp, str):

View File

@@ -0,0 +1,74 @@
import json
from abc import ABC, abstractmethod
class LLMHandler(ABC):
    """Interface for provider-specific handling of an LLM response.

    Concrete handlers implement ``handle_response`` and return the
    response object that results from their processing.
    """

    @abstractmethod
    def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
        """Process ``resp`` on behalf of ``agent``.

        Args:
            agent: The agent that owns the LLM and executes tool actions.
            resp: The response object returned by the LLM.
            tools_dict: Tool configuration passed through to the agent.
            messages: The running conversation; handlers may append to it.
            **kwargs: Extra, handler-specific options.

        The base implementation does nothing and returns ``None``.
        """
class OpenAILLMHandler(LLMHandler):
    """Run the tool-calling loop for OpenAI-style chat completion responses."""

    def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
        """Execute requested tool calls until the model stops asking for them.

        While ``resp.finish_reason`` equals ``"tool_calls"``, the assistant
        message is appended to ``messages``, every tool call is executed via
        ``agent._execute_tool_action`` and its result (or error text) is
        appended as a ``"tool"`` message, and a new response is requested
        from ``agent.llm.gen``.

        Args:
            agent: Object providing ``_execute_tool_action``, ``llm``,
                ``gpt_model`` and ``tools``.
            resp: Response exposing ``finish_reason``, ``message`` and
                ``model_dump_json()``.
            tools_dict: Tool configuration forwarded to the agent.
            messages: Conversation list; mutated in place.
            **kwargs: Accepted for compatibility with ``LLMHandler``; unused.

        Returns:
            The first response whose ``finish_reason`` is not ``"tool_calls"``.
        """
        # Fields stripped from the assistant message before it is echoed back.
        keys_to_remove = {"audio", "function_call", "refusal"}

        while resp.finish_reason == "tool_calls":
            message = json.loads(resp.model_dump_json())["message"]
            messages.append(
                {k: v for k, v in message.items() if k not in keys_to_remove}
            )

            for call in resp.message.tool_calls:
                try:
                    tool_response, call_id = agent._execute_tool_action(
                        tools_dict, call
                    )
                    content = str(tool_response)
                except Exception as e:
                    # Tool failures are reported back to the model rather than
                    # raised. The id must come from the call itself: the
                    # ``call_id`` bound in the ``try`` is unset (or stale from
                    # a previous iteration) when execution fails.
                    call_id = getattr(call, "id", None)
                    content = f"Error executing tool: {str(e)}"
                messages.append(
                    {
                        "role": "tool",
                        "content": content,
                        "tool_call_id": call_id,
                    }
                )

            # Ask the model again now that the tool results are in the history.
            resp = agent.llm.gen(
                model=agent.gpt_model, messages=messages, tools=agent.tools
            )
        return resp
class GoogleLLMHandler(LLMHandler):
    """Run the function-calling loop for Google GenAI responses."""

    def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
        """Execute requested function calls until the model stops asking.

        While the first part of the first candidate carries a function call,
        the call is executed via ``agent._execute_tool_action``, the call part
        and a matching function-response part are added to ``messages``, and
        a new response is requested from ``agent.llm.gen``.

        Args:
            agent: Object providing ``_execute_tool_action``, ``llm``,
                ``gpt_model`` and ``tools``.
            resp: Response exposing ``candidates[0].content.parts``.
            tools_dict: Tool configuration forwarded to the agent.
            messages: Conversation list; mutated in place.
            **kwargs: Accepted for compatibility with ``LLMHandler``; unused.

        Returns:
            The first response whose leading part has no function call.
        """
        from google.genai import types

        while True:
            # Read the part the same way for the loop test and the body; the
            # response object exposes its content through ``candidates``.
            function_call_part = resp.candidates[0].content.parts[0]
            if not function_call_part.function_call:
                break

            # The call id is not needed here, only the tool result.
            tool_response, _ = agent._execute_tool_action(
                tools_dict, function_call_part.function_call
            )
            # NOTE(review): ``from_function_response`` is documented to take a
            # dict for ``response`` - confirm tool results are dict-shaped.
            function_response_part = types.Part.from_function_response(
                name=function_call_part.function_call.name, response=tool_response
            )
            # ``list.append`` accepts a single item, so add both parts with
            # ``extend``.
            # NOTE(review): confirm ``agent.llm.gen`` accepts raw parts as
            # history entries for this provider.
            messages.extend((function_call_part, function_response_part))

            resp = agent.llm.gen(
                model=agent.gpt_model, messages=messages, tools=agent.tools
            )
        return resp
def get_llm_handler(llm_type):
    """Return a new handler instance for ``llm_type``.

    Recognised keys are ``"openai"`` and ``"google"``; any other value
    yields an ``OpenAILLMHandler``.
    """
    handler_classes = {
        "openai": OpenAILLMHandler,
        "google": GoogleLLMHandler,
    }
    handler_cls = handler_classes.get(llm_type, OpenAILLMHandler)
    return handler_cls()

View File

@@ -0,0 +1,26 @@
import json
class ToolActionParser:
    """Extract ``(tool_id, action_name, call_args)`` from a provider tool call.

    The parser is chosen by ``llm_type`` (the LLM class name); unknown types
    fall back to the OpenAI parser.
    """

    def __init__(self, llm_type):
        """Store ``llm_type`` and build the class-name-to-parser table."""
        self.llm_type = llm_type
        self.parsers = {
            "OpenAILLM": self._parse_openai_llm,
            "GoogleLLM": self._parse_google_llm,
        }

    def parse_args(self, call):
        """Parse ``call`` with the parser registered for ``self.llm_type``.

        Returns:
            A ``(tool_id, action_name, call_args)`` tuple.
        """
        parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
        return parser(call)

    @staticmethod
    def _split_name(name):
        """Split ``"<action>_<tool_id>"`` on its last underscore.

        A name without an underscore is returned as both tool id and action
        name.
        """
        parts = name.rsplit("_", 1)
        return parts[-1], parts[0]

    @staticmethod
    def _load_args(raw):
        """Return call arguments as a fresh ``dict``.

        Accepts a JSON string, an already-decoded mapping, or ``None``
        (treated as no arguments). Mappings are copied so that callers can
        add defaults without mutating the provider's call object.
        """
        if raw is None:
            return {}
        if isinstance(raw, (str, bytes, bytearray)):
            return json.loads(raw)
        return dict(raw)

    def _parse_openai_llm(self, call):
        """Parse a call exposing ``function.name`` and ``function.arguments``."""
        tool_id, action_name = self._split_name(call.function.name)
        return tool_id, action_name, self._load_args(call.function.arguments)

    def _parse_google_llm(self, call):
        """Parse a call exposing ``name`` and ``args``.

        ``args`` may already be a mapping, which ``json.loads`` would reject,
        so decoding is delegated to ``_load_args``.
        """
        tool_id, action_name = self._split_name(call.name)
        return tool_id, action_name, self._load_args(call.args)