refactor: folder restructure for agent based workflow

This commit is contained in:
Siddhant Rai
2025-02-25 09:03:45 +05:30
parent 6fed84958e
commit 1f0b779c64
12 changed files with 38 additions and 35 deletions

View File

@@ -1,102 +0,0 @@
import uuid
from typing import Dict, Generator
from application.retriever.base import BaseRetriever
from application.tools.base_agent import BaseAgent
class ClassicAgent(BaseAgent):
def __init__(
self,
llm_name,
gpt_model,
api_key,
user_api_key=None,
prompt="",
chat_history=None,
):
super().__init__(llm_name, gpt_model, api_key, user_api_key)
self.prompt = prompt
self.chat_history = chat_history if chat_history is not None else []
def gen(self, query: str, retriever: BaseRetriever) -> Generator[Dict, None, None]:
retrieved_data = retriever.search(query)
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
if len(self.chat_history) > 0:
for i in self.chat_history:
if "prompt" in i and "response" in i:
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append(
{"role": "assistant", "content": i["response"]}
)
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
call_id = tool_call.get("call_id")
if call_id is None or call_id == "None":
call_id = str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages_combine.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages_combine.append(
{"role": "tool", "content": [function_response_dict]}
)
messages_combine.append({"role": "user", "content": query})
tools_dict = self._get_user_tools()
self._prepare_tools(tools_dict)
resp = self.llm.gen(
model=self.gpt_model, messages=messages_combine, tools=self.tools
)
if isinstance(resp, str):
yield {"answer": resp}
return
if (
hasattr(resp, "message")
and hasattr(resp.message, "content")
and resp.message.content is not None
):
yield {"answer": resp.message.content}
return
resp = self.llm_handler.handle_response(
self, resp, tools_dict, messages_combine
)
if isinstance(resp, str):
yield {"answer": resp}
elif (
hasattr(resp, "message")
and hasattr(resp.message, "content")
and resp.message.content is not None
):
yield {"answer": resp.message.content}
else:
completion = self.llm.gen_stream(
model=self.gpt_model, messages=messages_combine, tools=self.tools
)
for line in completion:
yield {"answer": line}
yield {"tool_calls": self.tool_calls.copy()}

View File

@@ -1,21 +0,0 @@
from abc import ABC, abstractmethod
class Tool(ABC):
@abstractmethod
def execute_action(self, action_name: str, **kwargs):
pass
@abstractmethod
def get_actions_metadata(self):
"""
Returns a list of JSON objects describing the actions supported by the tool.
"""
pass
@abstractmethod
def get_config_requirements(self):
"""
Returns a dictionary describing the configuration requirements for the tool.
"""
pass

View File

