feat: tools frontend and endpoints refactor

This commit is contained in:
Siddhant Rai
2024-12-18 22:48:40 +05:30
parent f87ae429f4
commit f9a7db11eb
23 changed files with 1069 additions and 168 deletions

View File

@@ -1,14 +1,14 @@
import datetime
import math
import os
import shutil
import uuid
import math
from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import Blueprint, jsonify, make_response, request, redirect
from flask_restx import inputs, fields, Namespace, Resource
from flask import Blueprint, jsonify, make_response, redirect, request
from flask_restx import fields, inputs, Namespace, Resource
from werkzeug.utils import secure_filename
from application.api.user.tasks import ingest, ingest_remote
@@ -16,9 +16,10 @@ from application.api.user.tasks import ingest, ingest_remote
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.extensions import api
from application.tools.tool_manager import ToolManager
from application.tts.google_tts import GoogleTTS
from application.utils import check_required_fields
from application.vectorstore.vector_creator import VectorCreator
from application.tts.google_tts import GoogleTTS
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
@@ -40,6 +41,9 @@ current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
tool_config = {}
tool_manager = ToolManager(config=tool_config)
def generate_minute_range(start_date, end_date):
return {
@@ -1789,24 +1793,88 @@ class TextToSpeech(Resource):
return make_response(jsonify({"success": False, "error": str(err)}), 400)
@user_ns.route("/api/available_tools")
class AvailableTools(Resource):
@api.doc(description="Get available tools for a user")
def get(self):
try:
tools_metadata = []
for tool_name, tool_instance in tool_manager.tools.items():
doc = tool_instance.__doc__.strip()
lines = doc.split("\n", 1)
name = lines[0].strip()
description = lines[1].strip() if len(lines) > 1 else ""
tools_metadata.append(
{
"name": tool_name,
"displayName": name,
"description": description,
"configRequirements": tool_instance.get_config_requirements(),
"actions": tool_instance.get_actions_metadata(),
}
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True, "data": tools_metadata}), 200)
@user_ns.route("/api/get_tools")
class GetTools(Resource):
@api.doc(description="Get tools created by a user")
def get(self):
try:
user = "local"
tools = user_tools_collection.find({"user": user})
user_tools = []
for tool in tools:
tool["id"] = str(tool["_id"])
tool.pop("_id")
user_tools.append(tool)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True, "tools": user_tools}), 200)
@user_ns.route("/api/create_tool")
class CreateTool(Resource):
create_tool_model = api.model(
"CreateToolModel",
{
"name": fields.String(required=True, description="Name of the tool"),
"config": fields.Raw(required=True, description="Configuration of the tool"),
"actions": fields.List(required=True, description="Actions the tool can perform"),
"status": fields.Boolean(required=True, description="Status of the tool")
},
@api.expect(
api.model(
"CreateToolModel",
{
"name": fields.String(required=True, description="Name of the tool"),
"displayName": fields.String(
required=True, description="Display name for the tool"
),
"description": fields.String(
required=True, description="Tool description"
),
"config": fields.Raw(
required=True, description="Configuration of the tool"
),
"actions": fields.List(
fields.Raw,
required=True,
description="Actions the tool can perform",
),
"status": fields.Boolean(
required=True, description="Status of the tool"
),
},
)
)
@api.expect(create_tool_model)
@api.doc(description="Create a new tool")
def post(self):
data = request.get_json()
required_fields = ["name", "config", "actions", "status"]
required_fields = [
"name",
"displayName",
"description",
"actions",
"config",
"status",
]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
@@ -1814,10 +1882,12 @@ class CreateTool(Resource):
user = "local"
try:
new_tool = {
"name": data["name"],
"config": data["config"],
"actions": data["actions"],
"user": user,
"name": data["name"],
"displayName": data["displayName"],
"description": data["description"],
"actions": data["actions"],
"config": data["config"],
"status": data["status"],
}
resp = user_tools_collection.insert_one(new_tool)
@@ -1826,18 +1896,72 @@ class CreateTool(Resource):
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"id": new_id}), 200)
@user_ns.route("/api/update_tool")
class UpdateTool(Resource):
@api.expect(
api.model(
"UpdateToolModel",
{
"id": fields.String(required=True, description="Tool ID"),
"name": fields.String(description="Name of the tool"),
"displayName": fields.String(description="Display name for the tool"),
"description": fields.String(description="Tool description"),
"config": fields.Raw(description="Configuration of the tool"),
"actions": fields.List(
fields.Raw, description="Actions the tool can perform"
),
"status": fields.Boolean(description="Status of the tool"),
},
)
)
@api.doc(description="Update a tool by ID")
def post(self):
data = request.get_json()
required_fields = ["id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
update_data = {}
if "name" in data:
update_data["name"] = data["name"]
if "displayName" in data:
update_data["displayName"] = data["displayName"]
if "description" in data:
update_data["description"] = data["description"]
if "actions" in data:
update_data["actions"] = data["actions"]
if "config" in data:
update_data["config"] = data["config"]
if "status" in data:
update_data["status"] = data["status"]
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": "local"},
{"$set": update_data},
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True}), 200)
@user_ns.route("/api/update_tool_config")
class UpdateToolConfig(Resource):
update_tool_config_model = api.model(
"UpdateToolConfigModel",
{
"id": fields.String(required=True, description="Tool ID"),
"config": fields.Raw(required=True, description="Configuration of the tool"),
},
@api.expect(
api.model(
"UpdateToolConfigModel",
{
"id": fields.String(required=True, description="Tool ID"),
"config": fields.Raw(
required=True, description="Configuration of the tool"
),
},
)
)
@api.expect(update_tool_config_model)
@api.doc(description="Update the configuration of a tool")
def post(self):
data = request.get_json()
@@ -1855,18 +1979,23 @@ class UpdateToolConfig(Resource):
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True}), 200)
@user_ns.route("/api/update_tool_actions")
class UpdateToolActions(Resource):
update_tool_actions_model = api.model(
"UpdateToolActionsModel",
{
"id": fields.String(required=True, description="Tool ID"),
"actions": fields.List(required=True, description="Actions the tool can perform"),
},
@api.expect(
api.model(
"UpdateToolActionsModel",
{
"id": fields.String(required=True, description="Tool ID"),
"actions": fields.List(
fields.Raw,
required=True,
description="Actions the tool can perform",
),
},
)
)
@api.expect(update_tool_actions_model)
@api.doc(description="Update the actions of a tool")
def post(self):
data = request.get_json()
@@ -1885,17 +2014,20 @@ class UpdateToolActions(Resource):
return make_response(jsonify({"success": True}), 200)
@user_ns.route("/api/update_tool_status")
class UpdateToolStatus(Resource):
update_tool_status_model = api.model(
"UpdateToolStatusModel",
{
"id": fields.String(required=True, description="Tool ID"),
"status": fields.Boolean(required=True, description="Status of the tool"),
},
@api.expect(
api.model(
"UpdateToolStatusModel",
{
"id": fields.String(required=True, description="Tool ID"),
"status": fields.Boolean(
required=True, description="Status of the tool"
),
},
)
)
@api.expect(update_tool_status_model)
@api.doc(description="Update the status of a tool")
def post(self):
data = request.get_json()
@@ -1913,15 +2045,16 @@ class UpdateToolStatus(Resource):
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True}), 200)
@user_ns.route("/api/delete_tool")
class DeleteTool(Resource):
delete_tool_model = api.model(
"DeleteToolModel",
{"id": fields.String(required=True, description="Tool ID")},
@api.expect(
api.model(
"DeleteToolModel",
{"id": fields.String(required=True, description="Tool ID")},
)
)
@api.expect(delete_tool_model)
@api.doc(description="Delete a tool by ID")
def post(self):
data = request.get_json()
@@ -1937,4 +2070,4 @@ class DeleteTool(Resource):
except Exception as err:
return {"success": False, "error": str(err)}, 400
return {"success": True}, 200
return {"success": True}, 200

View File

@@ -1,45 +1,32 @@
from application.llm.base import BaseLLM
from openai import OpenAI
class GroqLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from openai import OpenAI
super().__init__(*args, **kwargs)
self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
self.api_key = api_key
self.user_api_key = user_api_key
def _raw_gen(
self,
baseself,
model,
messages,
stream=False,
**kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self,
baseself,
model,
messages,
stream=True,
**kwargs
):
self, baseself, model, messages, stream=True, tools=None, **kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
for line in response:
# import sys
# print(line.choices[0].delta.content, file=sys.stderr)
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content

View File

@@ -1,9 +1,10 @@
from application.llm.llm_creator import LLMCreator
from application.core.settings import settings
from application.tools.tool_manager import ToolManager
from application.core.mongo_db import MongoDB
import json
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.tools.tool_manager import ToolManager
tool_tg = {
"name": "telegram_send_message",
"description": "Send a notification to telegram about current chat",
@@ -12,15 +13,15 @@ tool_tg = {
"properties": {
"text": {
"type": "string",
"description": "Text to send in the notification"
"description": "Text to send in the notification",
}
},
"required": ["text"],
"additionalProperties": False
}
"additionalProperties": False,
},
}
tool_crypto = {
tool_crypto = {
"name": "cryptoprice_get",
"description": "Retrieve the price of a specified cryptocurrency in a given currency",
"parameters": {
@@ -28,33 +29,30 @@ tool_crypto = {
"properties": {
"symbol": {
"type": "string",
"description": "The cryptocurrency symbol (e.g. BTC)"
"description": "The cryptocurrency symbol (e.g. BTC)",
},
"currency": {
"type": "string",
"description": "The currency in which you want the price (e.g. USD)"
}
"description": "The currency in which you want the price (e.g. USD)",
},
},
"required": ["symbol", "currency"],
"additionalProperties": False
}
"additionalProperties": False,
},
}
class Agent:
def __init__(self, llm_name, gpt_model, api_key, user_api_key=None):
# Initialize the LLM with the provided parameters
self.llm = LLMCreator.create_llm(llm_name, api_key=api_key, user_api_key=user_api_key)
self.llm = LLMCreator.create_llm(
llm_name, api_key=api_key, user_api_key=user_api_key
)
self.gpt_model = gpt_model
# Static tool configuration (to be replaced later)
self.tools = [
{
"type": "function",
"function": tool_crypto
}
]
self.tool_config = {
}
self.tools = [{"type": "function", "function": tool_crypto}]
self.tool_config = {}
def _get_user_tools(self, user="local"):
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
@@ -65,19 +63,17 @@ class Agent:
tool.pop("_id")
user_tools = {tool["name"]: tool for tool in user_tools}
return user_tools
def _simple_tool_agent(self, messages):
tools_dict = self._get_user_tools()
# combine all tool_actions into one list
self.tools.extend([
{
"type": "function",
"function": tool_action
}
for tool in tools_dict.values()
for tool_action in tool["actions"]
])
self.tools.extend(
[
{"type": "function", "function": tool_action}
for tool in tools_dict.values()
for tool_action in tool["actions"]
]
)
resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)
@@ -88,7 +84,7 @@ class Agent:
while resp.finish_reason == "tool_calls":
# Append the assistant's message to the conversation
messages.append(json.loads(resp.model_dump_json())['message'])
messages.append(json.loads(resp.model_dump_json())["message"])
# Handle each tool call
tool_calls = resp.message.tool_calls
for call in tool_calls:
@@ -98,25 +94,27 @@ class Agent:
call_id = call.id
# Determine the tool name and load it
tool_name = call_name.split("_")[0]
tool = tm.load_tool(tool_name, tool_config=tools_dict[tool_name]['config'])
tool = tm.load_tool(
tool_name, tool_config=tools_dict[tool_name]["config"]
)
# Execute the tool's action
resp_tool = tool.execute_action(call_name, **call_args)
# Append the tool's response to the conversation
messages.append(
{
"role": "tool",
"content": str(resp_tool),
"tool_call_id": call_id
}
{"role": "tool", "content": str(resp_tool), "tool_call_id": call_id}
)
# Generate a new response from the LLM after processing tools
resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)
resp = self.llm.gen(
model=self.gpt_model, messages=messages, tools=self.tools
)
# If no tool calls are needed, generate the final response
if isinstance(resp, str):
yield resp
else:
completion = self.llm.gen_stream(model=self.gpt_model, messages=messages, tools=self.tools)
completion = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=self.tools
)
for line in completion:
yield line
@@ -127,4 +125,4 @@ class Agent:
else:
resp = self.llm.gen_stream(model=self.gpt_model, messages=messages)
for line in resp:
yield line
yield line

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
class Tool(ABC):
@abstractmethod
def execute_action(self, action_name: str, **kwargs):

