feat: add support for multiple sources in agent configuration and update related components

This commit is contained in:
Siddhant Rai
2025-09-08 22:10:08 +05:30
parent 07d59b6640
commit 2f88890c94
5 changed files with 592 additions and 158 deletions

View File

@@ -492,9 +492,9 @@ class DeleteOldIndexes(Resource):
)
if not doc:
return make_response(jsonify({"status": "not found"}), 404)
storage = StorageCreator.get_storage()
try:
# Delete vector index
if settings.VECTOR_STORE == "faiss":
@@ -508,7 +508,7 @@ class DeleteOldIndexes(Resource):
settings.VECTOR_STORE, source_id=str(doc["_id"])
)
vectorstore.delete_index()
if "file_path" in doc and doc["file_path"]:
file_path = doc["file_path"]
if storage.is_directory(file_path):
@@ -517,7 +517,7 @@ class DeleteOldIndexes(Resource):
storage.delete_file(f)
else:
storage.delete_file(file_path)
except FileNotFoundError:
pass
except Exception as err:
@@ -525,7 +525,7 @@ class DeleteOldIndexes(Resource):
f"Error deleting files and indexes: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
sources_collection.delete_one({"_id": ObjectId(source_id)})
return make_response(jsonify({"success": True}), 200)
@@ -573,55 +573,75 @@ class UploadFile(Resource):
try:
storage = StorageCreator.get_storage()
for file in files:
original_filename = file.filename
safe_file = safe_filename(original_filename)
with tempfile.TemporaryDirectory() as temp_dir:
temp_file_path = os.path.join(temp_dir, safe_file)
file.save(temp_file_path)
if zipfile.is_zipfile(temp_file_path):
try:
with zipfile.ZipFile(temp_file_path, 'r') as zip_ref:
with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
zip_ref.extractall(path=temp_dir)
# Walk through extracted files and upload them
for root, _, files in os.walk(temp_dir):
for extracted_file in files:
if os.path.join(root, extracted_file) == temp_file_path:
if (
os.path.join(root, extracted_file)
== temp_file_path
):
continue
rel_path = os.path.relpath(os.path.join(root, extracted_file), temp_dir)
rel_path = os.path.relpath(
os.path.join(root, extracted_file), temp_dir
)
storage_path = f"{base_path}/{rel_path}"
with open(os.path.join(root, extracted_file), 'rb') as f:
with open(
os.path.join(root, extracted_file), "rb"
) as f:
storage.save_file(f, storage_path)
except Exception as e:
current_app.logger.error(f"Error extracting zip: {e}", exc_info=True)
current_app.logger.error(
f"Error extracting zip: {e}", exc_info=True
)
# If zip extraction fails, save the original zip file
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, 'rb') as f:
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
else:
# For non-zip files, save directly
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, 'rb') as f:
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
task = ingest.delay(
settings.UPLOAD_FOLDER,
[
".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub",
".html", ".mdx", ".json", ".xlsx", ".pptx", ".png",
".jpg", ".jpeg",
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
job_name,
user,
file_path=base_path,
filename=dir_name
filename=dir_name,
)
except Exception as err:
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
@@ -635,12 +655,29 @@ class ManageSourceFiles(Resource):
api.model(
"ManageSourceFilesModel",
{
"source_id": fields.String(required=True, description="Source ID to modify"),
"operation": fields.String(required=True, description="Operation: 'add', 'remove', or 'remove_directory'"),
"file_paths": fields.List(fields.String, required=False, description="File paths to remove (for remove operation)"),
"directory_path": fields.String(required=False, description="Directory path to remove (for remove_directory operation)"),
"file": fields.Raw(required=False, description="Files to add (for add operation)"),
"parent_dir": fields.String(required=False, description="Parent directory path relative to source root"),
"source_id": fields.String(
required=True, description="Source ID to modify"
),
"operation": fields.String(
required=True,
description="Operation: 'add', 'remove', or 'remove_directory'",
),
"file_paths": fields.List(
fields.String,
required=False,
description="File paths to remove (for remove operation)",
),
"directory_path": fields.String(
required=False,
description="Directory path to remove (for remove_directory operation)",
),
"file": fields.Raw(
required=False, description="Files to add (for add operation)"
),
"parent_dir": fields.String(
required=False,
description="Parent directory path relative to source root",
),
},
)
)
@@ -650,7 +687,9 @@ class ManageSourceFiles(Resource):
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False, "message": "Unauthorized"}), 401)
return make_response(
jsonify({"success": False, "message": "Unauthorized"}), 401
)
user = decoded_token.get("sub")
source_id = request.form.get("source_id")
@@ -658,12 +697,24 @@ class ManageSourceFiles(Resource):
if not source_id or not operation:
return make_response(
jsonify({"success": False, "message": "source_id and operation are required"}), 400
jsonify(
{
"success": False,
"message": "source_id and operation are required",
}
),
400,
)
if operation not in ["add", "remove", "remove_directory"]:
return make_response(
jsonify({"success": False, "message": "operation must be 'add', 'remove', or 'remove_directory'"}), 400
jsonify(
{
"success": False,
"message": "operation must be 'add', 'remove', or 'remove_directory'",
}
),
400,
)
try:
@@ -674,34 +725,53 @@ class ManageSourceFiles(Resource):
)
try:
source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user})
source = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": user}
)
if not source:
return make_response(
jsonify({"success": False, "message": "Source not found or access denied"}), 404
jsonify(
{
"success": False,
"message": "Source not found or access denied",
}
),
404,
)
except Exception as err:
current_app.logger.error(f"Error finding source: {err}", exc_info=True)
return make_response(jsonify({"success": False, "message": "Database error"}), 500)
return make_response(
jsonify({"success": False, "message": "Database error"}), 500
)
try:
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
parent_dir = request.form.get("parent_dir", "")
parent_dir = request.form.get("parent_dir", "")
if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir):
return make_response(
jsonify({"success": False, "message": "Invalid parent directory path"}), 400
jsonify(
{"success": False, "message": "Invalid parent directory path"}
),
400,
)
if operation == "add":
files = request.files.getlist("file")
if not files or all(file.filename == "" for file in files):
return make_response(
jsonify({"success": False, "message": "No files provided for add operation"}), 400
jsonify(
{
"success": False,
"message": "No files provided for add operation",
}
),
400,
)
added_files = []
target_dir = source_file_path
if parent_dir:
target_dir = f"{source_file_path}/{parent_dir}"
@@ -720,26 +790,44 @@ class ManageSourceFiles(Resource):
task = reingest_source_task.delay(source_id=source_id, user=user)
return make_response(jsonify({
"success": True,
"message": f"Added {len(added_files)} files",
"added_files": added_files,
"parent_dir": parent_dir,
"reingest_task_id": task.id
}), 200)
return make_response(
jsonify(
{
"success": True,
"message": f"Added {len(added_files)} files",
"added_files": added_files,
"parent_dir": parent_dir,
"reingest_task_id": task.id,
}
),
200,
)
elif operation == "remove":
file_paths_str = request.form.get("file_paths")
if not file_paths_str:
return make_response(
jsonify({"success": False, "message": "file_paths required for remove operation"}), 400
jsonify(
{
"success": False,
"message": "file_paths required for remove operation",
}
),
400,
)
try:
file_paths = json.loads(file_paths_str) if isinstance(file_paths_str, str) else file_paths_str
file_paths = (
json.loads(file_paths_str)
if isinstance(file_paths_str, str)
else file_paths_str
)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid file_paths format"}), 400
jsonify(
{"success": False, "message": "Invalid file_paths format"}
),
400,
)
# Remove files from storage and directory structure
@@ -757,18 +845,29 @@ class ManageSourceFiles(Resource):
task = reingest_source_task.delay(source_id=source_id, user=user)
return make_response(jsonify({
"success": True,
"message": f"Removed {len(removed_files)} files",
"removed_files": removed_files,
"reingest_task_id": task.id
}), 200)
return make_response(
jsonify(
{
"success": True,
"message": f"Removed {len(removed_files)} files",
"removed_files": removed_files,
"reingest_task_id": task.id,
}
),
200,
)
elif operation == "remove_directory":
directory_path = request.form.get("directory_path")
if not directory_path:
return make_response(
jsonify({"success": False, "message": "directory_path required for remove_directory operation"}), 400
jsonify(
{
"success": False,
"message": "directory_path required for remove_directory operation",
}
),
400,
)
# Validate directory path (prevent path traversal)
@@ -778,10 +877,17 @@ class ManageSourceFiles(Resource):
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}"
)
return make_response(
jsonify({"success": False, "message": "Invalid directory path"}), 400
jsonify(
{"success": False, "message": "Invalid directory path"}
),
400,
)
full_directory_path = f"{source_file_path}/{directory_path}" if directory_path else source_file_path
full_directory_path = (
f"{source_file_path}/{directory_path}"
if directory_path
else source_file_path
)
if not storage.is_directory(full_directory_path):
current_app.logger.warning(
@@ -790,7 +896,13 @@ class ManageSourceFiles(Resource):
f"Full path: {full_directory_path}"
)
return make_response(
jsonify({"success": False, "message": "Directory not found or is not a directory"}), 404
jsonify(
{
"success": False,
"message": "Directory not found or is not a directory",
}
),
404,
)
success = storage.remove_directory(full_directory_path)
@@ -802,7 +914,10 @@ class ManageSourceFiles(Resource):
f"Full path: {full_directory_path}"
)
return make_response(
jsonify({"success": False, "message": "Failed to remove directory"}), 500
jsonify(
{"success": False, "message": "Failed to remove directory"}
),
500,
)
current_app.logger.info(
@@ -816,12 +931,17 @@ class ManageSourceFiles(Resource):
task = reingest_source_task.delay(source_id=source_id, user=user)
return make_response(jsonify({
"success": True,
"message": f"Successfully removed directory: {directory_path}",
"removed_directory": directory_path,
"reingest_task_id": task.id
}), 200)
return make_response(
jsonify(
{
"success": True,
"message": f"Successfully removed directory: {directory_path}",
"removed_directory": directory_path,
"reingest_task_id": task.id,
}
),
200,
)
except Exception as err:
error_context = f"operation={operation}, user={user}, source_id={source_id}"
@@ -835,8 +955,12 @@ class ManageSourceFiles(Resource):
parent_dir = request.form.get("parent_dir", "")
error_context += f", parent_dir={parent_dir}"
current_app.logger.error(f"Error managing source files: {err} ({error_context})", exc_info=True)
return make_response(jsonify({"success": False, "message": "Operation failed"}), 500)
current_app.logger.error(
f"Error managing source files: {err} ({error_context})", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Operation failed"}), 500
)
@user_ns.route("/api/remote")
@@ -984,7 +1108,7 @@ class PaginatedSources(Resource):
"tokens": doc.get("tokens", ""),
"retriever": doc.get("retriever", "classic"),
"syncFrequency": doc.get("sync_frequency", ""),
"isNested": bool(doc.get("directory_structure"))
"isNested": bool(doc.get("directory_structure")),
}
paginated_docs.append(doc_data)
response = {
@@ -1032,7 +1156,7 @@ class CombinedJson(Resource):
"tokens": index.get("tokens", ""),
"retriever": index.get("retriever", "classic"),
"syncFrequency": index.get("sync_frequency", ""),
"is_nested": bool(index.get("directory_structure"))
"is_nested": bool(index.get("directory_structure")),
}
)
except Exception as err:
@@ -1272,6 +1396,16 @@ class GetAgent(Resource):
and (source_doc := db.dereference(agent.get("source")))
else ""
),
"sources": [
(
str(db.dereference(source_ref)["_id"])
if isinstance(source_ref, DBRef) and db.dereference(source_ref)
else source_ref
)
for source_ref in agent.get("sources", [])
if (isinstance(source_ref, DBRef) and db.dereference(source_ref))
or source_ref == "default"
],
"chunks": agent["chunks"],
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
@@ -1325,8 +1459,24 @@ class GetAgents(Resource):
str(source_doc["_id"])
if isinstance(agent.get("source"), DBRef)
and (source_doc := db.dereference(agent.get("source")))
else ""
else (
agent.get("source", "")
if agent.get("source") == "default"
else ""
)
),
"sources": [
(
source_ref
if source_ref == "default"
else str(db.dereference(source_ref)["_id"])
)
for source_ref in agent.get("sources", [])
if source_ref == "default"
or (
isinstance(source_ref, DBRef) and db.dereference(source_ref)
)
],
"chunks": agent["chunks"],
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
@@ -1351,6 +1501,7 @@ class GetAgents(Resource):
for agent in agents
if "source" in agent or "retriever" in agent
]
except Exception as err:
current_app.logger.error(f"Error retrieving agents: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -1369,7 +1520,14 @@ class CreateAgent(Resource):
"image": fields.Raw(
required=False, description="Image file upload", type="file"
),
"source": fields.String(required=True, description="Source ID"),
"source": fields.String(
required=False, description="Source ID (legacy single source)"
),
"sources": fields.List(
fields.String,
required=False,
description="List of source identifiers for multiple sources",
),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
@@ -1381,7 +1539,8 @@ class CreateAgent(Resource):
required=True, description="Status of the agent (draft or published)"
),
"json_schema": fields.Raw(
required=False, description="JSON schema for enforcing structured output format"
required=False,
description="JSON schema for enforcing structured output format",
),
},
)
@@ -1401,13 +1560,18 @@ class CreateAgent(Resource):
data["tools"] = json.loads(data["tools"])
except json.JSONDecodeError:
data["tools"] = []
if "sources" in data:
try:
data["sources"] = json.loads(data["sources"])
except json.JSONDecodeError:
data["sources"] = []
if "json_schema" in data:
try:
data["json_schema"] = json.loads(data["json_schema"])
except json.JSONDecodeError:
data["json_schema"] = None
print(f"Received data: {data}")
# Validate JSON schema if provided
if data.get("json_schema"):
try:
@@ -1415,20 +1579,32 @@ class CreateAgent(Resource):
json_schema = data.get("json_schema")
if not isinstance(json_schema, dict):
return make_response(
jsonify({"success": False, "message": "JSON schema must be a valid JSON object"}),
400
jsonify(
{
"success": False,
"message": "JSON schema must be a valid JSON object",
}
),
400,
)
# Validate that it has either a 'schema' property or is itself a schema
if "schema" not in json_schema and "type" not in json_schema:
return make_response(
jsonify({"success": False, "message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property"}),
400
jsonify(
{
"success": False,
"message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property",
}
),
400,
)
except Exception as e:
return make_response(
jsonify({"success": False, "message": f"Invalid JSON schema: {str(e)}"}),
400
jsonify(
{"success": False, "message": f"Invalid JSON schema: {str(e)}"}
),
400,
)
if data.get("status") not in ["draft", "published"]:
@@ -1446,12 +1622,22 @@ class CreateAgent(Resource):
required_fields = [
"name",
"description",
"source",
"chunks",
"retriever",
"prompt_id",
"agent_type",
]
# Require either source or sources (but not both)
if not data.get("source") and not data.get("sources"):
return make_response(
jsonify(
{
"success": False,
"message": "Either 'source' or 'sources' field is required for published agents",
}
),
400,
)
validate_fields = ["name", "description", "prompt_id", "agent_type"]
else:
required_fields = ["name"]
@@ -1471,16 +1657,31 @@ class CreateAgent(Resource):
try:
key = str(uuid.uuid4()) if data.get("status") == "published" else ""
sources_list = []
if data.get("sources") and len(data.get("sources", [])) > 0:
for source_id in data.get("sources", []):
if source_id == "default":
sources_list.append("default")
elif ObjectId.is_valid(source_id):
sources_list.append(DBRef("sources", ObjectId(source_id)))
source_field = ""
else:
source_value = data.get("source", "")
if source_value == "default":
source_field = "default"
elif ObjectId.is_valid(source_value):
source_field = DBRef("sources", ObjectId(source_value))
else:
source_field = ""
new_agent = {
"user": user,
"name": data.get("name"),
"description": data.get("description", ""),
"image": image_url,
"source": (
DBRef("sources", ObjectId(data.get("source")))
if ObjectId.is_valid(data.get("source"))
else ""
),
"source": source_field,
"sources": sources_list,
"chunks": data.get("chunks", ""),
"retriever": data.get("retriever", ""),
"prompt_id": data.get("prompt_id", ""),
@@ -1495,7 +1696,11 @@ class CreateAgent(Resource):
}
if new_agent["chunks"] == "":
new_agent["chunks"] = "0"
if new_agent["source"] == "" and new_agent["retriever"] == "":
if (
new_agent["source"] == ""
and new_agent["retriever"] == ""
and not new_agent["sources"]
):
new_agent["retriever"] = "classic"
resp = agents_collection.insert_one(new_agent)
new_id = str(resp.inserted_id)
@@ -1517,7 +1722,14 @@ class UpdateAgent(Resource):
"image": fields.String(
required=False, description="New image URL or identifier"
),
"source": fields.String(required=True, description="Source ID"),
"source": fields.String(
required=False, description="Source ID (legacy single source)"
),
"sources": fields.List(
fields.String,
required=False,
description="List of source identifiers for multiple sources",
),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
@@ -1529,7 +1741,8 @@ class UpdateAgent(Resource):
required=True, description="Status of the agent (draft or published)"
),
"json_schema": fields.Raw(
required=False, description="JSON schema for enforcing structured output format"
required=False,
description="JSON schema for enforcing structured output format",
),
},
)
@@ -1549,6 +1762,11 @@ class UpdateAgent(Resource):
data["tools"] = json.loads(data["tools"])
except json.JSONDecodeError:
data["tools"] = []
if "sources" in data:
try:
data["sources"] = json.loads(data["sources"])
except json.JSONDecodeError:
data["sources"] = []
if "json_schema" in data:
try:
data["json_schema"] = json.loads(data["json_schema"])
@@ -1593,6 +1811,7 @@ class UpdateAgent(Resource):
"description",
"image",
"source",
"sources",
"chunks",
"retriever",
"prompt_id",
@@ -1616,7 +1835,10 @@ class UpdateAgent(Resource):
update_fields[field] = new_status
elif field == "source":
source_id = data.get("source")
if source_id and ObjectId.is_valid(source_id):
if source_id == "default":
# Handle special "default" source
update_fields[field] = "default"
elif source_id and ObjectId.is_valid(source_id):
update_fields[field] = DBRef("sources", ObjectId(source_id))
elif source_id:
return make_response(
@@ -1630,6 +1852,30 @@ class UpdateAgent(Resource):
)
else:
update_fields[field] = ""
elif field == "sources":
sources_list = data.get("sources", [])
if sources_list and isinstance(sources_list, list):
valid_sources = []
for source_id in sources_list:
if source_id == "default":
valid_sources.append("default")
elif ObjectId.is_valid(source_id):
valid_sources.append(
DBRef("sources", ObjectId(source_id))
)
else:
return make_response(
jsonify(
{
"success": False,
"message": f"Invalid source ID format: {source_id}",
}
),
400,
)
update_fields[field] = valid_sources
else:
update_fields[field] = []
elif field == "chunks":
chunks_value = data.get("chunks")
if chunks_value == "":
@@ -3532,7 +3778,7 @@ class GetChunks(Resource):
"page": "Page number for pagination",
"per_page": "Number of chunks per page",
"path": "Optional: Filter chunks by relative file path",
"search": "Optional: Search term to filter chunks by title or content"
"search": "Optional: Search term to filter chunks by title or content",
},
)
def get(self):
@@ -3556,7 +3802,7 @@ class GetChunks(Resource):
try:
store = get_vector_store(doc_id)
chunks = store.get_chunks()
filtered_chunks = []
for chunk in chunks:
metadata = chunk.get("metadata", {})
@@ -3577,9 +3823,9 @@ class GetChunks(Resource):
continue
filtered_chunks.append(chunk)
chunks = filtered_chunks
total_chunks = len(chunks)
start = (page - 1) * per_page
end = start + per_page
@@ -3593,7 +3839,7 @@ class GetChunks(Resource):
"total": total_chunks,
"chunks": paginated_chunks,
"path": path if path else None,
"search": search_term if search_term else None
"search": search_term if search_term else None,
}
),
200,
@@ -3602,6 +3848,7 @@ class GetChunks(Resource):
current_app.logger.error(f"Error getting chunks: {e}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
@user_ns.route("/api/add_chunk")
class AddChunk(Resource):
@api.expect(
@@ -3768,7 +4015,9 @@ class UpdateChunk(Resource):
deleted = store.delete_chunk(chunk_id)
if not deleted:
current_app.logger.warning(f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created")
current_app.logger.warning(
f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
)
return make_response(
jsonify(
@@ -3900,39 +4149,38 @@ class DirectoryStructure(Resource):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
doc_id = request.args.get("id")
if not doc_id:
return make_response(
jsonify({"error": "Document ID is required"}), 400
)
return make_response(jsonify({"error": "Document ID is required"}), 400)
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid document ID"}), 400)
try:
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
directory_structure = doc.get("directory_structure", {})
return make_response(
jsonify({
"success": True,
"directory_structure": directory_structure,
"base_path": doc.get("file_path", "")
}), 200
jsonify(
{
"success": True,
"directory_structure": directory_structure,
"base_path": doc.get("file_path", ""),
}
),
200,
)
except Exception as e:
current_app.logger.error(
f"Error retrieving directory structure: {e}", exc_info=True
)
return make_response(
jsonify({"success": False, "error": str(e)}), 500
)
return make_response(jsonify({"success": False, "error": str(e)}), 500)