Mirror of https://github.com/arc53/DocsGPT.git, synced 2026-02-19 10:51:01 +00:00
Add Amazon S3 support and synchronization features (#2244)

* Add Amazon S3 support and synchronization features
* refactor: remove unused variable in load_data test
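The loader is registered under the "s3" key in RemoteCreator (diff below), so it can be requested by name. A hypothetical usage sketch follows; the module path and factory method name are assumptions, since the hunk only shows the loader mapping and a @classmethod decorator:

    from application.parser.remote.remote_creator import RemoteCreator  # path assumed

    loader = RemoteCreator.create_loader("s3")  # method name assumed; resolves to S3Loader
    docs = loader.load_data({
        "aws_access_key_id": "AKIA...",   # placeholder credentials
        "aws_secret_access_key": "...",
        "bucket": "my-bucket",
        "prefix": "docs/",                # optional
    })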
@@ -3,6 +3,7 @@ from application.parser.remote.crawler_loader import CrawlerLoader
 from application.parser.remote.web_loader import WebLoader
 from application.parser.remote.reddit_loader import RedditPostsLoaderRemote
 from application.parser.remote.github_loader import GitHubLoader
+from application.parser.remote.s3_loader import S3Loader


 class RemoteCreator:
@@ -22,6 +23,7 @@ class RemoteCreator:
         "crawler": CrawlerLoader,
         "reddit": RedditPostsLoaderRemote,
         "github": GitHubLoader,
+        "s3": S3Loader,
     }

     @classmethod
application/parser/remote/s3_loader.py (new file, 427 lines)
@@ -0,0 +1,427 @@
import json
import logging
import os
import tempfile
import mimetypes
from typing import List, Optional

from application.parser.remote.base import BaseRemote
from application.parser.schema.base import Document

try:
    import boto3
    from botocore.exceptions import ClientError, NoCredentialsError
except ImportError:
    boto3 = None

logger = logging.getLogger(__name__)


class S3Loader(BaseRemote):
    """Load documents from an AWS S3 bucket."""

    def __init__(self):
        if boto3 is None:
            raise ImportError(
                "boto3 is required for S3Loader. Install it with: pip install boto3"
            )
        self.s3_client = None

    def _normalize_endpoint_url(self, endpoint_url: str, bucket: str) -> tuple[str, str]:
        """
        Normalize the endpoint URL for S3-compatible services.

        Detects common mistakes like bucket-prefixed URLs and extracts
        the correct endpoint and bucket name.

        Args:
            endpoint_url: The provided endpoint URL
            bucket: The provided bucket name

        Returns:
            Tuple of (normalized_endpoint_url, bucket_name)
        """
        import re
        from urllib.parse import urlparse

        if not endpoint_url:
            return endpoint_url, bucket

        parsed = urlparse(endpoint_url)
        host = parsed.netloc or parsed.path

        # Check for the DigitalOcean Spaces bucket-prefixed URL pattern,
        # e.g. https://mybucket.nyc3.digitaloceanspaces.com
        do_match = re.match(r"^([^.]+)\.([a-z0-9]+)\.digitaloceanspaces\.com$", host)
        if do_match:
            extracted_bucket = do_match.group(1)
            region = do_match.group(2)
            correct_endpoint = f"https://{region}.digitaloceanspaces.com"
            logger.warning(
                f"Detected bucket-prefixed DigitalOcean Spaces URL. "
                f"Extracted bucket '{extracted_bucket}' from endpoint. "
                f"Using endpoint: {correct_endpoint}"
            )
            # If the bucket wasn't provided or differs, use the extracted one
            if not bucket or bucket != extracted_bucket:
                logger.info(f"Using extracted bucket name: '{extracted_bucket}' (was: '{bucket}')")
                bucket = extracted_bucket
            return correct_endpoint, bucket

        # Check for a bare "digitaloceanspaces.com" without a region
        if host == "digitaloceanspaces.com":
            logger.error(
                "Invalid DigitalOcean Spaces endpoint: missing region. "
                "Use format: https://<region>.digitaloceanspaces.com (e.g., https://lon1.digitaloceanspaces.com)"
            )

        return endpoint_url, bucket

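    # Illustrative behaviour of the normalizer (example values, not from the
    # diff): a bucket-prefixed Spaces URL is split into endpoint and bucket,
    # while other endpoints pass through untouched.
    #
    #   _normalize_endpoint_url("https://docs.nyc3.digitaloceanspaces.com", "")
    #       -> ("https://nyc3.digitaloceanspaces.com", "docs")
    #   _normalize_endpoint_url("https://s3.amazonaws.com", "docs")
    #       -> ("https://s3.amazonaws.com", "docs")
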
    def _init_client(
        self,
        aws_access_key_id: str,
        aws_secret_access_key: str,
        region_name: str = "us-east-1",
        endpoint_url: Optional[str] = None,
        bucket: Optional[str] = None,
    ) -> Optional[str]:
        """
        Initialize the S3 client with credentials.

        Returns:
            The potentially corrected bucket name if the endpoint URL was normalized
        """
        from botocore.config import Config

        client_kwargs = {
            "aws_access_key_id": aws_access_key_id,
            "aws_secret_access_key": aws_secret_access_key,
            "region_name": region_name,
        }

        logger.info(f"Initializing S3 client with region: {region_name}")

        corrected_bucket = bucket
        if endpoint_url:
            # Normalize the endpoint URL and potentially extract the bucket name
            normalized_endpoint, corrected_bucket = self._normalize_endpoint_url(endpoint_url, bucket)
            logger.info(f"Original endpoint URL: {endpoint_url}")
            logger.info(f"Normalized endpoint URL: {normalized_endpoint}")
            logger.info(f"Bucket name: '{corrected_bucket}'")

            client_kwargs["endpoint_url"] = normalized_endpoint
            # Use path-style addressing for S3-compatible services
            # (DigitalOcean Spaces, MinIO, etc.)
            client_kwargs["config"] = Config(s3={"addressing_style": "path"})
        else:
            logger.info("Using default AWS S3 endpoint")

        self.s3_client = boto3.client("s3", **client_kwargs)
        logger.info("S3 client initialized successfully")

        return corrected_bucket

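    # Design note: path-style addressing (https://endpoint/bucket/key) is forced
    # for custom endpoints because many S3-compatible stores do not resolve
    # virtual-hosted-style URLs (https://bucket.endpoint/key). Illustrative call,
    # with placeholder credentials, against a local MinIO instance:
    #
    #   loader._init_client("minio-key", "minio-secret",
    #                       endpoint_url="http://localhost:9000", bucket="docs")
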
    def is_text_file(self, file_path: str) -> bool:
        """Determine if a file is a text file based on its extension."""
        text_extensions = {
            ".txt", ".md", ".markdown", ".rst", ".json", ".xml", ".yaml", ".yml",
            ".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".c", ".cpp", ".h",
            ".hpp", ".cs", ".go", ".rs", ".rb", ".php", ".swift", ".kt", ".scala",
            ".html", ".css", ".scss", ".sass", ".less",
            ".sh", ".bash", ".zsh", ".fish", ".sql", ".r", ".m", ".mat",
            ".ini", ".cfg", ".conf", ".config", ".env",
            ".gitignore", ".dockerignore", ".editorconfig",
            ".log", ".csv", ".tsv",
        }

        file_lower = file_path.lower()
        for ext in text_extensions:
            if file_lower.endswith(ext):
                return True

        mime_type, _ = mimetypes.guess_type(file_path)
        if mime_type and (
            mime_type.startswith("text")
            or mime_type in ["application/json", "application/xml"]
        ):
            return True

        return False

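    # Illustrative checks (extension first, MIME type as a fallback):
    #   is_text_file("notes/README.md") -> True   (".md" is in text_extensions)
    #   is_text_file("report.pdf")      -> False  (handled by is_supported_document)
    #   is_text_file("archive.tar.gz")  -> False  (no text extension, MIME not text/*)
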
    def is_supported_document(self, file_path: str) -> bool:
        """Check if the file is a supported document type for parsing."""
        document_extensions = {
            ".pdf", ".docx", ".doc", ".xlsx", ".xls",
            ".pptx", ".ppt", ".epub", ".odt", ".rtf",
        }

        file_lower = file_path.lower()
        for ext in document_extensions:
            if file_lower.endswith(ext):
                return True

        return False

    def list_objects(self, bucket: str, prefix: str = "") -> List[str]:
        """
        List all objects in the bucket with the given prefix.

        Args:
            bucket: S3 bucket name
            prefix: Optional path prefix to filter objects

        Returns:
            List of object keys
        """
        objects = []
        paginator = self.s3_client.get_paginator("list_objects_v2")

        logger.info(f"Listing objects in bucket: '{bucket}' with prefix: '{prefix}'")
        logger.debug(f"S3 client endpoint: {self.s3_client.meta.endpoint_url}")

        try:
            page_count = 0
            for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
                page_count += 1
                logger.debug(f"Processing page {page_count}, keys in response: {list(page.keys())}")
                if "Contents" in page:
                    for obj in page["Contents"]:
                        key = obj["Key"]
                        # Skip directory markers (keys ending with "/")
                        if not key.endswith("/"):
                            objects.append(key)
                            logger.debug(f"Found object: {key}")
                else:
                    logger.info(f"Page {page_count} has no 'Contents' key - bucket may be empty or prefix not found")

            logger.info(f"Found {len(objects)} objects in bucket '{bucket}'")

        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "")
            error_message = e.response.get("Error", {}).get("Message", "")
            logger.error(f"ClientError listing objects - Code: {error_code}, Message: {error_message}")
            logger.error(f"Full error response: {e.response}")
            logger.error(f"Bucket: '{bucket}', Prefix: '{prefix}', Endpoint: {self.s3_client.meta.endpoint_url}")

            if error_code == "NoSuchBucket":
                raise Exception(f"S3 bucket '{bucket}' does not exist")
            elif error_code == "AccessDenied":
                raise Exception(
                    f"Access denied to S3 bucket '{bucket}'. Check your credentials and permissions."
                )
            elif error_code == "NoSuchKey":
                # Unusual for ListObjectsV2 - may indicate an endpoint/bucket configuration issue
                logger.error(
                    "NoSuchKey error on ListObjectsV2 - this may indicate the bucket name "
                    "is incorrect or the endpoint URL format is wrong. "
                    "For DigitalOcean Spaces, the endpoint should be like: "
                    "https://<region>.digitaloceanspaces.com and bucket should be just the space name."
                )
                raise Exception(
                    f"S3 error: {e}. For S3-compatible services, verify: "
                    f"1) Endpoint URL format (e.g., https://nyc3.digitaloceanspaces.com), "
                    f"2) Bucket name is just the space/bucket name without region prefix"
                )
            else:
                raise Exception(f"S3 error: {e}")
        except NoCredentialsError:
            raise Exception(
                "AWS credentials not found. Please provide valid credentials."
            )

        return objects

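    # Illustrative result (hypothetical bucket layout): with prefix "docs/" and
    # objects "docs/", "docs/a.md", "docs/img/b.png", this returns
    # ["docs/a.md", "docs/img/b.png"]; the "docs/" directory marker is skipped.
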
    def get_object_content(self, bucket: str, key: str) -> Optional[str]:
        """
        Get the content of an S3 object as text.

        Args:
            bucket: S3 bucket name
            key: Object key

        Returns:
            File content as a string, or None if the file should be skipped
        """
        if not self.is_text_file(key) and not self.is_supported_document(key):
            return None

        try:
            response = self.s3_client.get_object(Bucket=bucket, Key=key)
            content = response["Body"].read()

            if self.is_text_file(key):
                try:
                    decoded_content = content.decode("utf-8").strip()
                    if not decoded_content:
                        return None
                    return decoded_content
                except UnicodeDecodeError:
                    # Not valid UTF-8 despite the text-like extension; skip it
                    return None
            elif self.is_supported_document(key):
                return self._process_document(content, key)

        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "")
            if error_code == "NoSuchKey":
                return None
            elif error_code == "AccessDenied":
                logger.warning(f"Access denied to object: {key}")
                return None
            else:
                logger.error(f"Error fetching object {key}: {e}")
                return None

        return None

    def _process_document(self, content: bytes, key: str) -> Optional[str]:
        """
        Process a document file (PDF, DOCX, etc.) and extract text.

        Args:
            content: File content as bytes
            key: Object key (filename)

        Returns:
            Extracted text content
        """
        ext = os.path.splitext(key)[1].lower()

        # Write to a temp file so the file parser can read it from disk;
        # delete=False because the reader opens it after this block closes it.
        with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp_file:
            tmp_file.write(content)
            tmp_path = tmp_file.name

        try:
            from application.parser.file.bulk import SimpleDirectoryReader

            reader = SimpleDirectoryReader(input_files=[tmp_path])
            documents = reader.load_data()
            if documents:
                return "\n\n".join(doc.text for doc in documents if doc.text)
            return None
        except Exception as e:
            logger.error(f"Error processing document {key}: {e}")
            return None
        finally:
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)

    def load_data(self, inputs) -> List[Document]:
        """
        Load documents from an S3 bucket.

        Args:
            inputs: JSON string or dict containing:
                - aws_access_key_id: AWS access key ID
                - aws_secret_access_key: AWS secret access key
                - bucket: S3 bucket name
                - prefix: Optional path prefix to filter objects
                - region: AWS region (default: us-east-1)
                - endpoint_url: Custom S3 endpoint URL (for MinIO, R2, etc.)

        Returns:
            List of Document objects
        """
        if isinstance(inputs, str):
            try:
                data = json.loads(inputs)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON input: {e}")
        else:
            data = inputs

        required_fields = ["aws_access_key_id", "aws_secret_access_key", "bucket"]
        missing_fields = [field for field in required_fields if not data.get(field)]
        if missing_fields:
            raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")

        aws_access_key_id = data["aws_access_key_id"]
        aws_secret_access_key = data["aws_secret_access_key"]
        bucket = data["bucket"]
        prefix = data.get("prefix", "")
        region = data.get("region", "us-east-1")
        endpoint_url = data.get("endpoint_url", "")

        logger.info(f"Loading data from S3 - Bucket: '{bucket}', Prefix: '{prefix}', Region: '{region}'")
        if endpoint_url:
            logger.info(f"Custom endpoint URL provided: '{endpoint_url}'")

        corrected_bucket = self._init_client(
            aws_access_key_id, aws_secret_access_key, region, endpoint_url or None, bucket
        )

        # Use the corrected bucket name if endpoint URL normalization extracted one
        if corrected_bucket and corrected_bucket != bucket:
            logger.info(f"Using corrected bucket name: '{corrected_bucket}' (original: '{bucket}')")
            bucket = corrected_bucket

        objects = self.list_objects(bucket, prefix)
        documents = []

        for key in objects:
            content = self.get_object_content(bucket, key)
            if content is None:
                continue

            documents.append(
                Document(
                    text=content,
                    doc_id=key,
                    extra_info={
                        "title": os.path.basename(key),
                        "source": f"s3://{bucket}/{key}",
                        "bucket": bucket,
                        "key": key,
                    },
                )
            )

        logger.info(f"Loaded {len(documents)} documents from S3 bucket '{bucket}'")
        return documents
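
A minimal end-to-end sketch of the loader above, with placeholder credentials and bucket names (the payload keys come straight from the load_data docstring):

    from application.parser.remote.s3_loader import S3Loader

    loader = S3Loader()
    docs = loader.load_data({
        "aws_access_key_id": "AKIA...",                         # placeholder
        "aws_secret_access_key": "...",                         # placeholder
        "bucket": "docs",
        "prefix": "guides/",
        "endpoint_url": "https://nyc3.digitaloceanspaces.com",  # S3-compatible store
    })
    for doc in docs:
        print(doc.extra_info["source"])  # e.g. s3://docs/guides/intro.md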