mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
refactor: folder restructure for agent based workflow
This commit is contained in:
@@ -1,10 +1,11 @@
|
|||||||
from typing import Dict, Generator
|
from typing import Dict, Generator
|
||||||
|
|
||||||
|
from application.agents.llm_handler import get_llm_handler
|
||||||
|
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||||
|
from application.agents.tools.tool_manager import ToolManager
|
||||||
|
|
||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
from application.llm.llm_creator import LLMCreator
|
from application.llm.llm_creator import LLMCreator
|
||||||
from application.tools.llm_handler import get_llm_handler
|
|
||||||
from application.tools.tool_action_parser import ToolActionParser
|
|
||||||
from application.tools.tool_manager import ToolManager
|
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent:
|
class BaseAgent:
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, Generator
|
from typing import Dict, Generator
|
||||||
|
|
||||||
|
from application.agents.base import BaseAgent
|
||||||
|
|
||||||
from application.retriever.base import BaseRetriever
|
from application.retriever.base import BaseRetriever
|
||||||
from application.tools.base_agent import BaseAgent
|
|
||||||
|
|
||||||
|
|
||||||
class ClassicAgent(BaseAgent):
|
class ClassicAgent(BaseAgent):
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from application.tools.base import Tool
|
from application.agents.tools.base import Tool
|
||||||
|
|
||||||
|
|
||||||
class APITool(Tool):
|
class APITool(Tool):
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
import requests
|
import requests
|
||||||
from application.tools.base import Tool
|
from application.agents.tools.base import Tool
|
||||||
|
|
||||||
|
|
||||||
class CryptoPriceTool(Tool):
|
class CryptoPriceTool(Tool):
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
import requests
|
import requests
|
||||||
from application.tools.base import Tool
|
from application.agents.tools.base import Tool
|
||||||
|
|
||||||
|
|
||||||
class TelegramTool(Tool):
|
class TelegramTool(Tool):
|
||||||
@@ -3,7 +3,7 @@ import inspect
|
|||||||
import os
|
import os
|
||||||
import pkgutil
|
import pkgutil
|
||||||
|
|
||||||
from application.tools.base import Tool
|
from application.agents.tools.base import Tool
|
||||||
|
|
||||||
|
|
||||||
class ToolManager:
|
class ToolManager:
|
||||||
@@ -13,13 +13,11 @@ class ToolManager:
|
|||||||
self.load_tools()
|
self.load_tools()
|
||||||
|
|
||||||
def load_tools(self):
|
def load_tools(self):
|
||||||
tools_dir = os.path.join(os.path.dirname(__file__), "implementations")
|
tools_dir = os.path.join(os.path.dirname(__file__))
|
||||||
for finder, name, ispkg in pkgutil.iter_modules([tools_dir]):
|
for finder, name, ispkg in pkgutil.iter_modules([tools_dir]):
|
||||||
if name == "base" or name.startswith("__"):
|
if name == "base" or name.startswith("__"):
|
||||||
continue
|
continue
|
||||||
module = importlib.import_module(
|
module = importlib.import_module(f"application.agents.tools.{name}")
|
||||||
f"application.tools.implementations.{name}"
|
|
||||||
)
|
|
||||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||||
if issubclass(obj, Tool) and obj is not Tool:
|
if issubclass(obj, Tool) and obj is not Tool:
|
||||||
tool_config = self.config.get(name, {})
|
tool_config = self.config.get(name, {})
|
||||||
@@ -27,9 +25,7 @@ class ToolManager:
|
|||||||
|
|
||||||
def load_tool(self, tool_name, tool_config):
|
def load_tool(self, tool_name, tool_config):
|
||||||
self.config[tool_name] = tool_config
|
self.config[tool_name] = tool_config
|
||||||
module = importlib.import_module(
|
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||||
f"application.tools.implementations.{tool_name}"
|
|
||||||
)
|
|
||||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||||
if issubclass(obj, Tool) and obj is not Tool:
|
if issubclass(obj, Tool) and obj is not Tool:
|
||||||
return obj(tool_config)
|
return obj(tool_config)
|
||||||
@@ -1,15 +1,16 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
|
||||||
|
|
||||||
from bson.dbref import DBRef
|
from bson.dbref import DBRef
|
||||||
from bson.objectid import ObjectId
|
from bson.objectid import ObjectId
|
||||||
from flask import Blueprint, make_response, request, Response
|
from flask import Blueprint, make_response, request, Response
|
||||||
from flask_restx import fields, Namespace, Resource
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
|
from application.agents.classic_agent import ClassicAgent
|
||||||
|
|
||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
@@ -17,7 +18,6 @@ from application.error import bad_request
|
|||||||
from application.extensions import api
|
from application.extensions import api
|
||||||
from application.llm.llm_creator import LLMCreator
|
from application.llm.llm_creator import LLMCreator
|
||||||
from application.retriever.retriever_creator import RetrieverCreator
|
from application.retriever.retriever_creator import RetrieverCreator
|
||||||
from application.tools.agent import ClassicAgent
|
|
||||||
from application.utils import check_required_fields, limit_chat_history
|
from application.utils import check_required_fields, limit_chat_history
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
|
||||||
|
|
||||||
from bson.binary import Binary, UuidRepresentation
|
from bson.binary import Binary, UuidRepresentation
|
||||||
from bson.dbref import DBRef
|
from bson.dbref import DBRef
|
||||||
@@ -12,12 +12,13 @@ from flask import Blueprint, current_app, jsonify, make_response, redirect, requ
|
|||||||
from flask_restx import fields, inputs, Namespace, Resource
|
from flask_restx import fields, inputs, Namespace, Resource
|
||||||
from werkzeug.utils import secure_filename
|
from werkzeug.utils import secure_filename
|
||||||
|
|
||||||
|
from application.agents.tools.tool_manager import ToolManager
|
||||||
|
|
||||||
from application.api.user.tasks import ingest, ingest_remote
|
from application.api.user.tasks import ingest, ingest_remote
|
||||||
|
|
||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.extensions import api
|
from application.extensions import api
|
||||||
from application.tools.tool_manager import ToolManager
|
|
||||||
from application.tts.google_tts import GoogleTTS
|
from application.tts.google_tts import GoogleTTS
|
||||||
from application.utils import check_required_fields, validate_function_name
|
from application.utils import check_required_fields, validate_function_name
|
||||||
from application.vectorstore.vector_creator import VectorCreator
|
from application.vectorstore.vector_creator import VectorCreator
|
||||||
@@ -429,22 +430,21 @@ class UploadRemote(Resource):
|
|||||||
return missing_fields
|
return missing_fields
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config = json.loads(data["data"])
|
config = json.loads(data["data"])
|
||||||
source_data = None
|
source_data = None
|
||||||
|
|
||||||
if data["source"] == "github":
|
if data["source"] == "github":
|
||||||
source_data = config.get("repo_url")
|
source_data = config.get("repo_url")
|
||||||
elif data["source"] in ["crawler", "url"]:
|
elif data["source"] in ["crawler", "url"]:
|
||||||
source_data = config.get("url")
|
source_data = config.get("url")
|
||||||
elif data["source"] == "reddit":
|
elif data["source"] == "reddit":
|
||||||
source_data = config
|
source_data = config
|
||||||
|
|
||||||
|
task = ingest_remote.delay(
|
||||||
task = ingest_remote.delay(
|
|
||||||
source_data=source_data,
|
source_data=source_data,
|
||||||
job_name=data["name"],
|
job_name=data["name"],
|
||||||
user=data["user"],
|
user=data["user"],
|
||||||
loader=data["source"]
|
loader=data["source"],
|
||||||
)
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(f"Error uploading remote source: {err}")
|
current_app.logger.error(f"Error uploading remote source: {err}")
|
||||||
@@ -1936,11 +1936,14 @@ class UpdateTool(Resource):
|
|||||||
for action_name in list(data["config"]["actions"].keys()):
|
for action_name in list(data["config"]["actions"].keys()):
|
||||||
if not validate_function_name(action_name):
|
if not validate_function_name(action_name):
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({
|
jsonify(
|
||||||
"success": False,
|
{
|
||||||
"message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.",
|
"success": False,
|
||||||
"param": "tools[].function.name"
|
"message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.",
|
||||||
}), 400
|
"param": "tools[].function.name",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
)
|
)
|
||||||
update_data["config"] = data["config"]
|
update_data["config"] = data["config"]
|
||||||
if "status" in data:
|
if "status" in data:
|
||||||
|
|||||||
@@ -51,10 +51,12 @@ class ClassicRAG(BaseRetriever):
|
|||||||
|
|
||||||
Rephrase the following user question to be a standalone search query
|
Rephrase the following user question to be a standalone search query
|
||||||
that captures all relevant context from the conversation:
|
that captures all relevant context from the conversation:
|
||||||
{self.original_question}
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
messages = [{"role": "system", "content": prompt}]
|
messages = [
|
||||||
|
{"role": "system", "content": prompt},
|
||||||
|
{"role": "user", "content": self.original_question},
|
||||||
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages)
|
rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages)
|
||||||
|
|||||||
Reference in New Issue
Block a user