refactor: folder restructure for agent based workflow

This commit is contained in:
Siddhant Rai
2025-02-25 09:03:45 +05:30
parent 6fed84958e
commit 1f0b779c64
12 changed files with 38 additions and 35 deletions

View File

@@ -1,10 +1,11 @@
from typing import Dict, Generator from typing import Dict, Generator
from application.agents.llm_handler import get_llm_handler
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.mongo_db import MongoDB from application.core.mongo_db import MongoDB
from application.llm.llm_creator import LLMCreator from application.llm.llm_creator import LLMCreator
from application.tools.llm_handler import get_llm_handler
from application.tools.tool_action_parser import ToolActionParser
from application.tools.tool_manager import ToolManager
class BaseAgent: class BaseAgent:

View File

@@ -1,8 +1,9 @@
import uuid import uuid
from typing import Dict, Generator from typing import Dict, Generator
from application.agents.base import BaseAgent
from application.retriever.base import BaseRetriever from application.retriever.base import BaseRetriever
from application.tools.base_agent import BaseAgent
class ClassicAgent(BaseAgent): class ClassicAgent(BaseAgent):

View File

@@ -1,7 +1,7 @@
import json import json
import requests import requests
from application.tools.base import Tool from application.agents.tools.base import Tool
class APITool(Tool): class APITool(Tool):

View File

@@ -1,5 +1,5 @@
import requests import requests
from application.tools.base import Tool from application.agents.tools.base import Tool
class CryptoPriceTool(Tool): class CryptoPriceTool(Tool):

View File

@@ -1,5 +1,5 @@
import requests import requests
from application.tools.base import Tool from application.agents.tools.base import Tool
class TelegramTool(Tool): class TelegramTool(Tool):

View File

@@ -3,7 +3,7 @@ import inspect
import os import os
import pkgutil import pkgutil
from application.tools.base import Tool from application.agents.tools.base import Tool
class ToolManager: class ToolManager:
@@ -13,13 +13,11 @@ class ToolManager:
self.load_tools() self.load_tools()
def load_tools(self): def load_tools(self):
tools_dir = os.path.join(os.path.dirname(__file__), "implementations") tools_dir = os.path.join(os.path.dirname(__file__))
for finder, name, ispkg in pkgutil.iter_modules([tools_dir]): for finder, name, ispkg in pkgutil.iter_modules([tools_dir]):
if name == "base" or name.startswith("__"): if name == "base" or name.startswith("__"):
continue continue
module = importlib.import_module( module = importlib.import_module(f"application.agents.tools.{name}")
f"application.tools.implementations.{name}"
)
for member_name, obj in inspect.getmembers(module, inspect.isclass): for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool: if issubclass(obj, Tool) and obj is not Tool:
tool_config = self.config.get(name, {}) tool_config = self.config.get(name, {})
@@ -27,9 +25,7 @@ class ToolManager:
def load_tool(self, tool_name, tool_config): def load_tool(self, tool_name, tool_config):
self.config[tool_name] = tool_config self.config[tool_name] = tool_config
module = importlib.import_module( module = importlib.import_module(f"application.agents.tools.{tool_name}")
f"application.tools.implementations.{tool_name}"
)
for member_name, obj in inspect.getmembers(module, inspect.isclass): for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool: if issubclass(obj, Tool) and obj is not Tool:
return obj(tool_config) return obj(tool_config)

View File

@@ -1,15 +1,16 @@
import asyncio import asyncio
import datetime import datetime
import json import json
import logging
import os import os
import traceback import traceback
import logging
from bson.dbref import DBRef from bson.dbref import DBRef
from bson.objectid import ObjectId from bson.objectid import ObjectId
from flask import Blueprint, make_response, request, Response from flask import Blueprint, make_response, request, Response
from flask_restx import fields, Namespace, Resource from flask_restx import fields, Namespace, Resource
from application.agents.classic_agent import ClassicAgent
from application.core.mongo_db import MongoDB from application.core.mongo_db import MongoDB
from application.core.settings import settings from application.core.settings import settings
@@ -17,7 +18,6 @@ from application.error import bad_request
from application.extensions import api from application.extensions import api
from application.llm.llm_creator import LLMCreator from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator from application.retriever.retriever_creator import RetrieverCreator
from application.tools.agent import ClassicAgent
from application.utils import check_required_fields, limit_chat_history from application.utils import check_required_fields, limit_chat_history
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -1,9 +1,9 @@
import datetime import datetime
import json
import math import math
import os import os
import shutil import shutil
import uuid import uuid
import json
from bson.binary import Binary, UuidRepresentation from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef from bson.dbref import DBRef
@@ -12,12 +12,13 @@ from flask import Blueprint, current_app, jsonify, make_response, redirect, requ
from flask_restx import fields, inputs, Namespace, Resource from flask_restx import fields, inputs, Namespace, Resource
from werkzeug.utils import secure_filename from werkzeug.utils import secure_filename
from application.agents.tools.tool_manager import ToolManager
from application.api.user.tasks import ingest, ingest_remote from application.api.user.tasks import ingest, ingest_remote
from application.core.mongo_db import MongoDB from application.core.mongo_db import MongoDB
from application.core.settings import settings from application.core.settings import settings
from application.extensions import api from application.extensions import api
from application.tools.tool_manager import ToolManager
from application.tts.google_tts import GoogleTTS from application.tts.google_tts import GoogleTTS
from application.utils import check_required_fields, validate_function_name from application.utils import check_required_fields, validate_function_name
from application.vectorstore.vector_creator import VectorCreator from application.vectorstore.vector_creator import VectorCreator
@@ -429,22 +430,21 @@ class UploadRemote(Resource):
return missing_fields return missing_fields
try: try:
config = json.loads(data["data"]) config = json.loads(data["data"])
source_data = None source_data = None
if data["source"] == "github": if data["source"] == "github":
source_data = config.get("repo_url") source_data = config.get("repo_url")
elif data["source"] in ["crawler", "url"]: elif data["source"] in ["crawler", "url"]:
source_data = config.get("url") source_data = config.get("url")
elif data["source"] == "reddit": elif data["source"] == "reddit":
source_data = config source_data = config
task = ingest_remote.delay(
task = ingest_remote.delay(
source_data=source_data, source_data=source_data,
job_name=data["name"], job_name=data["name"],
user=data["user"], user=data["user"],
loader=data["source"] loader=data["source"],
) )
except Exception as err: except Exception as err:
current_app.logger.error(f"Error uploading remote source: {err}") current_app.logger.error(f"Error uploading remote source: {err}")
@@ -1936,11 +1936,14 @@ class UpdateTool(Resource):
for action_name in list(data["config"]["actions"].keys()): for action_name in list(data["config"]["actions"].keys()):
if not validate_function_name(action_name): if not validate_function_name(action_name):
return make_response( return make_response(
jsonify({ jsonify(
"success": False, {
"message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.", "success": False,
"param": "tools[].function.name" "message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.",
}), 400 "param": "tools[].function.name",
}
),
400,
) )
update_data["config"] = data["config"] update_data["config"] = data["config"]
if "status" in data: if "status" in data:

View File

@@ -51,10 +51,12 @@ class ClassicRAG(BaseRetriever):
Rephrase the following user question to be a standalone search query Rephrase the following user question to be a standalone search query
that captures all relevant context from the conversation: that captures all relevant context from the conversation:
{self.original_question}
""" """
messages = [{"role": "system", "content": prompt}] messages = [
{"role": "system", "content": prompt},
{"role": "user", "content": self.original_question},
]
try: try:
rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages) rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages)