(fix:s3) processor func

This commit is contained in:
ManishMadan2882
2025-04-17 02:36:55 +05:30
parent 0a0e16547e
commit 9454150f7d
2 changed files with 21 additions and 23 deletions

View File

@@ -1,6 +1,6 @@
"""Base storage class for file system abstraction."""
from abc import ABC, abstractmethod
from typing import BinaryIO, List, Optional, Callable
from typing import BinaryIO, List, Callable
class BaseStorage(ABC):

View File

@@ -1,28 +1,31 @@
"""S3 storage implementation."""
import io
from typing import BinaryIO, List, Callable
import os
import boto3
from botocore.exceptions import ClientError
from application.storage.base import BaseStorage
from application.core.settings import settings
class S3Storage(BaseStorage):
"""AWS S3 storage implementation."""
def __init__(self, bucket_name: str, aws_access_key_id=None,
aws_secret_access_key=None, region_name=None):
def __init__(self, bucket_name=None):
"""
Initialize S3 storage.
Args:
bucket_name: S3 bucket name
aws_access_key_id: AWS access key ID (optional if using IAM roles)
aws_secret_access_key: AWS secret access key (optional if using IAM roles)
region_name: AWS region name (optional)
bucket_name: S3 bucket name (optional, defaults to settings)
"""
self.bucket_name = bucket_name
self.bucket_name = bucket_name or getattr(settings, "S3_BUCKET_NAME", "docsgpt-test-bucket")
# Get credentials from settings
aws_access_key_id = getattr(settings, "SAGEMAKER_ACCESS_KEY", None)
aws_secret_access_key = getattr(settings, "SAGEMAKER_SECRET_KEY", None)
region_name = getattr(settings, "SAGEMAKER_REGION", None)
self.s3 = boto3.client(
's3',
@@ -83,8 +86,6 @@ class S3Storage(BaseStorage):
"""
Process a file using the provided processor function.
For S3 storage, we need to download the file to a temporary location first.
Args:
path: Path to the file
processor_func: Function that processes the file
@@ -94,21 +95,18 @@ class S3Storage(BaseStorage):
The result of the processor function
"""
import tempfile
import os
import logging
if not self.file_exists(path):
raise FileNotFoundError(f"File not found: {path}")
raise FileNotFoundError(f"File not found in S3: {path}")
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
self.s3.download_fileobj(self.bucket_name, path, temp_file)
temp_path = temp_file.name
try:
result = processor_func(file_path=temp_path, **kwargs)
return result
finally:
with tempfile.NamedTemporaryFile(suffix=os.path.splitext(path)[1], delete=True) as temp_file:
try:
os.unlink(temp_path)
# Download the file from S3 to the temporary file
self.s3.download_fileobj(self.bucket_name, path, temp_file)
temp_file.flush()
result = processor_func(file_path=temp_file.name, **kwargs)
return result
except Exception as e:
import logging
logging.warning(f"Failed to delete temporary file: {e}")
logging.error(f"Error processing S3 file {path}: {e}", exc_info=True)
raise