mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-03-02 08:12:06 +00:00
feat: tooling init
This commit is contained in:
98
application/tools/agent.py
Normal file
98
application/tools/agent.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.core.settings import settings
|
||||
from application.tools.tool_manager import ToolManager
|
||||
import json
|
||||
|
||||
tool_tg = {
|
||||
"name": "telegram_send_message",
|
||||
"description": "Send a notification to telegram about current chat",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text to send in the notification"
|
||||
}
|
||||
},
|
||||
"required": ["text"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
|
||||
tool_crypto = {
|
||||
"name": "cryptoprice_get",
|
||||
"description": "Retrieve the price of a specified cryptocurrency in a given currency",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"symbol": {
|
||||
"type": "string",
|
||||
"description": "The cryptocurrency symbol (e.g. BTC)"
|
||||
},
|
||||
"currency": {
|
||||
"type": "string",
|
||||
"description": "The currency in which you want the price (e.g. USD)"
|
||||
}
|
||||
},
|
||||
"required": ["symbol", "currency"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
|
||||
class Agent:
|
||||
def __init__(self, llm_name, gpt_model, api_key, user_api_key=None):
|
||||
# Initialize the LLM with the provided parameters
|
||||
self.llm = LLMCreator.create_llm(llm_name, api_key=api_key, user_api_key=user_api_key)
|
||||
self.gpt_model = gpt_model
|
||||
# Static tool configuration (to be replaced later)
|
||||
self.tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": tool_crypto
|
||||
}
|
||||
]
|
||||
self.tool_config = {
|
||||
}
|
||||
|
||||
def gen(self, messages):
|
||||
# Generate initial response from the LLM
|
||||
resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)
|
||||
|
||||
if isinstance(resp, str):
|
||||
# Yield the response if it's a string and exit
|
||||
yield resp
|
||||
return
|
||||
|
||||
while resp.finish_reason == "tool_calls":
|
||||
# Append the assistant's message to the conversation
|
||||
messages.append(json.loads(resp.model_dump_json())['message'])
|
||||
# Handle each tool call
|
||||
tool_calls = resp.message.tool_calls
|
||||
for call in tool_calls:
|
||||
tm = ToolManager(config={})
|
||||
call_name = call.function.name
|
||||
call_args = json.loads(call.function.arguments)
|
||||
call_id = call.id
|
||||
# Determine the tool name and load it
|
||||
tool_name = call_name.split("_")[0]
|
||||
tool = tm.load_tool(tool_name, tool_config=self.tool_config)
|
||||
# Execute the tool's action
|
||||
resp_tool = tool.execute_action(call_name, **call_args)
|
||||
# Append the tool's response to the conversation
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": str(resp_tool),
|
||||
"tool_call_id": call_id
|
||||
}
|
||||
)
|
||||
# Generate a new response from the LLM after processing tools
|
||||
resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)
|
||||
|
||||
# If no tool calls are needed, generate the final response
|
||||
if isinstance(resp, str):
|
||||
yield resp
|
||||
else:
|
||||
completion = self.llm.gen_stream(model=self.gpt_model, messages=messages, tools=self.tools)
|
||||
for line in completion:
|
||||
yield line
|
||||
20
application/tools/base.py
Normal file
20
application/tools/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class Tool(ABC):
|
||||
@abstractmethod
|
||||
def execute_action(self, action_name: str, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_actions_metadata(self):
|
||||
"""
|
||||
Returns a list of JSON objects describing the actions supported by the tool.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_config_requirements(self):
|
||||
"""
|
||||
Returns a dictionary describing the configuration requirements for the tool.
|
||||
"""
|
||||
pass
|
||||
73
application/tools/cryptoprice.py
Normal file
73
application/tools/cryptoprice.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from application.tools.base import Tool
|
||||
import requests
|
||||
|
||||
class CryptoPriceTool(Tool):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
actions = {
|
||||
"cryptoprice_get": self.get_price
|
||||
}
|
||||
|
||||
if action_name in actions:
|
||||
return actions[action_name](**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
|
||||
def get_price(self, symbol, currency):
|
||||
"""
|
||||
Fetches the current price of a given cryptocurrency symbol in the specified currency.
|
||||
Example:
|
||||
symbol = "BTC"
|
||||
currency = "USD"
|
||||
returns price in USD.
|
||||
"""
|
||||
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
# data will be like {"USD": <price>} if the call is successful
|
||||
if currency.upper() in data:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"price": data[currency.upper()],
|
||||
"message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully."
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}."
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": "Failed to retrieve price."
|
||||
}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
return [
|
||||
{
|
||||
"name": "cryptoprice_get",
|
||||
"description": "Retrieve the price of a specified cryptocurrency in a given currency",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"symbol": {
|
||||
"type": "string",
|
||||
"description": "The cryptocurrency symbol (e.g. BTC)"
|
||||
},
|
||||
"currency": {
|
||||
"type": "string",
|
||||
"description": "The currency in which you want the price (e.g. USD)"
|
||||
}
|
||||
},
|
||||
"required": ["symbol", "currency"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
# No specific configuration needed for this tool as it just queries a public endpoint
|
||||
return {}
|
||||
79
application/tools/telegram.py
Normal file
79
application/tools/telegram.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from application.tools.base import Tool
|
||||
import requests
|
||||
|
||||
class TelegramTool(Tool):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.chat_id = config.get("chat_id", "142189016")
|
||||
self.token = config.get("token", "YOUR_TG_TOKEN")
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
actions = {
|
||||
"telegram_send_message": self.send_message,
|
||||
"telegram_send_image": self.send_image
|
||||
}
|
||||
|
||||
if action_name in actions:
|
||||
return actions[action_name](**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
|
||||
def send_message(self, text):
|
||||
print(f"Sending message: {text}")
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
|
||||
payload = {"chat_id": self.chat_id, "text": text}
|
||||
response = requests.post(url, data=payload)
|
||||
return {"status_code": response.status_code, "message": "Message sent"}
|
||||
|
||||
def send_image(self, image_url):
|
||||
print(f"Sending image: {image_url}")
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
|
||||
payload = {"chat_id": self.chat_id, "photo": image_url}
|
||||
response = requests.post(url, data=payload)
|
||||
return {"status_code": response.status_code, "message": "Image sent"}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
return [
|
||||
{
|
||||
"name": "telegram_send_message",
|
||||
"description": "Send a notification to telegram chat",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text to send in the notification"
|
||||
}
|
||||
},
|
||||
"required": ["text"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "telegram_send_image",
|
||||
"description": "Send an image to the Telegram chat",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image_url": {
|
||||
"type": "string",
|
||||
"description": "URL of the image to send"
|
||||
}
|
||||
},
|
||||
"required": ["image_url"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"description": "Telegram chat ID to send messages to"
|
||||
},
|
||||
"token": {
|
||||
"type": "string",
|
||||
"description": "Bot token for authentication"
|
||||
}
|
||||
}
|
||||
43
application/tools/tool_manager.py
Normal file
43
application/tools/tool_manager.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import pkgutil
|
||||
import os
|
||||
|
||||
from application.tools.base import Tool
|
||||
|
||||
class ToolManager:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.tools = {}
|
||||
self.load_tools()
|
||||
|
||||
def load_tools(self):
|
||||
tools_dir = os.path.dirname(__file__)
|
||||
for finder, name, ispkg in pkgutil.iter_modules([tools_dir]):
|
||||
if name == 'base' or name.startswith('__'):
|
||||
continue
|
||||
module = importlib.import_module(f'application.tools.{name}')
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
tool_config = self.config.get(name, {})
|
||||
self.tools[name] = obj(tool_config)
|
||||
|
||||
def load_tool(self, tool_name, tool_config):
|
||||
self.config[tool_name] = tool_config
|
||||
tools_dir = os.path.dirname(__file__)
|
||||
module = importlib.import_module(f'application.tools.{tool_name}')
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
return obj(tool_config)
|
||||
|
||||
|
||||
def execute_action(self, tool_name, action_name, **kwargs):
|
||||
if tool_name not in self.tools:
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||
return self.tools[tool_name].execute_action(action_name, **kwargs)
|
||||
|
||||
def get_all_actions_metadata(self):
|
||||
metadata = []
|
||||
for tool in self.tools.values():
|
||||
metadata.extend(tool.get_actions_metadata())
|
||||
return metadata
|
||||
Reference in New Issue
Block a user