mcp: add SSE listener + pending request handling to support async 202 flows

This commit is contained in:
giveen
2026-01-14 16:52:28 -07:00
parent 580fc37614
commit 1476b1e117

View File

@@ -175,6 +175,10 @@ class SSETransport(MCPTransport):
self.session: Optional[Any] = None # aiohttp.ClientSession
self._connected = False
self._post_url: Optional[str] = None
self._sse_response: Optional[Any] = None
self._sse_task: Optional[asyncio.Task] = None
self._pending: dict[str, asyncio.Future] = {}
self._pending_lock = asyncio.Lock()
@property
def is_connected(self) -> bool:
@@ -188,46 +192,20 @@ class SSETransport(MCPTransport):
self.session = aiohttp.ClientSession()
# Try to discover the POST endpoint for sending messages.
# Many MCP SSE servers (including MetasploitMCP) expose an SSE
# endpoint (e.g. /sse) which returns an initial `endpoint` event
# with a `data: /messages/?session_id=...` value. We perform a
# short GET to read that event and extract the messages POST URL.
# Open a persistent SSE connection so we can receive async
# responses delivered over the event stream. Keep the response
# object alive and run a background task to parse events.
try:
async with self.session.get(self.url, timeout=5) as resp:
if resp.status != 200:
# Still consider connected (we may only need POST),
# but leave discovery to send() which will give clearer
# errors if the server isn't compatible.
self._connected = True
return
# Read a few lines to find `data:`
for _ in range(20):
line = await resp.content.readline()
if not line:
break
try:
text = line.decode(errors="ignore").strip()
except Exception:
continue
if text.startswith("data:"):
endpoint = text.split("data:", 1)[1].strip()
# Build absolute POST URL from the discovered endpoint
from urllib.parse import urlparse
p = urlparse(self.url)
if endpoint.startswith("http"):
self._post_url = endpoint
elif endpoint.startswith("/"):
self._post_url = f"{p.scheme}://{p.netloc}{endpoint}"
else:
self._post_url = f"{p.scheme}://{p.netloc}/{endpoint.lstrip('/') }"
break
# Do not use a short timeout; keep the connection open.
resp = await self.session.get(self.url, timeout=None)
# Store response and start background reader
self._sse_response = resp
self._sse_task = asyncio.create_task(self._sse_listener(resp))
except Exception:
# Discovery failed — still create session and let send() report errors.
pass
# If opening the SSE stream fails, still mark connected so
# send() can attempt POST discovery and report meaningful errors.
self._sse_response = None
self._sse_task = None
self._connected = True
except ImportError as e:
@@ -248,24 +226,144 @@ class SSETransport(MCPTransport):
if not self.session:
raise RuntimeError("Transport not connected")
if not self.session:
raise RuntimeError("Transport not connected")
post_target = self._post_url or self.url
try:
post_target = self._post_url or self.url
async with self.session.post(
post_target, json=message, headers={"Content-Type": "application/json"}
) as response:
if response.status != 200:
raise RuntimeError(f"HTTP error: {response.status}")
return await response.json()
status = response.status
if status == 200:
return await response.json()
if status == 202:
# Asynchronous response: wait for matching SSE event with the same id
if "id" not in message:
return {}
msg_id = str(message["id"])
fut = asyncio.get_running_loop().create_future()
async with self._pending_lock:
self._pending[msg_id] = fut
try:
result = await asyncio.wait_for(fut, timeout=15.0)
return result
finally:
async with self._pending_lock:
self._pending.pop(msg_id, None)
# Other statuses are errors
raise RuntimeError(f"HTTP error: {status}")
except Exception as e:
raise RuntimeError(f"SSE request failed: {e}") from e
async def disconnect(self):
"""Close the HTTP session."""
# Cancel listener and close SSE response
try:
if self._sse_task:
self._sse_task.cancel()
try:
await self._sse_task
except Exception:
pass
self._sse_task = None
except Exception:
pass
try:
if self._sse_response:
try:
await self._sse_response.release()
except Exception:
pass
self._sse_response = None
except Exception:
pass
# Fail any pending requests
async with self._pending_lock:
for fut in list(self._pending.values()):
if not fut.done():
fut.set_exception(RuntimeError("Transport disconnected"))
self._pending.clear()
if self.session:
await self.session.close()
self.session = None
self._connected = False
async def _sse_listener(self, resp: Any):
"""Background task that reads SSE events and resolves pending futures.
The listener expects SSE-formatted events where `data:` lines may
contain JSON payloads. If a JSON object contains an `id` field that
matches a pending request, the corresponding future is completed with
that JSON value.
"""
try:
# Read the stream line-by-line, accumulating event blocks
event_lines: list[str] = []
async for raw in resp.content:
try:
line = raw.decode(errors="ignore").rstrip("\r\n")
except Exception:
continue
if line == "":
# End of event; process accumulated lines
event_name = None
data_lines: list[str] = []
for l in event_lines:
if l.startswith("event:"):
event_name = l.split(":", 1)[1].strip()
elif l.startswith("data:"):
data_lines.append(l.split(":", 1)[1].lstrip())
if data_lines:
data_text = "\n".join(data_lines)
# If this is an endpoint announcement, record POST URL
if event_name == "endpoint":
try:
from urllib.parse import urlparse
p = urlparse(self.url)
endpoint = data_text.strip()
if endpoint.startswith("http"):
self._post_url = endpoint
elif endpoint.startswith("/"):
self._post_url = f"{p.scheme}://{p.netloc}{endpoint}"
else:
self._post_url = f"{p.scheme}://{p.netloc}/{endpoint.lstrip('/')}"
except Exception:
pass
else:
# Try to parse as JSON and resolve pending futures
try:
obj = json.loads(data_text)
if isinstance(obj, dict) and "id" in obj:
msg_id = str(obj.get("id"))
async with self._pending_lock:
fut = self._pending.get(msg_id)
if fut and not fut.done():
fut.set_result(obj)
except Exception:
pass
event_lines = []
else:
event_lines.append(line)
except asyncio.CancelledError:
return
except Exception:
# On error, fail pending futures
async with self._pending_lock:
for fut in list(self._pending.values()):
if not fut.done():
fut.set_exception(RuntimeError("SSE listener error"))
self._pending.clear()
finally:
# Ensure we mark disconnected state
self._connected = False