mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-02-22 20:32:11 +00:00
* feat: implement WorkflowAgent and GraphExecutor for workflow management and execution * refactor: workflow schemas and introduce WorkflowEngine - Updated schemas in `schemas.py` to include new agent types and configurations. - Created `WorkflowEngine` class in `workflow_engine.py` to manage workflow execution. - Enhanced `StreamProcessor` to handle workflow-related data. - Added new routes and utilities for managing workflows in the user API. - Implemented validation and serialization functions for workflows. - Established MongoDB collections and indexes for workflows and related entities. * refactor: improve WorkflowAgent documentation and update type hints in WorkflowEngine * feat: workflow builder and managing in frontend - Added new endpoints for workflows in `endpoints.ts`. - Implemented `getWorkflow`, `createWorkflow`, and `updateWorkflow` methods in `userService.ts`. - Introduced new UI components for alerts, buttons, commands, dialogs, multi-select, popovers, and selects. - Enhanced styling in `index.css` with new theme variables and animations. - Refactored modal components for better layout and styling. - Configured TypeScript paths and Vite aliases for cleaner imports. * feat: add workflow preview component and related state management - Implemented WorkflowPreview component for displaying workflow execution. - Created WorkflowPreviewSlice for managing workflow preview state, including queries and execution steps. - Added WorkflowMiniMap for visual representation of workflow nodes and their statuses. - Integrated conversation handling with the ability to fetch answers and manage query states. - Introduced reusable Sheet component for UI overlays. - Updated Redux store to include workflowPreview reducer. 
* feat: enhance workflow execution details and state management in WorkflowEngine and WorkflowPreview * feat: enhance workflow components with improved UI and functionality - Updated WorkflowPreview to allow text truncation for better display of long names. - Enhanced BaseNode with connectable handles and improved styling for better visibility. - Added MobileBlocker component to inform users about desktop requirements for the Workflow Builder. - Introduced PromptTextArea component for improved variable insertion and search functionality, including upstream variable extraction and context addition. * feat(workflow): add owner validation and graph version support * fix: ruff lint --------- Co-authored-by: Alex <a@tushynski.me>
379 lines
10 KiB
Python
379 lines
10 KiB
Python
"""Centralized utilities for API routes."""
|
|
|
|
from functools import wraps
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
|
from bson.errors import InvalidId
|
|
from bson.objectid import ObjectId
|
|
from flask import jsonify, make_response, request, Response
|
|
from pymongo.collection import Collection
|
|
|
|
|
|
def get_user_id() -> Optional[str]:
    """
    Return the user ID carried by the request's decoded JWT token.

    Returns:
        The token's "sub" value, or None when the request has no
        ``decoded_token`` attribute or that attribute is empty.
    """
    token = getattr(request, "decoded_token", None)
    if not token:
        return None
    return token.get("sub")
|
|
|
|
|
|
def require_auth(func: Callable) -> Callable:
    """
    Route-handler decorator that rejects unauthenticated requests.

    The wrapped handler only runs when get_user_id() yields a truthy user
    ID; otherwise a 401 "Unauthorized" error response is returned and the
    handler is never called.

    Usage:
        @require_auth
        def get(self):
            user_id = get_user_id()
            ...
    """

    @wraps(func)
    def guarded(*args, **kwargs):
        if get_user_id():
            return func(*args, **kwargs)
        return error_response("Unauthorized", 401)

    return guarded
|
|
|
|
|
|
def success_response(
    data: Optional[Dict[str, Any]] = None, status: int = 200
) -> Response:
    """
    Build the standard JSON envelope for a successful request.

    Args:
        data: Optional dictionary whose keys are merged into the envelope
            after the "success" flag (so they take precedence on clashes)
        status: HTTP status code (default: 200)

    Returns:
        Flask Response object

    Example:
        return success_response({"users": [...], "total": 10})
    """
    payload: Dict[str, Any] = {"success": True}
    # A falsy ``data`` (None, empty dict) contributes nothing.
    payload.update(data or {})
    return make_response(jsonify(payload), status)
|
|
|
|
|
|
def error_response(message: str, status: int = 400, **kwargs) -> Response:
    """
    Build the standard JSON envelope for a failed request.

    Args:
        message: Error message string
        status: HTTP status code (default: 400)
        **kwargs: Extra fields merged into the envelope after "success"
            and "message"

    Returns:
        Flask Response object

    Example:
        return error_response("Resource not found", 404)
        return error_response("Invalid input", 400, errors=["field1", "field2"])
    """
    payload = {"success": False, "message": message, **kwargs}
    return make_response(jsonify(payload), status)
