diff --git a/application/api/user/sources/routes.py b/application/api/user/sources/routes.py index 829371bc..bc5eafeb 100644 --- a/application/api/user/sources/routes.py +++ b/application/api/user/sources/routes.py @@ -9,6 +9,7 @@ from flask_restx import fields, Namespace, Resource from application.api import api from application.api.user.base import sources_collection +from application.api.user.tasks import sync_source from application.core.settings import settings from application.storage.storage_creator import StorageCreator from application.utils import check_required_fields @@ -20,6 +21,21 @@ sources_ns = Namespace( ) +def _get_provider_from_remote_data(remote_data): + if not remote_data: + return None + if isinstance(remote_data, dict): + return remote_data.get("provider") + if isinstance(remote_data, str): + try: + remote_data_obj = json.loads(remote_data) + except Exception: + return None + if isinstance(remote_data_obj, dict): + return remote_data_obj.get("provider") + return None + + @sources_ns.route("/sources") class CombinedJson(Resource): @api.doc(description="Provide JSON file with combined available indexes") @@ -41,6 +57,7 @@ class CombinedJson(Resource): try: for index in sources_collection.find({"user": user}).sort("date", -1): + provider = _get_provider_from_remote_data(index.get("remote_data")) data.append( { "id": str(index["_id"]), @@ -51,6 +68,7 @@ class CombinedJson(Resource): "tokens": index.get("tokens", ""), "retriever": index.get("retriever", "classic"), "syncFrequency": index.get("sync_frequency", ""), + "provider": provider, "is_nested": bool(index.get("directory_structure")), "type": index.get( "type", "file" @@ -107,6 +125,7 @@ class PaginatedSources(Resource): paginated_docs = [] for doc in documents: + provider = _get_provider_from_remote_data(doc.get("remote_data")) doc_data = { "id": str(doc["_id"]), "name": doc.get("name", ""), @@ -116,6 +135,7 @@ class PaginatedSources(Resource): "tokens": doc.get("tokens", ""), "retriever": doc.get("retriever", "classic"), "syncFrequency": doc.get("sync_frequency", ""), + "provider": provider, "isNested": bool(doc.get("directory_structure")), "type": doc.get("type", "file"), } @@ -240,7 +260,7 @@ class ManageSync(Resource): if not decoded_token: return make_response(jsonify({"success": False}), 401) user = decoded_token.get("sub") - data = request.get_json() + data = request.get_json() or {} required_fields = ["source_id", "sync_frequency"] missing_fields = check_required_fields(data, required_fields) if missing_fields: @@ -269,6 +289,72 @@ class ManageSync(Resource): return make_response(jsonify({"success": True}), 200) +@sources_ns.route("/sync_source") +class SyncSource(Resource): + sync_source_model = api.model( + "SyncSourceModel", + {"source_id": fields.String(required=True, description="Source ID")}, + ) + + @api.expect(sync_source_model) + @api.doc(description="Trigger an immediate sync for a source") + def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") + data = request.get_json() + required_fields = ["source_id"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + source_id = data["source_id"] + if not ObjectId.is_valid(source_id): + return make_response( + jsonify({"success": False, "message": "Invalid source ID"}), 400 + ) + doc = sources_collection.find_one( + {"_id": ObjectId(source_id), "user": user} + ) + if not doc: + return 
make_response( + jsonify({"success": False, "message": "Source not found"}), 404 + ) + source_type = doc.get("type", "") + if source_type.startswith("connector"): + return make_response( + jsonify( + { + "success": False, + "message": "Connector sources must be synced via /api/connectors/sync", + } + ), + 400, + ) + source_data = doc.get("remote_data") + if not source_data: + return make_response( + jsonify({"success": False, "message": "Source is not syncable"}), 400 + ) + try: + task = sync_source.delay( + source_data=source_data, + job_name=doc.get("name", ""), + user=user, + loader=source_type, + sync_frequency=doc.get("sync_frequency", "never"), + retriever=doc.get("retriever", "classic"), + doc_id=source_id, + ) + except Exception as err: + current_app.logger.error( + f"Error starting sync for source {source_id}: {err}", + exc_info=True, + ) + return make_response(jsonify({"success": False}), 400) + return make_response(jsonify({"success": True, "task_id": task.id}), 200) + + @sources_ns.route("/directory_structure") class DirectoryStructure(Resource): @api.doc( diff --git a/application/api/user/sources/upload.py b/application/api/user/sources/upload.py index 801e5b44..0d2f43c8 100644 --- a/application/api/user/sources/upload.py +++ b/application/api/user/sources/upload.py @@ -187,6 +187,8 @@ class UploadRemote(Resource): source_data = config.get("url") elif data["source"] == "reddit": source_data = config + elif data["source"] == "s3": + source_data = config elif data["source"] in ConnectorCreator.get_supported_connectors(): session_token = config.get("session_token") if not session_token: diff --git a/application/api/user/tasks.py b/application/api/user/tasks.py index c7414b9f..1011be00 100644 --- a/application/api/user/tasks.py +++ b/application/api/user/tasks.py @@ -8,6 +8,7 @@ from application.worker import ( mcp_oauth, mcp_oauth_status, remote_worker, + sync, sync_worker, ) @@ -38,6 +39,30 @@ def schedule_syncs(self, frequency): return resp +@celery.task(bind=True) +def sync_source( + self, + source_data, + job_name, + user, + loader, + sync_frequency, + retriever, + doc_id, +): + resp = sync( + self, + source_data, + job_name, + user, + loader, + sync_frequency, + retriever, + doc_id, + ) + return resp + + @celery.task(bind=True) def store_attachment(self, file_info, user): resp = attachment_worker(self, file_info, user) diff --git a/application/parser/remote/remote_creator.py b/application/parser/remote/remote_creator.py index a47b186a..f74e35e9 100644 --- a/application/parser/remote/remote_creator.py +++ b/application/parser/remote/remote_creator.py @@ -3,6 +3,7 @@ from application.parser.remote.crawler_loader import CrawlerLoader from application.parser.remote.web_loader import WebLoader from application.parser.remote.reddit_loader import RedditPostsLoaderRemote from application.parser.remote.github_loader import GitHubLoader +from application.parser.remote.s3_loader import S3Loader class RemoteCreator: @@ -22,6 +23,7 @@ class RemoteCreator: "crawler": CrawlerLoader, "reddit": RedditPostsLoaderRemote, "github": GitHubLoader, + "s3": S3Loader, } @classmethod diff --git a/application/parser/remote/s3_loader.py b/application/parser/remote/s3_loader.py new file mode 100644 index 00000000..e99fa3c1 --- /dev/null +++ b/application/parser/remote/s3_loader.py @@ -0,0 +1,427 @@ +import json +import logging +import os +import tempfile +import mimetypes +from typing import List, Optional +from application.parser.remote.base import BaseRemote +from application.parser.schema.base import 
Document + +try: + import boto3 + from botocore.exceptions import ClientError, NoCredentialsError +except ImportError: + boto3 = None + +logger = logging.getLogger(__name__) + + +class S3Loader(BaseRemote): + """Load documents from an AWS S3 bucket.""" + + def __init__(self): + if boto3 is None: + raise ImportError( + "boto3 is required for S3Loader. Install it with: pip install boto3" + ) + self.s3_client = None + + def _normalize_endpoint_url(self, endpoint_url: str, bucket: str) -> tuple[str, str]: + """ + Normalize endpoint URL for S3-compatible services. + + Detects common mistakes like using bucket-prefixed URLs and extracts + the correct endpoint and bucket name. + + Args: + endpoint_url: The provided endpoint URL + bucket: The provided bucket name + + Returns: + Tuple of (normalized_endpoint_url, bucket_name) + """ + import re + from urllib.parse import urlparse + + if not endpoint_url: + return endpoint_url, bucket + + parsed = urlparse(endpoint_url) + host = parsed.netloc or parsed.path + + # Check for DigitalOcean Spaces bucket-prefixed URL pattern + # e.g., https://mybucket.nyc3.digitaloceanspaces.com + do_match = re.match(r"^([^.]+)\.([a-z0-9]+)\.digitaloceanspaces\.com$", host) + if do_match: + extracted_bucket = do_match.group(1) + region = do_match.group(2) + correct_endpoint = f"https://{region}.digitaloceanspaces.com" + logger.warning( + f"Detected bucket-prefixed DigitalOcean Spaces URL. " + f"Extracted bucket '{extracted_bucket}' from endpoint. " + f"Using endpoint: {correct_endpoint}" + ) + # If bucket wasn't provided or differs, use extracted one + if not bucket or bucket != extracted_bucket: + logger.info(f"Using extracted bucket name: '{extracted_bucket}' (was: '{bucket}')") + bucket = extracted_bucket + return correct_endpoint, bucket + + # Check for just "digitaloceanspaces.com" without region + if host == "digitaloceanspaces.com": + logger.error( + "Invalid DigitalOcean Spaces endpoint: missing region. " + "Use format: https://.digitaloceanspaces.com (e.g., https://lon1.digitaloceanspaces.com)" + ) + + return endpoint_url, bucket + + def _init_client( + self, + aws_access_key_id: str, + aws_secret_access_key: str, + region_name: str = "us-east-1", + endpoint_url: Optional[str] = None, + bucket: Optional[str] = None, + ) -> Optional[str]: + """ + Initialize the S3 client with credentials. + + Returns: + The potentially corrected bucket name if endpoint URL was normalized + """ + from botocore.config import Config + + client_kwargs = { + "aws_access_key_id": aws_access_key_id, + "aws_secret_access_key": aws_secret_access_key, + "region_name": region_name, + } + + logger.info(f"Initializing S3 client with region: {region_name}") + + corrected_bucket = bucket + if endpoint_url: + # Normalize the endpoint URL and potentially extract bucket name + normalized_endpoint, corrected_bucket = self._normalize_endpoint_url(endpoint_url, bucket) + logger.info(f"Original endpoint URL: {endpoint_url}") + logger.info(f"Normalized endpoint URL: {normalized_endpoint}") + logger.info(f"Bucket name: '{corrected_bucket}'") + + client_kwargs["endpoint_url"] = normalized_endpoint + # Use path-style addressing for S3-compatible services + # (DigitalOcean Spaces, MinIO, etc.) 
+ client_kwargs["config"] = Config(s3={"addressing_style": "path"}) + else: + logger.info("Using default AWS S3 endpoint") + + self.s3_client = boto3.client("s3", **client_kwargs) + logger.info("S3 client initialized successfully") + + return corrected_bucket + + def is_text_file(self, file_path: str) -> bool: + """Determine if a file is a text file based on extension.""" + text_extensions = { + ".txt", + ".md", + ".markdown", + ".rst", + ".json", + ".xml", + ".yaml", + ".yml", + ".py", + ".js", + ".ts", + ".jsx", + ".tsx", + ".java", + ".c", + ".cpp", + ".h", + ".hpp", + ".cs", + ".go", + ".rs", + ".rb", + ".php", + ".swift", + ".kt", + ".scala", + ".html", + ".css", + ".scss", + ".sass", + ".less", + ".sh", + ".bash", + ".zsh", + ".fish", + ".sql", + ".r", + ".m", + ".mat", + ".ini", + ".cfg", + ".conf", + ".config", + ".env", + ".gitignore", + ".dockerignore", + ".editorconfig", + ".log", + ".csv", + ".tsv", + } + + file_lower = file_path.lower() + for ext in text_extensions: + if file_lower.endswith(ext): + return True + + mime_type, _ = mimetypes.guess_type(file_path) + if mime_type and ( + mime_type.startswith("text") + or mime_type in ["application/json", "application/xml"] + ): + return True + + return False + + def is_supported_document(self, file_path: str) -> bool: + """Check if file is a supported document type for parsing.""" + document_extensions = { + ".pdf", + ".docx", + ".doc", + ".xlsx", + ".xls", + ".pptx", + ".ppt", + ".epub", + ".odt", + ".rtf", + } + + file_lower = file_path.lower() + for ext in document_extensions: + if file_lower.endswith(ext): + return True + + return False + + def list_objects(self, bucket: str, prefix: str = "") -> List[str]: + """ + List all objects in the bucket with the given prefix. + + Args: + bucket: S3 bucket name + prefix: Optional path prefix to filter objects + + Returns: + List of object keys + """ + objects = [] + paginator = self.s3_client.get_paginator("list_objects_v2") + + logger.info(f"Listing objects in bucket: '{bucket}' with prefix: '{prefix}'") + logger.debug(f"S3 client endpoint: {self.s3_client.meta.endpoint_url}") + + try: + page_count = 0 + for page in paginator.paginate(Bucket=bucket, Prefix=prefix): + page_count += 1 + logger.debug(f"Processing page {page_count}, keys in response: {list(page.keys())}") + if "Contents" in page: + for obj in page["Contents"]: + key = obj["Key"] + if not key.endswith("/"): + objects.append(key) + logger.debug(f"Found object: {key}") + else: + logger.info(f"Page {page_count} has no 'Contents' key - bucket may be empty or prefix not found") + + logger.info(f"Found {len(objects)} objects in bucket '{bucket}'") + + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "") + error_message = e.response.get("Error", {}).get("Message", "") + logger.error(f"ClientError listing objects - Code: {error_code}, Message: {error_message}") + logger.error(f"Full error response: {e.response}") + logger.error(f"Bucket: '{bucket}', Prefix: '{prefix}', Endpoint: {self.s3_client.meta.endpoint_url}") + + if error_code == "NoSuchBucket": + raise Exception(f"S3 bucket '{bucket}' does not exist") + elif error_code == "AccessDenied": + raise Exception( + f"Access denied to S3 bucket '{bucket}'. Check your credentials and permissions." 
+ ) + elif error_code == "NoSuchKey": + # This is unusual for ListObjectsV2 - may indicate endpoint/bucket configuration issue + logger.error( + "NoSuchKey error on ListObjectsV2 - this may indicate the bucket name " + "is incorrect or the endpoint URL format is wrong. " + "For DigitalOcean Spaces, the endpoint should be like: " + "https://.digitaloceanspaces.com and bucket should be just the space name." + ) + raise Exception( + f"S3 error: {e}. For S3-compatible services, verify: " + f"1) Endpoint URL format (e.g., https://nyc3.digitaloceanspaces.com), " + f"2) Bucket name is just the space/bucket name without region prefix" + ) + else: + raise Exception(f"S3 error: {e}") + except NoCredentialsError: + raise Exception( + "AWS credentials not found. Please provide valid credentials." + ) + + return objects + + def get_object_content(self, bucket: str, key: str) -> Optional[str]: + """ + Get the content of an S3 object as text. + + Args: + bucket: S3 bucket name + key: Object key + + Returns: + File content as string, or None if file should be skipped + """ + if not self.is_text_file(key) and not self.is_supported_document(key): + return None + + try: + response = self.s3_client.get_object(Bucket=bucket, Key=key) + content = response["Body"].read() + + if self.is_text_file(key): + try: + decoded_content = content.decode("utf-8").strip() + if not decoded_content: + return None + return decoded_content + except UnicodeDecodeError: + return None + elif self.is_supported_document(key): + return self._process_document(content, key) + + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "") + if error_code == "NoSuchKey": + return None + elif error_code == "AccessDenied": + print(f"Access denied to object: {key}") + return None + else: + print(f"Error fetching object {key}: {e}") + return None + + return None + + def _process_document(self, content: bytes, key: str) -> Optional[str]: + """ + Process a document file (PDF, DOCX, etc.) and extract text. + + Args: + content: File content as bytes + key: Object key (filename) + + Returns: + Extracted text content + """ + ext = os.path.splitext(key)[1].lower() + + with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp_file: + tmp_file.write(content) + tmp_path = tmp_file.name + + try: + from application.parser.file.bulk import SimpleDirectoryReader + + reader = SimpleDirectoryReader(input_files=[tmp_path]) + documents = reader.load_data() + if documents: + return "\n\n".join(doc.text for doc in documents if doc.text) + return None + except Exception as e: + print(f"Error processing document {key}: {e}") + return None + finally: + if os.path.exists(tmp_path): + os.unlink(tmp_path) + + def load_data(self, inputs) -> List[Document]: + """ + Load documents from an S3 bucket. + + Args: + inputs: JSON string or dict containing: + - aws_access_key_id: AWS access key ID + - aws_secret_access_key: AWS secret access key + - bucket: S3 bucket name + - prefix: Optional path prefix to filter objects + - region: AWS region (default: us-east-1) + - endpoint_url: Custom S3 endpoint URL (for MinIO, R2, etc.) 
+ + Returns: + List of Document objects + """ + if isinstance(inputs, str): + try: + data = json.loads(inputs) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON input: {e}") + else: + data = inputs + + required_fields = ["aws_access_key_id", "aws_secret_access_key", "bucket"] + missing_fields = [field for field in required_fields if not data.get(field)] + if missing_fields: + raise ValueError(f"Missing required fields: {', '.join(missing_fields)}") + + aws_access_key_id = data["aws_access_key_id"] + aws_secret_access_key = data["aws_secret_access_key"] + bucket = data["bucket"] + prefix = data.get("prefix", "") + region = data.get("region", "us-east-1") + endpoint_url = data.get("endpoint_url", "") + + logger.info(f"Loading data from S3 - Bucket: '{bucket}', Prefix: '{prefix}', Region: '{region}'") + if endpoint_url: + logger.info(f"Custom endpoint URL provided: '{endpoint_url}'") + + corrected_bucket = self._init_client( + aws_access_key_id, aws_secret_access_key, region, endpoint_url or None, bucket + ) + + # Use the corrected bucket name if endpoint URL normalization extracted one + if corrected_bucket and corrected_bucket != bucket: + logger.info(f"Using corrected bucket name: '{corrected_bucket}' (original: '{bucket}')") + bucket = corrected_bucket + + objects = self.list_objects(bucket, prefix) + documents = [] + + for key in objects: + content = self.get_object_content(bucket, key) + if content is None: + continue + + documents.append( + Document( + text=content, + doc_id=key, + extra_info={ + "title": os.path.basename(key), + "source": f"s3://{bucket}/{key}", + "bucket": bucket, + "key": key, + }, + ) + ) + + logger.info(f"Loaded {len(documents)} documents from S3 bucket '{bucket}'") + return documents diff --git a/application/worker.py b/application/worker.py index fa2b6cd7..44668247 100755 --- a/application/worker.py +++ b/application/worker.py @@ -868,6 +868,58 @@ def remote_worker( tokens = count_tokens_docs(docs) logging.info("Total tokens calculated: %d", tokens) + # Build directory structure from loaded documents + # Format matches local file uploads: flat structure with type, size_bytes, token_count + directory_structure = {} + for doc in raw_docs: + # Get the file path/name from doc_id or extra_info + file_path = doc.doc_id or "" + if not file_path and doc.extra_info: + file_path = doc.extra_info.get("key", "") or doc.extra_info.get( + "title", "" + ) + + if file_path: + # Use just the filename (last part of path) for flat structure + file_name = file_path.split("/")[-1] if "/" in file_path else file_path + + # Calculate token count + token_count = len(doc.text.split()) if doc.text else 0 + + # Estimate size in bytes from text content + size_bytes = len(doc.text.encode("utf-8")) if doc.text else 0 + + # Guess mime type from extension + ext = os.path.splitext(file_name)[1].lower() + mime_types = { + ".txt": "text/plain", + ".md": "text/markdown", + ".pdf": "application/pdf", + ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ".doc": "application/msword", + ".html": "text/html", + ".json": "application/json", + ".csv": "text/csv", + ".xml": "application/xml", + ".py": "text/x-python", + ".js": "text/javascript", + ".ts": "text/typescript", + ".jsx": "text/jsx", + ".tsx": "text/tsx", + } + file_type = mime_types.get(ext, "application/octet-stream") + + directory_structure[file_name] = { + "type": file_type, + "size_bytes": size_bytes, + "token_count": token_count, + } + + logging.info( + f"Built directory structure with 
{len(directory_structure)} files: " + f"{list(directory_structure.keys())}" + ) + if operation_mode == "upload": id = ObjectId() embed_and_store_documents(docs, full_path, id, self) @@ -879,6 +931,10 @@ def remote_worker( embed_and_store_documents(docs, full_path, id, self) self.update_state(state="PROGRESS", meta={"current": 100}) + # Serialize remote_data as JSON if it's a dict (for S3, Reddit, etc.) + remote_data_serialized = ( + json.dumps(source_data) if isinstance(source_data, dict) else source_data + ) file_data = { "name": name_job, "user": user, @@ -886,8 +942,9 @@ def remote_worker( "retriever": retriever, "id": str(id), "type": loader, - "remote_data": source_data, + "remote_data": remote_data_serialized, "sync_frequency": sync_frequency, + "directory_structure": json.dumps(directory_structure), } if operation_mode == "sync": diff --git a/frontend/src/api/endpoints.ts b/frontend/src/api/endpoints.ts index 6bd8a834..a127cf32 100644 --- a/frontend/src/api/endpoints.ts +++ b/frontend/src/api/endpoints.ts @@ -34,6 +34,7 @@ const endpoints = { FEEDBACK_ANALYTICS: '/api/get_feedback_analytics', LOGS: `/api/get_user_logs`, MANAGE_SYNC: '/api/manage_sync', + SYNC_SOURCE: '/api/sync_source', GET_AVAILABLE_TOOLS: '/api/available_tools', GET_USER_TOOLS: '/api/get_tools', CREATE_TOOL: '/api/create_tool', diff --git a/frontend/src/api/services/userService.ts b/frontend/src/api/services/userService.ts index 1dcf9f4c..f707f646 100644 --- a/frontend/src/api/services/userService.ts +++ b/frontend/src/api/services/userService.ts @@ -72,6 +72,8 @@ const userService = { apiClient.post(endpoints.USER.LOGS, data, token), manageSync: (data: any, token: string | null): Promise => apiClient.post(endpoints.USER.MANAGE_SYNC, data, token), + syncSource: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.SYNC_SOURCE, data, token), getAvailableTools: (token: string | null): Promise => apiClient.get(endpoints.USER.GET_AVAILABLE_TOOLS, token), getUserTools: (token: string | null): Promise => diff --git a/frontend/src/assets/s3.svg b/frontend/src/assets/s3.svg new file mode 100644 index 00000000..3fdc41ff --- /dev/null +++ b/frontend/src/assets/s3.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/frontend/src/locale/de.json b/frontend/src/locale/de.json index 55e10877..89890ae7 100644 --- a/frontend/src/locale/de.json +++ b/frontend/src/locale/de.json @@ -67,6 +67,7 @@ "preLoaded": "Vorgeladen", "private": "Privat", "sync": "Synchronisieren", + "syncNow": "Jetzt synchronisieren", "syncing": "Synchronisiere...", "syncConfirmation": "Bist du sicher, dass du \"{{sourceName}}\" synchronisieren möchtest? Dies aktualisiert den Inhalt mit deinem Cloud-Speicher und kann Änderungen an einzelnen Chunks überschreiben.", "syncFrequency": { @@ -316,6 +317,10 @@ "google_drive": { "label": "Google Drive", "heading": "Von Google Drive hochladen" + }, + "s3": { + "label": "Amazon S3", + "heading": "Inhalt von Amazon S3 hinzufügen" } }, "connectors": { diff --git a/frontend/src/locale/en.json b/frontend/src/locale/en.json index 94680a96..f1c3afd5 100644 --- a/frontend/src/locale/en.json +++ b/frontend/src/locale/en.json @@ -67,6 +67,7 @@ "preLoaded": "Pre-loaded", "private": "Private", "sync": "Sync", + "syncNow": "Sync now", "syncing": "Syncing...", "syncConfirmation": "Are you sure you want to sync \"{{sourceName}}\"? 
This will update the content with your cloud storage and may override any edits you made to individual chunks.", "syncFrequency": { @@ -316,6 +317,10 @@ "google_drive": { "label": "Google Drive", "heading": "Upload from Google Drive" + }, + "s3": { + "label": "Amazon S3", + "heading": "Add content from Amazon S3" } }, "connectors": { diff --git a/frontend/src/locale/es.json b/frontend/src/locale/es.json index f388796e..5e20dff9 100644 --- a/frontend/src/locale/es.json +++ b/frontend/src/locale/es.json @@ -67,6 +67,7 @@ "preLoaded": "Precargado", "private": "Privado", "sync": "Sincronizar", + "syncNow": "Sincronizar ahora", "syncing": "Sincronizando...", "syncConfirmation": "¿Estás seguro de que deseas sincronizar \"{{sourceName}}\"? Esto actualizará el contenido con tu almacenamiento en la nube y puede anular cualquier edición que hayas realizado en fragmentos individuales.", "syncFrequency": { @@ -316,6 +317,10 @@ "google_drive": { "label": "Google Drive", "heading": "Subir desde Google Drive" + }, + "s3": { + "label": "Amazon S3", + "heading": "Agregar contenido desde Amazon S3" } }, "connectors": { diff --git a/frontend/src/locale/jp.json b/frontend/src/locale/jp.json index cc5a9de3..233969eb 100644 --- a/frontend/src/locale/jp.json +++ b/frontend/src/locale/jp.json @@ -67,6 +67,7 @@ "preLoaded": "プリロード済み", "private": "プライベート", "sync": "同期", + "syncNow": "今すぐ同期", "syncing": "同期中...", "syncConfirmation": "\"{{sourceName}}\"を同期してもよろしいですか?これにより、コンテンツがクラウドストレージで更新され、個々のチャンクに加えた編集が上書きされる可能性があります。", "syncFrequency": { @@ -316,6 +317,10 @@ "google_drive": { "label": "Google Drive", "heading": "Google Driveからアップロード" + }, + "s3": { + "label": "Amazon S3", + "heading": "Amazon S3からコンテンツを追加" } }, "connectors": { diff --git a/frontend/src/locale/ru.json b/frontend/src/locale/ru.json index d9146781..76b7b316 100644 --- a/frontend/src/locale/ru.json +++ b/frontend/src/locale/ru.json @@ -67,6 +67,7 @@ "preLoaded": "Предзагруженный", "private": "Частный", "sync": "Синхронизация", + "syncNow": "Синхронизировать сейчас", "syncing": "Синхронизация...", "syncConfirmation": "Вы уверены, что хотите синхронизировать \"{{sourceName}}\"? 
Это обновит содержимое с вашим облачным хранилищем и может перезаписать любые изменения, внесенные вами в отдельные фрагменты.", "syncFrequency": { @@ -316,6 +317,10 @@ "google_drive": { "label": "Google Drive", "heading": "Загрузить из Google Drive" + }, + "s3": { + "label": "Amazon S3", + "heading": "Добавить контент из Amazon S3" } }, "connectors": { diff --git a/frontend/src/locale/zh-TW.json b/frontend/src/locale/zh-TW.json index 8e8d9714..2b1a9e0b 100644 --- a/frontend/src/locale/zh-TW.json +++ b/frontend/src/locale/zh-TW.json @@ -67,6 +67,7 @@ "preLoaded": "預載入", "private": "私人", "sync": "同步", + "syncNow": "立即同步", "syncing": "同步中...", "syncConfirmation": "您確定要同步 \"{{sourceName}}\" 嗎?這將使用您的雲端儲存更新內容,並可能覆蓋您對個別文本塊所做的任何編輯。", "syncFrequency": { @@ -316,6 +317,10 @@ "google_drive": { "label": "Google Drive", "heading": "從Google Drive上傳" + }, + "s3": { + "label": "Amazon S3", + "heading": "從Amazon S3新增內容" } }, "connectors": { diff --git a/frontend/src/locale/zh.json b/frontend/src/locale/zh.json index f85cef5b..33635026 100644 --- a/frontend/src/locale/zh.json +++ b/frontend/src/locale/zh.json @@ -67,6 +67,7 @@ "preLoaded": "预加载", "private": "私有", "sync": "同步", + "syncNow": "立即同步", "syncing": "同步中...", "syncConfirmation": "您确定要同步 \"{{sourceName}}\" 吗?这将使用您的云存储更新内容,并可能覆盖您对单个文本块所做的任何编辑。", "syncFrequency": { @@ -316,6 +317,10 @@ "google_drive": { "label": "Google Drive", "heading": "从Google Drive上传" + }, + "s3": { + "label": "Amazon S3", + "heading": "从Amazon S3添加内容" } }, "connectors": { diff --git a/frontend/src/models/misc.ts b/frontend/src/models/misc.ts index 644d1ae9..61f0c56d 100644 --- a/frontend/src/models/misc.ts +++ b/frontend/src/models/misc.ts @@ -13,6 +13,7 @@ export type Doc = { retriever?: string; syncFrequency?: string; isNested?: boolean; + provider?: string; }; export type GetDocsResponse = { diff --git a/frontend/src/settings/Sources.tsx b/frontend/src/settings/Sources.tsx index 15ede1bd..6f6edcd6 100644 --- a/frontend/src/settings/Sources.tsx +++ b/frontend/src/settings/Sources.tsx @@ -201,6 +201,61 @@ export default function Sources({ }); }; + const getConnectorProvider = async (doc: Doc): Promise => { + if (doc.provider) { + return doc.provider; + } + if (!doc.id) { + return null; + } + try { + const directoryResponse = await userService.getDirectoryStructure( + doc.id, + token, + ); + const directoryData = await directoryResponse.json(); + return directoryData?.provider ?? 
null; + } catch (error) { + console.error('Error fetching connector provider:', error); + return null; + } + }; + + const handleSyncNow = async (doc: Doc) => { + if (!doc.id) { + return; + } + try { + if (doc.type?.startsWith('connector')) { + const provider = await getConnectorProvider(doc); + if (!provider) { + console.error('Sync now failed: provider not found'); + return; + } + const response = await userService.syncConnector( + doc.id, + provider, + token, + ); + const data = await response.json(); + if (!data.success) { + console.error('Sync now failed:', data.error || data.message); + } + return; + } + const response = await userService.syncSource( + { source_id: doc.id }, + token, + ); + const data = await response.json(); + if (!data.success) { + console.error('Sync now failed:', data.error || data.message); + } + } catch (error) { + console.error('Error syncing source:', error); + } + }; + const [documentToDelete, setDocumentToDelete] = useState<{ index: number; document: Doc; @@ -250,6 +305,16 @@ export default function Sources({ iconHeight: 14, variant: 'primary', }); + actions.push({ + icon: SyncIcon, + label: t('settings.sources.syncNow'), + onClick: () => { + handleSyncNow(document); + }, + iconWidth: 14, + iconHeight: 14, + variant: 'primary', + }); } actions.push({ diff --git a/frontend/src/upload/types/ingestor.ts b/frontend/src/upload/types/ingestor.ts index c06c9043..4e853262 100644 --- a/frontend/src/upload/types/ingestor.ts +++ b/frontend/src/upload/types/ingestor.ts @@ -4,6 +4,7 @@ import UrlIcon from '../../assets/url.svg'; import GithubIcon from '../../assets/github.svg'; import RedditIcon from '../../assets/reddit.svg'; import DriveIcon from '../../assets/drive.svg'; +import S3Icon from '../../assets/s3.svg'; export type IngestorType = | 'crawler' @@ -11,7 +12,8 @@ export type IngestorType = | 'reddit' | 'url' | 'google_drive' - | 'local_file'; + | 'local_file' + | 's3'; export interface IngestorConfig { type: IngestorType | null; @@ -147,6 +149,50 @@ export const IngestorFormSchemas: IngestorSchema[] = [ }, ], }, + { + key: 's3', + label: 'Amazon S3', + icon: S3Icon, + heading: 'Add content from Amazon S3', + fields: [ + { + name: 'aws_access_key_id', + label: 'AWS Access Key ID', + type: 'string', + required: true, + }, + { + name: 'aws_secret_access_key', + label: 'AWS Secret Access Key', + type: 'string', + required: true, + }, + { + name: 'bucket', + label: 'Bucket Name', + type: 'string', + required: true, + }, + { + name: 'prefix', + label: 'Path Prefix (optional)', + type: 'string', + required: false, + }, + { + name: 'region', + label: 'AWS Region', + type: 'string', + required: false, + }, + { + name: 'endpoint_url', + label: 'Custom Endpoint URL (optional)', + type: 'string', + required: false, + }, + ], + }, ]; export const IngestorDefaultConfigs: Record< @@ -175,6 +221,17 @@ export const IngestorDefaultConfigs: Record< }, }, local_file: { name: '', config: { files: [] } }, + s3: { + name: '', + config: { + aws_access_key_id: '', + aws_secret_access_key: '', + bucket: '', + prefix: '', + region: 'us-east-1', + endpoint_url: '', + }, + }, }; export interface IngestorOption { diff --git a/tests/api/user/sources/__init__.py b/tests/api/user/sources/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/api/user/sources/test_routes.py b/tests/api/user/sources/test_routes.py new file mode 100644 index 00000000..288da7ff --- /dev/null +++ b/tests/api/user/sources/test_routes.py @@ -0,0 +1,357 @@ +"""Tests for sources routes.""" + 
+import json +import pytest +from unittest.mock import MagicMock, patch +from bson import ObjectId + + +class TestGetProviderFromRemoteData: + """Test the _get_provider_from_remote_data helper function.""" + + def test_returns_none_for_none_input(self): + """Should return None when remote_data is None.""" + from application.api.user.sources.routes import _get_provider_from_remote_data + + result = _get_provider_from_remote_data(None) + assert result is None + + def test_returns_none_for_empty_string(self): + """Should return None when remote_data is empty string.""" + from application.api.user.sources.routes import _get_provider_from_remote_data + + result = _get_provider_from_remote_data("") + assert result is None + + def test_extracts_provider_from_dict(self): + """Should extract provider from dict remote_data.""" + from application.api.user.sources.routes import _get_provider_from_remote_data + + remote_data = {"provider": "s3", "bucket": "my-bucket"} + result = _get_provider_from_remote_data(remote_data) + assert result == "s3" + + def test_extracts_provider_from_json_string(self): + """Should extract provider from JSON string remote_data.""" + from application.api.user.sources.routes import _get_provider_from_remote_data + + remote_data = json.dumps({"provider": "github", "repo": "test/repo"}) + result = _get_provider_from_remote_data(remote_data) + assert result == "github" + + def test_returns_none_for_dict_without_provider(self): + """Should return None when dict has no provider key.""" + from application.api.user.sources.routes import _get_provider_from_remote_data + + remote_data = {"bucket": "my-bucket", "region": "us-east-1"} + result = _get_provider_from_remote_data(remote_data) + assert result is None + + def test_returns_none_for_invalid_json(self): + """Should return None for invalid JSON string.""" + from application.api.user.sources.routes import _get_provider_from_remote_data + + result = _get_provider_from_remote_data("not valid json") + assert result is None + + def test_returns_none_for_json_array(self): + """Should return None when JSON parses to non-dict.""" + from application.api.user.sources.routes import _get_provider_from_remote_data + + result = _get_provider_from_remote_data('["item1", "item2"]') + assert result is None + + def test_returns_none_for_non_string_non_dict(self): + """Should return None for other types like int.""" + from application.api.user.sources.routes import _get_provider_from_remote_data + + result = _get_provider_from_remote_data(123) + assert result is None + + +def _get_response_status(response): + """Helper to get status code from response (handles both tuple and Response).""" + if isinstance(response, tuple): + return response[1] + return response.status_code + + +def _get_response_json(response): + """Helper to get JSON from response (handles both tuple and Response).""" + if isinstance(response, tuple): + return response[0].json + return response.json + + +@pytest.mark.unit +class TestSyncSourceEndpoint: + """Test the /sync_source endpoint.""" + + @pytest.fixture + def mock_sources_collection(self, mock_mongo_db): + """Get mock sources collection.""" + from application.core.settings import settings + + return mock_mongo_db[settings.MONGO_DB_NAME]["sources"] + + def test_sync_source_returns_401_without_token(self, flask_app): + """Should return 401 when no decoded_token is present.""" + from flask import Flask + from application.api.user.sources.routes import SyncSource + + app = Flask(__name__) + + with app.test_request_context( + 
"/api/sync_source", method="POST", json={"source_id": "123"} + ): + from flask import request + + request.decoded_token = None + resource = SyncSource() + response = resource.post() + + assert _get_response_status(response) == 401 + + def test_sync_source_returns_400_for_missing_source_id(self, flask_app): + """Should return 400 when source_id is missing.""" + from flask import Flask + from application.api.user.sources.routes import SyncSource + + app = Flask(__name__) + + with app.test_request_context("/api/sync_source", method="POST", json={}): + from flask import request + + request.decoded_token = {"sub": "test_user"} + resource = SyncSource() + response = resource.post() + + # check_required_fields returns a response tuple on missing fields + assert response is not None + + def test_sync_source_returns_400_for_invalid_source_id(self, flask_app): + """Should return 400 for invalid ObjectId.""" + from flask import Flask + from application.api.user.sources.routes import SyncSource + + app = Flask(__name__) + + with app.test_request_context( + "/api/sync_source", method="POST", json={"source_id": "invalid"} + ): + from flask import request + + request.decoded_token = {"sub": "test_user"} + resource = SyncSource() + response = resource.post() + + assert _get_response_status(response) == 400 + assert "Invalid source ID" in _get_response_json(response)["message"] + + def test_sync_source_returns_404_for_nonexistent_source( + self, flask_app, mock_mongo_db + ): + """Should return 404 when source doesn't exist.""" + from flask import Flask + from application.api.user.sources.routes import SyncSource + + app = Flask(__name__) + source_id = str(ObjectId()) + + with app.test_request_context( + "/api/sync_source", method="POST", json={"source_id": source_id} + ): + from flask import request + + request.decoded_token = {"sub": "test_user"} + + with patch( + "application.api.user.sources.routes.sources_collection", + mock_mongo_db["docsgpt"]["sources"], + ): + resource = SyncSource() + response = resource.post() + + assert _get_response_status(response) == 404 + assert "not found" in _get_response_json(response)["message"] + + def test_sync_source_returns_400_for_connector_type( + self, flask_app, mock_mongo_db, mock_sources_collection + ): + """Should return 400 for connector sources.""" + from flask import Flask + from application.api.user.sources.routes import SyncSource + + app = Flask(__name__) + source_id = ObjectId() + + # Insert a connector source + mock_sources_collection.insert_one( + { + "_id": source_id, + "user": "test_user", + "type": "connector_slack", + "name": "Slack Source", + } + ) + + with app.test_request_context( + "/api/sync_source", method="POST", json={"source_id": str(source_id)} + ): + from flask import request + + request.decoded_token = {"sub": "test_user"} + + with patch( + "application.api.user.sources.routes.sources_collection", + mock_sources_collection, + ): + resource = SyncSource() + response = resource.post() + + assert _get_response_status(response) == 400 + assert "Connector sources" in _get_response_json(response)["message"] + + def test_sync_source_returns_400_for_non_syncable_source( + self, flask_app, mock_mongo_db, mock_sources_collection + ): + """Should return 400 when source has no remote_data.""" + from flask import Flask + from application.api.user.sources.routes import SyncSource + + app = Flask(__name__) + source_id = ObjectId() + + # Insert a source without remote_data + mock_sources_collection.insert_one( + { + "_id": source_id, + "user": 
"test_user", + "type": "file", + "name": "Local Source", + "remote_data": None, + } + ) + + with app.test_request_context( + "/api/sync_source", method="POST", json={"source_id": str(source_id)} + ): + from flask import request + + request.decoded_token = {"sub": "test_user"} + + with patch( + "application.api.user.sources.routes.sources_collection", + mock_sources_collection, + ): + resource = SyncSource() + response = resource.post() + + assert _get_response_status(response) == 400 + assert "not syncable" in _get_response_json(response)["message"] + + def test_sync_source_triggers_sync_task( + self, flask_app, mock_mongo_db, mock_sources_collection + ): + """Should trigger sync task for valid syncable source.""" + from flask import Flask + from application.api.user.sources.routes import SyncSource + + app = Flask(__name__) + source_id = ObjectId() + + # Insert a valid syncable source + mock_sources_collection.insert_one( + { + "_id": source_id, + "user": "test_user", + "type": "s3", + "name": "S3 Source", + "remote_data": json.dumps( + { + "provider": "s3", + "bucket": "my-bucket", + "aws_access_key_id": "key", + "aws_secret_access_key": "secret", + } + ), + "sync_frequency": "daily", + "retriever": "classic", + } + ) + + mock_task = MagicMock() + mock_task.id = "task-123" + + with app.test_request_context( + "/api/sync_source", method="POST", json={"source_id": str(source_id)} + ): + from flask import request + + request.decoded_token = {"sub": "test_user"} + + with patch( + "application.api.user.sources.routes.sources_collection", + mock_sources_collection, + ): + with patch( + "application.api.user.sources.routes.sync_source" + ) as mock_sync: + mock_sync.delay.return_value = mock_task + + resource = SyncSource() + response = resource.post() + + assert _get_response_status(response) == 200 + assert _get_response_json(response)["success"] is True + assert _get_response_json(response)["task_id"] == "task-123" + + mock_sync.delay.assert_called_once() + call_kwargs = mock_sync.delay.call_args[1] + assert call_kwargs["user"] == "test_user" + assert call_kwargs["loader"] == "s3" + assert call_kwargs["doc_id"] == str(source_id) + + def test_sync_source_handles_task_error( + self, flask_app, mock_mongo_db, mock_sources_collection + ): + """Should return 400 when task fails to start.""" + from flask import Flask + from application.api.user.sources.routes import SyncSource + + app = Flask(__name__) + source_id = ObjectId() + + mock_sources_collection.insert_one( + { + "_id": source_id, + "user": "test_user", + "type": "github", + "name": "GitHub Source", + "remote_data": "https://github.com/test/repo", + "sync_frequency": "weekly", + "retriever": "classic", + } + ) + + with app.test_request_context( + "/api/sync_source", method="POST", json={"source_id": str(source_id)} + ): + from flask import request + + request.decoded_token = {"sub": "test_user"} + + with patch( + "application.api.user.sources.routes.sources_collection", + mock_sources_collection, + ): + with patch( + "application.api.user.sources.routes.sync_source" + ) as mock_sync: + mock_sync.delay.side_effect = Exception("Celery error") + + resource = SyncSource() + response = resource.post() + + assert _get_response_status(response) == 400 + assert _get_response_json(response)["success"] is False diff --git a/tests/parser/remote/test_s3_loader.py b/tests/parser/remote/test_s3_loader.py new file mode 100644 index 00000000..13fbe514 --- /dev/null +++ b/tests/parser/remote/test_s3_loader.py @@ -0,0 +1,714 @@ +"""Tests for S3 loader 
implementation.""" + +import json +import pytest +from unittest.mock import MagicMock, patch + +from botocore.exceptions import ClientError, NoCredentialsError + + +@pytest.fixture +def mock_boto3(): + """Mock boto3 module.""" + with patch.dict("sys.modules", {"boto3": MagicMock()}): + with patch("application.parser.remote.s3_loader.boto3") as mock: + yield mock + + +@pytest.fixture +def s3_loader(mock_boto3): + """Create S3Loader instance with mocked boto3.""" + from application.parser.remote.s3_loader import S3Loader + + loader = S3Loader() + return loader + + +class TestS3LoaderInit: + """Test S3Loader initialization.""" + + def test_init_raises_import_error_when_boto3_missing(self): + """Should raise ImportError when boto3 is not installed.""" + with patch("application.parser.remote.s3_loader.boto3", None): + from application.parser.remote.s3_loader import S3Loader + + with pytest.raises(ImportError, match="boto3 is required"): + S3Loader() + + def test_init_sets_client_to_none(self, mock_boto3): + """Should initialize with s3_client as None.""" + from application.parser.remote.s3_loader import S3Loader + + loader = S3Loader() + assert loader.s3_client is None + + +class TestNormalizeEndpointUrl: + """Test endpoint URL normalization for S3-compatible services.""" + + def test_returns_unchanged_for_empty_endpoint(self, s3_loader): + """Should return unchanged values when endpoint_url is empty.""" + endpoint, bucket = s3_loader._normalize_endpoint_url("", "my-bucket") + assert endpoint == "" + assert bucket == "my-bucket" + + def test_returns_unchanged_for_none_endpoint(self, s3_loader): + """Should return unchanged values when endpoint_url is None.""" + endpoint, bucket = s3_loader._normalize_endpoint_url(None, "my-bucket") + assert endpoint is None + assert bucket == "my-bucket" + + def test_extracts_bucket_from_do_spaces_url(self, s3_loader): + """Should extract bucket name from DigitalOcean Spaces bucket-prefixed URL.""" + endpoint, bucket = s3_loader._normalize_endpoint_url( + "https://mybucket.nyc3.digitaloceanspaces.com", "" + ) + assert endpoint == "https://nyc3.digitaloceanspaces.com" + assert bucket == "mybucket" + + def test_extracts_bucket_overrides_provided_bucket(self, s3_loader): + """Should use extracted bucket when it differs from provided one.""" + endpoint, bucket = s3_loader._normalize_endpoint_url( + "https://mybucket.lon1.digitaloceanspaces.com", "other-bucket" + ) + assert endpoint == "https://lon1.digitaloceanspaces.com" + assert bucket == "mybucket" + + def test_keeps_provided_bucket_when_matches_extracted(self, s3_loader): + """Should keep bucket when provided matches extracted.""" + endpoint, bucket = s3_loader._normalize_endpoint_url( + "https://mybucket.sfo3.digitaloceanspaces.com", "mybucket" + ) + assert endpoint == "https://sfo3.digitaloceanspaces.com" + assert bucket == "mybucket" + + def test_returns_unchanged_for_standard_do_endpoint(self, s3_loader): + """Should return unchanged for standard DO Spaces endpoint.""" + endpoint, bucket = s3_loader._normalize_endpoint_url( + "https://nyc3.digitaloceanspaces.com", "my-bucket" + ) + assert endpoint == "https://nyc3.digitaloceanspaces.com" + assert bucket == "my-bucket" + + def test_returns_unchanged_for_aws_endpoint(self, s3_loader): + """Should return unchanged for standard AWS S3 endpoints.""" + endpoint, bucket = s3_loader._normalize_endpoint_url( + "https://s3.us-east-1.amazonaws.com", "my-bucket" + ) + assert endpoint == "https://s3.us-east-1.amazonaws.com" + assert bucket == "my-bucket" + + def 
test_handles_minio_endpoint(self, s3_loader): + """Should return unchanged for MinIO endpoints.""" + endpoint, bucket = s3_loader._normalize_endpoint_url( + "http://localhost:9000", "my-bucket" + ) + assert endpoint == "http://localhost:9000" + assert bucket == "my-bucket" + + +class TestInitClient: + """Test S3 client initialization.""" + + def test_init_client_creates_boto3_client(self, s3_loader, mock_boto3): + """Should create boto3 S3 client with provided credentials.""" + s3_loader._init_client( + aws_access_key_id="test-key", + aws_secret_access_key="test-secret", + region_name="us-west-2", + ) + + mock_boto3.client.assert_called_once() + call_kwargs = mock_boto3.client.call_args[1] + assert call_kwargs["aws_access_key_id"] == "test-key" + assert call_kwargs["aws_secret_access_key"] == "test-secret" + assert call_kwargs["region_name"] == "us-west-2" + + def test_init_client_with_custom_endpoint(self, s3_loader, mock_boto3): + """Should configure path-style addressing for custom endpoints.""" + s3_loader._init_client( + aws_access_key_id="test-key", + aws_secret_access_key="test-secret", + region_name="us-east-1", + endpoint_url="https://nyc3.digitaloceanspaces.com", + bucket="my-bucket", + ) + + call_kwargs = mock_boto3.client.call_args[1] + assert call_kwargs["endpoint_url"] == "https://nyc3.digitaloceanspaces.com" + assert "config" in call_kwargs + + def test_init_client_normalizes_do_endpoint(self, s3_loader, mock_boto3): + """Should normalize DigitalOcean Spaces bucket-prefixed URLs.""" + corrected_bucket = s3_loader._init_client( + aws_access_key_id="test-key", + aws_secret_access_key="test-secret", + region_name="us-east-1", + endpoint_url="https://mybucket.nyc3.digitaloceanspaces.com", + bucket="", + ) + + assert corrected_bucket == "mybucket" + call_kwargs = mock_boto3.client.call_args[1] + assert call_kwargs["endpoint_url"] == "https://nyc3.digitaloceanspaces.com" + + def test_init_client_returns_bucket_name(self, s3_loader, mock_boto3): + """Should return the bucket name (potentially corrected).""" + result = s3_loader._init_client( + aws_access_key_id="test-key", + aws_secret_access_key="test-secret", + region_name="us-east-1", + bucket="my-bucket", + ) + + assert result == "my-bucket" + + +class TestIsTextFile: + """Test text file detection.""" + + def test_recognizes_common_text_extensions(self, s3_loader): + """Should recognize common text file extensions.""" + text_files = [ + "readme.txt", + "docs.md", + "config.json", + "data.yaml", + "script.py", + "app.js", + "main.go", + "style.css", + "index.html", + ] + for filename in text_files: + assert s3_loader.is_text_file(filename), f"{filename} should be text" + + def test_rejects_binary_extensions(self, s3_loader): + """Should reject binary file extensions.""" + binary_files = ["image.png", "photo.jpg", "archive.zip", "app.exe", "doc.pdf"] + for filename in binary_files: + assert not s3_loader.is_text_file(filename), f"{filename} should not be text" + + def test_case_insensitive_matching(self, s3_loader): + """Should match extensions case-insensitively.""" + assert s3_loader.is_text_file("README.TXT") + assert s3_loader.is_text_file("Config.JSON") + assert s3_loader.is_text_file("Script.PY") + + +class TestIsSupportedDocument: + """Test document file detection.""" + + def test_recognizes_document_extensions(self, s3_loader): + """Should recognize document file extensions.""" + doc_files = [ + "report.pdf", + "document.docx", + "spreadsheet.xlsx", + "presentation.pptx", + "book.epub", + ] + for filename in doc_files: + 
assert s3_loader.is_supported_document( + filename + ), f"{filename} should be document" + + def test_rejects_non_document_extensions(self, s3_loader): + """Should reject non-document file extensions.""" + non_doc_files = ["image.png", "script.py", "readme.txt", "archive.zip"] + for filename in non_doc_files: + assert not s3_loader.is_supported_document( + filename + ), f"{filename} should not be document" + + def test_case_insensitive_matching(self, s3_loader): + """Should match extensions case-insensitively.""" + assert s3_loader.is_supported_document("Report.PDF") + assert s3_loader.is_supported_document("Document.DOCX") + + +class TestListObjects: + """Test S3 object listing.""" + + def test_list_objects_returns_file_keys(self, s3_loader, mock_boto3): + """Should return list of file keys from bucket.""" + mock_client = MagicMock() + s3_loader.s3_client = mock_client + + paginator = MagicMock() + mock_client.get_paginator.return_value = paginator + paginator.paginate.return_value = [ + { + "Contents": [ + {"Key": "file1.txt"}, + {"Key": "file2.md"}, + {"Key": "folder/"}, # Directory marker, should be skipped + {"Key": "folder/file3.py"}, + ] + } + ] + + result = s3_loader.list_objects("test-bucket", "") + + assert result == ["file1.txt", "file2.md", "folder/file3.py"] + mock_client.get_paginator.assert_called_once_with("list_objects_v2") + paginator.paginate.assert_called_once_with(Bucket="test-bucket", Prefix="") + + def test_list_objects_with_prefix(self, s3_loader): + """Should filter objects by prefix.""" + mock_client = MagicMock() + s3_loader.s3_client = mock_client + + paginator = MagicMock() + mock_client.get_paginator.return_value = paginator + paginator.paginate.return_value = [ + {"Contents": [{"Key": "docs/readme.md"}, {"Key": "docs/guide.txt"}]} + ] + + result = s3_loader.list_objects("test-bucket", "docs/") + + paginator.paginate.assert_called_once_with(Bucket="test-bucket", Prefix="docs/") + assert len(result) == 2 + + def test_list_objects_handles_empty_bucket(self, s3_loader): + """Should return empty list for empty bucket.""" + mock_client = MagicMock() + s3_loader.s3_client = mock_client + + paginator = MagicMock() + mock_client.get_paginator.return_value = paginator + paginator.paginate.return_value = [{}] # No Contents key + + result = s3_loader.list_objects("test-bucket", "") + + assert result == [] + + def test_list_objects_raises_on_no_such_bucket(self, s3_loader): + """Should raise exception when bucket doesn't exist.""" + mock_client = MagicMock() + s3_loader.s3_client = mock_client + + paginator = MagicMock() + mock_client.get_paginator.return_value = paginator + paginator.paginate.return_value.__iter__ = MagicMock( + side_effect=ClientError( + {"Error": {"Code": "NoSuchBucket", "Message": "Bucket not found"}}, + "ListObjectsV2", + ) + ) + + with pytest.raises(Exception, match="does not exist"): + s3_loader.list_objects("nonexistent-bucket", "") + + def test_list_objects_raises_on_access_denied(self, s3_loader): + """Should raise exception on access denied.""" + mock_client = MagicMock() + s3_loader.s3_client = mock_client + + paginator = MagicMock() + mock_client.get_paginator.return_value = paginator + paginator.paginate.return_value.__iter__ = MagicMock( + side_effect=ClientError( + {"Error": {"Code": "AccessDenied", "Message": "Access denied"}}, + "ListObjectsV2", + ) + ) + + with pytest.raises(Exception, match="Access denied"): + s3_loader.list_objects("test-bucket", "") + + def test_list_objects_raises_on_no_credentials(self, s3_loader): + """Should 
raise exception when credentials are missing.""" + mock_client = MagicMock() + s3_loader.s3_client = mock_client + + paginator = MagicMock() + mock_client.get_paginator.return_value = paginator + paginator.paginate.return_value.__iter__ = MagicMock( + side_effect=NoCredentialsError() + ) + + with pytest.raises(Exception, match="credentials not found"): + s3_loader.list_objects("test-bucket", "") + + +class TestGetObjectContent: + """Test S3 object content retrieval.""" + + def test_get_text_file_content(self, s3_loader): + """Should return decoded text content for text files.""" + mock_client = MagicMock() + s3_loader.s3_client = mock_client + + mock_body = MagicMock() + mock_body.read.return_value = b"Hello, World!" + mock_client.get_object.return_value = {"Body": mock_body} + + result = s3_loader.get_object_content("test-bucket", "readme.txt") + + assert result == "Hello, World!" + mock_client.get_object.assert_called_once_with( + Bucket="test-bucket", Key="readme.txt" + ) + + def test_skip_unsupported_file_types(self, s3_loader): + """Should return None for unsupported file types.""" + mock_client = MagicMock() + s3_loader.s3_client = mock_client + + result = s3_loader.get_object_content("test-bucket", "image.png") + + assert result is None + mock_client.get_object.assert_not_called() + + def test_skip_empty_text_files(self, s3_loader): + """Should return None for empty text files.""" + mock_client = MagicMock() + s3_loader.s3_client = mock_client + + mock_body = MagicMock() + mock_body.read.return_value = b" \n\t " + mock_client.get_object.return_value = {"Body": mock_body} + + result = s3_loader.get_object_content("test-bucket", "empty.txt") + + assert result is None + + def test_returns_none_on_unicode_decode_error(self, s3_loader): + """Should return None when text file can't be decoded.""" + mock_client = MagicMock() + s3_loader.s3_client = mock_client + + mock_body = MagicMock() + mock_body.read.return_value = b"\xff\xfe" # Invalid UTF-8 + mock_client.get_object.return_value = {"Body": mock_body} + + result = s3_loader.get_object_content("test-bucket", "binary.txt") + + assert result is None + + def test_returns_none_on_no_such_key(self, s3_loader): + """Should return None when object doesn't exist.""" + mock_client = MagicMock() + s3_loader.s3_client = mock_client + mock_client.get_object.side_effect = ClientError( + {"Error": {"Code": "NoSuchKey", "Message": "Key not found"}}, + "GetObject", + ) + + result = s3_loader.get_object_content("test-bucket", "missing.txt") + + assert result is None + + def test_returns_none_on_access_denied(self, s3_loader): + """Should return None when access is denied.""" + mock_client = MagicMock() + s3_loader.s3_client = mock_client + mock_client.get_object.side_effect = ClientError( + {"Error": {"Code": "AccessDenied", "Message": "Access denied"}}, + "GetObject", + ) + + result = s3_loader.get_object_content("test-bucket", "secret.txt") + + assert result is None + + def test_processes_document_files(self, s3_loader): + """Should process document files through parser.""" + mock_client = MagicMock() + s3_loader.s3_client = mock_client + + mock_body = MagicMock() + mock_body.read.return_value = b"PDF content" + mock_client.get_object.return_value = {"Body": mock_body} + + with patch.object( + s3_loader, "_process_document", return_value="Extracted text" + ) as mock_process: + result = s3_loader.get_object_content("test-bucket", "document.pdf") + + assert result == "Extracted text" + mock_process.assert_called_once_with(b"PDF content", "document.pdf") 
+ + +class TestLoadData: + """Test main load_data method.""" + + def test_load_data_from_dict_input(self, s3_loader, mock_boto3): + """Should load documents from dict input.""" + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + + # Setup mock paginator + paginator = MagicMock() + mock_client.get_paginator.return_value = paginator + paginator.paginate.return_value = [ + {"Contents": [{"Key": "readme.md"}, {"Key": "guide.txt"}]} + ] + + # Setup mock get_object + def get_object_side_effect(Bucket, Key): + mock_body = MagicMock() + mock_body.read.return_value = f"Content of {Key}".encode() + return {"Body": mock_body} + + mock_client.get_object.side_effect = get_object_side_effect + + input_data = { + "aws_access_key_id": "test-key", + "aws_secret_access_key": "test-secret", + "bucket": "test-bucket", + } + + docs = s3_loader.load_data(input_data) + + assert len(docs) == 2 + assert docs[0].text == "Content of readme.md" + assert docs[0].extra_info["bucket"] == "test-bucket" + assert docs[0].extra_info["key"] == "readme.md" + assert docs[0].extra_info["source"] == "s3://test-bucket/readme.md" + + def test_load_data_from_json_string(self, s3_loader, mock_boto3): + """Should load documents from JSON string input.""" + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + + paginator = MagicMock() + mock_client.get_paginator.return_value = paginator + paginator.paginate.return_value = [{"Contents": [{"Key": "file.txt"}]}] + + mock_body = MagicMock() + mock_body.read.return_value = b"File content" + mock_client.get_object.return_value = {"Body": mock_body} + + input_json = json.dumps( + { + "aws_access_key_id": "test-key", + "aws_secret_access_key": "test-secret", + "bucket": "test-bucket", + } + ) + + docs = s3_loader.load_data(input_json) + + assert len(docs) == 1 + assert docs[0].text == "File content" + + def test_load_data_with_prefix(self, s3_loader, mock_boto3): + """Should filter objects by prefix.""" + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + + paginator = MagicMock() + mock_client.get_paginator.return_value = paginator + paginator.paginate.return_value = [{"Contents": [{"Key": "docs/readme.md"}]}] + + mock_body = MagicMock() + mock_body.read.return_value = b"Documentation" + mock_client.get_object.return_value = {"Body": mock_body} + + input_data = { + "aws_access_key_id": "test-key", + "aws_secret_access_key": "test-secret", + "bucket": "test-bucket", + "prefix": "docs/", + } + + s3_loader.load_data(input_data) + + paginator.paginate.assert_called_once_with(Bucket="test-bucket", Prefix="docs/") + + def test_load_data_with_custom_region(self, s3_loader, mock_boto3): + """Should use custom region.""" + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + + paginator = MagicMock() + mock_client.get_paginator.return_value = paginator + paginator.paginate.return_value = [{}] + + input_data = { + "aws_access_key_id": "test-key", + "aws_secret_access_key": "test-secret", + "bucket": "test-bucket", + "region": "eu-west-1", + } + + s3_loader.load_data(input_data) + + call_kwargs = mock_boto3.client.call_args[1] + assert call_kwargs["region_name"] == "eu-west-1" + + def test_load_data_with_custom_endpoint(self, s3_loader, mock_boto3): + """Should use custom endpoint URL.""" + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + + paginator = MagicMock() + mock_client.get_paginator.return_value = paginator + paginator.paginate.return_value = [{}] + + input_data = { + 
"aws_access_key_id": "test-key", + "aws_secret_access_key": "test-secret", + "bucket": "test-bucket", + "endpoint_url": "https://nyc3.digitaloceanspaces.com", + } + + s3_loader.load_data(input_data) + + call_kwargs = mock_boto3.client.call_args[1] + assert call_kwargs["endpoint_url"] == "https://nyc3.digitaloceanspaces.com" + + def test_load_data_raises_on_invalid_json(self, s3_loader): + """Should raise ValueError for invalid JSON input.""" + with pytest.raises(ValueError, match="Invalid JSON"): + s3_loader.load_data("not valid json") + + def test_load_data_raises_on_missing_required_fields(self, s3_loader): + """Should raise ValueError when required fields are missing.""" + with pytest.raises(ValueError, match="Missing required fields"): + s3_loader.load_data({"aws_access_key_id": "test-key"}) + + with pytest.raises(ValueError, match="Missing required fields"): + s3_loader.load_data( + {"aws_access_key_id": "test-key", "aws_secret_access_key": "secret"} + ) + + def test_load_data_skips_unsupported_files(self, s3_loader, mock_boto3): + """Should skip unsupported file types.""" + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + + paginator = MagicMock() + mock_client.get_paginator.return_value = paginator + paginator.paginate.return_value = [ + { + "Contents": [ + {"Key": "readme.txt"}, + {"Key": "image.png"}, # Unsupported + {"Key": "photo.jpg"}, # Unsupported + ] + } + ] + + def get_object_side_effect(Bucket, Key): + mock_body = MagicMock() + mock_body.read.return_value = b"Text content" + return {"Body": mock_body} + + mock_client.get_object.side_effect = get_object_side_effect + + input_data = { + "aws_access_key_id": "test-key", + "aws_secret_access_key": "test-secret", + "bucket": "test-bucket", + } + + docs = s3_loader.load_data(input_data) + + # Only txt file should be loaded + assert len(docs) == 1 + assert docs[0].extra_info["key"] == "readme.txt" + + def test_load_data_uses_corrected_bucket_from_endpoint(self, s3_loader, mock_boto3): + """Should use bucket name extracted from DO Spaces URL.""" + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + + paginator = MagicMock() + mock_client.get_paginator.return_value = paginator + paginator.paginate.return_value = [{"Contents": [{"Key": "file.txt"}]}] + + mock_body = MagicMock() + mock_body.read.return_value = b"Content" + mock_client.get_object.return_value = {"Body": mock_body} + + input_data = { + "aws_access_key_id": "test-key", + "aws_secret_access_key": "test-secret", + "bucket": "wrong-bucket", # Will be corrected from endpoint + "endpoint_url": "https://mybucket.nyc3.digitaloceanspaces.com", + } + + docs = s3_loader.load_data(input_data) + + # Verify bucket name was corrected + paginator.paginate.assert_called_once_with(Bucket="mybucket", Prefix="") + assert docs[0].extra_info["bucket"] == "mybucket" + + +class TestProcessDocument: + """Test document processing.""" + + def test_process_document_extracts_text(self, s3_loader): + """Should extract text from document files.""" + mock_doc = MagicMock() + mock_doc.text = "Extracted document text" + + with patch( + "application.parser.file.bulk.SimpleDirectoryReader" + ) as mock_reader_class: + mock_reader = MagicMock() + mock_reader.load_data.return_value = [mock_doc] + mock_reader_class.return_value = mock_reader + + with patch("tempfile.NamedTemporaryFile") as mock_temp: + mock_file = MagicMock() + mock_file.__enter__ = MagicMock(return_value=mock_file) + mock_file.__exit__ = MagicMock(return_value=False) + mock_file.name = 
"/tmp/test.pdf" + mock_temp.return_value = mock_file + + with patch("os.path.exists", return_value=True): + with patch("os.unlink"): + result = s3_loader._process_document( + b"PDF content", "document.pdf" + ) + + assert result == "Extracted document text" + + def test_process_document_returns_none_on_error(self, s3_loader): + """Should return None when document processing fails.""" + with patch( + "application.parser.file.bulk.SimpleDirectoryReader" + ) as mock_reader_class: + mock_reader_class.side_effect = Exception("Parse error") + + with patch("tempfile.NamedTemporaryFile") as mock_temp: + mock_file = MagicMock() + mock_file.__enter__ = MagicMock(return_value=mock_file) + mock_file.__exit__ = MagicMock(return_value=False) + mock_file.name = "/tmp/test.pdf" + mock_temp.return_value = mock_file + + with patch("os.path.exists", return_value=True): + with patch("os.unlink"): + result = s3_loader._process_document( + b"PDF content", "document.pdf" + ) + + assert result is None + + def test_process_document_cleans_up_temp_file(self, s3_loader): + """Should clean up temporary file after processing.""" + with patch( + "application.parser.file.bulk.SimpleDirectoryReader" + ) as mock_reader_class: + mock_reader = MagicMock() + mock_reader.load_data.return_value = [] + mock_reader_class.return_value = mock_reader + + with patch("tempfile.NamedTemporaryFile") as mock_temp: + mock_file = MagicMock() + mock_file.__enter__ = MagicMock(return_value=mock_file) + mock_file.__exit__ = MagicMock(return_value=False) + mock_file.name = "/tmp/test.pdf" + mock_temp.return_value = mock_file + + with patch("os.path.exists", return_value=True) as mock_exists: + with patch("os.unlink") as mock_unlink: + s3_loader._process_document(b"PDF content", "document.pdf") + + mock_exists.assert_called_with("/tmp/test.pdf") + mock_unlink.assert_called_with("/tmp/test.pdf")