mcp: add unix socket control + tools call/list via daemon

This commit is contained in:
giveen
2026-01-21 11:06:36 -07:00
parent f3f3b0956b
commit 080a32a8fa
2 changed files with 309 additions and 7 deletions

View File

@@ -94,6 +94,17 @@ Examples:
help="Temporarily connect to configured MCP servers and include their tools",
)
# tools call
tools_call = tools_subparsers.add_parser("call", help="Call a tool (via MCP daemon if available)")
tools_call.add_argument("server", help="MCP server name")
tools_call.add_argument("tool", help="Tool name")
tools_call.add_argument(
"--json",
dest="json_args",
help="JSON string of arguments to pass to the tool",
default=None,
)
# tools info
tools_info = tools_subparsers.add_parser("info", help="Show tool details")
tools_info.add_argument("name", help="Tool name")
@@ -108,6 +119,9 @@ Examples:
# mcp list
mcp_subparsers.add_parser("list", help="List configured MCP servers")
# mcp status
mcp_subparsers.add_parser("status", help="Show MCP daemon status (socket)" )
# mcp add
mcp_add = mcp_subparsers.add_parser("add", help="Add an MCP server")
mcp_add.add_argument("name", help="Server name")
@@ -195,14 +209,49 @@ def handle_tools_command(args: argparse.Namespace):
if args.tools_command == "list":
# Optionally include MCP-discovered tools by connecting temporarily
manager = None
if getattr(args, "include_mcp", False):
from ..mcp.manager import MCPManager
mcp_socket_path = None
try:
from pathlib import Path
manager = MCPManager()
try:
asyncio.run(manager.connect_all())
except Exception:
pass
mcp_socket_path = Path.home() / ".pentestagent" / "mcp.sock"
except Exception:
mcp_socket_path = None
if getattr(args, "include_mcp", False):
# Try to query running MCP daemon via unix socket first
tried_socket = False
if mcp_socket_path and mcp_socket_path.exists():
tried_socket = True
try:
import socket, json
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
s.connect(str(mcp_socket_path))
s.sendall((json.dumps({"cmd": "list_tools"}) + "\n").encode("utf-8"))
# Read until EOF
resp = b""
while True:
part = s.recv(4096)
if not part:
break
resp += part
data = json.loads(resp.decode("utf-8"))
mcp_tools = []
if data.get("status") == "ok":
mcp_tools = data.get("tools", [])
else:
mcp_tools = []
except Exception:
tried_socket = False
if not tried_socket:
from ..mcp.manager import MCPManager
manager = MCPManager()
try:
asyncio.run(manager.connect_all())
except Exception:
pass
try:
tools = get_all_tools()
@@ -215,6 +264,18 @@ def handle_tools_command(args: argparse.Namespace):
except Exception:
pass
# Merge MCP daemon tools (if returned by socket) into displayed list
if 'mcp_tools' in locals() and mcp_tools:
# Create lightweight objects to display alongside registered tools
class _FakeTool:
def __init__(self, name, category, description):
self.name = name
self.category = category
self.description = description
for t in mcp_tools:
tools.append(_FakeTool(f"mcp_{t.get('server')}_{t.get('name')}", "mcp", t.get("description", "")))
if not tools:
console.print("[yellow]No tools found[/]")
return
@@ -291,6 +352,63 @@ def handle_tools_command(args: argparse.Namespace):
else:
console.print("[yellow]Use 'pentestagent tools --help' for commands[/]")
if args.tools_command == "call":
import json, socket
server = args.server
tool = args.tool
json_args = {}
if args.json_args:
try:
json_args = json.loads(args.json_args)
except Exception as e:
console.print(f"[red]Invalid JSON for --json: {e}[/]")
return
# Try daemon socket first
from pathlib import Path
sock = Path.home() / ".pentestagent" / "mcp.sock"
if sock.exists():
try:
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
s.connect(str(sock))
s.sendall((json.dumps({"cmd": "call_tool", "server": server, "tool": tool, "args": json_args}) + "\n").encode("utf-8"))
resp = b""
while True:
part = s.recv(4096)
if not part:
break
resp += part
data = json.loads(resp.decode("utf-8"))
if data.get("status") == "ok":
console.print(f"[green]Tool call succeeded. Result:[/] {data.get('result')}")
else:
console.print(f"[red]Tool call failed: {data.get('error')} {data.get('message','')}[/]")
return
except Exception:
pass
# Fallback: temporary connect and call
from ..mcp.manager import MCPManager
manager = MCPManager()
async def _call():
sv = await manager.connect_server(server)
if not sv:
raise RuntimeError(f"Failed to connect to server: {server}")
try:
res = await manager.call_tool(server, tool, json_args)
return res
finally:
await manager.disconnect_all()
try:
res = asyncio.run(_call())
console.print(f"[green]Tool call succeeded. Result:[/] {res}")
except Exception as e:
console.print(f"[red]Tool call failed: {e}[/]")
def handle_mcp_command(args: argparse.Namespace):
"""Handle MCP subcommand."""
@@ -408,6 +526,12 @@ def handle_mcp_command(args: argparse.Namespace):
console.print(f"[red]Failed to connect: {name}[/]")
return
# Start control socket so other CLI invocations can query daemon
try:
await manager.start_control_server()
except Exception:
pass
console.print("[green]Connected. Press Ctrl-C to stop and disconnect.[/]")
await stop_event.wait()
@@ -416,6 +540,10 @@ def handle_mcp_command(args: argparse.Namespace):
await manager.disconnect_all()
except Exception:
pass
try:
await manager.stop_control_server()
except Exception:
pass
# If detach requested, perform a simple double-fork to daemonize
if detach:
@@ -519,6 +647,50 @@ def handle_mcp_command(args: argparse.Namespace):
asyncio.run(run_disconnect())
elif args.mcp_command == "status":
# Try querying the daemon socket
from pathlib import Path
import socket, json
sock = Path.home() / ".pentestagent" / "mcp.sock"
if sock.exists():
try:
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
s.connect(str(sock))
s.sendall((json.dumps({"cmd": "status"}) + "\n").encode("utf-8"))
resp = b""
while True:
part = s.recv(4096)
if not part:
break
resp += part
data = json.loads(resp.decode("utf-8"))
if data.get("status") == "ok":
rows = data.get("servers", [])
if not rows:
console.print("[yellow]No MCP servers connected[/]")
return
table = Table(title="MCP Daemon Status")
table.add_column("Name")
table.add_column("Connected")
table.add_column("Tools")
for r in rows:
table.add_row(r.get("name"), "+" if r.get("connected") else "-", str(r.get("tool_count", 0)))
console.print(table)
return
except Exception:
pass
# Fallback: show configured servers and whether manager can see them
servers = manager.list_configured_servers()
table = Table(title="Configured MCP Servers")
table.add_column("Name")
table.add_column("Command")
table.add_column("Connected")
for s in servers:
table.add_row(s.get("name"), s.get("command"), "+" if s.get("connected") else "-")
console.print(table)
else:
console.print("[yellow]Use 'pentestagent mcp --help' for available commands[/]")