|
|
|
|
|
|
def validate_object_id(
    id_string: Optional[str], resource_name: str = "Resource"
) -> Tuple[Optional[ObjectId], Optional[Response]]:
    """
    Validate and convert string to ObjectId.

    Args:
        id_string: String to convert; None (a missing ID) is rejected
        resource_name: Name of resource for error message

    Returns:
        Tuple of (ObjectId or None, error_response or None). Exactly one
        element of the tuple is set.

    Example:
        obj_id, error = validate_object_id(workflow_id, "Workflow")
        if error:
            return error
    """
    # ObjectId(None) does not raise: it generates a brand-new identifier.
    # Without this guard a missing ID would "validate" and come back as a
    # random ObjectId instead of an error.
    if id_string is not None:
        try:
            return ObjectId(id_string), None
        except (InvalidId, TypeError):
            pass
    return None, error_response(f"Invalid {resource_name} ID format")
|
|
|
|
|
|
def validate_pagination(
    default_limit: int = 20, max_limit: int = 100
) -> Tuple[int, int, Optional[Response]]:
    """
    Read the "limit" and "skip" query parameters and validate them.

    Args:
        default_limit: Limit used when the request has no "limit" parameter
        max_limit: Upper bound; larger requested limits are clamped to it

    Returns:
        Tuple of (limit, skip, error_response or None). On failure the
        tuple is (0, 0, error): either value is not an integer, the
        clamped limit is below 1, or skip is negative.

    Example:
        limit, skip, error = validate_pagination()
        if error:
            return error
    """
    params = request.args
    try:
        requested_limit = int(params.get("limit", default_limit))
        skip = int(params.get("skip", 0))
    except ValueError:
        return 0, 0, error_response("Invalid pagination parameters")

    limit = min(requested_limit, max_limit)
    if limit >= 1 and skip >= 0:
        return limit, skip, None
    return 0, 0, error_response("Invalid pagination parameters")
|
|
|
|
|
|
def check_resource_ownership(
    collection: Collection,
    resource_id: ObjectId,
    user_id: str,
    resource_name: str = "Resource",
) -> Tuple[Optional[Dict], Optional[Response]]:
    """
    Fetch a resource only if it exists and is owned by the given user.

    The lookup matches on both "_id" and "user" in a single query, so a
    resource owned by someone else is reported the same way as a missing
    one (404 "<resource_name> not found").

    Args:
        collection: MongoDB collection
        resource_id: Resource ObjectId
        user_id: User ID string
        resource_name: Name of resource for error messages

    Returns:
        Tuple of (resource_dict or None, error_response or None)

    Example:
        workflow, error = check_resource_ownership(
            workflows_collection,
            workflow_id,
            user_id,
            "Workflow"
        )
        if error:
            return error
    """
    owned = collection.find_one({"_id": resource_id, "user": user_id})
    if owned:
        return owned, None
    return None, error_response(f"{resource_name} not found", 404)
|
|
|
|
|
|
def serialize_object_id(
    obj: Dict[str, Any], id_field: str = "_id", new_field: str = "id"
) -> Dict[str, Any]:
    """
    Replace an ObjectId field with its string form, in place.

    Args:
        obj: Dictionary containing ObjectId (it is mutated, not copied)
        id_field: Field name containing ObjectId
        new_field: Field name that receives the string ID; when it differs
            from ``id_field`` the original field is removed

    Returns:
        The same dictionary object that was passed in. It is returned
        untouched when ``id_field`` is absent.

    Example:
        user = serialize_object_id(user_doc)
        # user["id"] = "507f1f77bcf86cd799439011"
    """
    if id_field not in obj:
        return obj

    text_id = str(obj[id_field])
    if id_field != new_field:
        del obj[id_field]
    obj[new_field] = text_id
    return obj
|
|
|
|
|
|
def serialize_list(items: List[Dict], serializer: Callable[[Dict], Dict]) -> List[Dict]:
    """
    Run every item through ``serializer`` and collect the results.

    Args:
        items: List of dictionaries
        serializer: Function to apply to each item

    Returns:
        New list holding the serializer's output for each item, in the
        same order as ``items``

    Example:
        workflows = serialize_list(workflow_docs, serialize_workflow)
    """
    return list(map(serializer, items))
