mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 00:23:17 +00:00
refactor: tool agent for action parser and handlers
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
from application.llm.base import BaseLLM
|
||||
from application.core.settings import settings
|
||||
|
||||
from application.llm.base import BaseLLM
|
||||
|
||||
|
||||
class OpenAILLM(BaseLLM):
|
||||
@@ -10,10 +9,7 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
if settings.OPENAI_BASE_URL:
|
||||
self.client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=settings.OPENAI_BASE_URL
|
||||
)
|
||||
self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
|
||||
else:
|
||||
self.client = OpenAI(api_key=api_key)
|
||||
self.api_key = api_key
|
||||
@@ -27,8 +23,8 @@ class OpenAILLM(BaseLLM):
|
||||
stream=False,
|
||||
tools=None,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
**kwargs
|
||||
):
|
||||
**kwargs,
|
||||
):
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, tools=tools, **kwargs
|
||||
@@ -48,18 +44,16 @@ class OpenAILLM(BaseLLM):
|
||||
stream=True,
|
||||
tools=None,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
**kwargs
|
||||
):
|
||||
**kwargs,
|
||||
):
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
|
||||
for line in response:
|
||||
# import sys
|
||||
# print(line.choices[0].delta.content, file=sys.stderr)
|
||||
if line.choices[0].delta.content is not None:
|
||||
yield line.choices[0].delta.content
|
||||
|
||||
|
||||
def _supports_tools(self):
|
||||
return True
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.tools.llm_handler import get_llm_handler
|
||||
from application.tools.tool_action_parser import ToolActionParser
|
||||
from application.tools.tool_manager import ToolManager
|
||||
|
||||
|
||||
@@ -12,6 +11,7 @@ class Agent:
|
||||
self.llm = LLMCreator.create_llm(
|
||||
llm_name, api_key=api_key, user_api_key=user_api_key
|
||||
)
|
||||
self.llm_handler = get_llm_handler(llm_name)
|
||||
self.gpt_model = gpt_model
|
||||
# Static tool configuration (to be replaced later)
|
||||
self.tools = []
|
||||
@@ -61,10 +61,8 @@ class Agent:
|
||||
]
|
||||
|
||||
def _execute_tool_action(self, tools_dict, call):
|
||||
call_id = call.id
|
||||
call_args = json.loads(call.function.arguments)
|
||||
tool_id = call.function.name.split("_")[-1]
|
||||
action_name = call.function.name.rsplit("_", 1)[0]
|
||||
parser = ToolActionParser(self.llm.__class__.__name__)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
action_data = next(
|
||||
@@ -78,26 +76,9 @@ class Agent:
|
||||
tm = ToolManager(config={})
|
||||
tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"])
|
||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||
return tool.execute_action(action_name, **call_args), call_id
|
||||
|
||||
def _execute_tool_action_google(self, tools_dict, call):
|
||||
call_args = json.loads(call.args)
|
||||
tool_id = call.name.split("_")[-1]
|
||||
action_name = call.name.rsplit("_", 1)[0]
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
action_data = next(
|
||||
action for action in tool_data["actions"] if action["name"] == action_name
|
||||
)
|
||||
|
||||
for param, details in action_data["parameters"]["properties"].items():
|
||||
if param not in call_args and "value" in details:
|
||||
call_args[param] = details["value"]
|
||||
|
||||
tm = ToolManager(config={})
|
||||
tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"])
|
||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||
return tool.execute_action(action_name, **call_args)
|
||||
result = tool.execute_action(action_name, **call_args)
|
||||
call_id = getattr(call, "id", None)
|
||||
return result, call_id
|
||||
|
||||
def _simple_tool_agent(self, messages):
|
||||
tools_dict = self._get_user_tools()
|
||||
@@ -111,47 +92,8 @@ class Agent:
|
||||
if resp.message.content:
|
||||
yield resp.message.content
|
||||
return
|
||||
# check if self.llm class is GoogleLLM
|
||||
while self.llm.__class__.__name__ == "GoogleLLM" and resp.content.parts[0].function_call:
|
||||
from google.genai import types
|
||||
|
||||
function_call_part = resp.candidates[0].content.parts[0]
|
||||
tool_response = self._execute_tool_action_google(tools_dict, function_call_part.function_call)
|
||||
function_response_part = types.Part.from_function_response(
|
||||
name=function_call_part.function_call.name,
|
||||
response=tool_response
|
||||
)
|
||||
|
||||
while self.llm.__class__.__name__ == "OpenAILLM" and resp.finish_reason == "tool_calls":
|
||||
message = json.loads(resp.model_dump_json())["message"]
|
||||
keys_to_remove = {"audio", "function_call", "refusal"}
|
||||
filtered_data = {
|
||||
k: v for k, v in message.items() if k not in keys_to_remove
|
||||
}
|
||||
messages.append(filtered_data)
|
||||
tool_calls = resp.message.tool_calls
|
||||
for call in tool_calls:
|
||||
try:
|
||||
tool_response, call_id = self._execute_tool_action(tools_dict, call)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": str(tool_response),
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": f"Error executing tool: {str(e)}",
|
||||
"tool_call_id": call.id,
|
||||
}
|
||||
)
|
||||
# Generate a new response from the LLM after processing tools
|
||||
resp = self.llm.gen(
|
||||
model=self.gpt_model, messages=messages, tools=self.tools
|
||||
)
|
||||
resp = self.llm_handler.handle_response(self, resp, tools_dict, messages)
|
||||
|
||||
# If no tool calls are needed, generate the final response
|
||||
if isinstance(resp, str):
|
||||
|
||||
74
application/tools/llm_handler.py
Normal file
74
application/tools/llm_handler.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class LLMHandler(ABC):
|
||||
@abstractmethod
|
||||
def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class OpenAILLMHandler(LLMHandler):
|
||||
def handle_response(self, agent, resp, tools_dict, messages):
|
||||
while resp.finish_reason == "tool_calls":
|
||||
message = json.loads(resp.model_dump_json())["message"]
|
||||
keys_to_remove = {"audio", "function_call", "refusal"}
|
||||
filtered_data = {
|
||||
k: v for k, v in message.items() if k not in keys_to_remove
|
||||
}
|
||||
messages.append(filtered_data)
|
||||
|
||||
tool_calls = resp.message.tool_calls
|
||||
for call in tool_calls:
|
||||
try:
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, call
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": str(tool_response),
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": f"Error executing tool: {str(e)}",
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
)
|
||||
resp = agent.llm.gen(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
class GoogleLLMHandler(LLMHandler):
|
||||
def handle_response(self, agent, resp, tools_dict, messages):
|
||||
from google.genai import types
|
||||
|
||||
while resp.content.parts[0].function_call:
|
||||
function_call_part = resp.candidates[0].content.parts[0]
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, function_call_part.function_call
|
||||
)
|
||||
function_response_part = types.Part.from_function_response(
|
||||
name=function_call_part.function_call.name, response=tool_response
|
||||
)
|
||||
|
||||
messages.append(function_call_part, function_response_part)
|
||||
resp = agent.llm.gen(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
def get_llm_handler(llm_type):
|
||||
handlers = {
|
||||
"openai": OpenAILLMHandler(),
|
||||
"google": GoogleLLMHandler(),
|
||||
}
|
||||
return handlers.get(llm_type, OpenAILLMHandler())
|
||||
26
application/tools/tool_action_parser.py
Normal file
26
application/tools/tool_action_parser.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import json
|
||||
|
||||
|
||||
class ToolActionParser:
|
||||
def __init__(self, llm_type):
|
||||
self.llm_type = llm_type
|
||||
self.parsers = {
|
||||
"OpenAILLM": self._parse_openai_llm,
|
||||
"GoogleLLM": self._parse_google_llm,
|
||||
}
|
||||
|
||||
def parse_args(self, call):
|
||||
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
|
||||
return parser(call)
|
||||
|
||||
def _parse_openai_llm(self, call):
|
||||
call_args = json.loads(call.function.arguments)
|
||||
tool_id = call.function.name.split("_")[-1]
|
||||
action_name = call.function.name.rsplit("_", 1)[0]
|
||||
return tool_id, action_name, call_args
|
||||
|
||||
def _parse_google_llm(self, call):
|
||||
call_args = json.loads(call.args)
|
||||
tool_id = call.name.split("_")[-1]
|
||||
action_name = call.name.rsplit("_", 1)[0]
|
||||
return tool_id, action_name, call_args
|
||||
Reference in New Issue
Block a user