mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Merge branch 'main' into feat/remote-mcp
This commit is contained in:
@@ -5,19 +5,17 @@ from typing import Dict, Generator, List, Optional
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.logging import build_stack_data, log_activity, LogContext
|
||||
from application.retriever.base import BaseRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
def __init__(
|
||||
@@ -157,7 +155,7 @@ class BaseAgent(ABC):
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return f"Failed to parse tool call.", call_id
|
||||
return "Failed to parse tool call.", call_id
|
||||
|
||||
# Check if tool_id exists in available tools
|
||||
if tool_id not in tools_dict:
|
||||
|
||||
@@ -69,11 +69,8 @@ class StreamProcessor:
|
||||
self.decoded_token.get("sub") if self.decoded_token is not None else None
|
||||
)
|
||||
self.conversation_id = self.data.get("conversation_id")
|
||||
self.source = (
|
||||
{"active_docs": self.data["active_docs"]}
|
||||
if "active_docs" in self.data
|
||||
else {}
|
||||
)
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
self.attachments = []
|
||||
self.history = []
|
||||
self.agent_config = {}
|
||||
@@ -85,6 +82,8 @@ class StreamProcessor:
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize all required components for processing"""
|
||||
self._configure_agent()
|
||||
self._configure_source()
|
||||
self._configure_retriever()
|
||||
self._configure_agent()
|
||||
self._load_conversation_history()
|
||||
@@ -171,13 +170,77 @@ class StreamProcessor:
|
||||
source = data.get("source")
|
||||
if isinstance(source, DBRef):
|
||||
source_doc = self.db.dereference(source)
|
||||
data["source"] = str(source_doc["_id"])
|
||||
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
||||
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
||||
if source_doc:
|
||||
data["source"] = str(source_doc["_id"])
|
||||
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
||||
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
||||
else:
|
||||
data["source"] = None
|
||||
elif source == "default":
|
||||
data["source"] = "default"
|
||||
else:
|
||||
data["source"] = None
|
||||
# Handle multiple sources
|
||||
|
||||
sources = data.get("sources", [])
|
||||
if sources and isinstance(sources, list):
|
||||
sources_list = []
|
||||
for i, source_ref in enumerate(sources):
|
||||
if source_ref == "default":
|
||||
processed_source = {
|
||||
"id": "default",
|
||||
"retriever": "classic",
|
||||
"chunks": data.get("chunks", "2"),
|
||||
}
|
||||
sources_list.append(processed_source)
|
||||
elif isinstance(source_ref, DBRef):
|
||||
source_doc = self.db.dereference(source_ref)
|
||||
if source_doc:
|
||||
processed_source = {
|
||||
"id": str(source_doc["_id"]),
|
||||
"retriever": source_doc.get("retriever", "classic"),
|
||||
"chunks": source_doc.get("chunks", data.get("chunks", "2")),
|
||||
}
|
||||
sources_list.append(processed_source)
|
||||
data["sources"] = sources_list
|
||||
else:
|
||||
data["sources"] = []
|
||||
return data
|
||||
|
||||
def _configure_source(self):
|
||||
"""Configure the source based on agent data"""
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
|
||||
if api_key:
|
||||
agent_data = self._get_data_from_api_key(api_key)
|
||||
|
||||
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
|
||||
source_ids = [
|
||||
source["id"] for source in agent_data["sources"] if source.get("id")
|
||||
]
|
||||
if source_ids:
|
||||
self.source = {"active_docs": source_ids}
|
||||
else:
|
||||
self.source = {}
|
||||
self.all_sources = agent_data["sources"]
|
||||
elif agent_data.get("source"):
|
||||
self.source = {"active_docs": agent_data["source"]}
|
||||
self.all_sources = [
|
||||
{
|
||||
"id": agent_data["source"],
|
||||
"retriever": agent_data.get("retriever", "classic"),
|
||||
}
|
||||
]
|
||||
else:
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
return
|
||||
if "active_docs" in self.data:
|
||||
self.source = {"active_docs": self.data["active_docs"]}
|
||||
return
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
|
||||
def _configure_agent(self):
|
||||
"""Configure the agent based on request data"""
|
||||
agent_id = self.data.get("agent_id")
|
||||
@@ -203,7 +266,13 @@ class StreamProcessor:
|
||||
if data_key.get("retriever"):
|
||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||
if data_key.get("chunks") is not None:
|
||||
self.retriever_config["chunks"] = data_key["chunks"]
|
||||
try:
|
||||
self.retriever_config["chunks"] = int(data_key["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
|
||||
)
|
||||
self.retriever_config["chunks"] = 2
|
||||
elif self.agent_key:
|
||||
data_key = self._get_data_from_api_key(self.agent_key)
|
||||
self.agent_config.update(
|
||||
@@ -224,7 +293,13 @@ class StreamProcessor:
|
||||
if data_key.get("retriever"):
|
||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||
if data_key.get("chunks") is not None:
|
||||
self.retriever_config["chunks"] = data_key["chunks"]
|
||||
try:
|
||||
self.retriever_config["chunks"] = int(data_key["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
|
||||
)
|
||||
self.retriever_config["chunks"] = 2
|
||||
else:
|
||||
self.agent_config.update(
|
||||
{
|
||||
@@ -243,7 +318,8 @@ class StreamProcessor:
|
||||
"token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY),
|
||||
}
|
||||
|
||||
if "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||
self.retriever_config["chunks"] = 0
|
||||
|
||||
def create_agent(self):
|
||||
|
||||
626
application/api/connector/routes.py
Normal file
626
application/api/connector/routes.py
Normal file
@@ -0,0 +1,626 @@
|
||||
import datetime
|
||||
import json
|
||||
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import (
|
||||
Blueprint,
|
||||
current_app,
|
||||
jsonify,
|
||||
make_response,
|
||||
request
|
||||
)
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
|
||||
|
||||
|
||||
from application.api.user.tasks import (
|
||||
ingest_connector_task,
|
||||
)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.api import api
|
||||
|
||||
|
||||
from application.utils import (
|
||||
check_required_fields
|
||||
)
|
||||
|
||||
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
|
||||
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
sources_collection = db["sources"]
|
||||
sessions_collection = db["connector_sessions"]
|
||||
|
||||
connector = Blueprint("connector", __name__)
|
||||
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
|
||||
api.add_namespace(connectors_ns)
|
||||
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/upload")
|
||||
class UploadConnector(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ConnectorUploadModel",
|
||||
{
|
||||
"user": fields.String(required=True, description="User ID"),
|
||||
"source": fields.String(
|
||||
required=True, description="Source type (google_drive, github, etc.)"
|
||||
),
|
||||
"name": fields.String(required=True, description="Job name"),
|
||||
"data": fields.String(required=True, description="Configuration data"),
|
||||
"repo_url": fields.String(description="GitHub repository URL"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Uploads connector source for vectorization",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.form
|
||||
required_fields = ["user", "source", "name", "data"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
config = json.loads(data["data"])
|
||||
source_data = None
|
||||
sync_frequency = config.get("sync_frequency", "never")
|
||||
|
||||
if data["source"] == "github":
|
||||
source_data = config.get("repo_url")
|
||||
elif data["source"] in ["crawler", "url"]:
|
||||
source_data = config.get("url")
|
||||
elif data["source"] == "reddit":
|
||||
source_data = config
|
||||
elif data["source"] in ConnectorCreator.get_supported_connectors():
|
||||
session_token = config.get("session_token")
|
||||
if not session_token:
|
||||
return make_response(jsonify({
|
||||
"success": False,
|
||||
"error": f"Missing session_token in {data['source']} configuration"
|
||||
}), 400)
|
||||
|
||||
file_ids = config.get("file_ids", [])
|
||||
if isinstance(file_ids, str):
|
||||
file_ids = [id.strip() for id in file_ids.split(',') if id.strip()]
|
||||
elif not isinstance(file_ids, list):
|
||||
file_ids = []
|
||||
|
||||
folder_ids = config.get("folder_ids", [])
|
||||
if isinstance(folder_ids, str):
|
||||
folder_ids = [id.strip() for id in folder_ids.split(',') if id.strip()]
|
||||
elif not isinstance(folder_ids, list):
|
||||
folder_ids = []
|
||||
|
||||
config["file_ids"] = file_ids
|
||||
config["folder_ids"] = folder_ids
|
||||
|
||||
task = ingest_connector_task.delay(
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
source_type=data["source"],
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=config.get("recursive", False),
|
||||
retriever=config.get("retriever", "classic"),
|
||||
sync_frequency=sync_frequency
|
||||
)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
task = ingest_connector_task.delay(
|
||||
source_data=source_data,
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
loader=data["source"],
|
||||
sync_frequency=sync_frequency
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error uploading connector source: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/task_status")
|
||||
class ConnectorTaskStatus(Resource):
|
||||
task_status_model = api.model(
|
||||
"ConnectorTaskStatusModel",
|
||||
{"task_id": fields.String(required=True, description="Task ID")},
|
||||
)
|
||||
|
||||
@api.expect(task_status_model)
|
||||
@api.doc(description="Get connector task status")
|
||||
def get(self):
|
||||
task_id = request.args.get("task_id")
|
||||
if not task_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Task ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
from application.celery_init import celery
|
||||
|
||||
task = celery.AsyncResult(task_id)
|
||||
task_meta = task.info
|
||||
print(f"Task status: {task.status}")
|
||||
if not isinstance(
|
||||
task_meta, (dict, list, str, int, float, bool, type(None))
|
||||
):
|
||||
task_meta = str(task_meta)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting task status: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"status": task.status, "result": task_meta}), 200)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/sources")
|
||||
class ConnectorSources(Resource):
|
||||
@api.doc(description="Get connector sources")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
sources = sources_collection.find({"user": user, "type": "connector"}).sort("date", -1)
|
||||
connector_sources = []
|
||||
for source in sources:
|
||||
connector_sources.append({
|
||||
"id": str(source["_id"]),
|
||||
"name": source.get("name"),
|
||||
"date": source.get("date"),
|
||||
"type": source.get("type"),
|
||||
"source": source.get("source"),
|
||||
"tokens": source.get("tokens", ""),
|
||||
"retriever": source.get("retriever", "classic"),
|
||||
"syncFrequency": source.get("sync_frequency", ""),
|
||||
})
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving connector sources: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(connector_sources), 200)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/delete")
|
||||
class DeleteConnectorSource(Resource):
|
||||
@api.doc(
|
||||
description="Delete a connector source",
|
||||
params={"source_id": "The source ID to delete"},
|
||||
)
|
||||
def delete(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
source_id = request.args.get("source_id")
|
||||
if not source_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "source_id is required"}), 400
|
||||
)
|
||||
try:
|
||||
result = sources_collection.delete_one(
|
||||
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
|
||||
)
|
||||
if result.deleted_count == 0:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Source not found"}), 404
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting connector source: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/auth")
|
||||
class ConnectorAuth(Resource):
|
||||
@api.doc(description="Get connector OAuth authorization URL", params={"provider": "Connector provider (e.g., google_drive)"})
|
||||
def get(self):
|
||||
try:
|
||||
provider = request.args.get('provider') or request.args.get('source')
|
||||
if not provider:
|
||||
return make_response(jsonify({"success": False, "error": "Missing provider"}), 400)
|
||||
|
||||
if not ConnectorCreator.is_supported(provider):
|
||||
return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400)
|
||||
|
||||
import uuid
|
||||
state = str(uuid.uuid4())
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
authorization_url = auth.get_authorization_url(state=state)
|
||||
return make_response(jsonify({
|
||||
"success": True,
|
||||
"authorization_url": authorization_url,
|
||||
"state": state
|
||||
}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error generating connector auth URL: {e}")
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/callback")
|
||||
class ConnectorsCallback(Resource):
|
||||
@api.doc(description="Handle OAuth callback for external connectors")
|
||||
def get(self):
|
||||
"""Handle OAuth callback for external connectors"""
|
||||
try:
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from flask import request, redirect
|
||||
import uuid
|
||||
|
||||
provider = request.args.get('provider', 'google_drive')
|
||||
authorization_code = request.args.get('code')
|
||||
_ = request.args.get('state')
|
||||
error = request.args.get('error')
|
||||
|
||||
if error:
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider={provider}")
|
||||
|
||||
if not authorization_code:
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authorization+code+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider={provider}")
|
||||
|
||||
try:
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
token_info = auth.exchange_code_for_tokens(authorization_code)
|
||||
|
||||
session_token = str(uuid.uuid4())
|
||||
|
||||
|
||||
try:
|
||||
credentials = auth.create_credentials_from_token_info(token_info)
|
||||
service = auth.build_drive_service(credentials)
|
||||
user_info = service.about().get(fields="user").execute()
|
||||
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
|
||||
except Exception as e:
|
||||
current_app.logger.warning(f"Could not get user info: {e}")
|
||||
user_email = 'Connected User'
|
||||
|
||||
sanitized_token_info = {
|
||||
"access_token": token_info.get("access_token"),
|
||||
"refresh_token": token_info.get("refresh_token"),
|
||||
"token_uri": token_info.get("token_uri"),
|
||||
"expiry": token_info.get("expiry"),
|
||||
"scopes": token_info.get("scopes")
|
||||
}
|
||||
|
||||
user_id = request.decoded_token.get("sub") if getattr(request, "decoded_token", None) else None
|
||||
sessions_collection.insert_one({
|
||||
"session_token": session_token,
|
||||
"user": user_id,
|
||||
"token_info": sanitized_token_info,
|
||||
"created_at": datetime.datetime.now(datetime.timezone.utc),
|
||||
"user_email": user_email,
|
||||
"provider": provider
|
||||
})
|
||||
|
||||
# Redirect to success page with session token and user email
|
||||
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+exchange+authorization+code+for+tokens:+{str(e)}&provider={provider}")
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error handling connector callback: {e}")
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+complete+connector+authentication:+{str(e)}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.")
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/refresh")
|
||||
class ConnectorRefresh(Resource):
|
||||
@api.expect(api.model("ConnectorRefreshModel", {"provider": fields.String(required=True), "refresh_token": fields.String(required=True)}))
|
||||
@api.doc(description="Refresh connector access token")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
refresh_token = data.get('refresh_token')
|
||||
|
||||
if not provider or not refresh_token:
|
||||
return make_response(jsonify({"success": False, "error": "provider and refresh_token are required"}), 400)
|
||||
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
token_info = auth.refresh_access_token(refresh_token)
|
||||
return make_response(jsonify({"success": True, "token_info": token_info}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error refreshing token for connector: {e}")
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/files")
|
||||
class ConnectorFiles(Resource):
|
||||
@api.expect(api.model("ConnectorFilesModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True), "folder_id": fields.String(required=False), "limit": fields.Integer(required=False), "page_token": fields.String(required=False)}))
|
||||
@api.doc(description="List files from a connector provider (supports pagination)")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
session_token = data.get('session_token')
|
||||
folder_id = data.get('folder_id')
|
||||
limit = data.get('limit', 10)
|
||||
page_token = data.get('page_token')
|
||||
if not provider or not session_token:
|
||||
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user = decoded_token.get('sub')
|
||||
session = sessions_collection.find_one({"session_token": session_token, "user": user})
|
||||
if not session:
|
||||
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
|
||||
|
||||
loader = ConnectorCreator.create_connector(provider, session_token)
|
||||
documents = loader.load_data({
|
||||
'limit': limit,
|
||||
'list_only': True,
|
||||
'session_token': session_token,
|
||||
'folder_id': folder_id,
|
||||
'page_token': page_token
|
||||
})
|
||||
|
||||
files = []
|
||||
for doc in documents[:limit]:
|
||||
metadata = doc.extra_info
|
||||
modified_time = metadata.get('modified_time')
|
||||
if modified_time:
|
||||
date_part = modified_time.split('T')[0]
|
||||
time_part = modified_time.split('T')[1].split('.')[0].split('Z')[0]
|
||||
formatted_time = f"{date_part} {time_part}"
|
||||
else:
|
||||
formatted_time = None
|
||||
|
||||
files.append({
|
||||
'id': doc.doc_id,
|
||||
'name': metadata.get('file_name', 'Unknown File'),
|
||||
'type': metadata.get('mime_type', 'unknown'),
|
||||
'size': metadata.get('size', None),
|
||||
'modifiedTime': formatted_time
|
||||
})
|
||||
|
||||
next_token = getattr(loader, 'next_page_token', None)
|
||||
has_more = bool(next_token)
|
||||
|
||||
return make_response(jsonify({"success": True, "files": files, "total": len(files), "next_page_token": next_token, "has_more": has_more}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error loading connector files: {e}")
|
||||
return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/validate-session")
|
||||
class ConnectorValidateSession(Resource):
|
||||
@api.expect(api.model("ConnectorValidateSessionModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True)}))
|
||||
@api.doc(description="Validate connector session token and return user info")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
session_token = data.get('session_token')
|
||||
if not provider or not session_token:
|
||||
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user = decoded_token.get('sub')
|
||||
|
||||
session = sessions_collection.find_one({"session_token": session_token, "user": user})
|
||||
if not session or "token_info" not in session:
|
||||
return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)
|
||||
|
||||
token_info = session["token_info"]
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
is_expired = auth.is_token_expired(token_info)
|
||||
|
||||
return make_response(jsonify({
|
||||
"success": True,
|
||||
"expired": is_expired,
|
||||
"user_email": session.get('user_email', 'Connected User')
|
||||
}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error validating connector session: {e}")
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/disconnect")
|
||||
class ConnectorDisconnect(Resource):
|
||||
@api.expect(api.model("ConnectorDisconnectModel", {"provider": fields.String(required=True), "session_token": fields.String(required=False)}))
|
||||
@api.doc(description="Disconnect a connector session")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
session_token = data.get('session_token')
|
||||
if not provider:
|
||||
return make_response(jsonify({"success": False, "error": "provider is required"}), 400)
|
||||
|
||||
|
||||
if session_token:
|
||||
sessions_collection.delete_one({"session_token": session_token})
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error disconnecting connector session: {e}")
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/sync")
|
||||
class ConnectorSync(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ConnectorSyncModel",
|
||||
{
|
||||
"source_id": fields.String(required=True, description="Source ID to sync"),
|
||||
"session_token": fields.String(required=True, description="Authentication token")
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Sync connector source to check for modifications")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
|
||||
try:
|
||||
data = request.get_json()
|
||||
source_id = data.get('source_id')
|
||||
session_token = data.get('session_token')
|
||||
|
||||
if not all([source_id, session_token]):
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "source_id and session_token are required"
|
||||
}),
|
||||
400
|
||||
)
|
||||
source = sources_collection.find_one({"_id": ObjectId(source_id)})
|
||||
if not source:
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Source not found"
|
||||
}),
|
||||
404
|
||||
)
|
||||
|
||||
if source.get('user') != decoded_token.get('sub'):
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Unauthorized access to source"
|
||||
}),
|
||||
403
|
||||
)
|
||||
|
||||
remote_data = {}
|
||||
try:
|
||||
if source.get('remote_data'):
|
||||
remote_data = json.loads(source.get('remote_data'))
|
||||
except json.JSONDecodeError:
|
||||
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
|
||||
remote_data = {}
|
||||
|
||||
source_type = remote_data.get('provider')
|
||||
if not source_type:
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Source provider not found in remote_data"
|
||||
}),
|
||||
400
|
||||
)
|
||||
|
||||
# Extract configuration from remote_data
|
||||
file_ids = remote_data.get('file_ids', [])
|
||||
folder_ids = remote_data.get('folder_ids', [])
|
||||
recursive = remote_data.get('recursive', True)
|
||||
|
||||
# Start the sync task
|
||||
task = ingest_connector_task.delay(
|
||||
job_name=source.get('name'),
|
||||
user=decoded_token.get('sub'),
|
||||
source_type=source_type,
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=recursive,
|
||||
retriever=source.get('retriever', 'classic'),
|
||||
operation_mode="sync",
|
||||
doc_id=source_id,
|
||||
sync_frequency=source.get('sync_frequency', 'never')
|
||||
)
|
||||
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": True,
|
||||
"task_id": task.id
|
||||
}),
|
||||
200
|
||||
)
|
||||
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error syncing connector source: {err}",
|
||||
exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": str(err)
|
||||
}),
|
||||
400
|
||||
)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/callback-status")
|
||||
class ConnectorCallbackStatus(Resource):
|
||||
@api.doc(description="Return HTML page with connector authentication status")
|
||||
def get(self):
|
||||
"""Return HTML page with connector authentication status"""
|
||||
try:
|
||||
status = request.args.get('status', 'error')
|
||||
message = request.args.get('message', '')
|
||||
provider = request.args.get('provider', 'connector')
|
||||
session_token = request.args.get('session_token', '')
|
||||
user_email = request.args.get('user_email', '')
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>{provider.replace('_', ' ').title()} Authentication</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
|
||||
.container {{ max-width: 600px; margin: 0 auto; }}
|
||||
.success {{ color: #4CAF50; }}
|
||||
.error {{ color: #F44336; }}
|
||||
</style>
|
||||
<script>
|
||||
window.onload = function() {{
|
||||
const status = "{status}";
|
||||
const sessionToken = "{session_token}";
|
||||
const userEmail = "{user_email}";
|
||||
|
||||
if (status === "success" && window.opener) {{
|
||||
window.opener.postMessage({{
|
||||
type: '{provider}_auth_success',
|
||||
session_token: sessionToken,
|
||||
user_email: userEmail
|
||||
}}, '*');
|
||||
|
||||
setTimeout(() => window.close(), 3000);
|
||||
}}
|
||||
}};
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h2>{provider.replace('_', ' ').title()} Authentication</h2>
|
||||
<div class="{status}">
|
||||
<p>{message}</p>
|
||||
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
|
||||
</div>
|
||||
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else ''}</small></p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return make_response(html_content, 200, {'Content-Type': 'text/html'})
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error rendering callback status page: {e}")
|
||||
return make_response("Authentication error occurred", 500, {'Content-Type': 'text/html'})
|
||||
|
||||
|
||||
@@ -32,13 +32,15 @@ from application.api import api
|
||||
|
||||
from application.api.user.tasks import (
|
||||
ingest,
|
||||
ingest_connector_task,
|
||||
ingest_remote,
|
||||
process_agent_webhook,
|
||||
store_attachment,
|
||||
)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.security.encryption import encrypt_credentials, decrypt_credentials
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.tts.google_tts import GoogleTTS
|
||||
from application.utils import (
|
||||
@@ -76,7 +78,6 @@ try:
|
||||
users_collection.create_index("user_id", unique=True)
|
||||
except Exception as e:
|
||||
print("Error creating indexes:", e)
|
||||
|
||||
user = Blueprint("user", __name__)
|
||||
user_ns = Namespace("user", description="User related operations", path="/")
|
||||
api.add_namespace(user_ns)
|
||||
@@ -129,11 +130,9 @@ def ensure_user_doc(user_id):
|
||||
updates["agent_preferences.pinned"] = []
|
||||
if "shared_with_me" not in prefs:
|
||||
updates["agent_preferences.shared_with_me"] = []
|
||||
|
||||
if updates:
|
||||
users_collection.update_one({"user_id": user_id}, {"$set": updates})
|
||||
user_doc = users_collection.find_one({"user_id": user_id})
|
||||
|
||||
return user_doc
|
||||
|
||||
|
||||
@@ -185,7 +184,6 @@ def handle_image_upload(
|
||||
jsonify({"success": False, "message": "Image upload failed"}),
|
||||
400,
|
||||
)
|
||||
|
||||
return image_url, None
|
||||
|
||||
|
||||
@@ -299,8 +297,8 @@ class GetSingleConversation(Resource):
|
||||
)
|
||||
if not conversation:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
|
||||
# Process queries to include attachment names
|
||||
|
||||
queries = conversation["queries"]
|
||||
for query in queries:
|
||||
if "attachments" in query and query["attachments"]:
|
||||
@@ -501,6 +499,7 @@ class DeleteOldIndexes(Resource):
|
||||
|
||||
try:
|
||||
# Delete vector index
|
||||
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
index_path = f"indexes/{str(doc['_id'])}"
|
||||
if storage.file_exists(f"{index_path}/index.faiss"):
|
||||
@@ -571,6 +570,7 @@ class UploadFile(Resource):
|
||||
job_name = request.form["name"]
|
||||
|
||||
# Create safe versions for filesystem operations
|
||||
|
||||
safe_user = safe_filename(user)
|
||||
dir_name = safe_filename(job_name)
|
||||
base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
|
||||
@@ -592,6 +592,7 @@ class UploadFile(Resource):
|
||||
zip_ref.extractall(path=temp_dir)
|
||||
|
||||
# Walk through extracted files and upload them
|
||||
|
||||
for root, _, files in os.walk(temp_dir):
|
||||
for extracted_file in files:
|
||||
if (
|
||||
@@ -614,11 +615,13 @@ class UploadFile(Resource):
|
||||
f"Error extracting zip: {e}", exc_info=True
|
||||
)
|
||||
# If zip extraction fails, save the original zip file
|
||||
|
||||
file_path = f"{base_path}/{safe_file}"
|
||||
with open(temp_file_path, "rb") as f:
|
||||
storage.save_file(f, file_path)
|
||||
else:
|
||||
# For non-zip files, save directly
|
||||
|
||||
file_path = f"{base_path}/{safe_file}"
|
||||
with open(temp_file_path, "rb") as f:
|
||||
storage.save_file(f, file_path)
|
||||
@@ -709,7 +712,6 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
if operation not in ["add", "remove", "remove_directory"]:
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -720,14 +722,12 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
try:
|
||||
ObjectId(source_id)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid source ID format"}), 400
|
||||
)
|
||||
|
||||
try:
|
||||
source = sources_collection.find_one(
|
||||
{"_id": ObjectId(source_id), "user": user}
|
||||
@@ -760,7 +760,6 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
if operation == "add":
|
||||
files = request.files.getlist("file")
|
||||
if not files or all(file.filename == "" for file in files):
|
||||
@@ -773,23 +772,22 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
added_files = []
|
||||
|
||||
target_dir = source_file_path
|
||||
if parent_dir:
|
||||
target_dir = f"{source_file_path}/{parent_dir}"
|
||||
|
||||
for file in files:
|
||||
if file.filename:
|
||||
safe_filename_str = safe_filename(file.filename)
|
||||
file_path = f"{target_dir}/{safe_filename_str}"
|
||||
|
||||
# Save file to storage
|
||||
|
||||
storage.save_file(file, file_path)
|
||||
added_files.append(safe_filename_str)
|
||||
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
@@ -819,7 +817,6 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
try:
|
||||
file_paths = (
|
||||
json.loads(file_paths_str)
|
||||
@@ -833,18 +830,19 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
# Remove files from storage and directory structure
|
||||
|
||||
removed_files = []
|
||||
for file_path in file_paths:
|
||||
full_path = f"{source_file_path}/{file_path}"
|
||||
|
||||
# Remove from storage
|
||||
|
||||
if storage.file_exists(full_path):
|
||||
storage.delete_file(full_path)
|
||||
removed_files.append(file_path)
|
||||
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
@@ -873,8 +871,8 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
# Validate directory path (prevent path traversal)
|
||||
|
||||
if directory_path.startswith("/") or ".." in directory_path:
|
||||
current_app.logger.warning(
|
||||
f"Invalid directory path attempted for removal. "
|
||||
@@ -908,7 +906,6 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
404,
|
||||
)
|
||||
|
||||
success = storage.remove_directory(full_directory_path)
|
||||
|
||||
if not success:
|
||||
@@ -923,7 +920,6 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
current_app.logger.info(
|
||||
f"Successfully removed directory. "
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||
@@ -931,6 +927,7 @@ class ManageSourceFiles(Resource):
|
||||
)
|
||||
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
@@ -1005,6 +1002,50 @@ class UploadRemote(Resource):
|
||||
source_data = config.get("url")
|
||||
elif data["source"] == "reddit":
|
||||
source_data = config
|
||||
elif data["source"] in ConnectorCreator.get_supported_connectors():
|
||||
session_token = config.get("session_token")
|
||||
if not session_token:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"Missing session_token in {data['source']} configuration",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Process file_ids
|
||||
|
||||
file_ids = config.get("file_ids", [])
|
||||
if isinstance(file_ids, str):
|
||||
file_ids = [id.strip() for id in file_ids.split(",") if id.strip()]
|
||||
elif not isinstance(file_ids, list):
|
||||
file_ids = []
|
||||
# Process folder_ids
|
||||
|
||||
folder_ids = config.get("folder_ids", [])
|
||||
if isinstance(folder_ids, str):
|
||||
folder_ids = [
|
||||
id.strip() for id in folder_ids.split(",") if id.strip()
|
||||
]
|
||||
elif not isinstance(folder_ids, list):
|
||||
folder_ids = []
|
||||
config["file_ids"] = file_ids
|
||||
config["folder_ids"] = folder_ids
|
||||
|
||||
task = ingest_connector_task.delay(
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
source_type=data["source"],
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=config.get("recursive", False),
|
||||
retriever=config.get("retriever", "classic"),
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": task.id}), 200
|
||||
)
|
||||
task = ingest_remote.delay(
|
||||
source_data=source_data,
|
||||
job_name=data["name"],
|
||||
@@ -1113,6 +1154,7 @@ class PaginatedSources(Resource):
|
||||
"retriever": doc.get("retriever", "classic"),
|
||||
"syncFrequency": doc.get("sync_frequency", ""),
|
||||
"isNested": bool(doc.get("directory_structure")),
|
||||
"type": doc.get("type", "file"),
|
||||
}
|
||||
paginated_docs.append(doc_data)
|
||||
response = {
|
||||
@@ -1161,6 +1203,9 @@ class CombinedJson(Resource):
|
||||
"retriever": index.get("retriever", "classic"),
|
||||
"syncFrequency": index.get("sync_frequency", ""),
|
||||
"is_nested": bool(index.get("directory_structure")),
|
||||
"type": index.get(
|
||||
"type", "file"
|
||||
), # Add type field with default "file"
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -1376,17 +1421,14 @@ class GetAgent(Resource):
|
||||
def get(self):
|
||||
if not (decoded_token := request.decoded_token):
|
||||
return {"success": False}, 401
|
||||
|
||||
if not (agent_id := request.args.get("id")):
|
||||
return {"success": False, "message": "ID required"}, 400
|
||||
|
||||
try:
|
||||
agent = agents_collection.find_one(
|
||||
{"_id": ObjectId(agent_id), "user": decoded_token["sub"]}
|
||||
)
|
||||
if not agent:
|
||||
return {"status": "Not found"}, 404
|
||||
|
||||
data = {
|
||||
"id": str(agent["_id"]),
|
||||
"name": agent["name"],
|
||||
@@ -1400,6 +1442,16 @@ class GetAgent(Resource):
|
||||
and (source_doc := db.dereference(agent.get("source")))
|
||||
else ""
|
||||
),
|
||||
"sources": [
|
||||
(
|
||||
str(db.dereference(source_ref)["_id"])
|
||||
if isinstance(source_ref, DBRef) and db.dereference(source_ref)
|
||||
else source_ref
|
||||
)
|
||||
for source_ref in agent.get("sources", [])
|
||||
if (isinstance(source_ref, DBRef) and db.dereference(source_ref))
|
||||
or source_ref == "default"
|
||||
],
|
||||
"chunks": agent["chunks"],
|
||||
"retriever": agent.get("retriever", ""),
|
||||
"prompt_id": agent.get("prompt_id", ""),
|
||||
@@ -1422,7 +1474,6 @@ class GetAgent(Resource):
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
}
|
||||
return make_response(jsonify(data), 200)
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Agent fetch error: {e}", exc_info=True)
|
||||
return {"success": False}, 400
|
||||
@@ -1434,7 +1485,6 @@ class GetAgents(Resource):
|
||||
def get(self):
|
||||
if not (decoded_token := request.decoded_token):
|
||||
return {"success": False}, 401
|
||||
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
user_doc = ensure_user_doc(user)
|
||||
@@ -1453,8 +1503,24 @@ class GetAgents(Resource):
|
||||
str(source_doc["_id"])
|
||||
if isinstance(agent.get("source"), DBRef)
|
||||
and (source_doc := db.dereference(agent.get("source")))
|
||||
else ""
|
||||
else (
|
||||
agent.get("source", "")
|
||||
if agent.get("source") == "default"
|
||||
else ""
|
||||
)
|
||||
),
|
||||
"sources": [
|
||||
(
|
||||
source_ref
|
||||
if source_ref == "default"
|
||||
else str(db.dereference(source_ref)["_id"])
|
||||
)
|
||||
for source_ref in agent.get("sources", [])
|
||||
if source_ref == "default"
|
||||
or (
|
||||
isinstance(source_ref, DBRef) and db.dereference(source_ref)
|
||||
)
|
||||
],
|
||||
"chunks": agent["chunks"],
|
||||
"retriever": agent.get("retriever", ""),
|
||||
"prompt_id": agent.get("prompt_id", ""),
|
||||
@@ -1497,7 +1563,14 @@ class CreateAgent(Resource):
|
||||
"image": fields.Raw(
|
||||
required=False, description="Image file upload", type="file"
|
||||
),
|
||||
"source": fields.String(required=True, description="Source ID"),
|
||||
"source": fields.String(
|
||||
required=False, description="Source ID (legacy single source)"
|
||||
),
|
||||
"sources": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="List of source identifiers for multiple sources",
|
||||
),
|
||||
"chunks": fields.Integer(required=True, description="Chunks count"),
|
||||
"retriever": fields.String(required=True, description="Retriever ID"),
|
||||
"prompt_id": fields.String(required=True, description="Prompt ID"),
|
||||
@@ -1530,6 +1603,11 @@ class CreateAgent(Resource):
|
||||
data["tools"] = json.loads(data["tools"])
|
||||
except json.JSONDecodeError:
|
||||
data["tools"] = []
|
||||
if "sources" in data:
|
||||
try:
|
||||
data["sources"] = json.loads(data["sources"])
|
||||
except json.JSONDecodeError:
|
||||
data["sources"] = []
|
||||
if "json_schema" in data:
|
||||
try:
|
||||
data["json_schema"] = json.loads(data["json_schema"])
|
||||
@@ -1538,9 +1616,11 @@ class CreateAgent(Resource):
|
||||
print(f"Received data: {data}")
|
||||
|
||||
# Validate JSON schema if provided
|
||||
|
||||
if data.get("json_schema"):
|
||||
try:
|
||||
# Basic validation - ensure it's a valid JSON structure
|
||||
|
||||
json_schema = data.get("json_schema")
|
||||
if not isinstance(json_schema, dict):
|
||||
return make_response(
|
||||
@@ -1554,6 +1634,7 @@ class CreateAgent(Resource):
|
||||
)
|
||||
|
||||
# Validate that it has either a 'schema' property or is itself a schema
|
||||
|
||||
if "schema" not in json_schema and "type" not in json_schema:
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -1571,7 +1652,6 @@ class CreateAgent(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
if data.get("status") not in ["draft", "published"]:
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -1582,17 +1662,27 @@ class CreateAgent(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
if data.get("status") == "published":
|
||||
required_fields = [
|
||||
"name",
|
||||
"description",
|
||||
"source",
|
||||
"chunks",
|
||||
"retriever",
|
||||
"prompt_id",
|
||||
"agent_type",
|
||||
]
|
||||
# Require either source or sources (but not both)
|
||||
|
||||
if not data.get("source") and not data.get("sources"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Either 'source' or 'sources' field is required for published agents",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
validate_fields = ["name", "description", "prompt_id", "agent_type"]
|
||||
else:
|
||||
required_fields = ["name"]
|
||||
@@ -1603,25 +1693,37 @@ class CreateAgent(Resource):
|
||||
return missing_fields
|
||||
if invalid_fields:
|
||||
return invalid_fields
|
||||
|
||||
image_url, error = handle_image_upload(request, "", user, storage)
|
||||
if error:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Image upload failed"}), 400
|
||||
)
|
||||
|
||||
try:
|
||||
key = str(uuid.uuid4()) if data.get("status") == "published" else ""
|
||||
|
||||
sources_list = []
|
||||
if data.get("sources") and len(data.get("sources", [])) > 0:
|
||||
for source_id in data.get("sources", []):
|
||||
if source_id == "default":
|
||||
sources_list.append("default")
|
||||
elif ObjectId.is_valid(source_id):
|
||||
sources_list.append(DBRef("sources", ObjectId(source_id)))
|
||||
source_field = ""
|
||||
else:
|
||||
source_value = data.get("source", "")
|
||||
if source_value == "default":
|
||||
source_field = "default"
|
||||
elif ObjectId.is_valid(source_value):
|
||||
source_field = DBRef("sources", ObjectId(source_value))
|
||||
else:
|
||||
source_field = ""
|
||||
new_agent = {
|
||||
"user": user,
|
||||
"name": data.get("name"),
|
||||
"description": data.get("description", ""),
|
||||
"image": image_url,
|
||||
"source": (
|
||||
DBRef("sources", ObjectId(data.get("source")))
|
||||
if ObjectId.is_valid(data.get("source"))
|
||||
else ""
|
||||
),
|
||||
"source": source_field,
|
||||
"sources": sources_list,
|
||||
"chunks": data.get("chunks", ""),
|
||||
"retriever": data.get("retriever", ""),
|
||||
"prompt_id": data.get("prompt_id", ""),
|
||||
@@ -1636,7 +1738,11 @@ class CreateAgent(Resource):
|
||||
}
|
||||
if new_agent["chunks"] == "":
|
||||
new_agent["chunks"] = "0"
|
||||
if new_agent["source"] == "" and new_agent["retriever"] == "":
|
||||
if (
|
||||
new_agent["source"] == ""
|
||||
and new_agent["retriever"] == ""
|
||||
and not new_agent["sources"]
|
||||
):
|
||||
new_agent["retriever"] = "classic"
|
||||
resp = agents_collection.insert_one(new_agent)
|
||||
new_id = str(resp.inserted_id)
|
||||
@@ -1658,7 +1764,14 @@ class UpdateAgent(Resource):
|
||||
"image": fields.String(
|
||||
required=False, description="New image URL or identifier"
|
||||
),
|
||||
"source": fields.String(required=True, description="Source ID"),
|
||||
"source": fields.String(
|
||||
required=False, description="Source ID (legacy single source)"
|
||||
),
|
||||
"sources": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="List of source identifiers for multiple sources",
|
||||
),
|
||||
"chunks": fields.Integer(required=True, description="Chunks count"),
|
||||
"retriever": fields.String(required=True, description="Retriever ID"),
|
||||
"prompt_id": fields.String(required=True, description="Prompt ID"),
|
||||
@@ -1691,12 +1804,16 @@ class UpdateAgent(Resource):
|
||||
data["tools"] = json.loads(data["tools"])
|
||||
except json.JSONDecodeError:
|
||||
data["tools"] = []
|
||||
if "sources" in data:
|
||||
try:
|
||||
data["sources"] = json.loads(data["sources"])
|
||||
except json.JSONDecodeError:
|
||||
data["sources"] = []
|
||||
if "json_schema" in data:
|
||||
try:
|
||||
data["json_schema"] = json.loads(data["json_schema"])
|
||||
except json.JSONDecodeError:
|
||||
data["json_schema"] = None
|
||||
|
||||
if not ObjectId.is_valid(agent_id):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid agent ID format"}), 400
|
||||
@@ -1720,7 +1837,6 @@ class UpdateAgent(Resource):
|
||||
),
|
||||
404,
|
||||
)
|
||||
|
||||
image_url, error = handle_image_upload(
|
||||
request, existing_agent.get("image", ""), user, storage
|
||||
)
|
||||
@@ -1728,13 +1844,13 @@ class UpdateAgent(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Image upload failed"}), 400
|
||||
)
|
||||
|
||||
update_fields = {}
|
||||
allowed_fields = [
|
||||
"name",
|
||||
"description",
|
||||
"image",
|
||||
"source",
|
||||
"sources",
|
||||
"chunks",
|
||||
"retriever",
|
||||
"prompt_id",
|
||||
@@ -1758,7 +1874,11 @@ class UpdateAgent(Resource):
|
||||
update_fields[field] = new_status
|
||||
elif field == "source":
|
||||
source_id = data.get("source")
|
||||
if source_id and ObjectId.is_valid(source_id):
|
||||
if source_id == "default":
|
||||
# Handle special "default" source
|
||||
|
||||
update_fields[field] = "default"
|
||||
elif source_id and ObjectId.is_valid(source_id):
|
||||
update_fields[field] = DBRef("sources", ObjectId(source_id))
|
||||
elif source_id:
|
||||
return make_response(
|
||||
@@ -1772,6 +1892,30 @@ class UpdateAgent(Resource):
|
||||
)
|
||||
else:
|
||||
update_fields[field] = ""
|
||||
elif field == "sources":
|
||||
sources_list = data.get("sources", [])
|
||||
if sources_list and isinstance(sources_list, list):
|
||||
valid_sources = []
|
||||
for source_id in sources_list:
|
||||
if source_id == "default":
|
||||
valid_sources.append("default")
|
||||
elif ObjectId.is_valid(source_id):
|
||||
valid_sources.append(
|
||||
DBRef("sources", ObjectId(source_id))
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Invalid source ID format: {source_id}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
update_fields[field] = valid_sources
|
||||
else:
|
||||
update_fields[field] = []
|
||||
elif field == "chunks":
|
||||
chunks_value = data.get("chunks")
|
||||
if chunks_value == "":
|
||||
@@ -1837,7 +1981,6 @@ class UpdateAgent(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
if not existing_agent.get("key"):
|
||||
newly_generated_key = str(uuid.uuid4())
|
||||
update_fields["key"] = newly_generated_key
|
||||
@@ -1924,7 +2067,6 @@ class PinnedAgents(Resource):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
|
||||
user_id = decoded_token.get("sub")
|
||||
|
||||
try:
|
||||
@@ -1933,7 +2075,6 @@ class PinnedAgents(Resource):
|
||||
|
||||
if not pinned_ids:
|
||||
return make_response(jsonify([]), 200)
|
||||
|
||||
pinned_object_ids = [ObjectId(agent_id) for agent_id in pinned_ids]
|
||||
|
||||
pinned_agents_cursor = agents_collection.find(
|
||||
@@ -1943,6 +2084,7 @@ class PinnedAgents(Resource):
|
||||
existing_ids = {str(agent["_id"]) for agent in pinned_agents}
|
||||
|
||||
# Clean up any stale pinned IDs
|
||||
|
||||
stale_ids = [
|
||||
agent_id for agent_id in pinned_ids if agent_id not in existing_ids
|
||||
]
|
||||
@@ -1951,7 +2093,6 @@ class PinnedAgents(Resource):
|
||||
{"user_id": user_id},
|
||||
{"$pullAll": {"agent_preferences.pinned": stale_ids}},
|
||||
)
|
||||
|
||||
list_pinned_agents = [
|
||||
{
|
||||
"id": str(agent["_id"]),
|
||||
@@ -1988,11 +2129,9 @@ class PinnedAgents(Resource):
|
||||
for agent in pinned_agents
|
||||
if "source" in agent or "retriever" in agent
|
||||
]
|
||||
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving pinned agents: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
return make_response(jsonify(list_pinned_agents), 200)
|
||||
|
||||
|
||||
@@ -2056,7 +2195,6 @@ class RemoveSharedAgent(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
|
||||
try:
|
||||
agent = agents_collection.find_one(
|
||||
{"_id": ObjectId(agent_id), "shared_publicly": True}
|
||||
@@ -2066,7 +2204,6 @@ class RemoveSharedAgent(Resource):
|
||||
jsonify({"success": False, "message": "Shared agent not found"}),
|
||||
404,
|
||||
)
|
||||
|
||||
ensure_user_doc(user_id)
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
@@ -2079,7 +2216,6 @@ class RemoveSharedAgent(Resource):
|
||||
)
|
||||
|
||||
return make_response(jsonify({"success": True, "action": "removed"}), 200)
|
||||
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error removing shared agent: {err}")
|
||||
return make_response(
|
||||
@@ -2102,7 +2238,6 @@ class SharedAgent(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Token or ID is required"}), 400
|
||||
)
|
||||
|
||||
try:
|
||||
query = {
|
||||
"shared_publicly": True,
|
||||
@@ -2114,7 +2249,6 @@ class SharedAgent(Resource):
|
||||
jsonify({"success": False, "message": "Shared agent not found"}),
|
||||
404,
|
||||
)
|
||||
|
||||
agent_id = str(shared_agent["_id"])
|
||||
data = {
|
||||
"id": agent_id,
|
||||
@@ -2154,7 +2288,6 @@ class SharedAgent(Resource):
|
||||
if tool_data:
|
||||
enriched_tools.append(tool_data.get("name", ""))
|
||||
data["tools"] = enriched_tools
|
||||
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
if decoded_token:
|
||||
user_id = decoded_token.get("sub")
|
||||
@@ -2166,9 +2299,7 @@ class SharedAgent(Resource):
|
||||
{"user_id": user_id},
|
||||
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
|
||||
)
|
||||
|
||||
return make_response(jsonify(data), 200)
|
||||
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving shared agent: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -2202,7 +2333,6 @@ class SharedAgents(Resource):
|
||||
{"user_id": user_id},
|
||||
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
|
||||
)
|
||||
|
||||
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
|
||||
|
||||
list_shared_agents = [
|
||||
@@ -2229,7 +2359,6 @@ class SharedAgents(Resource):
|
||||
]
|
||||
|
||||
return make_response(jsonify(list_shared_agents), 200)
|
||||
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving shared agents: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -3762,20 +3891,21 @@ class GetChunks(Resource):
|
||||
metadata = chunk.get("metadata", {})
|
||||
|
||||
# Filter by path if provided
|
||||
|
||||
if path:
|
||||
chunk_source = metadata.get("source", "")
|
||||
# Check if the chunk's source matches the requested path
|
||||
|
||||
if not chunk_source or not chunk_source.endswith(path):
|
||||
continue
|
||||
|
||||
# Filter by search term if provided
|
||||
|
||||
if search_term:
|
||||
text_match = search_term in chunk.get("text", "").lower()
|
||||
title_match = search_term in metadata.get("title", "").lower()
|
||||
|
||||
if not (text_match or title_match):
|
||||
continue
|
||||
|
||||
filtered_chunks.append(chunk)
|
||||
|
||||
chunks = filtered_chunks
|
||||
@@ -3937,7 +4067,6 @@ class UpdateChunk(Resource):
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
metadata["token_count"] = token_count
|
||||
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
@@ -3952,7 +4081,6 @@ class UpdateChunk(Resource):
|
||||
existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None)
|
||||
if not existing_chunk:
|
||||
return make_response(jsonify({"error": "Chunk not found"}), 404)
|
||||
|
||||
new_text = text if text is not None else existing_chunk["text"]
|
||||
|
||||
if metadata is not None:
|
||||
@@ -3960,10 +4088,8 @@ class UpdateChunk(Resource):
|
||||
new_metadata.update(metadata)
|
||||
else:
|
||||
new_metadata = existing_chunk["metadata"].copy()
|
||||
|
||||
if text is not None:
|
||||
new_metadata["token_count"] = num_tokens_from_string(new_text)
|
||||
|
||||
try:
|
||||
new_chunk_id = store.add_chunk(new_text, new_metadata)
|
||||
|
||||
@@ -4019,7 +4145,6 @@ class StoreAttachment(Resource):
|
||||
jsonify({"status": "error", "message": "Missing file"}),
|
||||
400,
|
||||
)
|
||||
|
||||
user = None
|
||||
if decoded_token:
|
||||
user = safe_filename(decoded_token.get("sub"))
|
||||
@@ -4034,7 +4159,6 @@ class StoreAttachment(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}), 401
|
||||
)
|
||||
|
||||
try:
|
||||
attachment_id = ObjectId()
|
||||
original_filename = safe_filename(os.path.basename(file.filename))
|
||||
@@ -4076,7 +4200,6 @@ class ServeImage(Resource):
|
||||
content_type = f"image/{extension}"
|
||||
if extension == "jpg":
|
||||
content_type = "image/jpeg"
|
||||
|
||||
response = make_response(file_obj.read())
|
||||
response.headers.set("Content-Type", content_type)
|
||||
response.headers.set("Cache-Control", "max-age=86400")
|
||||
@@ -4121,18 +4244,29 @@ class DirectoryStructure(Resource):
|
||||
)
|
||||
|
||||
directory_structure = doc.get("directory_structure", {})
|
||||
base_path = doc.get("file_path", "")
|
||||
|
||||
provider = None
|
||||
remote_data = doc.get("remote_data")
|
||||
try:
|
||||
if isinstance(remote_data, str) and remote_data:
|
||||
remote_data_obj = json.loads(remote_data)
|
||||
provider = remote_data_obj.get("provider")
|
||||
except Exception as e:
|
||||
current_app.logger.warning(
|
||||
f"Failed to parse remote_data for doc {doc_id}: {e}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"directory_structure": directory_structure,
|
||||
"base_path": doc.get("file_path", ""),
|
||||
"base_path": base_path,
|
||||
"provider": provider,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving directory structure: {e}", exc_info=True
|
||||
|
||||
@@ -47,6 +47,39 @@ def process_agent_webhook(self, agent_id, payload):
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def ingest_connector_task(
|
||||
self,
|
||||
job_name,
|
||||
user,
|
||||
source_type,
|
||||
session_token=None,
|
||||
file_ids=None,
|
||||
folder_ids=None,
|
||||
recursive=True,
|
||||
retriever="classic",
|
||||
operation_mode="upload",
|
||||
doc_id=None,
|
||||
sync_frequency="never"
|
||||
):
|
||||
from application.worker import ingest_connector
|
||||
resp = ingest_connector(
|
||||
self,
|
||||
job_name,
|
||||
user,
|
||||
source_type,
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=recursive,
|
||||
retriever=retriever,
|
||||
operation_mode=operation_mode,
|
||||
doc_id=doc_id,
|
||||
sync_frequency=sync_frequency
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.on_after_configure.connect
|
||||
def setup_periodic_tasks(sender, **kwargs):
|
||||
sender.add_periodic_task(
|
||||
|
||||
@@ -16,6 +16,7 @@ from application.api import api # noqa: E402
|
||||
from application.api.answer import answer # noqa: E402
|
||||
from application.api.internal.routes import internal # noqa: E402
|
||||
from application.api.user.routes import user # noqa: E402
|
||||
from application.api.connector.routes import connector # noqa: E402
|
||||
from application.celery_init import celery # noqa: E402
|
||||
from application.core.settings import settings # noqa: E402
|
||||
|
||||
@@ -30,6 +31,7 @@ app = Flask(__name__)
|
||||
app.register_blueprint(user)
|
||||
app.register_blueprint(answer)
|
||||
app.register_blueprint(internal)
|
||||
app.register_blueprint(connector)
|
||||
app.config.update(
|
||||
UPLOAD_FOLDER="inputs",
|
||||
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
|
||||
|
||||
@@ -40,6 +40,13 @@ class Settings(BaseSettings):
|
||||
FALLBACK_LLM_NAME: Optional[str] = None # model name for fallback llm
|
||||
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
|
||||
|
||||
# Google Drive integration
|
||||
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
|
||||
GOOGLE_CLIENT_SECRET: Optional[str] = None# Replace with your actual Google OAuth client secret
|
||||
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = "http://127.0.0.1:7091/api/connectors/callback"
|
||||
##append ?provider={provider_name} in your Provider console like http://127.0.0.1:7091/api/connectors/callback?provider=google_drive
|
||||
|
||||
|
||||
# LLM Cache
|
||||
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
|
||||
|
||||
|
||||
18
application/parser/connectors/__init__.py
Normal file
18
application/parser/connectors/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
External knowledge base connectors for DocsGPT.
|
||||
|
||||
This module contains connectors for external knowledge bases and document storage systems
|
||||
that require authentication and specialized handling, separate from simple web scrapers.
|
||||
"""
|
||||
|
||||
from .base import BaseConnectorAuth, BaseConnectorLoader
|
||||
from .connector_creator import ConnectorCreator
|
||||
from .google_drive import GoogleDriveAuth, GoogleDriveLoader
|
||||
|
||||
__all__ = [
|
||||
'BaseConnectorAuth',
|
||||
'BaseConnectorLoader',
|
||||
'ConnectorCreator',
|
||||
'GoogleDriveAuth',
|
||||
'GoogleDriveLoader'
|
||||
]
|
||||
129
application/parser/connectors/base.py
Normal file
129
application/parser/connectors/base.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
Base classes for external knowledge base connectors.
|
||||
|
||||
This module provides minimal abstract base classes that define the essential
|
||||
interface for external knowledge base connectors.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
|
||||
class BaseConnectorAuth(ABC):
|
||||
"""
|
||||
Abstract base class for connector authentication.
|
||||
|
||||
Defines the minimal interface that all connector authentication
|
||||
implementations must follow.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_authorization_url(self, state: Optional[str] = None) -> str:
|
||||
"""
|
||||
Generate authorization URL for OAuth flows.
|
||||
|
||||
Args:
|
||||
state: Optional state parameter for CSRF protection
|
||||
|
||||
Returns:
|
||||
Authorization URL
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Exchange authorization code for access tokens.
|
||||
|
||||
Args:
|
||||
authorization_code: Authorization code from OAuth callback
|
||||
|
||||
Returns:
|
||||
Dictionary containing token information
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Refresh an expired access token.
|
||||
|
||||
Args:
|
||||
refresh_token: Refresh token
|
||||
|
||||
Returns:
|
||||
Dictionary containing refreshed token information
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if a token is expired.
|
||||
|
||||
Args:
|
||||
token_info: Token information dictionary
|
||||
|
||||
Returns:
|
||||
True if token is expired, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BaseConnectorLoader(ABC):
|
||||
"""
|
||||
Abstract base class for connector loaders.
|
||||
|
||||
Defines the minimal interface that all connector loader
|
||||
implementations must follow.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, session_token: str):
|
||||
"""
|
||||
Initialize the connector loader.
|
||||
|
||||
Args:
|
||||
session_token: Authentication session token
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
"""
|
||||
Load documents from the external knowledge base.
|
||||
|
||||
Args:
|
||||
inputs: Configuration dictionary containing:
|
||||
- file_ids: Optional list of specific file IDs to load
|
||||
- folder_ids: Optional list of folder IDs to browse/download
|
||||
- limit: Maximum number of items to return
|
||||
- list_only: If True, return metadata without content
|
||||
- recursive: Whether to recursively process folders
|
||||
|
||||
Returns:
|
||||
List of Document objects
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Download files/folders to a local directory.
|
||||
|
||||
Args:
|
||||
local_dir: Local directory path to download files to
|
||||
source_config: Configuration for what to download
|
||||
|
||||
Returns:
|
||||
Dictionary containing download results:
|
||||
- files_downloaded: Number of files downloaded
|
||||
- directory_path: Path where files were downloaded
|
||||
- empty_result: Whether no files were downloaded
|
||||
- source_type: Type of connector
|
||||
- config_used: Configuration that was used
|
||||
- error: Error message if download failed (optional)
|
||||
"""
|
||||
pass
|
||||
81
application/parser/connectors/connector_creator.py
Normal file
81
application/parser/connectors/connector_creator.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
|
||||
|
||||
class ConnectorCreator:
|
||||
"""
|
||||
Factory class for creating external knowledge base connectors and auth providers.
|
||||
|
||||
These are different from remote loaders as they typically require
|
||||
authentication and connect to external document storage systems.
|
||||
"""
|
||||
|
||||
connectors = {
|
||||
"google_drive": GoogleDriveLoader,
|
||||
}
|
||||
|
||||
auth_providers = {
|
||||
"google_drive": GoogleDriveAuth,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_connector(cls, connector_type, *args, **kwargs):
|
||||
"""
|
||||
Create a connector instance for the specified type.
|
||||
|
||||
Args:
|
||||
connector_type: Type of connector to create (e.g., 'google_drive')
|
||||
*args, **kwargs: Arguments to pass to the connector constructor
|
||||
|
||||
Returns:
|
||||
Connector instance
|
||||
|
||||
Raises:
|
||||
ValueError: If connector type is not supported
|
||||
"""
|
||||
connector_class = cls.connectors.get(connector_type.lower())
|
||||
if not connector_class:
|
||||
raise ValueError(f"No connector class found for type {connector_type}")
|
||||
return connector_class(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_auth(cls, connector_type):
|
||||
"""
|
||||
Create an auth provider instance for the specified connector type.
|
||||
|
||||
Args:
|
||||
connector_type: Type of connector auth to create (e.g., 'google_drive')
|
||||
|
||||
Returns:
|
||||
Auth provider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If connector type is not supported for auth
|
||||
"""
|
||||
auth_class = cls.auth_providers.get(connector_type.lower())
|
||||
if not auth_class:
|
||||
raise ValueError(f"No auth class found for type {connector_type}")
|
||||
return auth_class()
|
||||
|
||||
@classmethod
|
||||
def get_supported_connectors(cls):
|
||||
"""
|
||||
Get list of supported connector types.
|
||||
|
||||
Returns:
|
||||
List of supported connector type strings
|
||||
"""
|
||||
return list(cls.connectors.keys())
|
||||
|
||||
@classmethod
|
||||
def is_supported(cls, connector_type):
|
||||
"""
|
||||
Check if a connector type is supported.
|
||||
|
||||
Args:
|
||||
connector_type: Type of connector to check
|
||||
|
||||
Returns:
|
||||
True if supported, False otherwise
|
||||
"""
|
||||
return connector_type.lower() in cls.connectors
|
||||
10
application/parser/connectors/google_drive/__init__.py
Normal file
10
application/parser/connectors/google_drive/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Google Drive connector for DocsGPT.
|
||||
|
||||
This module provides authentication and document loading capabilities for Google Drive.
|
||||
"""
|
||||
|
||||
from .auth import GoogleDriveAuth
|
||||
from .loader import GoogleDriveLoader
|
||||
|
||||
__all__ = ['GoogleDriveAuth', 'GoogleDriveLoader']
|
||||
268
application/parser/connectors/google_drive/auth.py
Normal file
268
application/parser/connectors/google_drive/auth.py
Normal file
@@ -0,0 +1,268 @@
|
||||
import logging
|
||||
import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors.base import BaseConnectorAuth
|
||||
|
||||
|
||||
class GoogleDriveAuth(BaseConnectorAuth):
|
||||
"""
|
||||
Handles Google OAuth 2.0 authentication for Google Drive access.
|
||||
"""
|
||||
|
||||
SCOPES = [
|
||||
'https://www.googleapis.com/auth/drive.readonly',
|
||||
'https://www.googleapis.com/auth/drive.metadata.readonly'
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self.client_id = settings.GOOGLE_CLIENT_ID
|
||||
self.client_secret = settings.GOOGLE_CLIENT_SECRET
|
||||
self.redirect_uri = f"{settings.CONNECTOR_REDIRECT_BASE_URI}?provider=google_drive"
|
||||
|
||||
if not self.client_id or not self.client_secret:
|
||||
raise ValueError("Google OAuth credentials not configured. Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET in settings.")
|
||||
|
||||
|
||||
|
||||
def get_authorization_url(self, state: Optional[str] = None) -> str:
|
||||
try:
|
||||
flow = Flow.from_client_config(
|
||||
{
|
||||
"web": {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"redirect_uris": [self.redirect_uri]
|
||||
}
|
||||
},
|
||||
scopes=self.SCOPES
|
||||
)
|
||||
flow.redirect_uri = self.redirect_uri
|
||||
|
||||
authorization_url, _ = flow.authorization_url(
|
||||
access_type='offline',
|
||||
prompt='consent',
|
||||
include_granted_scopes='true',
|
||||
state=state
|
||||
)
|
||||
|
||||
return authorization_url
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error generating authorization URL: {e}")
|
||||
raise
|
||||
|
||||
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
|
||||
try:
|
||||
if not authorization_code:
|
||||
raise ValueError("Authorization code is required")
|
||||
|
||||
flow = Flow.from_client_config(
|
||||
{
|
||||
"web": {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"redirect_uris": [self.redirect_uri]
|
||||
}
|
||||
},
|
||||
scopes=self.SCOPES
|
||||
)
|
||||
flow.redirect_uri = self.redirect_uri
|
||||
|
||||
flow.fetch_token(code=authorization_code)
|
||||
|
||||
credentials = flow.credentials
|
||||
|
||||
if not credentials.refresh_token:
|
||||
logging.warning("OAuth flow did not return a refresh_token.")
|
||||
if not credentials.token:
|
||||
raise ValueError("OAuth flow did not return an access token")
|
||||
|
||||
if not credentials.token_uri:
|
||||
credentials.token_uri = "https://oauth2.googleapis.com/token"
|
||||
|
||||
if not credentials.client_id:
|
||||
credentials.client_id = self.client_id
|
||||
|
||||
if not credentials.client_secret:
|
||||
credentials.client_secret = self.client_secret
|
||||
|
||||
if not credentials.refresh_token:
|
||||
raise ValueError(
|
||||
"No refresh token received. This typically happens when offline access wasn't granted. "
|
||||
)
|
||||
|
||||
return {
|
||||
'access_token': credentials.token,
|
||||
'refresh_token': credentials.refresh_token,
|
||||
'token_uri': credentials.token_uri,
|
||||
'client_id': credentials.client_id,
|
||||
'client_secret': credentials.client_secret,
|
||||
'scopes': credentials.scopes,
|
||||
'expiry': credentials.expiry.isoformat() if credentials.expiry else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error exchanging code for tokens: {e}")
|
||||
raise
|
||||
|
||||
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||
try:
|
||||
if not refresh_token:
|
||||
raise ValueError("Refresh token is required")
|
||||
|
||||
credentials = Credentials(
|
||||
token=None,
|
||||
refresh_token=refresh_token,
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret
|
||||
)
|
||||
|
||||
from google.auth.transport.requests import Request
|
||||
credentials.refresh(Request())
|
||||
|
||||
return {
|
||||
'access_token': credentials.token,
|
||||
'refresh_token': refresh_token,
|
||||
'token_uri': credentials.token_uri,
|
||||
'client_id': credentials.client_id,
|
||||
'client_secret': credentials.client_secret,
|
||||
'scopes': credentials.scopes,
|
||||
'expiry': credentials.expiry.isoformat() if credentials.expiry else None
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"Error refreshing access token: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def create_credentials_from_token_info(self, token_info: Dict[str, Any]) -> Credentials:
|
||||
from application.core.settings import settings
|
||||
|
||||
access_token = token_info.get('access_token')
|
||||
if not access_token:
|
||||
raise ValueError("No access token found in token_info")
|
||||
|
||||
credentials = Credentials(
|
||||
token=access_token,
|
||||
refresh_token=token_info.get('refresh_token'),
|
||||
token_uri= 'https://oauth2.googleapis.com/token',
|
||||
client_id=settings.GOOGLE_CLIENT_ID,
|
||||
client_secret=settings.GOOGLE_CLIENT_SECRET,
|
||||
scopes=token_info.get('scopes', ['https://www.googleapis.com/auth/drive.readonly'])
|
||||
)
|
||||
|
||||
if not credentials.token:
|
||||
raise ValueError("Credentials created without valid access token")
|
||||
|
||||
return credentials
|
||||
|
||||
def build_drive_service(self, credentials: Credentials):
|
||||
try:
|
||||
if not credentials:
|
||||
raise ValueError("No credentials provided")
|
||||
|
||||
if not credentials.token and not credentials.refresh_token:
|
||||
raise ValueError("No access token or refresh token available. User must re-authorize with offline access.")
|
||||
|
||||
needs_refresh = credentials.expired or not credentials.token
|
||||
if needs_refresh:
|
||||
if credentials.refresh_token:
|
||||
try:
|
||||
from google.auth.transport.requests import Request
|
||||
credentials.refresh(Request())
|
||||
except Exception as refresh_error:
|
||||
raise ValueError(f"Failed to refresh credentials: {refresh_error}")
|
||||
else:
|
||||
raise ValueError("No access token or refresh token available. User must re-authorize with offline access.")
|
||||
|
||||
return build('drive', 'v3', credentials=credentials)
|
||||
|
||||
except HttpError as e:
|
||||
raise ValueError(f"Failed to build Google Drive service: HTTP {e.resp.status}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to build Google Drive service: {str(e)}")
|
||||
|
||||
def is_token_expired(self, token_info):
|
||||
if 'expiry' in token_info and token_info['expiry']:
|
||||
try:
|
||||
from dateutil import parser
|
||||
# Google Drive provides timezone-aware ISO8601 dates
|
||||
expiry_dt = parser.parse(token_info['expiry'])
|
||||
current_time = datetime.datetime.now(datetime.timezone.utc)
|
||||
return current_time >= expiry_dt - datetime.timedelta(seconds=60)
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
if 'access_token' in token_info and token_info['access_token']:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
|
||||
try:
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
sessions_collection = db["connector_sessions"]
|
||||
session = sessions_collection.find_one({"session_token": session_token})
|
||||
if not session:
|
||||
raise ValueError(f"Invalid session token: {session_token}")
|
||||
|
||||
if "token_info" not in session:
|
||||
raise ValueError("Session missing token information")
|
||||
|
||||
token_info = session["token_info"]
|
||||
if not token_info:
|
||||
raise ValueError("Invalid token information")
|
||||
|
||||
required_fields = ["access_token", "refresh_token"]
|
||||
missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)]
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required token fields: {missing_fields}")
|
||||
|
||||
if 'client_id' not in token_info:
|
||||
token_info['client_id'] = settings.GOOGLE_CLIENT_ID
|
||||
if 'client_secret' not in token_info:
|
||||
token_info['client_secret'] = settings.GOOGLE_CLIENT_SECRET
|
||||
if 'token_uri' not in token_info:
|
||||
token_info['token_uri'] = 'https://oauth2.googleapis.com/token'
|
||||
|
||||
return token_info
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to retrieve Google Drive token information: {str(e)}")
|
||||
|
||||
def validate_credentials(self, credentials: Credentials) -> bool:
|
||||
"""
|
||||
Validate Google Drive credentials by making a test API call.
|
||||
|
||||
Args:
|
||||
credentials: Google credentials object
|
||||
|
||||
Returns:
|
||||
True if credentials are valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
service = self.build_drive_service(credentials)
|
||||
service.about().get(fields="user").execute()
|
||||
return True
|
||||
|
||||
except HttpError as e:
|
||||
logging.error(f"HTTP error validating credentials: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.error(f"Error validating credentials: {e}")
|
||||
return False
|
||||
536
application/parser/connectors/google_drive/loader.py
Normal file
536
application/parser/connectors/google_drive/loader.py
Normal file
@@ -0,0 +1,536 @@
|
||||
"""
|
||||
Google Drive loader for DocsGPT.
|
||||
Loads documents from Google Drive using Google Drive API.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from googleapiclient.http import MediaIoBaseDownload
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from application.parser.connectors.base import BaseConnectorLoader
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
|
||||
class GoogleDriveLoader(BaseConnectorLoader):
|
||||
|
||||
SUPPORTED_MIME_TYPES = {
|
||||
'application/pdf': '.pdf',
|
||||
'application/vnd.google-apps.document': '.docx',
|
||||
'application/vnd.google-apps.presentation': '.pptx',
|
||||
'application/vnd.google-apps.spreadsheet': '.xlsx',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
|
||||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
|
||||
'application/msword': '.doc',
|
||||
'application/vnd.ms-powerpoint': '.ppt',
|
||||
'application/vnd.ms-excel': '.xls',
|
||||
'text/plain': '.txt',
|
||||
'text/csv': '.csv',
|
||||
'text/html': '.html',
|
||||
'application/rtf': '.rtf',
|
||||
'image/jpeg': '.jpg',
|
||||
'image/jpg': '.jpg',
|
||||
'image/png': '.png',
|
||||
}
|
||||
|
||||
EXPORT_FORMATS = {
|
||||
'application/vnd.google-apps.document': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'application/vnd.google-apps.presentation': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
'application/vnd.google-apps.spreadsheet': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
|
||||
}
|
||||
|
||||
def __init__(self, session_token: str):
|
||||
self.auth = GoogleDriveAuth()
|
||||
self.session_token = session_token
|
||||
|
||||
token_info = self.auth.get_token_info_from_session(session_token)
|
||||
self.credentials = self.auth.create_credentials_from_token_info(token_info)
|
||||
|
||||
try:
|
||||
self.service = self.auth.build_drive_service(self.credentials)
|
||||
except Exception as e:
|
||||
logging.warning(f"Could not build Google Drive service: {e}")
|
||||
self.service = None
|
||||
|
||||
self.next_page_token = None
|
||||
|
||||
|
||||
|
||||
def _process_file(self, file_metadata: Dict[str, Any], load_content: bool = True) -> Optional[Document]:
|
||||
try:
|
||||
file_id = file_metadata.get('id')
|
||||
file_name = file_metadata.get('name', 'Unknown')
|
||||
mime_type = file_metadata.get('mimeType', 'application/octet-stream')
|
||||
|
||||
if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
|
||||
return None
|
||||
if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
|
||||
logging.info(f"Skipping unsupported file type: {mime_type} for file {file_name}")
|
||||
return None
|
||||
# Google Drive provides timezone-aware ISO8601 dates
|
||||
doc_metadata = {
|
||||
'file_name': file_name,
|
||||
'mime_type': mime_type,
|
||||
'size': file_metadata.get('size', None),
|
||||
'created_time': file_metadata.get('createdTime'),
|
||||
'modified_time': file_metadata.get('modifiedTime'),
|
||||
'parents': file_metadata.get('parents', []),
|
||||
'source': 'google_drive'
|
||||
}
|
||||
|
||||
if not load_content:
|
||||
return Document(
|
||||
text="",
|
||||
doc_id=file_id,
|
||||
extra_info=doc_metadata
|
||||
)
|
||||
|
||||
content = self._download_file_content(file_id, mime_type)
|
||||
if content is None:
|
||||
logging.warning(f"Could not load content for file {file_name} ({file_id})")
|
||||
return None
|
||||
|
||||
return Document(
|
||||
text=content,
|
||||
doc_id=file_id,
|
||||
extra_info=doc_metadata
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing file: {e}")
|
||||
return None
|
||||
|
||||
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
session_token = inputs.get('session_token')
|
||||
if session_token and session_token != self.session_token:
|
||||
logging.warning("Session token in inputs differs from loader's session token. Using loader's session token.")
|
||||
self.config = inputs
|
||||
|
||||
try:
|
||||
documents: List[Document] = []
|
||||
|
||||
folder_id = inputs.get('folder_id')
|
||||
file_ids = inputs.get('file_ids', [])
|
||||
limit = inputs.get('limit', 100)
|
||||
list_only = inputs.get('list_only', False)
|
||||
load_content = not list_only
|
||||
page_token = inputs.get('page_token')
|
||||
self.next_page_token = None
|
||||
|
||||
if file_ids:
|
||||
# Specific files requested: load them
|
||||
for file_id in file_ids:
|
||||
try:
|
||||
doc = self._load_file_by_id(file_id, load_content=load_content)
|
||||
if doc:
|
||||
documents.append(doc)
|
||||
elif hasattr(self, '_credential_refreshed') and self._credential_refreshed:
|
||||
self._credential_refreshed = False
|
||||
logging.info(f"Retrying load of file {file_id} after credential refresh")
|
||||
doc = self._load_file_by_id(file_id, load_content=load_content)
|
||||
if doc:
|
||||
documents.append(doc)
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading file {file_id}: {e}")
|
||||
continue
|
||||
else:
|
||||
# Browsing mode: list immediate children of provided folder or root
|
||||
parent_id = folder_id if folder_id else 'root'
|
||||
documents = self._list_items_in_parent(parent_id, limit=limit, load_content=load_content, page_token=page_token)
|
||||
|
||||
logging.info(f"Loaded {len(documents)} documents from Google Drive")
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading data from Google Drive: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
def _load_file_by_id(self, file_id: str, load_content: bool = True) -> Optional[Document]:
|
||||
self._ensure_service()
|
||||
|
||||
try:
|
||||
file_metadata = self.service.files().get(
|
||||
fileId=file_id,
|
||||
fields='id,name,mimeType,size,createdTime,modifiedTime,parents'
|
||||
).execute()
|
||||
|
||||
return self._process_file(file_metadata, load_content=load_content)
|
||||
|
||||
except HttpError as e:
|
||||
logging.error(f"HTTP error loading file {file_id}: {e.resp.status} - {e.content}")
|
||||
|
||||
if e.resp.status in [401, 403]:
|
||||
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
|
||||
try:
|
||||
from google.auth.transport.requests import Request
|
||||
self.credentials.refresh(Request())
|
||||
self._ensure_service()
|
||||
return None
|
||||
except Exception as refresh_error:
|
||||
raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}")
|
||||
else:
|
||||
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading file {file_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None) -> List[Document]:
|
||||
self._ensure_service()
|
||||
|
||||
documents: List[Document] = []
|
||||
|
||||
try:
|
||||
query = f"'{parent_id}' in parents and trashed=false"
|
||||
next_token_out: Optional[str] = None
|
||||
|
||||
while True:
|
||||
page_size = 100
|
||||
if limit:
|
||||
remaining = max(0, limit - len(documents))
|
||||
if remaining == 0:
|
||||
break
|
||||
page_size = min(100, remaining)
|
||||
|
||||
results = self.service.files().list(
|
||||
q=query,
|
||||
fields='nextPageToken,files(id,name,mimeType,size,createdTime,modifiedTime,parents)',
|
||||
pageToken=page_token,
|
||||
pageSize=page_size
|
||||
).execute()
|
||||
|
||||
items = results.get('files', [])
|
||||
for item in items:
|
||||
mime_type = item.get('mimeType')
|
||||
if mime_type == 'application/vnd.google-apps.folder':
|
||||
doc_metadata = {
|
||||
'file_name': item.get('name', 'Unknown'),
|
||||
'mime_type': mime_type,
|
||||
'size': item.get('size', None),
|
||||
'created_time': item.get('createdTime'),
|
||||
'modified_time': item.get('modifiedTime'),
|
||||
'parents': item.get('parents', []),
|
||||
'source': 'google_drive',
|
||||
'is_folder': True
|
||||
}
|
||||
documents.append(Document(text="", doc_id=item.get('id'), extra_info=doc_metadata))
|
||||
else:
|
||||
doc = self._process_file(item, load_content=load_content)
|
||||
if doc:
|
||||
documents.append(doc)
|
||||
|
||||
if limit and len(documents) >= limit:
|
||||
self.next_page_token = results.get('nextPageToken')
|
||||
return documents
|
||||
|
||||
page_token = results.get('nextPageToken')
|
||||
next_token_out = page_token
|
||||
if not page_token:
|
||||
break
|
||||
|
||||
self.next_page_token = next_token_out
|
||||
return documents
|
||||
except Exception as e:
|
||||
logging.error(f"Error listing items under parent {parent_id}: {e}")
|
||||
return documents
|
||||
|
||||
|
||||
|
||||
|
||||
def _download_file_content(self, file_id: str, mime_type: str) -> Optional[str]:
|
||||
if not self.credentials.token:
|
||||
logging.warning("No access token in credentials, attempting to refresh")
|
||||
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
|
||||
try:
|
||||
from google.auth.transport.requests import Request
|
||||
self.credentials.refresh(Request())
|
||||
logging.info("Credentials refreshed successfully")
|
||||
self._ensure_service()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to refresh credentials: {e}")
|
||||
raise ValueError("Authentication failed and cannot be refreshed: missing or invalid refresh_token")
|
||||
else:
|
||||
logging.error("No access token and no refresh_token available")
|
||||
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
|
||||
|
||||
if self.credentials.expired:
|
||||
logging.warning("Credentials are expired, attempting to refresh")
|
||||
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
|
||||
try:
|
||||
from google.auth.transport.requests import Request
|
||||
self.credentials.refresh(Request())
|
||||
logging.info("Credentials refreshed successfully")
|
||||
self._ensure_service()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to refresh expired credentials: {e}")
|
||||
raise ValueError("Authentication failed and cannot be refreshed: expired credentials")
|
||||
else:
|
||||
logging.error("Credentials expired and no refresh_token available")
|
||||
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
|
||||
|
||||
try:
|
||||
if mime_type in self.EXPORT_FORMATS:
|
||||
export_mime_type = self.EXPORT_FORMATS[mime_type]
|
||||
request = self.service.files().export_media(
|
||||
fileId=file_id,
|
||||
mimeType=export_mime_type
|
||||
)
|
||||
else:
|
||||
request = self.service.files().get_media(fileId=file_id)
|
||||
|
||||
file_io = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(file_io, request)
|
||||
|
||||
done = False
|
||||
while done is False:
|
||||
try:
|
||||
_, done = downloader.next_chunk()
|
||||
except HttpError as e:
|
||||
logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error during download of file {file_id}: {e}")
|
||||
return None
|
||||
|
||||
content_bytes = file_io.getvalue()
|
||||
|
||||
try:
|
||||
content = content_bytes.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
content = content_bytes.decode('latin-1')
|
||||
except UnicodeDecodeError:
|
||||
logging.error(f"Could not decode file {file_id} as text")
|
||||
return None
|
||||
|
||||
return content
|
||||
|
||||
except HttpError as e:
|
||||
logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")
|
||||
|
||||
if e.resp.status in [401, 403]:
|
||||
logging.error(f"Authentication error downloading file {file_id}")
|
||||
|
||||
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
|
||||
logging.info(f"Attempting to refresh credentials for file {file_id}")
|
||||
try:
|
||||
from google.auth.transport.requests import Request
|
||||
self.credentials.refresh(Request())
|
||||
logging.info("Credentials refreshed successfully")
|
||||
self._credential_refreshed = True
|
||||
self._ensure_service()
|
||||
return None
|
||||
except Exception as refresh_error:
|
||||
logging.error(f"Error refreshing credentials: {refresh_error}")
|
||||
raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}")
|
||||
else:
|
||||
logging.error("Cannot refresh credentials: missing refresh_token")
|
||||
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading file {file_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _download_file_to_directory(self, file_id: str, local_dir: str) -> bool:
|
||||
try:
|
||||
self._ensure_service()
|
||||
return self._download_single_file(file_id, local_dir)
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading file {file_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def _ensure_service(self):
|
||||
if not self.service:
|
||||
try:
|
||||
self.service = self.auth.build_drive_service(self.credentials)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Cannot access Google Drive: {e}")
|
||||
|
||||
def _download_single_file(self, file_id: str, local_dir: str) -> bool:
|
||||
file_metadata = self.service.files().get(
|
||||
fileId=file_id,
|
||||
fields='name,mimeType'
|
||||
).execute()
|
||||
|
||||
file_name = file_metadata['name']
|
||||
mime_type = file_metadata['mimeType']
|
||||
|
||||
if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
|
||||
return False
|
||||
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
full_path = os.path.join(local_dir, file_name)
|
||||
|
||||
if mime_type in self.EXPORT_FORMATS:
|
||||
export_mime_type = self.EXPORT_FORMATS[mime_type]
|
||||
request = self.service.files().export_media(
|
||||
fileId=file_id,
|
||||
mimeType=export_mime_type
|
||||
)
|
||||
extension = self._get_extension_for_mime_type(export_mime_type)
|
||||
if not full_path.endswith(extension):
|
||||
full_path += extension
|
||||
else:
|
||||
request = self.service.files().get_media(fileId=file_id)
|
||||
|
||||
with open(full_path, 'wb') as f:
|
||||
downloader = MediaIoBaseDownload(f, request)
|
||||
done = False
|
||||
while not done:
|
||||
_, done = downloader.next_chunk()
|
||||
|
||||
return True
|
||||
|
||||
def _download_folder_recursive(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
|
||||
files_downloaded = 0
|
||||
try:
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
|
||||
query = f"'{folder_id}' in parents and trashed=false"
|
||||
page_token = None
|
||||
|
||||
while True:
|
||||
results = self.service.files().list(
|
||||
q=query,
|
||||
fields='nextPageToken, files(id, name, mimeType)',
|
||||
pageToken=page_token,
|
||||
pageSize=1000
|
||||
).execute()
|
||||
|
||||
items = results.get('files', [])
|
||||
logging.info(f"Found {len(items)} items in folder {folder_id}")
|
||||
|
||||
for item in items:
|
||||
item_name = item['name']
|
||||
item_id = item['id']
|
||||
mime_type = item['mimeType']
|
||||
|
||||
if mime_type == 'application/vnd.google-apps.folder':
|
||||
if recursive:
|
||||
# Create subfolder and recurse
|
||||
subfolder_path = os.path.join(local_dir, item_name)
|
||||
os.makedirs(subfolder_path, exist_ok=True)
|
||||
subfolder_files = self._download_folder_recursive(
|
||||
item_id,
|
||||
subfolder_path,
|
||||
recursive
|
||||
)
|
||||
files_downloaded += subfolder_files
|
||||
logging.info(f"Downloaded {subfolder_files} files from subfolder {item_name}")
|
||||
else:
|
||||
# Download file
|
||||
success = self._download_single_file(item_id, local_dir)
|
||||
if success:
|
||||
files_downloaded += 1
|
||||
logging.info(f"Downloaded file: {item_name}")
|
||||
else:
|
||||
logging.warning(f"Failed to download file: {item_name}")
|
||||
|
||||
page_token = results.get('nextPageToken')
|
||||
if not page_token:
|
||||
break
|
||||
|
||||
return files_downloaded
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in _download_folder_recursive for folder {folder_id}: {e}", exc_info=True)
|
||||
return files_downloaded
|
||||
|
||||
def _get_extension_for_mime_type(self, mime_type: str) -> str:
|
||||
extensions = {
|
||||
'application/pdf': '.pdf',
|
||||
'text/plain': '.txt',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
|
||||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
|
||||
'text/html': '.html',
|
||||
'text/markdown': '.md',
|
||||
}
|
||||
return extensions.get(mime_type, '.bin')
|
||||
|
||||
def _download_folder_contents(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
|
||||
try:
|
||||
self._ensure_service()
|
||||
return self._download_folder_recursive(folder_id, local_dir, recursive)
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
|
||||
return 0
|
||||
|
||||
def download_to_directory(self, local_dir: str, source_config: dict = None) -> dict:
|
||||
if source_config is None:
|
||||
source_config = {}
|
||||
|
||||
config = source_config if source_config else getattr(self, 'config', {})
|
||||
files_downloaded = 0
|
||||
|
||||
try:
|
||||
folder_ids = config.get('folder_ids', [])
|
||||
file_ids = config.get('file_ids', [])
|
||||
recursive = config.get('recursive', True)
|
||||
|
||||
self._ensure_service()
|
||||
|
||||
if file_ids:
|
||||
if isinstance(file_ids, str):
|
||||
file_ids = [file_ids]
|
||||
|
||||
for file_id in file_ids:
|
||||
if self._download_file_to_directory(file_id, local_dir):
|
||||
files_downloaded += 1
|
||||
|
||||
# Process folders
|
||||
if folder_ids:
|
||||
if isinstance(folder_ids, str):
|
||||
folder_ids = [folder_ids]
|
||||
|
||||
for folder_id in folder_ids:
|
||||
try:
|
||||
folder_metadata = self.service.files().get(
|
||||
fileId=folder_id,
|
||||
fields='name'
|
||||
).execute()
|
||||
folder_name = folder_metadata.get('name', '')
|
||||
folder_path = os.path.join(local_dir, folder_name)
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
|
||||
folder_files = self._download_folder_recursive(
|
||||
folder_id,
|
||||
folder_path,
|
||||
recursive
|
||||
)
|
||||
files_downloaded += folder_files
|
||||
logging.info(f"Downloaded {folder_files} files from folder {folder_name}")
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
|
||||
|
||||
if not file_ids and not folder_ids:
|
||||
raise ValueError("No folder_ids or file_ids provided for download")
|
||||
|
||||
return {
|
||||
"files_downloaded": files_downloaded,
|
||||
"directory_path": local_dir,
|
||||
"empty_result": files_downloaded == 0,
|
||||
"source_type": "google_drive",
|
||||
"config_used": config
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"files_downloaded": files_downloaded,
|
||||
"directory_path": local_dir,
|
||||
"empty_result": True,
|
||||
"source_type": "google_drive",
|
||||
"config_used": config,
|
||||
"error": str(e)
|
||||
}
|
||||
@@ -6,6 +6,16 @@ from application.parser.remote.github_loader import GitHubLoader
|
||||
|
||||
|
||||
class RemoteCreator:
|
||||
"""
|
||||
Factory class for creating remote content loaders.
|
||||
|
||||
These loaders fetch content from remote web sources like URLs,
|
||||
sitemaps, web crawlers, social media platforms, etc.
|
||||
|
||||
For external knowledge base connectors (like Google Drive),
|
||||
use ConnectorCreator instead.
|
||||
"""
|
||||
|
||||
loaders = {
|
||||
"url": WebLoader,
|
||||
"sitemap": SitemapLoader,
|
||||
@@ -18,5 +28,5 @@ class RemoteCreator:
|
||||
def create_loader(cls, type, *args, **kwargs):
|
||||
loader_class = cls.loaders.get(type.lower())
|
||||
if not loader_class:
|
||||
raise ValueError(f"No LLM class found for type {type}")
|
||||
raise ValueError(f"No loader class found for type {type}")
|
||||
return loader_class(*args, **kwargs)
|
||||
|
||||
@@ -14,6 +14,9 @@ Flask==3.1.1
|
||||
faiss-cpu==1.9.0.post1
|
||||
flask-restx==1.3.0
|
||||
google-genai==1.3.0
|
||||
google-api-python-client==2.179.0
|
||||
google-auth-httplib2==0.2.0
|
||||
google-auth-oauthlib==1.2.2
|
||||
gTTS==2.5.4
|
||||
gunicorn==23.0.0
|
||||
javalang==0.13.0
|
||||
|
||||
@@ -5,10 +5,6 @@ class BaseRetriever(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def gen(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.base import BaseRetriever
|
||||
@@ -20,10 +21,20 @@ class ClassicRAG(BaseRetriever):
|
||||
api_key=settings.API_KEY,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.original_question = ""
|
||||
"""Initialize ClassicRAG retriever with vectorstore sources and LLM configuration"""
|
||||
self.original_question = source.get("question", "")
|
||||
self.chat_history = chat_history if chat_history is not None else []
|
||||
self.prompt = prompt
|
||||
self.chunks = chunks
|
||||
if isinstance(chunks, str):
|
||||
try:
|
||||
self.chunks = int(chunks)
|
||||
except ValueError:
|
||||
logging.warning(
|
||||
f"Invalid chunks value '{chunks}', using default value 2"
|
||||
)
|
||||
self.chunks = 2
|
||||
else:
|
||||
self.chunks = chunks
|
||||
self.gpt_model = gpt_model
|
||||
self.token_limit = (
|
||||
token_limit
|
||||
@@ -44,25 +55,52 @@ class ClassicRAG(BaseRetriever):
|
||||
user_api_key=self.user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
self.vectorstore = source["active_docs"] if "active_docs" in source else None
|
||||
|
||||
if "active_docs" in source and source["active_docs"] is not None:
|
||||
if isinstance(source["active_docs"], list):
|
||||
self.vectorstores = source["active_docs"]
|
||||
else:
|
||||
self.vectorstores = [source["active_docs"]]
|
||||
else:
|
||||
self.vectorstores = []
|
||||
self.question = self._rephrase_query()
|
||||
self.decoded_token = decoded_token
|
||||
self._validate_vectorstore_config()
|
||||
|
||||
def _validate_vectorstore_config(self):
|
||||
"""Validate vectorstore IDs and remove any empty/invalid entries"""
|
||||
if not self.vectorstores:
|
||||
logging.warning("No vectorstores configured for retrieval")
|
||||
return
|
||||
invalid_ids = [
|
||||
vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip()
|
||||
]
|
||||
if invalid_ids:
|
||||
logging.warning(f"Found invalid vectorstore IDs: {invalid_ids}")
|
||||
self.vectorstores = [
|
||||
vs_id for vs_id in self.vectorstores if vs_id and vs_id.strip()
|
||||
]
|
||||
|
||||
def _rephrase_query(self):
|
||||
"""Rephrase user query with chat history context for better retrieval"""
|
||||
if (
|
||||
not self.original_question
|
||||
or not self.chat_history
|
||||
or self.chat_history == []
|
||||
or self.chunks == 0
|
||||
or self.vectorstore is None
|
||||
or not self.vectorstores
|
||||
):
|
||||
return self.original_question
|
||||
|
||||
prompt = f"""Given the following conversation history:
|
||||
|
||||
{self.chat_history}
|
||||
|
||||
|
||||
|
||||
Rephrase the following user question to be a standalone search query
|
||||
|
||||
that captures all relevant context from the conversation:
|
||||
|
||||
"""
|
||||
|
||||
messages = [
|
||||
@@ -79,44 +117,62 @@ class ClassicRAG(BaseRetriever):
|
||||
return self.original_question
|
||||
|
||||
def _get_data(self):
|
||||
if self.chunks == 0 or self.vectorstore is None:
|
||||
docs = []
|
||||
else:
|
||||
docsearch = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
|
||||
)
|
||||
docs_temp = docsearch.search(self.question, k=self.chunks)
|
||||
docs = [
|
||||
{
|
||||
"title": i.metadata.get(
|
||||
"title", i.metadata.get("post_title", i.page_content)
|
||||
).split("/")[-1],
|
||||
"text": i.page_content,
|
||||
"source": (
|
||||
i.metadata.get("source")
|
||||
if i.metadata.get("source")
|
||||
else "local"
|
||||
),
|
||||
}
|
||||
for i in docs_temp
|
||||
]
|
||||
"""Retrieve relevant documents from configured vectorstores"""
|
||||
if self.chunks == 0 or not self.vectorstores:
|
||||
return []
|
||||
all_docs = []
|
||||
chunks_per_source = max(1, self.chunks // len(self.vectorstores))
|
||||
|
||||
return docs
|
||||
for vectorstore_id in self.vectorstores:
|
||||
if vectorstore_id:
|
||||
try:
|
||||
docsearch = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY
|
||||
)
|
||||
docs_temp = docsearch.search(self.question, k=chunks_per_source)
|
||||
|
||||
def gen():
|
||||
pass
|
||||
for doc in docs_temp:
|
||||
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
|
||||
page_content = doc.page_content
|
||||
metadata = doc.metadata
|
||||
else:
|
||||
page_content = doc.get("text", doc.get("page_content", ""))
|
||||
metadata = doc.get("metadata", {})
|
||||
title = metadata.get(
|
||||
"title", metadata.get("post_title", page_content)
|
||||
)
|
||||
if isinstance(title, str):
|
||||
title = title.split("/")[-1]
|
||||
else:
|
||||
title = str(title).split("/")[-1]
|
||||
all_docs.append(
|
||||
{
|
||||
"title": title,
|
||||
"text": page_content,
|
||||
"source": metadata.get("source") or vectorstore_id,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error searching vectorstore {vectorstore_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
return all_docs
|
||||
|
||||
def search(self, query: str = ""):
|
||||
"""Search for documents using optional query override"""
|
||||
if query:
|
||||
self.original_question = query
|
||||
self.question = self._rephrase_query()
|
||||
return self._get_data()
|
||||
|
||||
def get_params(self):
|
||||
"""Return current retriever configuration parameters"""
|
||||
return {
|
||||
"question": self.original_question,
|
||||
"rephrased_question": self.question,
|
||||
"source": self.vectorstore,
|
||||
"sources": self.vectorstores,
|
||||
"chunks": self.chunks,
|
||||
"token_limit": self.token_limit,
|
||||
"gpt_model": self.gpt_model,
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class EmbeddingsWrapper:
|
||||
def __init__(self, model_name, *args, **kwargs):
|
||||
self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs)
|
||||
self.model = SentenceTransformer(
|
||||
model_name,
|
||||
config_kwargs={"allow_dangerous_deserialization": True},
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
self.dimension = self.model.get_sentence_embedding_dimension()
|
||||
|
||||
def embed_query(self, query: str):
|
||||
return self.model.encode(query).tolist()
|
||||
|
||||
|
||||
def embed_documents(self, documents: list):
|
||||
return self.model.encode(documents).tolist()
|
||||
|
||||
|
||||
def __call__(self, text):
|
||||
if isinstance(text, str):
|
||||
return self.embed_query(text)
|
||||
@@ -24,15 +32,14 @@ class EmbeddingsWrapper:
|
||||
raise ValueError("Input must be a string or a list of strings")
|
||||
|
||||
|
||||
|
||||
class EmbeddingsSingleton:
|
||||
_instances = {}
|
||||
|
||||
@staticmethod
|
||||
def get_instance(embeddings_name, *args, **kwargs):
|
||||
if embeddings_name not in EmbeddingsSingleton._instances:
|
||||
EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance(
|
||||
embeddings_name, *args, **kwargs
|
||||
EmbeddingsSingleton._instances[embeddings_name] = (
|
||||
EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs)
|
||||
)
|
||||
return EmbeddingsSingleton._instances[embeddings_name]
|
||||
|
||||
@@ -40,9 +47,15 @@ class EmbeddingsSingleton:
|
||||
def _create_instance(embeddings_name, *args, **kwargs):
|
||||
embeddings_factory = {
|
||||
"openai_text-embedding-ada-002": OpenAIEmbeddings,
|
||||
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
|
||||
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
|
||||
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"),
|
||||
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper(
|
||||
"sentence-transformers/all-mpnet-base-v2"
|
||||
),
|
||||
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper(
|
||||
"sentence-transformers/all-mpnet-base-v2"
|
||||
),
|
||||
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper(
|
||||
"hkunlp/instructor-large"
|
||||
),
|
||||
}
|
||||
|
||||
if embeddings_name in embeddings_factory:
|
||||
@@ -50,34 +63,63 @@ class EmbeddingsSingleton:
|
||||
else:
|
||||
return EmbeddingsWrapper(embeddings_name, *args, **kwargs)
|
||||
|
||||
|
||||
class BaseVectorStore(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, *args, **kwargs):
|
||||
"""Search for similar documents/chunks in the vectorstore"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_texts(self, texts, metadatas=None, *args, **kwargs):
|
||||
"""Add texts with their embeddings to the vectorstore"""
|
||||
pass
|
||||
|
||||
def delete_index(self, *args, **kwargs):
|
||||
"""Delete the entire index/collection"""
|
||||
pass
|
||||
|
||||
def save_local(self, *args, **kwargs):
|
||||
"""Save vectorstore to local storage"""
|
||||
pass
|
||||
|
||||
def get_chunks(self, *args, **kwargs):
|
||||
"""Get all chunks from the vectorstore"""
|
||||
pass
|
||||
|
||||
def add_chunk(self, text, metadata=None, *args, **kwargs):
|
||||
"""Add a single chunk to the vectorstore"""
|
||||
pass
|
||||
|
||||
def delete_chunk(self, chunk_id, *args, **kwargs):
|
||||
"""Delete a specific chunk from the vectorstore"""
|
||||
pass
|
||||
|
||||
def is_azure_configured(self):
|
||||
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
|
||||
return (
|
||||
settings.OPENAI_API_BASE
|
||||
and settings.OPENAI_API_VERSION
|
||||
and settings.AZURE_DEPLOYMENT_NAME
|
||||
)
|
||||
|
||||
def _get_embeddings(self, embeddings_name, embeddings_key=None):
|
||||
if embeddings_name == "openai_text-embedding-ada-002":
|
||||
if self.is_azure_configured():
|
||||
os.environ["OPENAI_API_TYPE"] = "azure"
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||
embeddings_name,
|
||||
model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
|
||||
embeddings_name, model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
|
||||
)
|
||||
else:
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||
embeddings_name,
|
||||
openai_api_key=embeddings_key
|
||||
embeddings_name, openai_api_key=embeddings_key
|
||||
)
|
||||
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
|
||||
if os.path.exists("./models/all-mpnet-base-v2"):
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||
embeddings_name = "./models/all-mpnet-base-v2",
|
||||
embeddings_name="./models/all-mpnet-base-v2",
|
||||
)
|
||||
else:
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||
@@ -87,4 +129,3 @@ class BaseVectorStore(ABC):
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
|
||||
|
||||
return embedding_instance
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import os
|
||||
import shutil
|
||||
import string
|
||||
import tempfile
|
||||
from typing import Any, Dict
|
||||
import zipfile
|
||||
|
||||
from collections import Counter
|
||||
@@ -21,6 +22,7 @@ from application.api.answer.services.stream_processor import get_prompt
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.parser.chunking import Chunker
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.parser.embedding_pipeline import embed_and_store_documents
|
||||
from application.parser.file.bulk import SimpleDirectoryReader
|
||||
from application.parser.remote.remote_creator import RemoteCreator
|
||||
@@ -649,8 +651,11 @@ def remote_worker(
|
||||
"id": str(id),
|
||||
"type": loader,
|
||||
"remote_data": source_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"sync_frequency": sync_frequency
|
||||
}
|
||||
|
||||
if operation_mode == "sync":
|
||||
file_data["last_sync"] = datetime.datetime.now()
|
||||
upload_index(full_path, file_data)
|
||||
except Exception as e:
|
||||
logging.error("Error in remote_worker task: %s", str(e), exc_info=True)
|
||||
@@ -707,7 +712,7 @@ def sync_worker(self, frequency):
|
||||
self, source_data, name, user, source_type, frequency, retriever, doc_id
|
||||
)
|
||||
sync_counts["total_sync_count"] += 1
|
||||
sync_counts[
|
||||
sync_counts[
|
||||
"sync_success" if resp["status"] == "success" else "sync_failure"
|
||||
] += 1
|
||||
return {
|
||||
@@ -744,7 +749,7 @@ def attachment_worker(self, file_info, user):
|
||||
input_files=[local_path], exclude_hidden=True, errors="ignore"
|
||||
)
|
||||
.load_data()[0]
|
||||
.text,
|
||||
.text,
|
||||
)
|
||||
|
||||
|
||||
@@ -835,3 +840,174 @@ def agent_webhook_worker(self, agent_id, payload):
|
||||
f"Webhook processed for agent {agent_id}", extra={"agent_id": agent_id}
|
||||
)
|
||||
return {"status": "success", "result": result}
|
||||
|
||||
|
||||
def ingest_connector(
|
||||
self,
|
||||
job_name: str,
|
||||
user: str,
|
||||
source_type: str,
|
||||
session_token=None,
|
||||
file_ids=None,
|
||||
folder_ids=None,
|
||||
recursive=True,
|
||||
retriever: str = "classic",
|
||||
operation_mode: str = "upload",
|
||||
doc_id=None,
|
||||
sync_frequency: str = "never",
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Ingestion for internal knowledge bases (GoogleDrive, etc.).
|
||||
|
||||
Args:
|
||||
job_name: Name of the ingestion job
|
||||
user: User identifier
|
||||
source_type: Type of remote source ("google_drive", "dropbox", etc.)
|
||||
session_token: Authentication token for the service
|
||||
file_ids: List of file IDs to download
|
||||
folder_ids: List of folder IDs to download
|
||||
recursive: Whether to recursively download folders
|
||||
retriever: Type of retriever to use
|
||||
operation_mode: "upload" for initial ingestion, "sync" for incremental sync
|
||||
doc_id: Document ID for sync operations (required when operation_mode="sync")
|
||||
sync_frequency: How often to sync ("never", "daily", "weekly", "monthly")
|
||||
"""
|
||||
logging.info(f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}")
|
||||
self.update_state(state="PROGRESS", meta={"current": 1})
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
try:
|
||||
# Step 1: Initialize the appropriate loader
|
||||
self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing connector"})
|
||||
|
||||
if not session_token:
|
||||
raise ValueError(f"{source_type} connector requires session_token")
|
||||
|
||||
if not ConnectorCreator.is_supported(source_type):
|
||||
raise ValueError(f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}")
|
||||
|
||||
remote_loader = ConnectorCreator.create_connector(source_type, session_token)
|
||||
|
||||
# Create a clean config for storage
|
||||
api_source_config = {
|
||||
"file_ids": file_ids or [],
|
||||
"folder_ids": folder_ids or [],
|
||||
"recursive": recursive
|
||||
}
|
||||
|
||||
# Step 2: Download files to temp directory
|
||||
self.update_state(state="PROGRESS", meta={"current": 20, "status": "Downloading files"})
|
||||
download_info = remote_loader.download_to_directory(
|
||||
temp_dir,
|
||||
api_source_config
|
||||
)
|
||||
|
||||
if download_info.get("empty_result", False) or not download_info.get("files_downloaded", 0):
|
||||
logging.warning(f"No files were downloaded from {source_type}")
|
||||
# Create empty result directly instead of calling a separate method
|
||||
return {
|
||||
"name": job_name,
|
||||
"user": user,
|
||||
"tokens": 0,
|
||||
"type": source_type,
|
||||
"source_config": api_source_config,
|
||||
"directory_structure": "{}",
|
||||
}
|
||||
|
||||
# Step 3: Use SimpleDirectoryReader to process downloaded files
|
||||
self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing files"})
|
||||
reader = SimpleDirectoryReader(
|
||||
input_dir=temp_dir,
|
||||
recursive=True,
|
||||
required_exts=[
|
||||
".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub",
|
||||
".html", ".mdx", ".json", ".xlsx", ".pptx", ".png",
|
||||
".jpg", ".jpeg",
|
||||
],
|
||||
exclude_hidden=True,
|
||||
file_metadata=metadata_from_filename,
|
||||
)
|
||||
raw_docs = reader.load_data()
|
||||
directory_structure = getattr(reader, 'directory_structure', {})
|
||||
|
||||
|
||||
|
||||
# Step 4: Process documents (chunking, embedding, etc.)
|
||||
self.update_state(state="PROGRESS", meta={"current": 60, "status": "Processing documents"})
|
||||
|
||||
chunker = Chunker(
|
||||
chunking_strategy="classic_chunk",
|
||||
max_tokens=MAX_TOKENS,
|
||||
min_tokens=MIN_TOKENS,
|
||||
duplicate_headers=False,
|
||||
)
|
||||
raw_docs = chunker.chunk(documents=raw_docs)
|
||||
|
||||
# Preserve source information in document metadata
|
||||
for doc in raw_docs:
|
||||
if hasattr(doc, 'extra_info') and doc.extra_info:
|
||||
source = doc.extra_info.get('source')
|
||||
if source and os.path.isabs(source):
|
||||
# Convert absolute path to relative path
|
||||
doc.extra_info['source'] = os.path.relpath(source, start=temp_dir)
|
||||
|
||||
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
|
||||
|
||||
if operation_mode == "upload":
|
||||
id = ObjectId()
|
||||
elif operation_mode == "sync":
|
||||
if not doc_id or not ObjectId.is_valid(doc_id):
|
||||
logging.error("Invalid doc_id provided for sync operation: %s", doc_id)
|
||||
raise ValueError("doc_id must be provided for sync operation.")
|
||||
id = ObjectId(doc_id)
|
||||
else:
|
||||
raise ValueError(f"Invalid operation_mode: {operation_mode}")
|
||||
|
||||
vector_store_path = os.path.join(temp_dir, "vector_store")
|
||||
os.makedirs(vector_store_path, exist_ok=True)
|
||||
|
||||
self.update_state(state="PROGRESS", meta={"current": 80, "status": "Storing documents"})
|
||||
embed_and_store_documents(docs, vector_store_path, id, self)
|
||||
|
||||
tokens = count_tokens_docs(docs)
|
||||
|
||||
# Step 6: Upload index files
|
||||
file_data = {
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"id": str(id),
|
||||
"type": "connector",
|
||||
"remote_data": json.dumps({
|
||||
"provider": source_type,
|
||||
**api_source_config
|
||||
}),
|
||||
"directory_structure": json.dumps(directory_structure),
|
||||
"sync_frequency": sync_frequency
|
||||
}
|
||||
|
||||
if operation_mode == "sync":
|
||||
file_data["last_sync"] = datetime.datetime.now()
|
||||
else:
|
||||
file_data["last_sync"] = datetime.datetime.now()
|
||||
|
||||
upload_index(vector_store_path, file_data)
|
||||
|
||||
# Ensure we mark the task as complete
|
||||
self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"})
|
||||
|
||||
logging.info(f"Remote ingestion completed: {job_name}")
|
||||
|
||||
return {
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"tokens": tokens,
|
||||
"type": source_type,
|
||||
"id": str(id),
|
||||
"status": "complete"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error during remote ingestion: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user