refactor: folder restructure for agent based workflow

This commit is contained in:
Siddhant Rai
2025-02-25 09:03:45 +05:30
parent 6fed84958e
commit 1f0b779c64
12 changed files with 38 additions and 35 deletions

View File

@@ -1,10 +1,11 @@
from typing import Dict, Generator
from application.agents.llm_handler import get_llm_handler
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.mongo_db import MongoDB
from application.llm.llm_creator import LLMCreator
from application.tools.llm_handler import get_llm_handler
from application.tools.tool_action_parser import ToolActionParser
from application.tools.tool_manager import ToolManager
class BaseAgent:

View File

@@ -1,8 +1,9 @@
import uuid
from typing import Dict, Generator
from application.agents.base import BaseAgent
from application.retriever.base import BaseRetriever
from application.tools.base_agent import BaseAgent
class ClassicAgent(BaseAgent):

View File

@@ -1,7 +1,7 @@
import json
import requests
from application.tools.base import Tool
from application.agents.tools.base import Tool
class APITool(Tool):

View File

@@ -1,5 +1,5 @@
import requests
from application.tools.base import Tool
from application.agents.tools.base import Tool
class CryptoPriceTool(Tool):

View File

@@ -1,5 +1,5 @@
import requests
from application.tools.base import Tool
from application.agents.tools.base import Tool
class TelegramTool(Tool):

View File

@@ -3,7 +3,7 @@ import inspect
import os
import pkgutil
from application.tools.base import Tool
from application.agents.tools.base import Tool
class ToolManager:
@@ -13,13 +13,11 @@ class ToolManager:
self.load_tools()
def load_tools(self):
tools_dir = os.path.join(os.path.dirname(__file__), "implementations")
tools_dir = os.path.join(os.path.dirname(__file__))
for finder, name, ispkg in pkgutil.iter_modules([tools_dir]):
if name == "base" or name.startswith("__"):
continue
module = importlib.import_module(
f"application.tools.implementations.{name}"
)
module = importlib.import_module(f"application.agents.tools.{name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
tool_config = self.config.get(name, {})
@@ -27,9 +25,7 @@ class ToolManager:
def load_tool(self, tool_name, tool_config):
self.config[tool_name] = tool_config
module = importlib.import_module(
f"application.tools.implementations.{tool_name}"
)
module = importlib.import_module(f"application.agents.tools.{tool_name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
return obj(tool_config)

View File

@@ -1,15 +1,16 @@
import asyncio
import datetime
import json
import logging
import os
import traceback
import logging
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import Blueprint, make_response, request, Response
from flask_restx import fields, Namespace, Resource
from application.agents.classic_agent import ClassicAgent
from application.core.mongo_db import MongoDB
from application.core.settings import settings
@@ -17,7 +18,6 @@ from application.error import bad_request
from application.extensions import api
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
from application.tools.agent import ClassicAgent
from application.utils import check_required_fields, limit_chat_history
logger = logging.getLogger(__name__)

View File

@@ -1,9 +1,9 @@
import datetime
import json
import math
import os
import shutil
import uuid
import json
from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
@@ -12,12 +12,13 @@ from flask import Blueprint, current_app, jsonify, make_response, redirect, requ
from flask_restx import fields, inputs, Namespace, Resource
from werkzeug.utils import secure_filename
from application.agents.tools.tool_manager import ToolManager
from application.api.user.tasks import ingest, ingest_remote
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.extensions import api
from application.tools.tool_manager import ToolManager
from application.tts.google_tts import GoogleTTS
from application.utils import check_required_fields, validate_function_name
from application.vectorstore.vector_creator import VectorCreator
@@ -439,12 +440,11 @@ class UploadRemote(Resource):
elif data["source"] == "reddit":
source_data = config
task = ingest_remote.delay(
source_data=source_data,
job_name=data["name"],
user=data["user"],
loader=data["source"]
loader=data["source"],
)
except Exception as err:
current_app.logger.error(f"Error uploading remote source: {err}")
@@ -1936,11 +1936,14 @@ class UpdateTool(Resource):
for action_name in list(data["config"]["actions"].keys()):
if not validate_function_name(action_name):
return make_response(
jsonify({
jsonify(
{
"success": False,
"message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.",
"param": "tools[].function.name"
}), 400
"param": "tools[].function.name",
}
),
400,
)
update_data["config"] = data["config"]
if "status" in data:

View File

@@ -51,10 +51,12 @@ class ClassicRAG(BaseRetriever):
Rephrase the following user question to be a standalone search query
that captures all relevant context from the conversation:
{self.original_question}
"""
messages = [{"role": "system", "content": prompt}]
messages = [
{"role": "system", "content": prompt},
{"role": "user", "content": self.original_question},
]
try:
rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages)