mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
feat: tools frontend and endpoints refactor
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.core.settings import settings
|
||||
from application.tools.tool_manager import ToolManager
|
||||
from application.core.mongo_db import MongoDB
|
||||
import json
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.tools.tool_manager import ToolManager
|
||||
|
||||
tool_tg = {
|
||||
"name": "telegram_send_message",
|
||||
"description": "Send a notification to telegram about current chat",
|
||||
@@ -12,15 +13,15 @@ tool_tg = {
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text to send in the notification"
|
||||
"description": "Text to send in the notification",
|
||||
}
|
||||
},
|
||||
"required": ["text"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
||||
|
||||
tool_crypto = {
|
||||
tool_crypto = {
|
||||
"name": "cryptoprice_get",
|
||||
"description": "Retrieve the price of a specified cryptocurrency in a given currency",
|
||||
"parameters": {
|
||||
@@ -28,33 +29,30 @@ tool_crypto = {
|
||||
"properties": {
|
||||
"symbol": {
|
||||
"type": "string",
|
||||
"description": "The cryptocurrency symbol (e.g. BTC)"
|
||||
"description": "The cryptocurrency symbol (e.g. BTC)",
|
||||
},
|
||||
"currency": {
|
||||
"type": "string",
|
||||
"description": "The currency in which you want the price (e.g. USD)"
|
||||
}
|
||||
"description": "The currency in which you want the price (e.g. USD)",
|
||||
},
|
||||
},
|
||||
"required": ["symbol", "currency"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, llm_name, gpt_model, api_key, user_api_key=None):
|
||||
# Initialize the LLM with the provided parameters
|
||||
self.llm = LLMCreator.create_llm(llm_name, api_key=api_key, user_api_key=user_api_key)
|
||||
self.llm = LLMCreator.create_llm(
|
||||
llm_name, api_key=api_key, user_api_key=user_api_key
|
||||
)
|
||||
self.gpt_model = gpt_model
|
||||
# Static tool configuration (to be replaced later)
|
||||
self.tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": tool_crypto
|
||||
}
|
||||
]
|
||||
self.tool_config = {
|
||||
}
|
||||
|
||||
self.tools = [{"type": "function", "function": tool_crypto}]
|
||||
self.tool_config = {}
|
||||
|
||||
def _get_user_tools(self, user="local"):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo["docsgpt"]
|
||||
@@ -65,19 +63,17 @@ class Agent:
|
||||
tool.pop("_id")
|
||||
user_tools = {tool["name"]: tool for tool in user_tools}
|
||||
return user_tools
|
||||
|
||||
|
||||
def _simple_tool_agent(self, messages):
|
||||
tools_dict = self._get_user_tools()
|
||||
# combine all tool_actions into one list
|
||||
self.tools.extend([
|
||||
{
|
||||
"type": "function",
|
||||
"function": tool_action
|
||||
}
|
||||
for tool in tools_dict.values()
|
||||
for tool_action in tool["actions"]
|
||||
])
|
||||
|
||||
self.tools.extend(
|
||||
[
|
||||
{"type": "function", "function": tool_action}
|
||||
for tool in tools_dict.values()
|
||||
for tool_action in tool["actions"]
|
||||
]
|
||||
)
|
||||
|
||||
resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)
|
||||
|
||||
@@ -88,7 +84,7 @@ class Agent:
|
||||
|
||||
while resp.finish_reason == "tool_calls":
|
||||
# Append the assistant's message to the conversation
|
||||
messages.append(json.loads(resp.model_dump_json())['message'])
|
||||
messages.append(json.loads(resp.model_dump_json())["message"])
|
||||
# Handle each tool call
|
||||
tool_calls = resp.message.tool_calls
|
||||
for call in tool_calls:
|
||||
@@ -98,25 +94,27 @@ class Agent:
|
||||
call_id = call.id
|
||||
# Determine the tool name and load it
|
||||
tool_name = call_name.split("_")[0]
|
||||
tool = tm.load_tool(tool_name, tool_config=tools_dict[tool_name]['config'])
|
||||
tool = tm.load_tool(
|
||||
tool_name, tool_config=tools_dict[tool_name]["config"]
|
||||
)
|
||||
# Execute the tool's action
|
||||
resp_tool = tool.execute_action(call_name, **call_args)
|
||||
# Append the tool's response to the conversation
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": str(resp_tool),
|
||||
"tool_call_id": call_id
|
||||
}
|
||||
{"role": "tool", "content": str(resp_tool), "tool_call_id": call_id}
|
||||
)
|
||||
# Generate a new response from the LLM after processing tools
|
||||
resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)
|
||||
resp = self.llm.gen(
|
||||
model=self.gpt_model, messages=messages, tools=self.tools
|
||||
)
|
||||
|
||||
# If no tool calls are needed, generate the final response
|
||||
if isinstance(resp, str):
|
||||
yield resp
|
||||
else:
|
||||
completion = self.llm.gen_stream(model=self.gpt_model, messages=messages, tools=self.tools)
|
||||
completion = self.llm.gen_stream(
|
||||
model=self.gpt_model, messages=messages, tools=self.tools
|
||||
)
|
||||
for line in completion:
|
||||
yield line
|
||||
|
||||
@@ -127,4 +125,4 @@ class Agent:
|
||||
else:
|
||||
resp = self.llm.gen_stream(model=self.gpt_model, messages=messages)
|
||||
for line in resp:
|
||||
yield line
|
||||
yield line
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
@abstractmethod
|
||||
def execute_action(self, action_name: str, **kwargs):
|
||||
|
||||
@@ -1,21 +1,25 @@
|
||||
from application.tools.base import Tool
|
||||
import requests
|
||||
from application.tools.base import Tool
|
||||
|
||||
|
||||
class CryptoPriceTool(Tool):
|
||||
"""
|
||||
CryptoPrice
|
||||
A tool for retrieving cryptocurrency prices using the CryptoCompare public API
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
actions = {
|
||||
"cryptoprice_get": self.get_price
|
||||
}
|
||||
actions = {"cryptoprice_get": self._get_price}
|
||||
|
||||
if action_name in actions:
|
||||
return actions[action_name](**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
|
||||
def get_price(self, symbol, currency):
|
||||
def _get_price(self, symbol, currency):
|
||||
"""
|
||||
Fetches the current price of a given cryptocurrency symbol in the specified currency.
|
||||
Example:
|
||||
@@ -32,17 +36,17 @@ class CryptoPriceTool(Tool):
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"price": data[currency.upper()],
|
||||
"message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully."
|
||||
"message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully.",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}."
|
||||
"message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}.",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": "Failed to retrieve price."
|
||||
"message": "Failed to retrieve price.",
|
||||
}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
@@ -55,16 +59,16 @@ class CryptoPriceTool(Tool):
|
||||
"properties": {
|
||||
"symbol": {
|
||||
"type": "string",
|
||||
"description": "The cryptocurrency symbol (e.g. BTC)"
|
||||
"description": "The cryptocurrency symbol (e.g. BTC)",
|
||||
},
|
||||
"currency": {
|
||||
"type": "string",
|
||||
"description": "The currency in which you want the price (e.g. USD)"
|
||||
}
|
||||
"description": "The currency in which you want the price (e.g. USD)",
|
||||
},
|
||||
},
|
||||
"required": ["symbol", "currency"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
from application.tools.base import Tool
|
||||
import requests
|
||||
from application.tools.base import Tool
|
||||
|
||||
|
||||
class TelegramTool(Tool):
|
||||
"""
|
||||
Telegram Bot
|
||||
A flexible Telegram tool for performing various actions (e.g., sending messages, images).
|
||||
Requires a bot token and chat ID for configuration
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.chat_id = config.get("chat_id", "142189016")
|
||||
self.token = config.get("token", "YOUR_TG_TOKEN")
|
||||
self.token = config.get("token", "")
|
||||
self.chat_id = config.get("chat_id", "")
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
actions = {
|
||||
"telegram_send_message": self.send_message,
|
||||
"telegram_send_image": self.send_image
|
||||
"telegram_send_message": self._send_message,
|
||||
"telegram_send_image": self._send_image,
|
||||
}
|
||||
|
||||
if action_name in actions:
|
||||
@@ -18,14 +25,14 @@ class TelegramTool(Tool):
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
|
||||
def send_message(self, text):
|
||||
def _send_message(self, text):
|
||||
print(f"Sending message: {text}")
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
|
||||
payload = {"chat_id": self.chat_id, "text": text}
|
||||
response = requests.post(url, data=payload)
|
||||
return {"status_code": response.status_code, "message": "Message sent"}
|
||||
|
||||
def send_image(self, image_url):
|
||||
def _send_image(self, image_url):
|
||||
print(f"Sending image: {image_url}")
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
|
||||
payload = {"chat_id": self.chat_id, "photo": image_url}
|
||||
@@ -42,12 +49,12 @@ class TelegramTool(Tool):
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text to send in the notification"
|
||||
"description": "Text to send in the notification",
|
||||
}
|
||||
},
|
||||
"required": ["text"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "telegram_send_image",
|
||||
@@ -57,23 +64,20 @@ class TelegramTool(Tool):
|
||||
"properties": {
|
||||
"image_url": {
|
||||
"type": "string",
|
||||
"description": "URL of the image to send"
|
||||
"description": "URL of the image to send",
|
||||
}
|
||||
},
|
||||
"required": ["image_url"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"description": "Telegram chat ID to send messages to"
|
||||
"description": "Telegram chat ID to send messages to",
|
||||
},
|
||||
"token": {
|
||||
"type": "string",
|
||||
"description": "Bot token for authentication"
|
||||
}
|
||||
"token": {"type": "string", "description": "Bot token for authentication"},
|
||||
}
|
||||
@@ -1,10 +1,11 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import pkgutil
|
||||
import os
|
||||
import pkgutil
|
||||
|
||||
from application.tools.base import Tool
|
||||
|
||||
|
||||
class ToolManager:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
@@ -12,11 +13,13 @@ class ToolManager:
|
||||
self.load_tools()
|
||||
|
||||
def load_tools(self):
|
||||
tools_dir = os.path.dirname(__file__)
|
||||
tools_dir = os.path.join(os.path.dirname(__file__), "implementations")
|
||||
for finder, name, ispkg in pkgutil.iter_modules([tools_dir]):
|
||||
if name == 'base' or name.startswith('__'):
|
||||
if name == "base" or name.startswith("__"):
|
||||
continue
|
||||
module = importlib.import_module(f'application.tools.{name}')
|
||||
module = importlib.import_module(
|
||||
f"application.tools.implementations.{name}"
|
||||
)
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
tool_config = self.config.get(name, {})
|
||||
@@ -24,13 +27,14 @@ class ToolManager:
|
||||
|
||||
def load_tool(self, tool_name, tool_config):
|
||||
self.config[tool_name] = tool_config
|
||||
tools_dir = os.path.dirname(__file__)
|
||||
module = importlib.import_module(f'application.tools.{tool_name}')
|
||||
tools_dir = os.path.join(os.path.dirname(__file__), "implementations")
|
||||
module = importlib.import_module(
|
||||
f"application.tools.implementations.{tool_name}"
|
||||
)
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
return obj(tool_config)
|
||||
|
||||
|
||||
def execute_action(self, tool_name, action_name, **kwargs):
|
||||
if tool_name not in self.tools:
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||
|
||||
Reference in New Issue
Block a user