diff --git a/pentestagent/mcp/transport.py b/pentestagent/mcp/transport.py index d9468dc..1232a8e 100644 --- a/pentestagent/mcp/transport.py +++ b/pentestagent/mcp/transport.py @@ -175,6 +175,10 @@ class SSETransport(MCPTransport): self.session: Optional[Any] = None # aiohttp.ClientSession self._connected = False self._post_url: Optional[str] = None + self._sse_response: Optional[Any] = None + self._sse_task: Optional[asyncio.Task] = None + self._pending: dict[str, asyncio.Future] = {} + self._pending_lock = asyncio.Lock() @property def is_connected(self) -> bool: @@ -188,46 +192,20 @@ class SSETransport(MCPTransport): self.session = aiohttp.ClientSession() - # Try to discover the POST endpoint for sending messages. - # Many MCP SSE servers (including MetasploitMCP) expose an SSE - # endpoint (e.g. /sse) which returns an initial `endpoint` event - # with a `data: /messages/?session_id=...` value. We perform a - # short GET to read that event and extract the messages POST URL. + # Open a persistent SSE connection so we can receive async + # responses delivered over the event stream. Keep the response + # object alive and run a background task to parse events. try: - async with self.session.get(self.url, timeout=5) as resp: - if resp.status != 200: - # Still consider connected (we may only need POST), - # but leave discovery to send() which will give clearer - # errors if the server isn't compatible. - self._connected = True - return - - # Read a few lines to find `data:` - for _ in range(20): - line = await resp.content.readline() - if not line: - break - try: - text = line.decode(errors="ignore").strip() - except Exception: - continue - if text.startswith("data:"): - endpoint = text.split("data:", 1)[1].strip() - # Build absolute POST URL from the discovered endpoint - from urllib.parse import urlparse - - p = urlparse(self.url) - if endpoint.startswith("http"): - self._post_url = endpoint - elif endpoint.startswith("/"): - self._post_url = f"{p.scheme}://{p.netloc}{endpoint}" - else: - self._post_url = f"{p.scheme}://{p.netloc}/{endpoint.lstrip('/') }" - break - + # Do not use a short timeout; keep the connection open. + resp = await self.session.get(self.url, timeout=None) + # Store response and start background reader + self._sse_response = resp + self._sse_task = asyncio.create_task(self._sse_listener(resp)) except Exception: - # Discovery failed — still create session and let send() report errors. - pass + # If opening the SSE stream fails, still mark connected so + # send() can attempt POST discovery and report meaningful errors. + self._sse_response = None + self._sse_task = None self._connected = True except ImportError as e: @@ -248,24 +226,144 @@ class SSETransport(MCPTransport): if not self.session: raise RuntimeError("Transport not connected") + if not self.session: + raise RuntimeError("Transport not connected") + + post_target = self._post_url or self.url + try: - post_target = self._post_url or self.url async with self.session.post( post_target, json=message, headers={"Content-Type": "application/json"} ) as response: - if response.status != 200: - raise RuntimeError(f"HTTP error: {response.status}") - - return await response.json() + status = response.status + if status == 200: + return await response.json() + if status == 202: + # Asynchronous response: wait for matching SSE event with the same id + if "id" not in message: + return {} + msg_id = str(message["id"]) + fut = asyncio.get_running_loop().create_future() + async with self._pending_lock: + self._pending[msg_id] = fut + try: + result = await asyncio.wait_for(fut, timeout=15.0) + return result + finally: + async with self._pending_lock: + self._pending.pop(msg_id, None) + # Other statuses are errors + raise RuntimeError(f"HTTP error: {status}") except Exception as e: raise RuntimeError(f"SSE request failed: {e}") from e async def disconnect(self): """Close the HTTP session.""" + # Cancel listener and close SSE response + try: + if self._sse_task: + self._sse_task.cancel() + try: + await self._sse_task + except Exception: + pass + self._sse_task = None + except Exception: + pass + + try: + if self._sse_response: + try: + await self._sse_response.release() + except Exception: + pass + self._sse_response = None + except Exception: + pass + + # Fail any pending requests + async with self._pending_lock: + for fut in list(self._pending.values()): + if not fut.done(): + fut.set_exception(RuntimeError("Transport disconnected")) + self._pending.clear() + if self.session: await self.session.close() self.session = None + self._connected = False + + async def _sse_listener(self, resp: Any): + """Background task that reads SSE events and resolves pending futures. + + The listener expects SSE-formatted events where `data:` lines may + contain JSON payloads. If a JSON object contains an `id` field that + matches a pending request, the corresponding future is completed with + that JSON value. + """ + try: + # Read the stream line-by-line, accumulating event blocks + event_lines: list[str] = [] + async for raw in resp.content: + try: + line = raw.decode(errors="ignore").rstrip("\r\n") + except Exception: + continue + if line == "": + # End of event; process accumulated lines + event_name = None + data_lines: list[str] = [] + for l in event_lines: + if l.startswith("event:"): + event_name = l.split(":", 1)[1].strip() + elif l.startswith("data:"): + data_lines.append(l.split(":", 1)[1].lstrip()) + + if data_lines: + data_text = "\n".join(data_lines) + # If this is an endpoint announcement, record POST URL + if event_name == "endpoint": + try: + from urllib.parse import urlparse + + p = urlparse(self.url) + endpoint = data_text.strip() + if endpoint.startswith("http"): + self._post_url = endpoint + elif endpoint.startswith("/"): + self._post_url = f"{p.scheme}://{p.netloc}{endpoint}" + else: + self._post_url = f"{p.scheme}://{p.netloc}/{endpoint.lstrip('/')}" + except Exception: + pass + else: + # Try to parse as JSON and resolve pending futures + try: + obj = json.loads(data_text) + if isinstance(obj, dict) and "id" in obj: + msg_id = str(obj.get("id")) + async with self._pending_lock: + fut = self._pending.get(msg_id) + if fut and not fut.done(): + fut.set_result(obj) + except Exception: + pass + + event_lines = [] + else: + event_lines.append(line) + except asyncio.CancelledError: + return + except Exception: + # On error, fail pending futures + async with self._pending_lock: + for fut in list(self._pending.values()): + if not fut.done(): + fut.set_exception(RuntimeError("SSE listener error")) + self._pending.clear() + finally: + # Ensure we mark disconnected state self._connected = False