- Added image upload functionality for agents in the backend and frontend.

- Implemented image URL generation based on storage strategy (S3 or local).
- Updated agent creation and update endpoints to handle image files.
- Enhanced frontend components to display agent images with fallbacks.
- New API endpoint to serve images from storage.
- Refactored API client to support FormData for file uploads.
- Improved error handling and logging for image processing.
This commit is contained in:
Siddhant Rai
2025-06-18 18:17:20 +05:30
parent cab1f3787a
commit e35f1d70e4
12 changed files with 385 additions and 146 deletions

View File

@@ -28,7 +28,12 @@ from application.core.settings import settings
from application.extensions import api
from application.storage.storage_creator import StorageCreator
from application.tts.google_tts import GoogleTTS
from application.utils import check_required_fields, safe_filename, validate_function_name
from application.utils import (
check_required_fields,
safe_filename,
validate_function_name,
generate_image_url,
)
from application.vectorstore.vector_creator import VectorCreator
storage = StorageCreator.get_storage()
@@ -253,7 +258,7 @@ class GetSingleConversation(Resource):
)
if not conversation:
return make_response(jsonify({"status": "not found"}), 404)
# Process queries to include attachment names
queries = conversation["queries"]
for query in queries:
@@ -265,13 +270,18 @@ class GetSingleConversation(Resource):
{"_id": ObjectId(attachment_id)}
)
if attachment:
attachment_details.append({
"id": str(attachment["_id"]),
"fileName": attachment.get("filename", "Unknown file")
})
attachment_details.append(
{
"id": str(attachment["_id"]),
"fileName": attachment.get(
"filename", "Unknown file"
),
}
)
except Exception as e:
current_app.logger.error(
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
f"Error retrieving attachment {attachment_id}: {e}",
exc_info=True,
)
query["attachments"] = attachment_details
except Exception as err:
@@ -499,7 +509,7 @@ class UploadFile(Resource):
)
user = decoded_token.get("sub")
job_name = request.form["name"]
# Create safe versions for filesystem operations
safe_user = safe_filename(user)
dir_name = safe_filename(job_name)
@@ -1069,27 +1079,28 @@ class UpdatePrompt(Resource):
@user_ns.route("/api/get_agent")
class GetAgent(Resource):
@api.doc(params={"id": "ID of the agent"}, description="Get a single agent by ID")
@api.doc(params={"id": "Agent ID"}, description="Get agent by ID")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
agent_id = request.args.get("id")
if not agent_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
if not (decoded_token := request.decoded_token):
return {"success": False}, 401
if not (agent_id := request.args.get("id")):
return {"success": False, "message": "ID required"}, 400
try:
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "user": user}
{"_id": ObjectId(agent_id), "user": decoded_token["sub"]}
)
if not agent:
return make_response(jsonify({"status": "Not found"}), 404)
return {"status": "Not found"}, 404
data = {
"id": str(agent["_id"]),
"name": agent["name"],
"description": agent.get("description", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"source": (
str(source_doc["_id"])
if isinstance(agent.get("source"), DBRef)
@@ -1116,19 +1127,20 @@ class GetAgent(Resource):
"shared_metadata": agent.get("shared_metadata", {}),
"shared_token": agent.get("shared_token", ""),
}
except Exception as err:
current_app.logger.error(f"Error retrieving agent: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(data), 200)
return make_response(jsonify(data), 200)
except Exception as e:
current_app.logger.error(f"Agent fetch error: {e}", exc_info=True)
return {"success": False}, 400
@user_ns.route("/api/get_agents")
class GetAgents(Resource):
@api.doc(description="Retrieve agents for the user")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
if not (decoded_token := request.decoded_token):
return {"success": False}, 401
user = decoded_token.get("sub")
try:
user_doc = ensure_user_doc(user)
@@ -1140,6 +1152,9 @@ class GetAgents(Resource):
"id": str(agent["_id"]),
"name": agent["name"],
"description": agent.get("description", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"source": (
str(source_doc["_id"])
if isinstance(agent.get("source"), DBRef)
@@ -1184,8 +1199,8 @@ class CreateAgent(Resource):
"description": fields.String(
required=True, description="Description of the agent"
),
"image": fields.String(
required=False, description="Image URL or identifier"
"image": fields.Raw(
required=False, description="Image file upload", type="file"
),
"source": fields.String(required=True, description="Source ID"),
"chunks": fields.Integer(required=True, description="Chunks count"),
@@ -1204,12 +1219,20 @@ class CreateAgent(Resource):
@api.expect(create_agent_model)
@api.doc(description="Create a new agent")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
if not (decoded_token := request.decoded_token):
return {"success": False}, 401
user = decoded_token.get("sub")
data = request.get_json()
if request.content_type == "application/json":
data = request.get_json()
else:
print(request.form)
data = request.form.to_dict()
if "tools" in data:
try:
data["tools"] = json.loads(data["tools"])
except json.JSONDecodeError:
data["tools"] = []
print(f"Received data: {data}")
if data.get("status") not in ["draft", "published"]:
return make_response(
jsonify({"success": False, "message": "Invalid status"}), 400
@@ -1230,13 +1253,29 @@ class CreateAgent(Resource):
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
image_url = ""
if "image" in request.files:
file = request.files["image"]
if file.filename != "":
filename = secure_filename(file.filename)
upload_path = f"agents/{user}/{str(uuid.uuid4())}_{filename}"
try:
storage.save_file(file, upload_path)
image_url = upload_path
except Exception as e:
current_app.logger.error(f"Error uploading agent image: {e}")
return make_response(
jsonify({"success": False, "message": "Image upload failed"}),
400,
)
try:
key = str(uuid.uuid4())
new_agent = {
"user": user,
"name": data.get("name"),
"description": data.get("description", ""),
"image": data.get("image", ""),
"image": image_url,
"source": (
DBRef("sources", ObjectId(data.get("source")))
if ObjectId.is_valid(data.get("source"))
@@ -1294,11 +1333,18 @@ class UpdateAgent(Resource):
@api.expect(update_agent_model)
@api.doc(description="Update an existing agent")
def put(self, agent_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
if not (decoded_token := request.decoded_token):
return {"success": False}, 401
user = decoded_token.get("sub")
data = request.get_json()
if request.content_type == "application/json":
data = request.get_json()
else:
data = request.form.to_dict()
if "tools" in data:
try:
data["tools"] = json.loads(data["tools"])
except json.JSONDecodeError:
data["tools"] = []
if not ObjectId.is_valid(agent_id):
return make_response(
@@ -1323,6 +1369,23 @@ class UpdateAgent(Resource):
),
404,
)
image_url = existing_agent.get("image", "")
if "image" in request.files:
file = request.files["image"]
if file.filename != "":
filename = secure_filename(file.filename)
upload_path = f"agents/{user}/{str(uuid.uuid4())}_{filename}"
try:
storage.save_file(file, upload_path)
image_url = upload_path
except Exception as e:
current_app.logger.error(f"Error uploading agent image: {e}")
return make_response(
jsonify({"success": False, "message": "Image upload failed"}),
400,
)
update_fields = {}
allowed_fields = [
"name",
@@ -1394,6 +1457,8 @@ class UpdateAgent(Resource):
)
else:
update_fields[field] = data[field]
if image_url:
update_fields["image"] = image_url
if not update_fields:
return make_response(
jsonify({"success": False, "message": "No update data provided"}), 400
@@ -1542,6 +1607,9 @@ class PinnedAgents(Resource):
"id": str(agent["_id"]),
"name": agent.get("name", ""),
"description": agent.get("description", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"source": (
str(db.dereference(agent["source"])["_id"])
if "source" in agent
@@ -1702,6 +1770,11 @@ class SharedAgent(Resource):
"id": agent_id,
"user": shared_agent.get("user", ""),
"name": shared_agent.get("name", ""),
"image": (
generate_image_url(shared_agent["image"])
if shared_agent.get("image")
else ""
),
"description": shared_agent.get("description", ""),
"tools": shared_agent.get("tools", []),
"tool_details": resolve_tool_details(shared_agent.get("tools", [])),
@@ -1777,6 +1850,9 @@ class SharedAgents(Resource):
"id": str(agent["_id"]),
"name": agent.get("name", ""),
"description": agent.get("description", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"tools": agent.get("tools", []),
"tool_details": resolve_tool_details(agent.get("tools", [])),
"agent_type": agent.get("agent_type", ""),
@@ -2241,7 +2317,7 @@ class GetPubliclySharedConversations(Resource):
conversation_queries = conversation["queries"][
: (shared["first_n_queries"])
]
for query in conversation_queries:
if "attachments" in query and query["attachments"]:
attachment_details = []
@@ -2251,13 +2327,18 @@ class GetPubliclySharedConversations(Resource):
{"_id": ObjectId(attachment_id)}
)
if attachment:
attachment_details.append({
"id": str(attachment["_id"]),
"fileName": attachment.get("filename", "Unknown file")
})
attachment_details.append(
{
"id": str(attachment["_id"]),
"fileName": attachment.get(
"filename", "Unknown file"
),
}
)
except Exception as e:
current_app.logger.error(
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
f"Error retrieving attachment {attachment_id}: {e}",
exc_info=True,
)
query["attachments"] = attachment_details
else:
@@ -3493,3 +3574,33 @@ class StoreAttachment(Resource):
except Exception as err:
current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": str(err)}), 400)
@user_ns.route("/api/images/<path:image_path>")
class ServeImage(Resource):
@api.doc(description="Serve an image from storage")
def get(self, image_path):
try:
s3_storage = StorageCreator.create_storage("s3")
file_obj = s3_storage.get_file(image_path)
extension = image_path.split(".")[-1].lower()
content_type = f"image/{extension}"
if extension == "jpg":
content_type = "image/jpeg"
response = make_response(file_obj.read())
response.headers.set("Content-Type", content_type)
response.headers.set("Cache-Control", "max-age=86400")
return response
except FileNotFoundError:
return make_response(
jsonify({"success": False, "message": "Image not found"}), 404
)
except Exception as e:
current_app.logger.error(f"Error serving image: {e}")
return make_response(
jsonify({"success": False, "message": "Error retrieving image"}), 500
)

View File

@@ -1,13 +1,14 @@
"""S3 storage implementation."""
import io
from typing import BinaryIO, List, Callable
import os
from typing import BinaryIO, Callable, List
import boto3
from botocore.exceptions import ClientError
from application.core.settings import settings
from application.storage.base import BaseStorage
from application.core.settings import settings
from botocore.exceptions import ClientError
class S3Storage(BaseStorage):
@@ -20,18 +21,21 @@ class S3Storage(BaseStorage):
Args:
bucket_name: S3 bucket name (optional, defaults to settings)
"""
self.bucket_name = bucket_name or getattr(settings, "S3_BUCKET_NAME", "docsgpt-test-bucket")
self.bucket_name = bucket_name or getattr(
settings, "S3_BUCKET_NAME", "docsgpt-test-bucket"
)
# Get credentials from settings
aws_access_key_id = getattr(settings, "SAGEMAKER_ACCESS_KEY", None)
aws_secret_access_key = getattr(settings, "SAGEMAKER_SECRET_KEY", None)
region_name = getattr(settings, "SAGEMAKER_REGION", None)
self.s3 = boto3.client(
's3',
"s3",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=region_name
region_name=region_name,
)
def save_file(self, file_data: BinaryIO, path: str) -> dict:
@@ -41,17 +45,16 @@ class S3Storage(BaseStorage):
region = getattr(settings, "SAGEMAKER_REGION", None)
return {
'storage_type': 's3',
'bucket_name': self.bucket_name,
'uri': f's3://{self.bucket_name}/{path}',
'region': region
"storage_type": "s3",
"bucket_name": self.bucket_name,
"uri": f"s3://{self.bucket_name}/{path}",
"region": region,
}
def get_file(self, path: str) -> BinaryIO:
"""Get a file from S3 storage."""
if not self.file_exists(path):
raise FileNotFoundError(f"File not found: {path}")
file_obj = io.BytesIO()
self.s3.download_fileobj(self.bucket_name, path, file_obj)
file_obj.seek(0)
@@ -76,18 +79,17 @@ class S3Storage(BaseStorage):
def list_files(self, directory: str) -> List[str]:
"""List all files in a directory in S3 storage."""
# Ensure directory ends with a slash if it's not empty
if directory and not directory.endswith('/'):
directory += '/'
if directory and not directory.endswith("/"):
directory += "/"
result = []
paginator = self.s3.get_paginator('list_objects_v2')
paginator = self.s3.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=directory)
for page in pages:
if 'Contents' in page:
for obj in page['Contents']:
result.append(obj['Key'])
if "Contents" in page:
for obj in page["Contents"]:
result.append(obj["Key"])
return result
def process_file(self, path: str, processor_func: Callable, **kwargs):
@@ -98,22 +100,24 @@ class S3Storage(BaseStorage):
path: Path to the file
processor_func: Function that processes the file
**kwargs: Additional arguments to pass to the processor function
Returns:
The result of the processor function
"""
import tempfile
import logging
import tempfile
if not self.file_exists(path):
raise FileNotFoundError(f"File not found in S3: {path}")
with tempfile.NamedTemporaryFile(suffix=os.path.splitext(path)[1], delete=True) as temp_file:
with tempfile.NamedTemporaryFile(
suffix=os.path.splitext(path)[1], delete=True
) as temp_file:
try:
# Download the file from S3 to the temporary file
self.s3.download_fileobj(self.bucket_name, path, temp_file)
temp_file.flush()
return processor_func(local_path=temp_file.name, **kwargs)
except Exception as e:
logging.error(f"Error processing S3 file {path}: {e}", exc_info=True)

View File

@@ -6,6 +6,7 @@ import uuid
import tiktoken
from flask import jsonify, make_response
from werkzeug.utils import secure_filename
from application.core.settings import settings
_encoding = None
@@ -22,24 +23,24 @@ def safe_filename(filename):
"""
Creates a safe filename that preserves the original extension.
Uses secure_filename, but ensures a proper filename is returned even with non-Latin characters.
Args:
filename (str): The original filename
Returns:
str: A safe filename that can be used for storage
"""
if not filename:
return str(uuid.uuid4())
_, extension = os.path.splitext(filename)
safe_name = secure_filename(filename)
# If secure_filename returns just the extension or an empty string
if not safe_name or safe_name == extension.lstrip('.'):
if not safe_name or safe_name == extension.lstrip("."):
return f"{str(uuid.uuid4())}{extension}"
return safe_name
@@ -137,3 +138,14 @@ def validate_function_name(function_name):
if not re.match(r"^[a-zA-Z0-9_-]+$", function_name):
return False
return True
def generate_image_url(image_path):
strategy = getattr(settings, "URL_STRATEGY", "backend")
if strategy == "s3":
bucket_name = getattr(settings, "S3_BUCKET_NAME", "docsgpt-test-bucket")
region_name = getattr(settings, "SAGEMAKER_REGION", "eu-central-1")
return f"https://{bucket_name}.s3.{region_name}.amazonaws.com/{image_path}"
else:
base_url = getattr(settings, "API_URL", "http://localhost:7091")
return f"{base_url}/api/images/{image_path}"