From 080a32a8fa081e0f4bff51209cdd31f2ed7d32ab Mon Sep 17 00:00:00 2001 From: giveen Date: Wed, 21 Jan 2026 11:06:36 -0700 Subject: [PATCH] mcp: add unix socket control + tools call/list via daemon --- pentestagent/interface/main.py | 186 +++++++++++++++++++++++++++++++-- pentestagent/mcp/manager.py | 130 +++++++++++++++++++++++ 2 files changed, 309 insertions(+), 7 deletions(-) diff --git a/pentestagent/interface/main.py b/pentestagent/interface/main.py index 682e3b3..09e6717 100644 --- a/pentestagent/interface/main.py +++ b/pentestagent/interface/main.py @@ -94,6 +94,17 @@ Examples: help="Temporarily connect to configured MCP servers and include their tools", ) + # tools call + tools_call = tools_subparsers.add_parser("call", help="Call a tool (via MCP daemon if available)") + tools_call.add_argument("server", help="MCP server name") + tools_call.add_argument("tool", help="Tool name") + tools_call.add_argument( + "--json", + dest="json_args", + help="JSON string of arguments to pass to the tool", + default=None, + ) + # tools info tools_info = tools_subparsers.add_parser("info", help="Show tool details") tools_info.add_argument("name", help="Tool name") @@ -108,6 +119,9 @@ Examples: # mcp list mcp_subparsers.add_parser("list", help="List configured MCP servers") + # mcp status + mcp_subparsers.add_parser("status", help="Show MCP daemon status (socket)" ) + # mcp add mcp_add = mcp_subparsers.add_parser("add", help="Add an MCP server") mcp_add.add_argument("name", help="Server name") @@ -195,14 +209,49 @@ def handle_tools_command(args: argparse.Namespace): if args.tools_command == "list": # Optionally include MCP-discovered tools by connecting temporarily manager = None - if getattr(args, "include_mcp", False): - from ..mcp.manager import MCPManager + mcp_socket_path = None + try: + from pathlib import Path - manager = MCPManager() - try: - asyncio.run(manager.connect_all()) - except Exception: - pass + mcp_socket_path = Path.home() / ".pentestagent" / "mcp.sock" + except Exception: + mcp_socket_path = None + + if getattr(args, "include_mcp", False): + # Try to query running MCP daemon via unix socket first + tried_socket = False + if mcp_socket_path and mcp_socket_path.exists(): + tried_socket = True + try: + import socket, json + + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: + s.connect(str(mcp_socket_path)) + s.sendall((json.dumps({"cmd": "list_tools"}) + "\n").encode("utf-8")) + # Read until EOF + resp = b"" + while True: + part = s.recv(4096) + if not part: + break + resp += part + data = json.loads(resp.decode("utf-8")) + mcp_tools = [] + if data.get("status") == "ok": + mcp_tools = data.get("tools", []) + else: + mcp_tools = [] + except Exception: + tried_socket = False + + if not tried_socket: + from ..mcp.manager import MCPManager + + manager = MCPManager() + try: + asyncio.run(manager.connect_all()) + except Exception: + pass try: tools = get_all_tools() @@ -215,6 +264,18 @@ def handle_tools_command(args: argparse.Namespace): except Exception: pass + # Merge MCP daemon tools (if returned by socket) into displayed list + if 'mcp_tools' in locals() and mcp_tools: + # Create lightweight objects to display alongside registered tools + class _FakeTool: + def __init__(self, name, category, description): + self.name = name + self.category = category + self.description = description + + for t in mcp_tools: + tools.append(_FakeTool(f"mcp_{t.get('server')}_{t.get('name')}", "mcp", t.get("description", ""))) + if not tools: console.print("[yellow]No tools found[/]") return @@ -291,6 +352,63 @@ def handle_tools_command(args: argparse.Namespace): else: console.print("[yellow]Use 'pentestagent tools --help' for commands[/]") + if args.tools_command == "call": + import json, socket + + server = args.server + tool = args.tool + json_args = {} + if args.json_args: + try: + json_args = json.loads(args.json_args) + except Exception as e: + console.print(f"[red]Invalid JSON for --json: {e}[/]") + return + + # Try daemon socket first + from pathlib import Path + sock = Path.home() / ".pentestagent" / "mcp.sock" + if sock.exists(): + try: + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: + s.connect(str(sock)) + s.sendall((json.dumps({"cmd": "call_tool", "server": server, "tool": tool, "args": json_args}) + "\n").encode("utf-8")) + resp = b"" + while True: + part = s.recv(4096) + if not part: + break + resp += part + data = json.loads(resp.decode("utf-8")) + if data.get("status") == "ok": + console.print(f"[green]Tool call succeeded. Result:[/] {data.get('result')}") + else: + console.print(f"[red]Tool call failed: {data.get('error')} {data.get('message','')}[/]") + return + except Exception: + pass + + # Fallback: temporary connect and call + from ..mcp.manager import MCPManager + + manager = MCPManager() + + async def _call(): + sv = await manager.connect_server(server) + if not sv: + raise RuntimeError(f"Failed to connect to server: {server}") + try: + res = await manager.call_tool(server, tool, json_args) + return res + finally: + await manager.disconnect_all() + + try: + res = asyncio.run(_call()) + console.print(f"[green]Tool call succeeded. Result:[/] {res}") + except Exception as e: + console.print(f"[red]Tool call failed: {e}[/]") + def handle_mcp_command(args: argparse.Namespace): """Handle MCP subcommand.""" @@ -408,6 +526,12 @@ def handle_mcp_command(args: argparse.Namespace): console.print(f"[red]Failed to connect: {name}[/]") return + # Start control socket so other CLI invocations can query daemon + try: + await manager.start_control_server() + except Exception: + pass + console.print("[green]Connected. Press Ctrl-C to stop and disconnect.[/]") await stop_event.wait() @@ -416,6 +540,10 @@ def handle_mcp_command(args: argparse.Namespace): await manager.disconnect_all() except Exception: pass + try: + await manager.stop_control_server() + except Exception: + pass # If detach requested, perform a simple double-fork to daemonize if detach: @@ -519,6 +647,50 @@ def handle_mcp_command(args: argparse.Namespace): asyncio.run(run_disconnect()) + elif args.mcp_command == "status": + # Try querying the daemon socket + from pathlib import Path + import socket, json + + sock = Path.home() / ".pentestagent" / "mcp.sock" + if sock.exists(): + try: + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: + s.connect(str(sock)) + s.sendall((json.dumps({"cmd": "status"}) + "\n").encode("utf-8")) + resp = b"" + while True: + part = s.recv(4096) + if not part: + break + resp += part + data = json.loads(resp.decode("utf-8")) + if data.get("status") == "ok": + rows = data.get("servers", []) + if not rows: + console.print("[yellow]No MCP servers connected[/]") + return + table = Table(title="MCP Daemon Status") + table.add_column("Name") + table.add_column("Connected") + table.add_column("Tools") + for r in rows: + table.add_row(r.get("name"), "+" if r.get("connected") else "-", str(r.get("tool_count", 0))) + console.print(table) + return + except Exception: + pass + + # Fallback: show configured servers and whether manager can see them + servers = manager.list_configured_servers() + table = Table(title="Configured MCP Servers") + table.add_column("Name") + table.add_column("Command") + table.add_column("Connected") + for s in servers: + table.add_row(s.get("name"), s.get("command"), "+" if s.get("connected") else "-") + console.print(table) + else: console.print("[yellow]Use 'pentestagent mcp --help' for available commands[/]") diff --git a/pentestagent/mcp/manager.py b/pentestagent/mcp/manager.py index c553324..93ff048 100644 --- a/pentestagent/mcp/manager.py +++ b/pentestagent/mcp/manager.py @@ -76,6 +76,10 @@ class MCPManager: # Track adapters we auto-started so we can stop them later self._started_adapters: Dict[str, object] = {} self._message_id = 0 + # Control socket server attributes + self._control_server: Optional[asyncio.AbstractServer] = None + self._control_task: Optional[asyncio.Task] = None + self._control_path: Optional[Path] = None # Ensure we attempt to clean up vendored servers on process exit try: atexit.register(self._atexit_cleanup) @@ -523,3 +527,129 @@ class MCPManager: def is_connected(self, name: str) -> bool: server = self.servers.get(name) return server is not None and server.connected + + async def start_control_server(self, path: Optional[str] = None) -> str: + """Start a lightweight UNIX-domain socket control server. + + The control server accepts newline-delimited JSON requests. Supported + commands: + {"cmd": "status"} -> returns connected servers and counts + {"cmd": "list_tools"} -> returns list of MCP tools (name, server, description) + + Returns the path of the socket in use. + """ + if not path: + path = str(Path.home() / ".pentestagent" / "mcp.sock") + + sock_path = Path(path) + # Ensure parent exists + sock_path.parent.mkdir(parents=True, exist_ok=True) + + # Remove stale socket if present + try: + if sock_path.exists(): + try: + sock_path.unlink() + except Exception: + pass + except Exception: + pass + + async def _handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + try: + data = await reader.readline() + if not data: + return + try: + req = json.loads(data.decode("utf-8")) + except Exception: + resp = {"status": "error", "error": "invalid_json"} + writer.write((json.dumps(resp) + "\n").encode("utf-8")) + await writer.drain() + return + + cmd = req.get("cmd") if isinstance(req, dict) else None + if cmd == "status": + servers = [] + for name, s in self.servers.items(): + servers.append({"name": name, "connected": bool(s.connected), "tool_count": len(s.tools)}) + resp = {"status": "ok", "servers": servers} + writer.write((json.dumps(resp) + "\n").encode("utf-8")) + await writer.drain() + elif cmd == "list_tools": + tools = [] + for sname, s in self.servers.items(): + for t in s.tools: + tools.append({"name": t.get("name"), "server": sname, "description": t.get("description", "")}) + resp = {"status": "ok", "tools": tools} + writer.write((json.dumps(resp) + "\n").encode("utf-8")) + await writer.drain() + elif cmd == "call_tool": + # Expecting: {"cmd":"call_tool","server":"name","tool":"tool_name","args":{...}} + server_name = req.get("server") + tool_name = req.get("tool") + arguments = req.get("args", {}) if isinstance(req.get("args", {}), dict) else {} + if not server_name or not tool_name: + writer.write((json.dumps({"status": "error", "error": "missing_parameters"}) + "\n").encode("utf-8")) + await writer.drain() + return + try: + # perform the tool call + result = await self.call_tool(server_name, tool_name, arguments) + writer.write((json.dumps({"status": "ok", "result": result}) + "\n").encode("utf-8")) + await writer.drain() + except Exception as e: + writer.write((json.dumps({"status": "error", "error": "call_failed", "message": str(e)}) + "\n").encode("utf-8")) + await writer.drain() + + else: + resp = {"status": "error", "error": "unknown_cmd"} + writer.write((json.dumps(resp) + "\n").encode("utf-8")) + await writer.drain() + except Exception: + try: + writer.write((json.dumps({"status": "error", "error": "internal"}) + "\n").encode("utf-8")) + await writer.drain() + except Exception: + pass + finally: + try: + writer.close() + except Exception: + pass + + # Start the asyncio unix server + loop = asyncio.get_running_loop() + server = await asyncio.start_unix_server(_handle, path=path) + self._control_server = server + self._control_path = Path(path) + + # Keep server serving in background task + self._control_task = loop.create_task(server.serve_forever()) + # Restrict socket access to current user where possible + try: + os.chmod(path, 0o600) + except Exception: + pass + return path + + async def stop_control_server(self): + try: + if self._control_task: + self._control_task.cancel() + self._control_task = None + if self._control_server: + self._control_server.close() + try: + await self._control_server.wait_closed() + except Exception: + pass + self._control_server = None + if self._control_path and self._control_path.exists(): + try: + self._control_path.unlink() + except Exception: + pass + self._control_path = None + except Exception: + pass