feat: tools frontend and endpoints refactor

This commit is contained in:
Siddhant Rai
2024-12-18 22:48:40 +05:30
parent f87ae429f4
commit f9a7db11eb
23 changed files with 1069 additions and 168 deletions

View File

@@ -1,9 +1,10 @@
from application.llm.llm_creator import LLMCreator
from application.core.settings import settings
from application.tools.tool_manager import ToolManager
from application.core.mongo_db import MongoDB
import json
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.tools.tool_manager import ToolManager
tool_tg = {
"name": "telegram_send_message",
"description": "Send a notification to telegram about current chat",
@@ -12,15 +13,15 @@ tool_tg = {
"properties": {
"text": {
"type": "string",
"description": "Text to send in the notification"
"description": "Text to send in the notification",
}
},
"required": ["text"],
"additionalProperties": False
}
"additionalProperties": False,
},
}
tool_crypto = {
tool_crypto = {
"name": "cryptoprice_get",
"description": "Retrieve the price of a specified cryptocurrency in a given currency",
"parameters": {
@@ -28,33 +29,30 @@ tool_crypto = {
"properties": {
"symbol": {
"type": "string",
"description": "The cryptocurrency symbol (e.g. BTC)"
"description": "The cryptocurrency symbol (e.g. BTC)",
},
"currency": {
"type": "string",
"description": "The currency in which you want the price (e.g. USD)"
}
"description": "The currency in which you want the price (e.g. USD)",
},
},
"required": ["symbol", "currency"],
"additionalProperties": False
}
"additionalProperties": False,
},
}
class Agent:
def __init__(self, llm_name, gpt_model, api_key, user_api_key=None):
# Initialize the LLM with the provided parameters
self.llm = LLMCreator.create_llm(llm_name, api_key=api_key, user_api_key=user_api_key)
self.llm = LLMCreator.create_llm(
llm_name, api_key=api_key, user_api_key=user_api_key
)
self.gpt_model = gpt_model
# Static tool configuration (to be replaced later)
self.tools = [
{
"type": "function",
"function": tool_crypto
}
]
self.tool_config = {
}
self.tools = [{"type": "function", "function": tool_crypto}]
self.tool_config = {}
def _get_user_tools(self, user="local"):
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
@@ -65,19 +63,17 @@ class Agent:
tool.pop("_id")
user_tools = {tool["name"]: tool for tool in user_tools}
return user_tools
def _simple_tool_agent(self, messages):
tools_dict = self._get_user_tools()
# combine all tool_actions into one list
self.tools.extend([
{
"type": "function",
"function": tool_action
}
for tool in tools_dict.values()
for tool_action in tool["actions"]
])
self.tools.extend(
[
{"type": "function", "function": tool_action}
for tool in tools_dict.values()
for tool_action in tool["actions"]
]
)
resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)
@@ -88,7 +84,7 @@ class Agent:
while resp.finish_reason == "tool_calls":
# Append the assistant's message to the conversation
messages.append(json.loads(resp.model_dump_json())['message'])
messages.append(json.loads(resp.model_dump_json())["message"])
# Handle each tool call
tool_calls = resp.message.tool_calls
for call in tool_calls:
@@ -98,25 +94,27 @@ class Agent:
call_id = call.id
# Determine the tool name and load it
tool_name = call_name.split("_")[0]
tool = tm.load_tool(tool_name, tool_config=tools_dict[tool_name]['config'])
tool = tm.load_tool(
tool_name, tool_config=tools_dict[tool_name]["config"]
)
# Execute the tool's action
resp_tool = tool.execute_action(call_name, **call_args)
# Append the tool's response to the conversation
messages.append(
{
"role": "tool",
"content": str(resp_tool),
"tool_call_id": call_id
}
{"role": "tool", "content": str(resp_tool), "tool_call_id": call_id}
)
# Generate a new response from the LLM after processing tools
resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)
resp = self.llm.gen(
model=self.gpt_model, messages=messages, tools=self.tools
)
# If no tool calls are needed, generate the final response
if isinstance(resp, str):
yield resp
else:
completion = self.llm.gen_stream(model=self.gpt_model, messages=messages, tools=self.tools)
completion = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=self.tools
)
for line in completion:
yield line
@@ -127,4 +125,4 @@ class Agent:
else:
resp = self.llm.gen_stream(model=self.gpt_model, messages=messages)
for line in resp:
yield line
yield line

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
class Tool(ABC):
@abstractmethod
def execute_action(self, action_name: str, **kwargs):

View File

