mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
feat: tools agent refactor for custom fields and unique actions
This commit is contained in:
@@ -4,42 +4,6 @@ from application.core.mongo_db import MongoDB
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.tools.tool_manager import ToolManager
|
||||
|
||||
tool_tg = {
|
||||
"name": "telegram_send_message",
|
||||
"description": "Send a notification to telegram about current chat",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text to send in the notification",
|
||||
}
|
||||
},
|
||||
"required": ["text"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
||||
|
||||
tool_crypto = {
|
||||
"name": "cryptoprice_get",
|
||||
"description": "Retrieve the price of a specified cryptocurrency in a given currency",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"symbol": {
|
||||
"type": "string",
|
||||
"description": "The cryptocurrency symbol (e.g. BTC)",
|
||||
},
|
||||
"currency": {
|
||||
"type": "string",
|
||||
"description": "The currency in which you want the price (e.g. USD)",
|
||||
},
|
||||
},
|
||||
"required": ["symbol", "currency"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, llm_name, gpt_model, api_key, user_api_key=None):
|
||||
@@ -49,7 +13,7 @@ class Agent:
|
||||
)
|
||||
self.gpt_model = gpt_model
|
||||
# Static tool configuration (to be replaced later)
|
||||
self.tools = [{"type": "function", "function": tool_crypto}]
|
||||
self.tools = []
|
||||
self.tool_config = {}
|
||||
|
||||
def _get_user_tools(self, user="local"):
|
||||
@@ -58,50 +22,102 @@ class Agent:
|
||||
user_tools_collection = db["user_tools"]
|
||||
user_tools = user_tools_collection.find({"user": user, "status": True})
|
||||
user_tools = list(user_tools)
|
||||
for tool in user_tools:
|
||||
tool.pop("_id")
|
||||
user_tools = {tool["name"]: tool for tool in user_tools}
|
||||
return user_tools
|
||||
tools_by_id = {str(tool["_id"]): tool for tool in user_tools}
|
||||
return tools_by_id
|
||||
|
||||
def _prepare_tools(self, tools_dict):
|
||||
self.tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f"{action['name']}_{tool_id}",
|
||||
"description": action["description"],
|
||||
"parameters": {
|
||||
**action["parameters"],
|
||||
"properties": {
|
||||
k: {
|
||||
key: value
|
||||
for key, value in v.items()
|
||||
if key != "filled_by_llm" and key != "value"
|
||||
}
|
||||
for k, v in action["parameters"]["properties"].items()
|
||||
if v.get("filled_by_llm", False)
|
||||
},
|
||||
"required": [
|
||||
key
|
||||
for key in action["parameters"]["required"]
|
||||
if key in action["parameters"]["properties"]
|
||||
and action["parameters"]["properties"][key].get(
|
||||
"filled_by_llm", False
|
||||
)
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
for tool_id, tool in tools_dict.items()
|
||||
for action in tool["actions"]
|
||||
if action["active"]
|
||||
]
|
||||
|
||||
def _execute_tool_action(self, tools_dict, call):
|
||||
call_id = call.id
|
||||
call_args = json.loads(call.function.arguments)
|
||||
tool_id = call.function.name.split("_")[-1]
|
||||
action_name = call.function.name.rsplit("_", 1)[0]
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
action_data = next(
|
||||
action for action in tool_data["actions"] if action["name"] == action_name
|
||||
)
|
||||
|
||||
for param, details in action_data["parameters"]["properties"].items():
|
||||
if param not in call_args and "value" in details:
|
||||
call_args[param] = details["value"]
|
||||
|
||||
tm = ToolManager(config={})
|
||||
tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"])
|
||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||
return tool.execute_action(action_name, **call_args), call_id
|
||||
|
||||
def _simple_tool_agent(self, messages):
|
||||
tools_dict = self._get_user_tools()
|
||||
# combine all tool_actions into one list
|
||||
self.tools.extend(
|
||||
[
|
||||
{"type": "function", "function": tool_action}
|
||||
for tool in tools_dict.values()
|
||||
for tool_action in tool["actions"]
|
||||
]
|
||||
)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)
|
||||
|
||||
if isinstance(resp, str):
|
||||
# Yield the response if it's a string and exit
|
||||
yield resp
|
||||
return
|
||||
if resp.message.content:
|
||||
yield resp.message.content
|
||||
return
|
||||
|
||||
while resp.finish_reason == "tool_calls":
|
||||
# Append the assistant's message to the conversation
|
||||
messages.append(json.loads(resp.model_dump_json())["message"])
|
||||
# Handle each tool call
|
||||
message = json.loads(resp.model_dump_json())["message"]
|
||||
keys_to_remove = {"audio", "function_call", "refusal"}
|
||||
filtered_data = {
|
||||
k: v for k, v in message.items() if k not in keys_to_remove
|
||||
}
|
||||
messages.append(filtered_data)
|
||||
tool_calls = resp.message.tool_calls
|
||||
for call in tool_calls:
|
||||
tm = ToolManager(config={})
|
||||
call_name = call.function.name
|
||||
call_args = json.loads(call.function.arguments)
|
||||
call_id = call.id
|
||||
# Determine the tool name and load it
|
||||
tool_name = call_name.split("_")[0]
|
||||
tool = tm.load_tool(
|
||||
tool_name, tool_config=tools_dict[tool_name]["config"]
|
||||
)
|
||||
# Execute the tool's action
|
||||
resp_tool = tool.execute_action(call_name, **call_args)
|
||||
# Append the tool's response to the conversation
|
||||
messages.append(
|
||||
{"role": "tool", "content": str(resp_tool), "tool_call_id": call_id}
|
||||
)
|
||||
try:
|
||||
tool_response, call_id = self._execute_tool_action(tools_dict, call)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": str(tool_response),
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": f"Error executing tool: {str(e)}",
|
||||
"tool_call_id": call.id,
|
||||
}
|
||||
)
|
||||
# Generate a new response from the LLM after processing tools
|
||||
resp = self.llm.gen(
|
||||
model=self.gpt_model, messages=messages, tools=self.tools
|
||||
@@ -110,6 +126,8 @@ class Agent:
|
||||
# If no tool calls are needed, generate the final response
|
||||
if isinstance(resp, str):
|
||||
yield resp
|
||||
elif resp.message.content:
|
||||
yield resp.message.content
|
||||
else:
|
||||
completion = self.llm.gen_stream(
|
||||
model=self.gpt_model, messages=messages, tools=self.tools
|
||||
@@ -117,6 +135,8 @@ class Agent:
|
||||
for line in completion:
|
||||
yield line
|
||||
|
||||
return
|
||||
|
||||
def gen(self, messages):
|
||||
# Generate initial response from the LLM
|
||||
if self.llm.supports_tools():
|
||||
|
||||
Reference in New Issue
Block a user