|
|
|
|
|
|
def paginated_response(
    collection: Collection,
    query: Dict[str, Any],
    serializer: Callable[[Dict], Dict],
    limit: int,
    skip: int,
    sort_field: str = "created_at",
    sort_order: int = -1,
    response_key: str = "items",
) -> Response:
    """
    Query one page of a collection and wrap it in a success response.

    The page is fetched with find/sort/skip/limit, and a separate
    count_documents call on the same query supplies the overall total.

    Args:
        collection: MongoDB collection
        query: Query dictionary
        serializer: Function to serialize each item
        limit: Items per page
        skip: Number of items to skip
        sort_field: Field to sort by
        sort_order: Sort order (1=asc, -1=desc)
        response_key: Key name for items in response

    Returns:
        Flask Response whose body holds the serialized page under
        ``response_key`` plus "total", "limit" and "skip"

    Example:
        return paginated_response(
            workflows_collection,
            {"user": user_id},
            serialize_workflow,
            limit, skip,
            response_key="workflows"
        )
    """
    cursor = collection.find(query).sort(sort_field, sort_order)
    page = list(cursor.skip(skip).limit(limit))
    total = collection.count_documents(query)

    payload = {
        response_key: serialize_list(page, serializer),
        "total": total,
        "limit": limit,
        "skip": skip,
    }
    return success_response(payload)
|
|
|
|
|
|
def require_fields(required: List[str]) -> Callable:
    """
    Decorator to validate required fields in request JSON.

    The wrapped handler is called only when the request body is a JSON
    object in which every required field has a truthy value. Otherwise an
    error response (default status of error_response) is returned.

    Args:
        required: List of required field names

    Returns:
        Decorator function

    Example:
        @require_fields(["name", "description"])
        def post(self):
            data = request.get_json()
            ...
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            data = request.get_json()
            if not data:
                return error_response("Request body required")
            # A JSON array, string or number is a valid body but has no
            # fields; without this check ``data.get`` below would raise
            # AttributeError and surface as an unhandled server error.
            if not isinstance(data, dict):
                return error_response("Request body must be a JSON object")
            # Truthiness test: a field that is present but falsy (empty
            # string, 0, False, empty list, None) is reported as missing.
            missing = [field for field in required if not data.get(field)]
            if missing:
                return error_response(f"Missing required fields: {', '.join(missing)}")
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
|
|
def safe_db_operation(
    operation: Callable, error_message: str = "Database operation failed"
) -> Tuple[Any, Optional[Response]]:
    """
    Call ``operation`` and turn any exception into an error response.

    Args:
        operation: Zero-argument callable to execute
        error_message: Prefix of the error message if operation fails;
            the exception's text is appended after ": "

    Returns:
        Tuple of (result, None) on success, or (None, error_response)
        when the operation raised

    Example:
        result, error = safe_db_operation(
            lambda: collection.insert_one(doc),
            "Failed to create resource"
        )
        if error:
            return error
    """
    try:
        outcome = operation()
    except Exception as exc:
        return None, error_response(f"{error_message}: {exc!s}")
    return outcome, None
|
|
|
|
|
|
def validate_enum(
    value: Any, allowed: List[Any], field_name: str
) -> Optional[Response]:
    """
    Check that ``value`` is one of the ``allowed`` values.

    Args:
        value: Value to validate
        allowed: List of allowed values
        field_name: Field name for error message

    Returns:
        None if valid; otherwise an error_response whose message lists
        every allowed value in single quotes

    Example:
        error = validate_enum(status, ["draft", "published"], "status")
        if error:
            return error
    """
    if value in allowed:
        return None

    choices = ", ".join(f"'{option}'" for option in allowed)
    return error_response(f"Invalid {field_name}. Must be one of: {choices}")
|
|
|
|
|
|
def extract_sort_params(
    default_field: str = "created_at",
    default_order: str = "desc",
    allowed_fields: Optional[List[str]] = None,
) -> Tuple[str, int]:
    """
    Read the "sort" and "order" query parameters of the current request.

    Args:
        default_field: Sort field used when "sort" is absent or not allowed
        default_order: Order used when "order" is absent ("asc" or "desc")
        allowed_fields: List of allowed sort fields; None or an empty list
            disables validation of the requested field

    Returns:
        Tuple of (sort_field, sort_order) where sort_order is -1 when the
        order is "desc" (case-insensitive) and 1 for any other value

    Example:
        sort_field, sort_order = extract_sort_params(
            allowed_fields=["name", "date", "status"]
        )
    """
    params = request.args
    requested_field = params.get("sort", default_field)
    direction = params.get("order", default_order).lower()

    # Fall back to the default instead of rejecting the request.
    field_rejected = bool(allowed_fields) and requested_field not in allowed_fields
    sort_field = default_field if field_rejected else requested_field

    if direction == "desc":
        return sort_field, -1
    return sort_field, 1
|