@@ -1,21 +1,25 @@
from application.tools.base import Tool
import requests
from application.tools.base import Tool
class CryptoPriceTool(Tool):
"""
CryptoPrice
A tool for retrieving cryptocurrency prices using the CryptoCompare public API
"""
def __init__(self, config):
self.config = config
def execute_action(self, action_name, **kwargs):
actions = {
"cryptoprice_get": self.get_price
}
actions = {"cryptoprice_get": self._get_price}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def get_price(self, symbol, currency):
def _get_price(self, symbol, currency):
"""
Fetches the current price of a given cryptocurrency symbol in the specified currency.
Example:
@@ -32,17 +36,17 @@ class CryptoPriceTool(Tool):
return {
"status_code": response.status_code,
"price": data[currency.upper()],
"message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully."
"message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully.",
}
else:
return {
"status_code": response.status_code,
"message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}."
"message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}.",
}
else:
return {
"status_code": response.status_code,
"message": "Failed to retrieve price."
"message": "Failed to retrieve price.",
}
def get_actions_metadata(self):
@@ -55,16 +59,16 @@ class CryptoPriceTool(Tool):
"properties": {
"symbol": {
"type": "string",
"description": "The cryptocurrency symbol (e.g. BTC)"
"description": "The cryptocurrency symbol (e.g. BTC)",
},
"currency": {
"type": "string",
"description": "The currency in which you want the price (e.g. USD)"
}
"description": "The currency in which you want the price (e.g. USD)",
},
},
"required": ["symbol", "currency"],
"additionalProperties": False
}
"additionalProperties": False,
},
}
]

View File

@@ -1,16 +1,23 @@
from application.tools.base import Tool
import requests
from application.tools.base import Tool
class TelegramTool(Tool):
"""
Telegram Bot
A flexible Telegram tool for performing various actions (e.g., sending messages, images).
Requires a bot token and chat ID for configuration
"""
def __init__(self, config):
self.config = config
self.chat_id = config.get("chat_id", "142189016")
self.token = config.get("token", "YOUR_TG_TOKEN")
self.token = config.get("token", "")
self.chat_id = config.get("chat_id", "")
def execute_action(self, action_name, **kwargs):
actions = {
"telegram_send_message": self.send_message,
"telegram_send_image": self.send_image
"telegram_send_message": self._send_message,
"telegram_send_image": self._send_image,
}
if action_name in actions:
@@ -18,14 +25,14 @@ class TelegramTool(Tool):
else:
raise ValueError(f"Unknown action: {action_name}")
def send_message(self, text):
def _send_message(self, text):
print(f"Sending message: {text}")
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
payload = {"chat_id": self.chat_id, "text": text}
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Message sent"}
def send_image(self, image_url):
def _send_image(self, image_url):
print(f"Sending image: {image_url}")
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
payload = {"chat_id": self.chat_id, "photo": image_url}
@@ -42,12 +49,12 @@ class TelegramTool(Tool):
"properties": {
"text": {
"type": "string",
"description": "Text to send in the notification"
"description": "Text to send in the notification",
}
},
"required": ["text"],
"additionalProperties": False
}
"additionalProperties": False,
},
},
{
"name": "telegram_send_image",
@@ -57,23 +64,20 @@ class TelegramTool(Tool):
"properties": {
"image_url": {
"type": "string",
"description": "URL of the image to send"
"description": "URL of the image to send",
}
},
"required": ["image_url"],
"additionalProperties": False
}
}
"additionalProperties": False,
},
},
]
def get_config_requirements(self):
return {
"chat_id": {
"type": "string",
"description": "Telegram chat ID to send messages to"
"description": "Telegram chat ID to send messages to",
},
"token": {
"type": "string",
"description": "Bot token for authentication"
}
"token": {"type": "string", "description": "Bot token for authentication"},
}

View File

@@ -1,10 +1,11 @@
import importlib
import inspect
import pkgutil
import os
import pkgutil
from application.tools.base import Tool
class ToolManager:
def __init__(self, config):
self.config = config
@@ -12,11 +13,13 @@ class ToolManager:
self.load_tools()
def load_tools(self):
tools_dir = os.path.dirname(__file__)
tools_dir = os.path.join(os.path.dirname(__file__), "implementations")
for finder, name, ispkg in pkgutil.iter_modules([tools_dir]):
if name == 'base' or name.startswith('__'):
if name == "base" or name.startswith("__"):
continue
module = importlib.import_module(f'application.tools.{name}')
module = importlib.import_module(
f"application.tools.implementations.{name}"
)
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
tool_config = self.config.get(name, {})
@@ -24,13 +27,14 @@ class ToolManager:
def load_tool(self, tool_name, tool_config):
self.config[tool_name] = tool_config
tools_dir = os.path.dirname(__file__)
module = importlib.import_module(f'application.tools.{tool_name}')
tools_dir = os.path.join(os.path.dirname(__file__), "implementations")
module = importlib.import_module(
f"application.tools.implementations.{tool_name}"
)
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
return obj(tool_config)
def execute_action(self, tool_name, action_name, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")