mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 06:30:03 +00:00
388 lines
10 KiB
Python
388 lines
10 KiB
Python
"""Centralized utilities for API routes."""
|
|
|
|
from functools import wraps
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
|
from bson.errors import InvalidId
|
|
from bson.objectid import ObjectId
|
|
from flask import (
|
|
Response,
|
|
current_app,
|
|
has_app_context,
|
|
jsonify,
|
|
make_response,
|
|
request,
|
|
)
|
|
from pymongo.collection import Collection
|
|
|
|
|
|
def get_user_id() -> Optional[str]:
    """
    Return the user ID stored in the current request's decoded JWT.

    Reads the ``decoded_token`` attribute of the Flask ``request`` object
    and looks up its ``"sub"`` entry.

    Returns:
        The token's ``"sub"`` value, or None when the request has no
        (or an empty) ``decoded_token``.
    """
    token_claims = getattr(request, "decoded_token", None)
    if not token_claims:
        # Attribute missing or empty: treat the caller as unauthenticated.
        return None
    return token_claims.get("sub")
|
|
|
|
|
|
def require_auth(func: Callable) -> Callable:
    """
    Decorator that rejects requests without an authenticated user.

    The wrapped handler runs only when ``get_user_id()`` returns a truthy
    value; otherwise a 401 "Unauthorized" error response is returned and
    the handler is never called.

    Usage:
        @require_auth
        def get(self):
            user_id = get_user_id()
            ...
    """

    @wraps(func)
    def guarded(*args, **kwargs):
        if get_user_id():
            return func(*args, **kwargs)
        return error_response("Unauthorized", 401)

    return guarded
|
|
|
|
|
|
def success_response(
    data: Optional[Dict[str, Any]] = None, status: int = 200
) -> Response:
    """
    Build the standard JSON success response.

    The body always starts as ``{"success": True}``; the entries of
    ``data`` (when it is non-empty) are merged on top of it.

    Args:
        data: Optional dictionary merged into the response body
        status: HTTP status code (default: 200)

    Returns:
        Flask Response object

    Example:
        return success_response({"users": [...], "total": 10})
    """
    payload = {"success": True}
    # A missing or empty ``data`` leaves the payload untouched.
    payload.update(data or {})
    return make_response(jsonify(payload), status)
|
|
|
|
|
|
def error_response(message: str, status: int = 400, **kwargs) -> Response:
    """
    Build the standard JSON error response.

    The body contains ``"success": False`` and ``"message"``, followed by
    any extra keyword arguments. Extra keywords are applied last, so a
    keyword named ``success`` or ``message`` replaces the default entry.

    Args:
        message: Error message string
        status: HTTP status code (default: 400)
        **kwargs: Additional fields to include in the response body

    Returns:
        Flask Response object

    Example:
        return error_response("Resource not found", 404)
        return error_response("Invalid input", 400, errors=["field1", "field2"])
    """
    payload = {"success": False, "message": message, **kwargs}
    return make_response(jsonify(payload), status)
|
|
|
|
|
|
def validate_object_id(
    id_string: str, resource_name: str = "Resource"
) -> Tuple[Optional[ObjectId], Optional[Response]]:
    """
    Convert a string to an ObjectId, reporting failure as an error response.

    Args:
        id_string: String to convert
        resource_name: Resource name interpolated into the error message

    Returns:
        ``(ObjectId, None)`` on success, or ``(None, error_response)`` when
        ``ObjectId`` raises ``InvalidId`` or ``TypeError``.

    Example:
        obj_id, error = validate_object_id(workflow_id, "Workflow")
        if error:
            return error
    """
    try:
        object_id = ObjectId(id_string)
    except (InvalidId, TypeError):
        return None, error_response(f"Invalid {resource_name} ID format")
    return object_id, None
