class BaseStorage(ABC):
    """Abstract base class for storage implementations."""

    @abstractmethod
    def save_file(self, file_data: BinaryIO, path: str) -> dict:
        """
        Save a file to storage.

        Args:
            file_data: File-like object containing the data
            path: Path where the file should be stored

        Returns:
            dict: A dictionary containing metadata about the saved file, including:
                - 'path': The path where the file was saved
                - 'storage_type': The type of storage (e.g., 'local', 's3')
                - Other storage-specific metadata (e.g., 'uri', 'bucket_name', etc.)
        """
        pass

    @abstractmethod
    def get_file(self, path: str) -> BinaryIO:
        """
        Retrieve a file from storage.

        Args:
            path: Path to the file

        Returns:
            BinaryIO: File-like object containing the file data
        """
        pass

    @abstractmethod
    def process_file(self, path: str, processor_func: Callable, **kwargs):
        """
        Process a file using the provided processor function.

        This method handles the details of retrieving the file and providing
        it to the processor function in an appropriate way based on the
        storage type.

        Args:
            path: Path to the file
            processor_func: Function that processes the file
            **kwargs: Additional arguments to pass to the processor function

        Returns:
            The result of the processor function
        """
        pass

    @abstractmethod
    def delete_file(self, path: str) -> bool:
        """
        Delete a file from storage.

        Args:
            path: Path to the file

        Returns:
            bool: True if deletion was successful
        """
        pass

    @abstractmethod
    def file_exists(self, path: str) -> bool:
        """
        Check if a file exists.

        Args:
            path: Path to the file

        Returns:
            bool: True if the file exists
        """
        pass

    @abstractmethod
    def list_files(self, directory: str) -> List[str]:
        """
        List all files in a directory.

        Args:
            directory: Directory path to list

        Returns:
            List[str]: List of file paths
        """
        pass
class LocalStorage(BaseStorage):
    """Local file system storage implementation."""

    def __init__(self, base_dir: str = None):
        """
        Initialize local storage.

        Args:
            base_dir: Base directory for all operations. If None, uses the
                directory three levels above this module (the project root).
        """
        self.base_dir = base_dir or os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        )

    def _get_full_path(self, path: str) -> str:
        """Get absolute path by combining base_dir and path."""
        if os.path.isabs(path):
            return path
        return os.path.join(self.base_dir, path)

    def save_file(self, file_data: BinaryIO, path: str) -> dict:
        """
        Save a file to local storage.

        Args:
            file_data: File-like object containing the data
            path: Path (relative to base_dir, or absolute) to store the file at

        Returns:
            dict: Metadata with 'path' and 'storage_type' keys.
        """
        full_path = self._get_full_path(path)

        os.makedirs(os.path.dirname(full_path), exist_ok=True)

        if hasattr(file_data, 'save'):
            # Flask's FileStorage objects expose a .save() method
            file_data.save(full_path)
        else:
            # Regular file-like objects: stream-copy to disk
            with open(full_path, 'wb') as f:
                shutil.copyfileobj(file_data, f)

        # Fix: the BaseStorage.save_file contract promises a 'path' key in
        # the returned metadata dict; include it alongside the storage type.
        return {
            'path': path,
            'storage_type': 'local'
        }

    def get_file(self, path: str) -> BinaryIO:
        """
        Get a file from local storage.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        full_path = self._get_full_path(path)

        if not os.path.exists(full_path):
            raise FileNotFoundError(f"File not found: {full_path}")

        return open(full_path, 'rb')

    def delete_file(self, path: str) -> bool:
        """Delete a file from local storage; returns False if it was missing."""
        full_path = self._get_full_path(path)

        if not os.path.exists(full_path):
            return False

        os.remove(full_path)
        return True

    def file_exists(self, path: str) -> bool:
        """Check if a file exists in local storage."""
        full_path = self._get_full_path(path)
        return os.path.exists(full_path)

    def list_files(self, directory: str) -> List[str]:
        """
        List all files under a directory (recursively), as paths relative
        to base_dir. Returns an empty list if the directory does not exist.
        """
        full_path = self._get_full_path(directory)

        if not os.path.exists(full_path):
            return []

        result = []
        for root, _, files in os.walk(full_path):
            for file in files:
                rel_path = os.path.relpath(os.path.join(root, file), self.base_dir)
                result.append(rel_path)

        return result

    def process_file(self, path: str, processor_func: Callable, **kwargs):
        """
        Process a file using the provided processor function.

        For local storage, we can directly pass the full path to the processor.

        Args:
            path: Path to the file
            processor_func: Function that processes the file
            **kwargs: Additional arguments to pass to the processor function

        Returns:
            The result of the processor function

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        full_path = self._get_full_path(path)

        if not os.path.exists(full_path):
            raise FileNotFoundError(f"File not found: {full_path}")

        return processor_func(file_path=full_path, **kwargs)