View File

@@ -1,21 +1,25 @@
from application.tools.base import Tool
import requests
from application.tools.base import Tool
class CryptoPriceTool(Tool):
"""
CryptoPrice
A tool for retrieving cryptocurrency prices using the CryptoCompare public API
"""
def __init__(self, config):
self.config = config
def execute_action(self, action_name, **kwargs):
actions = {
"cryptoprice_get": self.get_price
}
actions = {"cryptoprice_get": self._get_price}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def get_price(self, symbol, currency):
def _get_price(self, symbol, currency):
"""
Fetches the current price of a given cryptocurrency symbol in the specified currency.
Example:
@@ -32,17 +36,17 @@ class CryptoPriceTool(Tool):
return {
"status_code": response.status_code,
"price": data[currency.upper()],
"message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully."
"message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully.",
}
else:
return {
"status_code": response.status_code,
"message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}."
"message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}.",
}
else:
return {
"status_code": response.status_code,
"message": "Failed to retrieve price."
"message": "Failed to retrieve price.",
}
def get_actions_metadata(self):
@@ -55,16 +59,16 @@ class CryptoPriceTool(Tool):
"properties": {
"symbol": {
"type": "string",
"description": "The cryptocurrency symbol (e.g. BTC)"
"description": "The cryptocurrency symbol (e.g. BTC)",
},
"currency": {
"type": "string",
"description": "The currency in which you want the price (e.g. USD)"
}
"description": "The currency in which you want the price (e.g. USD)",
},
},
"required": ["symbol", "currency"],
"additionalProperties": False
}
"additionalProperties": False,
},
}
]

View File

@@ -1,16 +1,23 @@
from application.tools.base import Tool
import requests
from application.tools.base import Tool
class TelegramTool(Tool):
"""
Telegram Bot
A flexible Telegram tool for performing various actions (e.g., sending messages, images).
Requires a bot token and chat ID for configuration
"""
def __init__(self, config):
self.config = config
self.chat_id = config.get("chat_id", "142189016")
self.token = config.get("token", "YOUR_TG_TOKEN")
self.token = config.get("token", "")
self.chat_id = config.get("chat_id", "")
def execute_action(self, action_name, **kwargs):
actions = {
"telegram_send_message": self.send_message,
"telegram_send_image": self.send_image
"telegram_send_message": self._send_message,
"telegram_send_image": self._send_image,
}
if action_name in actions:
@@ -18,14 +25,14 @@ class TelegramTool(Tool):
else:
raise ValueError(f"Unknown action: {action_name}")
def send_message(self, text):
def _send_message(self, text):
print(f"Sending message: {text}")
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
payload = {"chat_id": self.chat_id, "text": text}
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Message sent"}
def send_image(self, image_url):
def _send_image(self, image_url):
print(f"Sending image: {image_url}")
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
payload = {"chat_id": self.chat_id, "photo": image_url}
@@ -42,12 +49,12 @@ class TelegramTool(Tool):
"properties": {
"text": {
"type": "string",
"description": "Text to send in the notification"
"description": "Text to send in the notification",
}
},
"required": ["text"],
"additionalProperties": False
}
"additionalProperties": False,
},
},
{
"name": "telegram_send_image",
@@ -57,23 +64,20 @@ class TelegramTool(Tool):
"properties": {
"image_url": {
"type": "string",
"description": "URL of the image to send"
"description": "URL of the image to send",
}
},
"required": ["image_url"],
"additionalProperties": False
}
}
"additionalProperties": False,
},
},
]
def get_config_requirements(self):
return {
"chat_id": {
"type": "string",
"description": "Telegram chat ID to send messages to"
"description": "Telegram chat ID to send messages to",
},
"token": {
"type": "string",
"description": "Bot token for authentication"
}
"token": {"type": "string", "description": "Bot token for authentication"},
}

View File

@@ -1,10 +1,11 @@
import importlib
import inspect
import pkgutil
import os
import pkgutil
from application.tools.base import Tool
class ToolManager:
def __init__(self, config):
self.config = config
@@ -12,11 +13,13 @@ class ToolManager:
self.load_tools()
def load_tools(self):
tools_dir = os.path.dirname(__file__)
tools_dir = os.path.join(os.path.dirname(__file__), "implementations")
for finder, name, ispkg in pkgutil.iter_modules([tools_dir]):
if name == 'base' or name.startswith('__'):
if name == "base" or name.startswith("__"):
continue
module = importlib.import_module(f'application.tools.{name}')
module = importlib.import_module(
f"application.tools.implementations.{name}"
)
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
tool_config = self.config.get(name, {})
@@ -24,13 +27,14 @@ class ToolManager:
def load_tool(self, tool_name, tool_config):
self.config[tool_name] = tool_config
tools_dir = os.path.dirname(__file__)
module = importlib.import_module(f'application.tools.{tool_name}')
tools_dir = os.path.join(os.path.dirname(__file__), "implementations")
module = importlib.import_module(
f"application.tools.implementations.{tool_name}"
)
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
return obj(tool_config)
def execute_action(self, tool_name, action_name, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")