|
|
|
|
|
|
def validate_pagination(
    default_limit: int = 20, max_limit: int = 100
) -> Tuple[int, int, Optional[Response]]:
    """
    Read ``limit`` and ``skip`` from the request query string and validate them.

    ``limit`` falls back to ``default_limit`` when absent and is capped at
    ``max_limit``; ``skip`` falls back to 0. Values that are not integers,
    a limit below 1, or a negative skip are rejected.

    Args:
        default_limit: Limit used when the request does not supply one
        max_limit: Upper bound applied to the limit

    Returns:
        ``(limit, skip, None)`` when valid, otherwise
        ``(0, 0, error_response)``.

    Example:
        limit, skip, error = validate_pagination()
        if error:
            return error
    """
    try:
        requested_limit = int(request.args.get("limit", default_limit))
        skip = int(request.args.get("skip", 0))
    except ValueError:
        # Non-numeric query values.
        return 0, 0, error_response("Invalid pagination parameters")

    limit = min(requested_limit, max_limit)
    if limit >= 1 and skip >= 0:
        return limit, skip, None
    return 0, 0, error_response("Invalid pagination parameters")
|
|
|
|
|
|
def check_resource_ownership(
    collection: Collection,
    resource_id: ObjectId,
    user_id: str,
    resource_name: str = "Resource",
) -> Tuple[Optional[Dict], Optional[Response]]:
    """
    Fetch a document by ID, restricted to the given user.

    Runs a single ``find_one`` that matches on both ``_id`` and ``user``,
    so a document that does not exist and a document owned by someone else
    produce the same 404 "not found" error.

    Args:
        collection: MongoDB collection to query
        resource_id: Resource ObjectId
        user_id: User ID string matched against the ``user`` field
        resource_name: Resource name interpolated into the error message

    Returns:
        ``(document, None)`` when found, otherwise ``(None, error_response)``.

    Example:
        workflow, error = check_resource_ownership(
            workflows_collection,
            workflow_id,
            user_id,
            "Workflow"
        )
        if error:
            return error
    """
    ownership_filter = {"_id": resource_id, "user": user_id}
    document = collection.find_one(ownership_filter)
    if document:
        return document, None
    return None, error_response(f"{resource_name} not found", 404)
|
|
|
|
|
|
def serialize_object_id(
    obj: Dict[str, Any], id_field: str = "_id", new_field: str = "id"
) -> Dict[str, Any]:
    """
    Replace an ID entry of a dictionary with its string form.

    The dictionary is modified in place: ``str(obj[id_field])`` is stored
    under ``new_field`` and, when the two field names differ, ``id_field``
    is removed. A dictionary without ``id_field`` is returned unchanged.

    Args:
        obj: Dictionary to modify
        id_field: Key holding the value to stringify
        new_field: Key under which the string is stored

    Returns:
        The same dictionary object that was passed in

    Example:
        user = serialize_object_id(user_doc)
        # user["id"] = "507f1f77bcf86cd799439011"
    """
    if id_field not in obj:
        return obj

    obj[new_field] = str(obj[id_field])
    if new_field != id_field:
        # Drop the original key only when it was not just overwritten.
        obj.pop(id_field, None)
    return obj
|
|
|
|
|
|
def serialize_list(items: List[Dict], serializer: Callable[[Dict], Dict]) -> List[Dict]:
    """
    Run a serializer over every item and collect the results.

    Args:
        items: List of dictionaries
        serializer: Function called once per item, in order

    Returns:
        New list holding the serializer's result for each item

    Example:
        workflows = serialize_list(workflow_docs, serialize_workflow)
    """
    return list(map(serializer, items))
|
|
|
|
|
|
def paginated_response(
    collection: Collection,
    query: Dict[str, Any],
    serializer: Callable[[Dict], Dict],
    limit: int,
    skip: int,
    sort_field: str = "created_at",
    sort_order: int = -1,
    response_key: str = "items",
) -> Response:
    """
    Query a collection page and wrap it in a success response.

    Fetches the sorted page of documents first, then counts all documents
    matching ``query`` (ignoring ``skip`` and ``limit``) for ``total``.

    Args:
        collection: MongoDB collection
        query: Query dictionary
        serializer: Function to serialize each item
        limit: Items per page
        skip: Number of items to skip
        sort_field: Field to sort by
        sort_order: Sort order (1=asc, -1=desc)
        response_key: Key name for the serialized items in the response

    Returns:
        Flask Response whose body holds the serialized items under
        ``response_key`` plus ``total``, ``limit`` and ``skip``.

    Example:
        return paginated_response(
            workflows_collection,
            {"user": user_id},
            serialize_workflow,
            limit, skip,
            response_key="workflows"
        )
    """
    cursor = collection.find(query).sort(sort_field, sort_order)
    documents = list(cursor.skip(skip).limit(limit))
    total = collection.count_documents(query)

    payload = {response_key: serialize_list(documents, serializer)}
    payload["total"] = total
    payload["limit"] = limit
    payload["skip"] = skip
    return success_response(payload)
