Merge branch 'main' of https://github.com/siiddhantt/DocsGPT into pr/1930

This commit is contained in:
Siddhant Rai
2025-09-10 20:15:20 +05:30
51 changed files with 3792 additions and 247 deletions

View File

@@ -3,11 +3,12 @@ import json
import math
import os
import secrets
import tempfile
import uuid
import zipfile
from functools import wraps
from typing import Optional, Tuple
import tempfile
import zipfile
from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
from bson.objectid import ObjectId
@@ -25,26 +26,28 @@ from pymongo import ReturnDocument
from werkzeug.utils import secure_filename
from application.agents.tools.tool_manager import ToolManager
from application.api import api
from application.api.user.tasks import (
ingest,
ingest_connector_task,
ingest_remote,
process_agent_webhook,
store_attachment,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.api import api
from application.parser.connectors.connector_creator import ConnectorCreator
from application.storage.storage_creator import StorageCreator
from application.tts.google_tts import GoogleTTS
from application.utils import (
check_required_fields,
generate_image_url,
num_tokens_from_string,
safe_filename,
validate_function_name,
validate_required_fields,
)
from application.utils import num_tokens_from_string
from application.vectorstore.vector_creator import VectorCreator
storage = StorageCreator.get_storage()
@@ -72,7 +75,6 @@ try:
users_collection.create_index("user_id", unique=True)
except Exception as e:
print("Error creating indexes:", e)
user = Blueprint("user", __name__)
user_ns = Namespace("user", description="User related operations", path="/")
api.add_namespace(user_ns)
@@ -125,11 +127,9 @@ def ensure_user_doc(user_id):
updates["agent_preferences.pinned"] = []
if "shared_with_me" not in prefs:
updates["agent_preferences.shared_with_me"] = []
if updates:
users_collection.update_one({"user_id": user_id}, {"$set": updates})
user_doc = users_collection.find_one({"user_id": user_id})
return user_doc
@@ -181,7 +181,6 @@ def handle_image_upload(
jsonify({"success": False, "message": "Image upload failed"}),
400,
)
return image_url, None
@@ -295,8 +294,8 @@ class GetSingleConversation(Resource):
)
if not conversation:
return make_response(jsonify({"status": "not found"}), 404)
# Process queries to include attachment names
queries = conversation["queries"]
for query in queries:
if "attachments" in query and query["attachments"]:
@@ -492,11 +491,11 @@ class DeleteOldIndexes(Resource):
)
if not doc:
return make_response(jsonify({"status": "not found"}), 404)
storage = StorageCreator.get_storage()
try:
# Delete vector index
if settings.VECTOR_STORE == "faiss":
index_path = f"indexes/{str(doc['_id'])}"
if storage.file_exists(f"{index_path}/index.faiss"):
@@ -508,7 +507,6 @@ class DeleteOldIndexes(Resource):
settings.VECTOR_STORE, source_id=str(doc["_id"])
)
vectorstore.delete_index()
if "file_path" in doc and doc["file_path"]:
file_path = doc["file_path"]
if storage.is_directory(file_path):
@@ -517,7 +515,6 @@ class DeleteOldIndexes(Resource):
storage.delete_file(f)
else:
storage.delete_file(file_path)
except FileNotFoundError:
pass
except Exception as err:
@@ -525,7 +522,6 @@ class DeleteOldIndexes(Resource):
f"Error deleting files and indexes: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
sources_collection.delete_one({"_id": ObjectId(source_id)})
return make_response(jsonify({"success": True}), 200)
@@ -567,6 +563,7 @@ class UploadFile(Resource):
job_name = request.form["name"]
# Create safe versions for filesystem operations
safe_user = safe_filename(user)
dir_name = safe_filename(job_name)
base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
@@ -588,6 +585,7 @@ class UploadFile(Resource):
zip_ref.extractall(path=temp_dir)
# Walk through extracted files and upload them
for root, _, files in os.walk(temp_dir):
for extracted_file in files:
if (
@@ -595,7 +593,6 @@ class UploadFile(Resource):
== temp_file_path
):
continue
rel_path = os.path.relpath(
os.path.join(root, extracted_file), temp_dir
)
@@ -610,15 +607,16 @@ class UploadFile(Resource):
f"Error extracting zip: {e}", exc_info=True
)
# If zip extraction fails, save the original zip file
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
else:
# For non-zip files, save directly
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
task = ingest.delay(
settings.UPLOAD_FOLDER,
[
@@ -690,7 +688,6 @@ class ManageSourceFiles(Resource):
return make_response(
jsonify({"success": False, "message": "Unauthorized"}), 401
)
user = decoded_token.get("sub")
source_id = request.form.get("source_id")
operation = request.form.get("operation")
@@ -705,7 +702,6 @@ class ManageSourceFiles(Resource):
),
400,
)
if operation not in ["add", "remove", "remove_directory"]:
return make_response(
jsonify(
@@ -716,14 +712,12 @@ class ManageSourceFiles(Resource):
),
400,
)
try:
ObjectId(source_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid source ID format"}), 400
)
try:
source = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": user}
@@ -743,7 +737,6 @@ class ManageSourceFiles(Resource):
return make_response(
jsonify({"success": False, "message": "Database error"}), 500
)
try:
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
@@ -756,7 +749,6 @@ class ManageSourceFiles(Resource):
),
400,
)
if operation == "add":
files = request.files.getlist("file")
if not files or all(file.filename == "" for file in files):
@@ -769,23 +761,22 @@ class ManageSourceFiles(Resource):
),
400,
)
added_files = []
target_dir = source_file_path
if parent_dir:
target_dir = f"{source_file_path}/{parent_dir}"
for file in files:
if file.filename:
safe_filename_str = safe_filename(file.filename)
file_path = f"{target_dir}/{safe_filename_str}"
# Save file to storage
storage.save_file(file, file_path)
added_files.append(safe_filename_str)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.delay(source_id=source_id, user=user)
@@ -802,7 +793,6 @@ class ManageSourceFiles(Resource):
),
200,
)
elif operation == "remove":
file_paths_str = request.form.get("file_paths")
if not file_paths_str:
@@ -815,7 +805,6 @@ class ManageSourceFiles(Resource):
),
400,
)
try:
file_paths = (
json.loads(file_paths_str)
@@ -829,18 +818,19 @@ class ManageSourceFiles(Resource):
),
400,
)
# Remove files from storage and directory structure
removed_files = []
for file_path in file_paths:
full_path = f"{source_file_path}/{file_path}"
# Remove from storage
if storage.file_exists(full_path):
storage.delete_file(full_path)
removed_files.append(file_path)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.delay(source_id=source_id, user=user)
@@ -856,7 +846,6 @@ class ManageSourceFiles(Resource):
),
200,
)
elif operation == "remove_directory":
directory_path = request.form.get("directory_path")
if not directory_path:
@@ -869,8 +858,8 @@ class ManageSourceFiles(Resource):
),
400,
)
# Validate directory path (prevent path traversal)
if directory_path.startswith("/") or ".." in directory_path:
current_app.logger.warning(
f"Invalid directory path attempted for removal. "
@@ -882,7 +871,6 @@ class ManageSourceFiles(Resource):
),
400,
)
full_directory_path = (
f"{source_file_path}/{directory_path}"
if directory_path
@@ -904,7 +892,6 @@ class ManageSourceFiles(Resource):
),
404,
)
success = storage.remove_directory(full_directory_path)
if not success:
@@ -919,7 +906,6 @@ class ManageSourceFiles(Resource):
),
500,
)
current_app.logger.info(
f"Successfully removed directory. "
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
@@ -927,6 +913,7 @@ class ManageSourceFiles(Resource):
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.delay(source_id=source_id, user=user)
@@ -942,7 +929,6 @@ class ManageSourceFiles(Resource):
),
200,
)
except Exception as err:
error_context = f"operation={operation}, user={user}, source_id={source_id}"
if operation == "remove_directory":
@@ -954,7 +940,6 @@ class ManageSourceFiles(Resource):
elif operation == "add":
parent_dir = request.form.get("parent_dir", "")
error_context += f", parent_dir={parent_dir}"
current_app.logger.error(
f"Error managing source files: {err} ({error_context})", exc_info=True
)
@@ -1001,6 +986,50 @@ class UploadRemote(Resource):
source_data = config.get("url")
elif data["source"] == "reddit":
source_data = config
elif data["source"] in ConnectorCreator.get_supported_connectors():
session_token = config.get("session_token")
if not session_token:
return make_response(
jsonify(
{
"success": False,
"error": f"Missing session_token in {data['source']} configuration",
}
),
400,
)
# Process file_ids
file_ids = config.get("file_ids", [])
if isinstance(file_ids, str):
file_ids = [id.strip() for id in file_ids.split(",") if id.strip()]
elif not isinstance(file_ids, list):
file_ids = []
# Process folder_ids
folder_ids = config.get("folder_ids", [])
if isinstance(folder_ids, str):
folder_ids = [
id.strip() for id in folder_ids.split(",") if id.strip()
]
elif not isinstance(folder_ids, list):
folder_ids = []
config["file_ids"] = file_ids
config["folder_ids"] = folder_ids
task = ingest_connector_task.delay(
job_name=data["name"],
user=decoded_token.get("sub"),
source_type=data["source"],
session_token=session_token,
file_ids=file_ids,
folder_ids=folder_ids,
recursive=config.get("recursive", False),
retriever=config.get("retriever", "classic"),
)
return make_response(
jsonify({"success": True, "task_id": task.id}), 200
)
task = ingest_remote.delay(
source_data=source_data,
job_name=data["name"],
@@ -1109,6 +1138,7 @@ class PaginatedSources(Resource):
"retriever": doc.get("retriever", "classic"),
"syncFrequency": doc.get("sync_frequency", ""),
"isNested": bool(doc.get("directory_structure")),
"type": doc.get("type", "file"),
}
paginated_docs.append(doc_data)
response = {
@@ -1157,6 +1187,9 @@ class CombinedJson(Resource):
"retriever": index.get("retriever", "classic"),
"syncFrequency": index.get("sync_frequency", ""),
"is_nested": bool(index.get("directory_structure")),
"type": index.get(
"type", "file"
), # Add type field with default "file"
}
)
except Exception as err:
@@ -1372,17 +1405,14 @@ class GetAgent(Resource):
def get(self):
if not (decoded_token := request.decoded_token):
return {"success": False}, 401
if not (agent_id := request.args.get("id")):
return {"success": False, "message": "ID required"}, 400
try:
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "user": decoded_token["sub"]}
)
if not agent:
return {"status": "Not found"}, 404
data = {
"id": str(agent["_id"]),
"name": agent["name"],
@@ -1428,7 +1458,6 @@ class GetAgent(Resource):
"shared_token": agent.get("shared_token", ""),
}
return make_response(jsonify(data), 200)
except Exception as e:
current_app.logger.error(f"Agent fetch error: {e}", exc_info=True)
return {"success": False}, 400
@@ -1440,7 +1469,6 @@ class GetAgents(Resource):
def get(self):
if not (decoded_token := request.decoded_token):
return {"success": False}, 401
user = decoded_token.get("sub")
try:
user_doc = ensure_user_doc(user)
@@ -1501,7 +1529,6 @@ class GetAgents(Resource):
for agent in agents
if "source" in agent or "retriever" in agent
]
except Exception as err:
current_app.logger.error(f"Error retrieving agents: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -1573,9 +1600,11 @@ class CreateAgent(Resource):
print(f"Received data: {data}")
# Validate JSON schema if provided
if data.get("json_schema"):
try:
# Basic validation - ensure it's a valid JSON structure
json_schema = data.get("json_schema")
if not isinstance(json_schema, dict):
return make_response(
@@ -1587,8 +1616,8 @@ class CreateAgent(Resource):
),
400,
)
# Validate that it has either a 'schema' property or is itself a schema
if "schema" not in json_schema and "type" not in json_schema:
return make_response(
jsonify(
@@ -1606,7 +1635,6 @@ class CreateAgent(Resource):
),
400,
)
if data.get("status") not in ["draft", "published"]:
return make_response(
jsonify(
@@ -1617,7 +1645,6 @@ class CreateAgent(Resource):
),
400,
)
if data.get("status") == "published":
required_fields = [
"name",
@@ -1628,6 +1655,7 @@ class CreateAgent(Resource):
"agent_type",
]
# Require either source or sources (but not both)
if not data.get("source") and not data.get("sources"):
return make_response(
jsonify(
@@ -1648,13 +1676,11 @@ class CreateAgent(Resource):
return missing_fields
if invalid_fields:
return invalid_fields
image_url, error = handle_image_upload(request, "", user, storage)
if error:
return make_response(
jsonify({"success": False, "message": "Image upload failed"}), 400
)
try:
key = str(uuid.uuid4()) if data.get("status") == "published" else ""
@@ -1674,7 +1700,6 @@ class CreateAgent(Resource):
source_field = DBRef("sources", ObjectId(source_value))
else:
source_field = ""
new_agent = {
"user": user,
"name": data.get("name"),
@@ -1772,7 +1797,6 @@ class UpdateAgent(Resource):
data["json_schema"] = json.loads(data["json_schema"])
except json.JSONDecodeError:
data["json_schema"] = None
if not ObjectId.is_valid(agent_id):
return make_response(
jsonify({"success": False, "message": "Invalid agent ID format"}), 400
@@ -1796,7 +1820,6 @@ class UpdateAgent(Resource):
),
404,
)
image_url, error = handle_image_upload(
request, existing_agent.get("image", ""), user, storage
)
@@ -1804,7 +1827,6 @@ class UpdateAgent(Resource):
return make_response(
jsonify({"success": False, "message": "Image upload failed"}), 400
)
update_fields = {}
allowed_fields = [
"name",
@@ -1837,6 +1859,7 @@ class UpdateAgent(Resource):
source_id = data.get("source")
if source_id == "default":
# Handle special "default" source
update_fields[field] = "default"
elif source_id and ObjectId.is_valid(source_id):
update_fields[field] = DBRef("sources", ObjectId(source_id))
@@ -1941,7 +1964,6 @@ class UpdateAgent(Resource):
),
400,
)
if not existing_agent.get("key"):
newly_generated_key = str(uuid.uuid4())
update_fields["key"] = newly_generated_key
@@ -2028,7 +2050,6 @@ class PinnedAgents(Resource):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
try:
@@ -2037,7 +2058,6 @@ class PinnedAgents(Resource):
if not pinned_ids:
return make_response(jsonify([]), 200)
pinned_object_ids = [ObjectId(agent_id) for agent_id in pinned_ids]
pinned_agents_cursor = agents_collection.find(
@@ -2047,6 +2067,7 @@ class PinnedAgents(Resource):
existing_ids = {str(agent["_id"]) for agent in pinned_agents}
# Clean up any stale pinned IDs
stale_ids = [
agent_id for agent_id in pinned_ids if agent_id not in existing_ids
]
@@ -2055,7 +2076,6 @@ class PinnedAgents(Resource):
{"user_id": user_id},
{"$pullAll": {"agent_preferences.pinned": stale_ids}},
)
list_pinned_agents = [
{
"id": str(agent["_id"]),
@@ -2092,11 +2112,9 @@ class PinnedAgents(Resource):
for agent in pinned_agents
if "source" in agent or "retriever" in agent
]
except Exception as err:
current_app.logger.error(f"Error retrieving pinned agents: {err}")
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(list_pinned_agents), 200)
@@ -2160,7 +2178,6 @@ class RemoveSharedAgent(Resource):
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
try:
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "shared_publicly": True}
@@ -2170,7 +2187,6 @@ class RemoveSharedAgent(Resource):
jsonify({"success": False, "message": "Shared agent not found"}),
404,
)
ensure_user_doc(user_id)
users_collection.update_one(
{"user_id": user_id},
@@ -2183,7 +2199,6 @@ class RemoveSharedAgent(Resource):
)
return make_response(jsonify({"success": True, "action": "removed"}), 200)
except Exception as err:
current_app.logger.error(f"Error removing shared agent: {err}")
return make_response(
@@ -2206,7 +2221,6 @@ class SharedAgent(Resource):
return make_response(
jsonify({"success": False, "message": "Token or ID is required"}), 400
)
try:
query = {
"shared_publicly": True,
@@ -2218,7 +2232,6 @@ class SharedAgent(Resource):
jsonify({"success": False, "message": "Shared agent not found"}),
404,
)
agent_id = str(shared_agent["_id"])
data = {
"id": agent_id,
@@ -2230,7 +2243,12 @@ class SharedAgent(Resource):
else ""
),
"description": shared_agent.get("description", ""),
"source": shared_agent.get("source", ""),
"source": (
str(source_doc["_id"])
if isinstance(shared_agent.get("source"), DBRef)
and (source_doc := db.dereference(shared_agent.get("source")))
else ""
),
"chunks": shared_agent.get("chunks", "0"),
"retriever": shared_agent.get("retriever", "classic"),
"prompt_id": shared_agent.get("prompt_id", "default"),
@@ -2253,7 +2271,6 @@ class SharedAgent(Resource):
if tool_data:
enriched_tools.append(tool_data.get("name", ""))
data["tools"] = enriched_tools
decoded_token = getattr(request, "decoded_token", None)
if decoded_token:
user_id = decoded_token.get("sub")
@@ -2265,9 +2282,7 @@ class SharedAgent(Resource):
{"user_id": user_id},
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
)
return make_response(jsonify(data), 200)
except Exception as err:
current_app.logger.error(f"Error retrieving shared agent: {err}")
return make_response(jsonify({"success": False}), 400)
@@ -2301,7 +2316,6 @@ class SharedAgents(Resource):
{"user_id": user_id},
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
)
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
list_shared_agents = [
@@ -2328,7 +2342,6 @@ class SharedAgents(Resource):
]
return make_response(jsonify(list_shared_agents), 200)
except Exception as err:
current_app.logger.error(f"Error retrieving shared agents: {err}")
return make_response(jsonify({"success": False}), 400)
@@ -3808,22 +3821,22 @@ class GetChunks(Resource):
metadata = chunk.get("metadata", {})
# Filter by path if provided
if path:
chunk_source = metadata.get("source", "")
# Check if the chunk's source matches the requested path
if not chunk_source or not chunk_source.endswith(path):
continue
# Filter by search term if provided
if search_term:
text_match = search_term in chunk.get("text", "").lower()
title_match = search_term in metadata.get("title", "").lower()
if not (text_match or title_match):
continue
filtered_chunks.append(chunk)
chunks = filtered_chunks
total_chunks = len(chunks)
@@ -3983,7 +3996,6 @@ class UpdateChunk(Resource):
if metadata is None:
metadata = {}
metadata["token_count"] = token_count
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
@@ -3998,7 +4010,6 @@ class UpdateChunk(Resource):
existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None)
if not existing_chunk:
return make_response(jsonify({"error": "Chunk not found"}), 404)
new_text = text if text is not None else existing_chunk["text"]
if metadata is not None:
@@ -4006,10 +4017,8 @@ class UpdateChunk(Resource):
new_metadata.update(metadata)
else:
new_metadata = existing_chunk["metadata"].copy()
if text is not None:
new_metadata["token_count"] = num_tokens_from_string(new_text)
try:
new_chunk_id = store.add_chunk(new_text, new_metadata)
@@ -4018,7 +4027,6 @@ class UpdateChunk(Resource):
current_app.logger.warning(
f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
)
return make_response(
jsonify(
{
@@ -4065,7 +4073,6 @@ class StoreAttachment(Resource):
jsonify({"status": "error", "message": "Missing file"}),
400,
)
user = None
if decoded_token:
user = safe_filename(decoded_token.get("sub"))
@@ -4080,7 +4087,6 @@ class StoreAttachment(Resource):
return make_response(
jsonify({"success": False, "message": "Authentication required"}), 401
)
try:
attachment_id = ObjectId()
original_filename = safe_filename(os.path.basename(file.filename))
@@ -4122,7 +4128,6 @@ class ServeImage(Resource):
content_type = f"image/{extension}"
if extension == "jpg":
content_type = "image/jpeg"
response = make_response(file_obj.read())
response.headers.set("Content-Type", content_type)
response.headers.set("Cache-Control", "max-age=86400")
@@ -4149,36 +4154,43 @@ class DirectoryStructure(Resource):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
doc_id = request.args.get("id")
if not doc_id:
return make_response(jsonify({"error": "Document ID is required"}), 400)
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid document ID"}), 400)
try:
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
directory_structure = doc.get("directory_structure", {})
base_path = doc.get("file_path", "")
provider = None
remote_data = doc.get("remote_data")
try:
if isinstance(remote_data, str) and remote_data:
remote_data_obj = json.loads(remote_data)
provider = remote_data_obj.get("provider")
except Exception as e:
current_app.logger.warning(
f"Failed to parse remote_data for doc {doc_id}: {e}"
)
return make_response(
jsonify(
{
"success": True,
"directory_structure": directory_structure,
"base_path": doc.get("file_path", ""),
"base_path": base_path,
"provider": provider,
}
),
200,
)
except Exception as e:
current_app.logger.error(
f"Error retrieving directory structure: {e}", exc_info=True