@@ -1,140 +0,0 @@
from typing import Dict, Generator
from application.core.mongo_db import MongoDB
from application.llm.llm_creator import LLMCreator
from application.tools.llm_handler import get_llm_handler
from application.tools.tool_action_parser import ToolActionParser
from application.tools.tool_manager import ToolManager
class BaseAgent:
def __init__(self, llm_name, gpt_model, api_key, user_api_key=None):
self.llm = LLMCreator.create_llm(
llm_name, api_key=api_key, user_api_key=user_api_key
)
self.llm_handler = get_llm_handler(llm_name)
self.gpt_model = gpt_model
self.tools = []
self.tool_config = {}
self.tool_calls = []
def gen(self, query: str) -> Generator[Dict, None, None]:
raise NotImplementedError('Method "gen" must be implemented in the child class')
def _get_user_tools(self, user="local"):
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
user_tools_collection = db["user_tools"]
user_tools = user_tools_collection.find({"user": user, "status": True})
user_tools = list(user_tools)
tools_by_id = {str(tool["_id"]): tool for tool in user_tools}
return tools_by_id
def _build_tool_parameters(self, action):
params = {"type": "object", "properties": {}, "required": []}
for param_type in ["query_params", "headers", "body", "parameters"]:
if param_type in action and action[param_type].get("properties"):
for k, v in action[param_type]["properties"].items():
if v.get("filled_by_llm", True):
params["properties"][k] = {
key: value
for key, value in v.items()
if key != "filled_by_llm" and key != "value"
}
params["required"].append(k)
return params
def _prepare_tools(self, tools_dict):
self.tools = [
{
"type": "function",
"function": {
"name": f"{action['name']}_{tool_id}",
"description": action["description"],
"parameters": self._build_tool_parameters(action),
},
}
for tool_id, tool in tools_dict.items()
if (
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
or (tool["name"] != "api_tool" and "actions" in tool)
)
for action in (
tool["config"]["actions"].values()
if tool["name"] == "api_tool"
else tool["actions"]
)
if action.get("active", True)
]
def _execute_tool_action(self, tools_dict, call):
parser = ToolActionParser(self.llm.__class__.__name__)
tool_id, action_name, call_args = parser.parse_args(call)
tool_data = tools_dict[tool_id]
action_data = (
tool_data["config"]["actions"][action_name]
if tool_data["name"] == "api_tool"
else next(
action
for action in tool_data["actions"]
if action["name"] == action_name
)
)
query_params, headers, body, parameters = {}, {}, {}, {}
param_types = {
"query_params": query_params,
"headers": headers,
"body": body,
"parameters": parameters,
}
for param_type, target_dict in param_types.items():
if param_type in action_data and action_data[param_type].get("properties"):
for param, details in action_data[param_type]["properties"].items():
if param not in call_args and "value" in details:
target_dict[param] = details["value"]
for param, value in call_args.items():
for param_type, target_dict in param_types.items():
if param_type in action_data and param in action_data[param_type].get(
"properties", {}
):
target_dict[param] = value
tm = ToolManager(config={})
tool = tm.load_tool(
tool_data["name"],
tool_config=(
{
"url": tool_data["config"]["actions"][action_name]["url"],
"method": tool_data["config"]["actions"][action_name]["method"],
"headers": headers,
"query_params": query_params,
}
if tool_data["name"] == "api_tool"
else tool_data["config"]
),
)
if tool_data["name"] == "api_tool":
print(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
print(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
call_id = getattr(call, "id", None)
tool_call_data = {
"tool_name": tool_data["name"],
"call_id": call_id if call_id is not None else "None",
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
"result": result,
}
self.tool_calls.append(tool_call_data)
return result, call_id

View File

@@ -1,71 +0,0 @@
import json
import requests
from application.tools.base import Tool
class APITool(Tool):
"""
API Tool
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs
"""
def __init__(self, config):
self.config = config
self.url = config.get("url", "")
self.method = config.get("method", "GET")
self.headers = config.get("headers", {"Content-Type": "application/json"})
self.query_params = config.get("query_params", {})
def execute_action(self, action_name, **kwargs):
return self._make_api_call(
self.url, self.method, self.headers, self.query_params, kwargs
)
def _make_api_call(self, url, method, headers, query_params, body):
if query_params:
url = f"{url}?{requests.compat.urlencode(query_params)}"
if isinstance(body, dict):
body = json.dumps(body)
try:
print(f"Making API call: {method} {url} with body: {body}")
response = requests.request(method, url, headers=headers, data=body)
response.raise_for_status()
content_type = response.headers.get(
"Content-Type", "application/json"
).lower()
if "application/json" in content_type:
try:
data = response.json()
except json.JSONDecodeError as e:
print(f"Error decoding JSON: {e}. Raw response: {response.text}")
return {
"status_code": response.status_code,
"message": f"API call returned invalid JSON. Error: {e}",
"data": response.text,
}
elif "text/" in content_type or "application/xml" in content_type:
data = response.text
elif not response.content:
data = None
else:
print(f"Unsupported content type: {content_type}")
data = response.content
return {
"status_code": response.status_code,
"data": data,
"message": "API call successful.",
}
except requests.exceptions.RequestException as e:
return {
"status_code": response.status_code if response else None,
"message": f"API call failed: {str(e)}",
}
def get_actions_metadata(self):
return []
def get_config_requirements(self):
return {}

View File

@@ -1,77 +0,0 @@
import requests
from application.tools.base import Tool
class CryptoPriceTool(Tool):
"""
CryptoPrice
A tool for retrieving cryptocurrency prices using the CryptoCompare public API
"""
def __init__(self, config):
self.config = config
def execute_action(self, action_name, **kwargs):
actions = {"cryptoprice_get": self._get_price}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def _get_price(self, symbol, currency):
"""
Fetches the current price of a given cryptocurrency symbol in the specified currency.
Example:
symbol = "BTC"
currency = "USD"
returns price in USD.
"""
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
response = requests.get(url)
if response.status_code == 200:
data = response.json()
# data will be like {"USD": <price>} if the call is successful
if currency.upper() in data:
return {
"status_code": response.status_code,
"price": data[currency.upper()],
"message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully.",
}
else:
return {
"status_code": response.status_code,
"message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}.",
}
else:
return {
"status_code": response.status_code,
"message": "Failed to retrieve price.",
}
def get_actions_metadata(self):
return [
{
"name": "cryptoprice_get",
"description": "Retrieve the price of a specified cryptocurrency in a given currency",
"parameters": {
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "The cryptocurrency symbol (e.g. BTC)",
},
"currency": {
"type": "string",
"description": "The currency in which you want the price (e.g. USD)",
},
},
"required": ["symbol", "currency"],
"additionalProperties": False,
},
}
]
def get_config_requirements(self):
# No specific configuration needed for this tool as it just queries a public endpoint
return {}

View File