|
|
|
|
|
|
def require_fields(required: List[str]) -> Callable:
    """
    Decorator to validate required fields in request JSON.

    The wrapped handler runs only when the request body is a non-empty
    JSON object and every name in ``required`` maps to a truthy value.
    Fields whose value is falsy (``""``, ``0``, ``False``, ``None``, empty
    containers) are reported as missing, the same as absent fields.

    Args:
        required: List of required field names

    Returns:
        Decorator function

    Example:
        @require_fields(["name", "description"])
        def post(self):
            data = request.get_json()
            ...
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            data = request.get_json()
            if not data:
                return error_response("Request body required")
            # A JSON body may parse to a list, string or number; those have
            # no ``.get`` and would otherwise fail with AttributeError.
            if not isinstance(data, dict):
                return error_response("Request body must be a JSON object")
            missing = [field for field in required if not data.get(field)]
            if missing:
                return error_response(f"Missing required fields: {', '.join(missing)}")
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
|
|
def safe_db_operation(
    operation: Callable, error_message: str = "Database operation failed"
) -> Tuple[Any, Optional[Response]]:
    """
    Run a callable and convert any exception into an error response.

    Any ``Exception`` raised by ``operation`` is caught. When a Flask
    application context is active the failure is logged, with traceback,
    through ``current_app.logger``; the caller always receives an error
    response built from ``error_message``.

    Args:
        operation: Zero-argument function to execute
        error_message: Message used for the log entry and the error response

    Returns:
        ``(result, None)`` on success, otherwise ``(None, error_response)``.

    Example:
        result, error = safe_db_operation(
            lambda: collection.insert_one(doc),
            "Failed to create resource"
        )
        if error:
            return error
    """
    try:
        outcome = operation()
    except Exception as err:
        # Logging needs an app context; skip it when none is active.
        if has_app_context():
            current_app.logger.error(f"{error_message}: {err}", exc_info=True)
        return None, error_response(error_message)
    return outcome, None
|
|
|
|
|
|
def validate_enum(
    value: Any, allowed: List[Any], field_name: str
) -> Optional[Response]:
    """
    Check that a value is one of the allowed options.

    Args:
        value: Value to validate
        allowed: List of allowed values
        field_name: Field name interpolated into the error message

    Returns:
        None when ``value`` is in ``allowed``; otherwise an error response
        listing every allowed value in single quotes.

    Example:
        error = validate_enum(status, ["draft", "published"], "status")
        if error:
            return error
    """
    if value in allowed:
        return None

    choices = ", ".join(f"'{option}'" for option in allowed)
    return error_response(f"Invalid {field_name}. Must be one of: {choices}")
|
|
|
|
|
|
def extract_sort_params(
    default_field: str = "created_at",
    default_order: str = "desc",
    allowed_fields: Optional[List[str]] = None,
) -> Tuple[str, int]:
    """
    Read ``sort`` and ``order`` from the request query string.

    A requested field that is not in a non-empty ``allowed_fields`` is
    replaced by ``default_field``. The order is compared case-insensitively:
    ``"desc"`` maps to -1 and every other value maps to 1.

    Args:
        default_field: Sort field used when none is requested or allowed
        default_order: Sort order used when none is requested ("asc" or "desc")
        allowed_fields: Allowed sort fields (None or empty = no validation)

    Returns:
        Tuple of (sort_field, sort_order)

    Example:
        sort_field, sort_order = extract_sort_params(
            allowed_fields=["name", "date", "status"]
        )
    """
    requested_field = request.args.get("sort", default_field)
    direction = request.args.get("order", default_order).lower()

    field_rejected = bool(allowed_fields) and requested_field not in allowed_fields
    sort_field = default_field if field_rejected else requested_field
    sort_order = 1 if direction != "desc" else -1
    return sort_field, sort_order
|