feat: Implement OAuth flow for MCP server integration

- Added MCPOAuthManager to handle OAuth authorization.
- Updated MCPServerSave resource to manage OAuth status and callback.
- Introduced new endpoints for OAuth status and callback handling.
- Enhanced user interface to support OAuth authentication type.
- Implemented polling mechanism for OAuth status in MCPServerModal.
- Updated frontend services and endpoints to accommodate new OAuth features.
- Improved error handling and user feedback for OAuth processes.
This commit is contained in:
Siddhant Rai
2025-09-26 02:44:08 +05:30
parent 00b4e133d4
commit 3b27db36f2
9 changed files with 991 additions and 289 deletions

View File

@@ -8,6 +8,7 @@ import uuid
import zipfile
from functools import wraps
from typing import Optional, Tuple
from urllib.parse import unquote
from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
@@ -25,7 +26,7 @@ from flask_restx import fields, inputs, Namespace, Resource
from pymongo import ReturnDocument
from werkzeug.utils import secure_filename
from application.agents.tools.mcp_tool import MCPTool
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
from application.agents.tools.tool_manager import ToolManager
from application.api import api
@@ -37,6 +38,8 @@ from application.api.user.tasks import (
process_agent_webhook,
store_attachment,
)
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.parser.connectors.connector_creator import ConnectorCreator
@@ -4303,7 +4306,7 @@ class TestMCPServerConfig(Resource):
test_config = config.copy()
test_config["auth_credentials"] = auth_credentials
mcp_tool = MCPTool(test_config, user)
mcp_tool = MCPTool(config=test_config, user_id=user)
result = mcp_tool.test_connection()
return make_response(jsonify(result), 200)
@@ -4371,8 +4374,33 @@ class MCPServerSave(Resource):
mcp_config = config.copy()
mcp_config["auth_credentials"] = auth_credentials
if auth_type == "none" or auth_credentials:
mcp_tool = MCPTool(mcp_config, user)
if auth_type == "oauth":
if not config.get("oauth_task_id"):
return make_response(
jsonify(
{
"success": False,
"error": "Connection not authorized. Please complete the OAuth authorization first.",
}
),
400,
)
redis_client = get_redis_instance()
manager = MCPOAuthManager(redis_client)
result = manager.get_oauth_status(config["oauth_task_id"])
if not result.get("status") == "completed":
return make_response(
jsonify(
{
"success": False,
"error": "OAuth failed or not completed. Please try authorizing again.",
}
),
400,
)
actions_metadata = result.get("tools", [])
elif auth_type == "none" or auth_credentials:
mcp_tool = MCPTool(config=mcp_config, user_id=user)
mcp_tool.discover_tools()
actions_metadata = mcp_tool.get_actions_metadata()
else:
@@ -4455,3 +4483,96 @@ class MCPServerSave(Resource):
),
500,
)
@user_ns.route("/api/mcp_server/callback")
class MCPOAuthCallback(Resource):
@api.expect(
api.model(
"MCPServerCallbackModel",
{
"code": fields.String(required=True, description="Authorization code"),
"state": fields.String(required=True, description="State parameter"),
"error": fields.String(
required=False, description="Error message (if any)"
),
},
)
)
@api.doc(
description="Handle OAuth callback by providing the authorization code and state"
)
def get(self):
code = request.args.get("code")
state = request.args.get("state")
error = request.args.get("error")
if error:
return redirect(
f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
)
if not code or not state:
return redirect(
f"/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
)
try:
redis_client = get_redis_instance()
if not redis_client:
return redirect(
f"/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
)
code = unquote(code)
manager = MCPOAuthManager(redis_client)
success = manager.handle_oauth_callback(state, code, error)
if success:
return redirect(
f"/api/connectors/callback-status?status=success&message=Authorization+code+received+successfully.+You+can+close+this+window.&provider=mcp_tool"
)
else:
return redirect(
f"/api/connectors/callback-status?status=error&message=OAuth+callback+failed.&provider=mcp_tool"
)
except Exception as e:
current_app.logger.error(
f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
)
return redirect(
f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool"
)
@user_ns.route("/api/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
def get(self, task_id):
"""
Get current status of OAuth flow.
Frontend should poll this endpoint periodically.
"""
try:
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
status_data = redis_client.get(status_key)
if status_data:
status = json.loads(status_data)
return make_response(
jsonify({"success": True, "task_id": task_id, **status})
)
else:
return make_response(
jsonify(
{
"success": False,
"error": "Task not found or expired",
"task_id": task_id,
}
),
404,
)
except Exception as e:
current_app.logger.error(
f"Error getting OAuth status for task {task_id}: {str(e)}"
)
return make_response(
jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
)