@@ -1,86 +0,0 @@
import requests
from application.tools.base import Tool
class TelegramTool(Tool):
"""
Telegram Bot
A flexible Telegram tool for performing various actions (e.g., sending messages, images).
Requires a bot token and chat ID for configuration
"""
def __init__(self, config):
self.config = config
self.token = config.get("token", "")
def execute_action(self, action_name, **kwargs):
actions = {
"telegram_send_message": self._send_message,
"telegram_send_image": self._send_image,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def _send_message(self, text, chat_id):
print(f"Sending message: {text}")
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
payload = {"chat_id": chat_id, "text": text}
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Message sent"}
def _send_image(self, image_url, chat_id):
print(f"Sending image: {image_url}")
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
payload = {"chat_id": chat_id, "photo": image_url}
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Image sent"}
def get_actions_metadata(self):
return [
{
"name": "telegram_send_message",
"description": "Send a notification to Telegram chat",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "Text to send in the notification",
},
"chat_id": {
"type": "string",
"description": "Chat ID to send the notification to",
},
},
"required": ["text"],
"additionalProperties": False,
},
},
{
"name": "telegram_send_image",
"description": "Send an image to the Telegram chat",
"parameters": {
"type": "object",
"properties": {
"image_url": {
"type": "string",
"description": "URL of the image to send",
},
"chat_id": {
"type": "string",
"description": "Chat ID to send the image to",
},
},
"required": ["image_url"],
"additionalProperties": False,
},
},
]
def get_config_requirements(self):
return {
"token": {"type": "string", "description": "Bot token for authentication"},
}

View File

@@ -1,112 +0,0 @@
import json
from abc import ABC, abstractmethod
class LLMHandler(ABC):
@abstractmethod
def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
pass
class OpenAILLMHandler(LLMHandler):
def handle_response(self, agent, resp, tools_dict, messages):
while resp.finish_reason == "tool_calls":
message = json.loads(resp.model_dump_json())["message"]
keys_to_remove = {"audio", "function_call", "refusal"}
filtered_data = {
k: v for k, v in message.items() if k not in keys_to_remove
}
messages.append(filtered_data)
tool_calls = resp.message.tool_calls
for call in tool_calls:
try:
tool_response, call_id = agent._execute_tool_action(
tools_dict, call
)
function_call_dict = {
"function_call": {
"name": call.function.name,
"args": call.function.arguments,
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": call.function.name,
"response": {"result": tool_response},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages.append(
{"role": "tool", "content": [function_response_dict]}
)
except Exception as e:
messages.append(
{
"role": "tool",
"content": f"Error executing tool: {str(e)}",
"tool_call_id": call_id,
}
)
resp = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
return resp
class GoogleLLMHandler(LLMHandler):
def handle_response(self, agent, resp, tools_dict, messages):
from google.genai import types
while True:
response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
if response.candidates and response.candidates[0].content.parts:
tool_call_found = False
for part in response.candidates[0].content.parts:
if part.function_call:
tool_call_found = True
tool_response, call_id = agent._execute_tool_action(
tools_dict, part.function_call
)
function_response_part = types.Part.from_function_response(
name=part.function_call.name,
response={"result": tool_response},
)
messages.append(
{"role": "model", "content": [part.to_json_dict()]}
)
messages.append(
{
"role": "tool",
"content": [function_response_part.to_json_dict()],
}
)
if (
not tool_call_found
and response.candidates[0].content.parts
and response.candidates[0].content.parts[0].text
):
return response.candidates[0].content.parts[0].text
elif not tool_call_found:
return response.candidates[0].content.parts
else:
return response
def get_llm_handler(llm_type):
handlers = {
"openai": OpenAILLMHandler(),
"google": GoogleLLMHandler(),
}
return handlers.get(llm_type, OpenAILLMHandler())

View File

@@ -1,26 +0,0 @@
import json
class ToolActionParser:
def __init__(self, llm_type):
self.llm_type = llm_type
self.parsers = {
"OpenAILLM": self._parse_openai_llm,
"GoogleLLM": self._parse_google_llm,
}
def parse_args(self, call):
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
return parser(call)
def _parse_openai_llm(self, call):
call_args = json.loads(call.function.arguments)
tool_id = call.function.name.split("_")[-1]
action_name = call.function.name.rsplit("_", 1)[0]
return tool_id, action_name, call_args
def _parse_google_llm(self, call):
call_args = call.args
tool_id = call.name.split("_")[-1]
action_name = call.name.rsplit("_", 1)[0]
return tool_id, action_name, call_args

View File

@@ -1,46 +0,0 @@
import importlib
import inspect
import os
import pkgutil
from application.tools.base import Tool
class ToolManager:
def __init__(self, config):
self.config = config
self.tools = {}
self.load_tools()
def load_tools(self):
tools_dir = os.path.join(os.path.dirname(__file__), "implementations")
for finder, name, ispkg in pkgutil.iter_modules([tools_dir]):
if name == "base" or name.startswith("__"):
continue
module = importlib.import_module(
f"application.tools.implementations.{name}"
)
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
tool_config = self.config.get(name, {})
self.tools[name] = obj(tool_config)
def load_tool(self, tool_name, tool_config):
self.config[tool_name] = tool_config
module = importlib.import_module(
f"application.tools.implementations.{tool_name}"
)
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
return obj(tool_config)
def execute_action(self, tool_name, action_name, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
return self.tools[tool_name].execute_action(action_name, **kwargs)
def get_all_actions_metadata(self):
metadata = []
for tool in self.tools.values():
metadata.extend(tool.get_actions_metadata())
return metadata