View File

@@ -76,6 +76,10 @@ class MCPManager:
# Track adapters we auto-started so we can stop them later
self._started_adapters: Dict[str, object] = {}
self._message_id = 0
# Control socket server attributes
self._control_server: Optional[asyncio.AbstractServer] = None
self._control_task: Optional[asyncio.Task] = None
self._control_path: Optional[Path] = None
# Ensure we attempt to clean up vendored servers on process exit
try:
atexit.register(self._atexit_cleanup)
@@ -523,3 +527,129 @@ class MCPManager:
def is_connected(self, name: str) -> bool:
server = self.servers.get(name)
return server is not None and server.connected
async def start_control_server(self, path: Optional[str] = None) -> str:
"""Start a lightweight UNIX-domain socket control server.
The control server accepts newline-delimited JSON requests. Supported
commands:
{"cmd": "status"} -> returns connected servers and counts
{"cmd": "list_tools"} -> returns list of MCP tools (name, server, description)
Returns the path of the socket in use.
"""
if not path:
path = str(Path.home() / ".pentestagent" / "mcp.sock")
sock_path = Path(path)
# Ensure parent exists
sock_path.parent.mkdir(parents=True, exist_ok=True)
# Remove stale socket if present
try:
if sock_path.exists():
try:
sock_path.unlink()
except Exception:
pass
except Exception:
pass
async def _handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
try:
data = await reader.readline()
if not data:
return
try:
req = json.loads(data.decode("utf-8"))
except Exception:
resp = {"status": "error", "error": "invalid_json"}
writer.write((json.dumps(resp) + "\n").encode("utf-8"))
await writer.drain()
return
cmd = req.get("cmd") if isinstance(req, dict) else None
if cmd == "status":
servers = []
for name, s in self.servers.items():
servers.append({"name": name, "connected": bool(s.connected), "tool_count": len(s.tools)})
resp = {"status": "ok", "servers": servers}
writer.write((json.dumps(resp) + "\n").encode("utf-8"))
await writer.drain()
elif cmd == "list_tools":
tools = []
for sname, s in self.servers.items():
for t in s.tools:
tools.append({"name": t.get("name"), "server": sname, "description": t.get("description", "")})
resp = {"status": "ok", "tools": tools}
writer.write((json.dumps(resp) + "\n").encode("utf-8"))
await writer.drain()
elif cmd == "call_tool":
# Expecting: {"cmd":"call_tool","server":"name","tool":"tool_name","args":{...}}
server_name = req.get("server")
tool_name = req.get("tool")
arguments = req.get("args", {}) if isinstance(req.get("args", {}), dict) else {}
if not server_name or not tool_name:
writer.write((json.dumps({"status": "error", "error": "missing_parameters"}) + "\n").encode("utf-8"))
await writer.drain()
return
try:
# perform the tool call
result = await self.call_tool(server_name, tool_name, arguments)
writer.write((json.dumps({"status": "ok", "result": result}) + "\n").encode("utf-8"))
await writer.drain()
except Exception as e:
writer.write((json.dumps({"status": "error", "error": "call_failed", "message": str(e)}) + "\n").encode("utf-8"))
await writer.drain()
else:
resp = {"status": "error", "error": "unknown_cmd"}
writer.write((json.dumps(resp) + "\n").encode("utf-8"))
await writer.drain()
except Exception:
try:
writer.write((json.dumps({"status": "error", "error": "internal"}) + "\n").encode("utf-8"))
await writer.drain()
except Exception:
pass
finally:
try:
writer.close()
except Exception:
pass
# Start the asyncio unix server
loop = asyncio.get_running_loop()
server = await asyncio.start_unix_server(_handle, path=path)
self._control_server = server
self._control_path = Path(path)
# Keep server serving in background task
self._control_task = loop.create_task(server.serve_forever())
# Restrict socket access to current user where possible
try:
os.chmod(path, 0o600)
except Exception:
pass
return path
async def stop_control_server(self):
try:
if self._control_task:
self._control_task.cancel()
self._control_task = None
if self._control_server:
self._control_server.close()
try:
await self._control_server.wait_closed()
except Exception:
pass
self._control_server = None
if self._control_path and self._control_path.exists():
try:
self._control_path.unlink()
except Exception:
pass
self._control_path = None
except Exception:
pass