mirror of
https://github.com/GH05TCREW/pentestagent.git
synced 2026-03-07 14:23:20 +00:00
mcp: add SSE listener + pending request handling to support async 202 flows
This commit is contained in:
@@ -175,6 +175,10 @@ class SSETransport(MCPTransport):
|
||||
self.session: Optional[Any] = None # aiohttp.ClientSession
|
||||
self._connected = False
|
||||
self._post_url: Optional[str] = None
|
||||
self._sse_response: Optional[Any] = None
|
||||
self._sse_task: Optional[asyncio.Task] = None
|
||||
self._pending: dict[str, asyncio.Future] = {}
|
||||
self._pending_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
@@ -188,46 +192,20 @@ class SSETransport(MCPTransport):
|
||||
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
# Try to discover the POST endpoint for sending messages.
|
||||
# Many MCP SSE servers (including MetasploitMCP) expose an SSE
|
||||
# endpoint (e.g. /sse) which returns an initial `endpoint` event
|
||||
# with a `data: /messages/?session_id=...` value. We perform a
|
||||
# short GET to read that event and extract the messages POST URL.
|
||||
# Open a persistent SSE connection so we can receive async
|
||||
# responses delivered over the event stream. Keep the response
|
||||
# object alive and run a background task to parse events.
|
||||
try:
|
||||
async with self.session.get(self.url, timeout=5) as resp:
|
||||
if resp.status != 200:
|
||||
# Still consider connected (we may only need POST),
|
||||
# but leave discovery to send() which will give clearer
|
||||
# errors if the server isn't compatible.
|
||||
self._connected = True
|
||||
return
|
||||
|
||||
# Read a few lines to find `data:`
|
||||
for _ in range(20):
|
||||
line = await resp.content.readline()
|
||||
if not line:
|
||||
break
|
||||
try:
|
||||
text = line.decode(errors="ignore").strip()
|
||||
except Exception:
|
||||
continue
|
||||
if text.startswith("data:"):
|
||||
endpoint = text.split("data:", 1)[1].strip()
|
||||
# Build absolute POST URL from the discovered endpoint
|
||||
from urllib.parse import urlparse
|
||||
|
||||
p = urlparse(self.url)
|
||||
if endpoint.startswith("http"):
|
||||
self._post_url = endpoint
|
||||
elif endpoint.startswith("/"):
|
||||
self._post_url = f"{p.scheme}://{p.netloc}{endpoint}"
|
||||
else:
|
||||
self._post_url = f"{p.scheme}://{p.netloc}/{endpoint.lstrip('/') }"
|
||||
break
|
||||
|
||||
# Do not use a short timeout; keep the connection open.
|
||||
resp = await self.session.get(self.url, timeout=None)
|
||||
# Store response and start background reader
|
||||
self._sse_response = resp
|
||||
self._sse_task = asyncio.create_task(self._sse_listener(resp))
|
||||
except Exception:
|
||||
# Discovery failed — still create session and let send() report errors.
|
||||
pass
|
||||
# If opening the SSE stream fails, still mark connected so
|
||||
# send() can attempt POST discovery and report meaningful errors.
|
||||
self._sse_response = None
|
||||
self._sse_task = None
|
||||
|
||||
self._connected = True
|
||||
except ImportError as e:
|
||||
@@ -248,24 +226,144 @@ class SSETransport(MCPTransport):
|
||||
if not self.session:
|
||||
raise RuntimeError("Transport not connected")
|
||||
|
||||
if not self.session:
|
||||
raise RuntimeError("Transport not connected")
|
||||
|
||||
post_target = self._post_url or self.url
|
||||
|
||||
try:
|
||||
post_target = self._post_url or self.url
|
||||
async with self.session.post(
|
||||
post_target, json=message, headers={"Content-Type": "application/json"}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise RuntimeError(f"HTTP error: {response.status}")
|
||||
|
||||
return await response.json()
|
||||
status = response.status
|
||||
if status == 200:
|
||||
return await response.json()
|
||||
if status == 202:
|
||||
# Asynchronous response: wait for matching SSE event with the same id
|
||||
if "id" not in message:
|
||||
return {}
|
||||
msg_id = str(message["id"])
|
||||
fut = asyncio.get_running_loop().create_future()
|
||||
async with self._pending_lock:
|
||||
self._pending[msg_id] = fut
|
||||
try:
|
||||
result = await asyncio.wait_for(fut, timeout=15.0)
|
||||
return result
|
||||
finally:
|
||||
async with self._pending_lock:
|
||||
self._pending.pop(msg_id, None)
|
||||
# Other statuses are errors
|
||||
raise RuntimeError(f"HTTP error: {status}")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"SSE request failed: {e}") from e
|
||||
|
||||
async def disconnect(self):
|
||||
"""Close the HTTP session."""
|
||||
# Cancel listener and close SSE response
|
||||
try:
|
||||
if self._sse_task:
|
||||
self._sse_task.cancel()
|
||||
try:
|
||||
await self._sse_task
|
||||
except Exception:
|
||||
pass
|
||||
self._sse_task = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
if self._sse_response:
|
||||
try:
|
||||
await self._sse_response.release()
|
||||
except Exception:
|
||||
pass
|
||||
self._sse_response = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fail any pending requests
|
||||
async with self._pending_lock:
|
||||
for fut in list(self._pending.values()):
|
||||
if not fut.done():
|
||||
fut.set_exception(RuntimeError("Transport disconnected"))
|
||||
self._pending.clear()
|
||||
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
self._connected = False
|
||||
|
||||
async def _sse_listener(self, resp: Any):
|
||||
"""Background task that reads SSE events and resolves pending futures.
|
||||
|
||||
The listener expects SSE-formatted events where `data:` lines may
|
||||
contain JSON payloads. If a JSON object contains an `id` field that
|
||||
matches a pending request, the corresponding future is completed with
|
||||
that JSON value.
|
||||
"""
|
||||
try:
|
||||
# Read the stream line-by-line, accumulating event blocks
|
||||
event_lines: list[str] = []
|
||||
async for raw in resp.content:
|
||||
try:
|
||||
line = raw.decode(errors="ignore").rstrip("\r\n")
|
||||
except Exception:
|
||||
continue
|
||||
if line == "":
|
||||
# End of event; process accumulated lines
|
||||
event_name = None
|
||||
data_lines: list[str] = []
|
||||
for l in event_lines:
|
||||
if l.startswith("event:"):
|
||||
event_name = l.split(":", 1)[1].strip()
|
||||
elif l.startswith("data:"):
|
||||
data_lines.append(l.split(":", 1)[1].lstrip())
|
||||
|
||||
if data_lines:
|
||||
data_text = "\n".join(data_lines)
|
||||
# If this is an endpoint announcement, record POST URL
|
||||
if event_name == "endpoint":
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
p = urlparse(self.url)
|
||||
endpoint = data_text.strip()
|
||||
if endpoint.startswith("http"):
|
||||
self._post_url = endpoint
|
||||
elif endpoint.startswith("/"):
|
||||
self._post_url = f"{p.scheme}://{p.netloc}{endpoint}"
|
||||
else:
|
||||
self._post_url = f"{p.scheme}://{p.netloc}/{endpoint.lstrip('/')}"
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# Try to parse as JSON and resolve pending futures
|
||||
try:
|
||||
obj = json.loads(data_text)
|
||||
if isinstance(obj, dict) and "id" in obj:
|
||||
msg_id = str(obj.get("id"))
|
||||
async with self._pending_lock:
|
||||
fut = self._pending.get(msg_id)
|
||||
if fut and not fut.done():
|
||||
fut.set_result(obj)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
event_lines = []
|
||||
else:
|
||||
event_lines.append(line)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception:
|
||||
# On error, fail pending futures
|
||||
async with self._pending_lock:
|
||||
for fut in list(self._pending.values()):
|
||||
if not fut.done():
|
||||
fut.set_exception(RuntimeError("SSE listener error"))
|
||||
self._pending.clear()
|
||||
finally:
|
||||
# Ensure we mark disconnected state
|
||||
self._connected = False
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user