class S3Storage(BaseStorage):
    """AWS S3 storage implementation."""

    def __init__(self, bucket_name=None):
        """
        Initialize S3 storage.

        Args:
            bucket_name: S3 bucket name (optional, defaults to settings)
        """
        self.bucket_name = bucket_name or getattr(
            settings, "S3_BUCKET_NAME", "docsgpt-test-bucket"
        )

        # Get credentials from settings
        aws_access_key_id = getattr(settings, "SAGEMAKER_ACCESS_KEY", None)
        aws_secret_access_key = getattr(settings, "SAGEMAKER_SECRET_KEY", None)
        region_name = getattr(settings, "SAGEMAKER_REGION", None)

        self.s3 = boto3.client(
            's3',
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name
        )

    def save_file(self, file_data: BinaryIO, path: str) -> dict:
        """
        Save a file to S3 storage.

        Returns:
            dict: Metadata with 'path', 'storage_type', 'bucket_name',
                'uri' and 'region' keys.
        """
        self.s3.upload_fileobj(file_data, self.bucket_name, path)

        region = getattr(settings, "SAGEMAKER_REGION", None)

        # Fix: include the 'path' key promised by the BaseStorage contract.
        return {
            'path': path,
            'storage_type': 's3',
            'bucket_name': self.bucket_name,
            'uri': f's3://{self.bucket_name}/{path}',
            'region': region
        }

    def get_file(self, path: str) -> BinaryIO:
        """
        Get a file from S3 storage as an in-memory buffer.

        Raises:
            FileNotFoundError: If the object does not exist in the bucket.
        """
        if not self.file_exists(path):
            raise FileNotFoundError(f"File not found: {path}")

        file_obj = io.BytesIO()
        self.s3.download_fileobj(self.bucket_name, path, file_obj)
        file_obj.seek(0)
        return file_obj

    def delete_file(self, path: str) -> bool:
        """Delete a file from S3 storage; returns False on client error."""
        try:
            # NOTE(review): the delete call itself was outside the visible
            # hunk context; reconstructed as the standard delete_object call
            # — confirm against the original file.
            self.s3.delete_object(Bucket=self.bucket_name, Key=path)
            return True
        except ClientError:
            return False

    def file_exists(self, path: str) -> bool:
        """Check if a file exists in S3 storage."""
        try:
            # NOTE(review): the existence check was outside the visible hunk
            # context; reconstructed as the standard head_object call —
            # confirm against the original file.
            self.s3.head_object(Bucket=self.bucket_name, Key=path)
            return True
        except ClientError:
            return False

    def list_files(self, directory: str) -> List[str]:
        """List all object keys under a prefix in S3 storage."""
        # Ensure directory ends with a slash if it's not empty
        if directory and not directory.endswith('/'):
            directory += '/'

        result = []
        paginator = self.s3.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=self.bucket_name, Prefix=directory)

        for page in pages:
            if 'Contents' in page:
                for obj in page['Contents']:
                    result.append(obj['Key'])

        return result

    def process_file(self, path: str, processor_func: Callable, **kwargs):
        """
        Process a file using the provided processor function.

        The object is downloaded to a temporary local file so processors
        that expect a filesystem path can operate on it.

        Args:
            path: Path to the file
            processor_func: Function that processes the file
            **kwargs: Additional arguments to pass to the processor function

        Returns:
            The result of the processor function

        Raises:
            FileNotFoundError: If the object does not exist in the bucket.
        """
        import tempfile
        import logging

        if not self.file_exists(path):
            raise FileNotFoundError(f"File not found in S3: {path}")

        with tempfile.NamedTemporaryFile(suffix=os.path.splitext(path)[1], delete=True) as temp_file:
            try:
                # Download the file from S3 to the temporary file
                # NOTE(review): the tail of this method was truncated in the
                # source under review; reconstructed to match the
                # LocalStorage.process_file contract (processor receives a
                # file_path kwarg) — confirm against the original file.
                self.s3.download_fileobj(self.bucket_name, path, temp_file)
                temp_file.flush()
                return processor_func(file_path=temp_file.name, **kwargs)
            except ClientError as e:
                logging.error(f"Error processing S3 file {path}: {e}")
                raise