mirror of
https://github.com/GH05TCREW/pentestagent.git
synced 2026-03-08 06:44:11 +00:00
mcp: add unix socket control + tools call/list via daemon
This commit is contained in:
@@ -94,6 +94,17 @@ Examples:
|
||||
help="Temporarily connect to configured MCP servers and include their tools",
|
||||
)
|
||||
|
||||
# tools call
|
||||
tools_call = tools_subparsers.add_parser("call", help="Call a tool (via MCP daemon if available)")
|
||||
tools_call.add_argument("server", help="MCP server name")
|
||||
tools_call.add_argument("tool", help="Tool name")
|
||||
tools_call.add_argument(
|
||||
"--json",
|
||||
dest="json_args",
|
||||
help="JSON string of arguments to pass to the tool",
|
||||
default=None,
|
||||
)
|
||||
|
||||
# tools info
|
||||
tools_info = tools_subparsers.add_parser("info", help="Show tool details")
|
||||
tools_info.add_argument("name", help="Tool name")
|
||||
@@ -108,6 +119,9 @@ Examples:
|
||||
# mcp list
|
||||
mcp_subparsers.add_parser("list", help="List configured MCP servers")
|
||||
|
||||
# mcp status
|
||||
mcp_subparsers.add_parser("status", help="Show MCP daemon status (socket)" )
|
||||
|
||||
# mcp add
|
||||
mcp_add = mcp_subparsers.add_parser("add", help="Add an MCP server")
|
||||
mcp_add.add_argument("name", help="Server name")
|
||||
@@ -195,14 +209,49 @@ def handle_tools_command(args: argparse.Namespace):
|
||||
if args.tools_command == "list":
|
||||
# Optionally include MCP-discovered tools by connecting temporarily
|
||||
manager = None
|
||||
if getattr(args, "include_mcp", False):
|
||||
from ..mcp.manager import MCPManager
|
||||
mcp_socket_path = None
|
||||
try:
|
||||
from pathlib import Path
|
||||
|
||||
manager = MCPManager()
|
||||
try:
|
||||
asyncio.run(manager.connect_all())
|
||||
except Exception:
|
||||
pass
|
||||
mcp_socket_path = Path.home() / ".pentestagent" / "mcp.sock"
|
||||
except Exception:
|
||||
mcp_socket_path = None
|
||||
|
||||
if getattr(args, "include_mcp", False):
|
||||
# Try to query running MCP daemon via unix socket first
|
||||
tried_socket = False
|
||||
if mcp_socket_path and mcp_socket_path.exists():
|
||||
tried_socket = True
|
||||
try:
|
||||
import socket, json
|
||||
|
||||
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
|
||||
s.connect(str(mcp_socket_path))
|
||||
s.sendall((json.dumps({"cmd": "list_tools"}) + "\n").encode("utf-8"))
|
||||
# Read until EOF
|
||||
resp = b""
|
||||
while True:
|
||||
part = s.recv(4096)
|
||||
if not part:
|
||||
break
|
||||
resp += part
|
||||
data = json.loads(resp.decode("utf-8"))
|
||||
mcp_tools = []
|
||||
if data.get("status") == "ok":
|
||||
mcp_tools = data.get("tools", [])
|
||||
else:
|
||||
mcp_tools = []
|
||||
except Exception:
|
||||
tried_socket = False
|
||||
|
||||
if not tried_socket:
|
||||
from ..mcp.manager import MCPManager
|
||||
|
||||
manager = MCPManager()
|
||||
try:
|
||||
asyncio.run(manager.connect_all())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
tools = get_all_tools()
|
||||
@@ -215,6 +264,18 @@ def handle_tools_command(args: argparse.Namespace):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Merge MCP daemon tools (if returned by socket) into displayed list
|
||||
if 'mcp_tools' in locals() and mcp_tools:
|
||||
# Create lightweight objects to display alongside registered tools
|
||||
class _FakeTool:
|
||||
def __init__(self, name, category, description):
|
||||
self.name = name
|
||||
self.category = category
|
||||
self.description = description
|
||||
|
||||
for t in mcp_tools:
|
||||
tools.append(_FakeTool(f"mcp_{t.get('server')}_{t.get('name')}", "mcp", t.get("description", "")))
|
||||
|
||||
if not tools:
|
||||
console.print("[yellow]No tools found[/]")
|
||||
return
|
||||
@@ -291,6 +352,63 @@ def handle_tools_command(args: argparse.Namespace):
|
||||
else:
|
||||
console.print("[yellow]Use 'pentestagent tools --help' for commands[/]")
|
||||
|
||||
if args.tools_command == "call":
|
||||
import json, socket
|
||||
|
||||
server = args.server
|
||||
tool = args.tool
|
||||
json_args = {}
|
||||
if args.json_args:
|
||||
try:
|
||||
json_args = json.loads(args.json_args)
|
||||
except Exception as e:
|
||||
console.print(f"[red]Invalid JSON for --json: {e}[/]")
|
||||
return
|
||||
|
||||
# Try daemon socket first
|
||||
from pathlib import Path
|
||||
sock = Path.home() / ".pentestagent" / "mcp.sock"
|
||||
if sock.exists():
|
||||
try:
|
||||
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
|
||||
s.connect(str(sock))
|
||||
s.sendall((json.dumps({"cmd": "call_tool", "server": server, "tool": tool, "args": json_args}) + "\n").encode("utf-8"))
|
||||
resp = b""
|
||||
while True:
|
||||
part = s.recv(4096)
|
||||
if not part:
|
||||
break
|
||||
resp += part
|
||||
data = json.loads(resp.decode("utf-8"))
|
||||
if data.get("status") == "ok":
|
||||
console.print(f"[green]Tool call succeeded. Result:[/] {data.get('result')}")
|
||||
else:
|
||||
console.print(f"[red]Tool call failed: {data.get('error')} {data.get('message','')}[/]")
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: temporary connect and call
|
||||
from ..mcp.manager import MCPManager
|
||||
|
||||
manager = MCPManager()
|
||||
|
||||
async def _call():
|
||||
sv = await manager.connect_server(server)
|
||||
if not sv:
|
||||
raise RuntimeError(f"Failed to connect to server: {server}")
|
||||
try:
|
||||
res = await manager.call_tool(server, tool, json_args)
|
||||
return res
|
||||
finally:
|
||||
await manager.disconnect_all()
|
||||
|
||||
try:
|
||||
res = asyncio.run(_call())
|
||||
console.print(f"[green]Tool call succeeded. Result:[/] {res}")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Tool call failed: {e}[/]")
|
||||
|
||||
|
||||
def handle_mcp_command(args: argparse.Namespace):
|
||||
"""Handle MCP subcommand."""
|
||||
@@ -408,6 +526,12 @@ def handle_mcp_command(args: argparse.Namespace):
|
||||
console.print(f"[red]Failed to connect: {name}[/]")
|
||||
return
|
||||
|
||||
# Start control socket so other CLI invocations can query daemon
|
||||
try:
|
||||
await manager.start_control_server()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
console.print("[green]Connected. Press Ctrl-C to stop and disconnect.[/]")
|
||||
await stop_event.wait()
|
||||
|
||||
@@ -416,6 +540,10 @@ def handle_mcp_command(args: argparse.Namespace):
|
||||
await manager.disconnect_all()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await manager.stop_control_server()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If detach requested, perform a simple double-fork to daemonize
|
||||
if detach:
|
||||
@@ -519,6 +647,50 @@ def handle_mcp_command(args: argparse.Namespace):
|
||||
|
||||
asyncio.run(run_disconnect())
|
||||
|
||||
elif args.mcp_command == "status":
|
||||
# Try querying the daemon socket
|
||||
from pathlib import Path
|
||||
import socket, json
|
||||
|
||||
sock = Path.home() / ".pentestagent" / "mcp.sock"
|
||||
if sock.exists():
|
||||
try:
|
||||
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
|
||||
s.connect(str(sock))
|
||||
s.sendall((json.dumps({"cmd": "status"}) + "\n").encode("utf-8"))
|
||||
resp = b""
|
||||
while True:
|
||||
part = s.recv(4096)
|
||||
if not part:
|
||||
break
|
||||
resp += part
|
||||
data = json.loads(resp.decode("utf-8"))
|
||||
if data.get("status") == "ok":
|
||||
rows = data.get("servers", [])
|
||||
if not rows:
|
||||
console.print("[yellow]No MCP servers connected[/]")
|
||||
return
|
||||
table = Table(title="MCP Daemon Status")
|
||||
table.add_column("Name")
|
||||
table.add_column("Connected")
|
||||
table.add_column("Tools")
|
||||
for r in rows:
|
||||
table.add_row(r.get("name"), "+" if r.get("connected") else "-", str(r.get("tool_count", 0)))
|
||||
console.print(table)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: show configured servers and whether manager can see them
|
||||
servers = manager.list_configured_servers()
|
||||
table = Table(title="Configured MCP Servers")
|
||||
table.add_column("Name")
|
||||
table.add_column("Command")
|
||||
table.add_column("Connected")
|
||||
for s in servers:
|
||||
table.add_row(s.get("name"), s.get("command"), "+" if s.get("connected") else "-")
|
||||
console.print(table)
|
||||
|
||||
else:
|
||||
console.print("[yellow]Use 'pentestagent mcp --help' for available commands[/]")
|
||||
|
||||
|
||||
@@ -76,6 +76,10 @@ class MCPManager:
|
||||
# Track adapters we auto-started so we can stop them later
|
||||
self._started_adapters: Dict[str, object] = {}
|
||||
self._message_id = 0
|
||||
# Control socket server attributes
|
||||
self._control_server: Optional[asyncio.AbstractServer] = None
|
||||
self._control_task: Optional[asyncio.Task] = None
|
||||
self._control_path: Optional[Path] = None
|
||||
# Ensure we attempt to clean up vendored servers on process exit
|
||||
try:
|
||||
atexit.register(self._atexit_cleanup)
|
||||
@@ -523,3 +527,129 @@ class MCPManager:
|
||||
def is_connected(self, name: str) -> bool:
|
||||
server = self.servers.get(name)
|
||||
return server is not None and server.connected
|
||||
|
||||
async def start_control_server(self, path: Optional[str] = None) -> str:
|
||||
"""Start a lightweight UNIX-domain socket control server.
|
||||
|
||||
The control server accepts newline-delimited JSON requests. Supported
|
||||
commands:
|
||||
{"cmd": "status"} -> returns connected servers and counts
|
||||
{"cmd": "list_tools"} -> returns list of MCP tools (name, server, description)
|
||||
|
||||
Returns the path of the socket in use.
|
||||
"""
|
||||
if not path:
|
||||
path = str(Path.home() / ".pentestagent" / "mcp.sock")
|
||||
|
||||
sock_path = Path(path)
|
||||
# Ensure parent exists
|
||||
sock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Remove stale socket if present
|
||||
try:
|
||||
if sock_path.exists():
|
||||
try:
|
||||
sock_path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
try:
|
||||
data = await reader.readline()
|
||||
if not data:
|
||||
return
|
||||
try:
|
||||
req = json.loads(data.decode("utf-8"))
|
||||
except Exception:
|
||||
resp = {"status": "error", "error": "invalid_json"}
|
||||
writer.write((json.dumps(resp) + "\n").encode("utf-8"))
|
||||
await writer.drain()
|
||||
return
|
||||
|
||||
cmd = req.get("cmd") if isinstance(req, dict) else None
|
||||
if cmd == "status":
|
||||
servers = []
|
||||
for name, s in self.servers.items():
|
||||
servers.append({"name": name, "connected": bool(s.connected), "tool_count": len(s.tools)})
|
||||
resp = {"status": "ok", "servers": servers}
|
||||
writer.write((json.dumps(resp) + "\n").encode("utf-8"))
|
||||
await writer.drain()
|
||||
elif cmd == "list_tools":
|
||||
tools = []
|
||||
for sname, s in self.servers.items():
|
||||
for t in s.tools:
|
||||
tools.append({"name": t.get("name"), "server": sname, "description": t.get("description", "")})
|
||||
resp = {"status": "ok", "tools": tools}
|
||||
writer.write((json.dumps(resp) + "\n").encode("utf-8"))
|
||||
await writer.drain()
|
||||
elif cmd == "call_tool":
|
||||
# Expecting: {"cmd":"call_tool","server":"name","tool":"tool_name","args":{...}}
|
||||
server_name = req.get("server")
|
||||
tool_name = req.get("tool")
|
||||
arguments = req.get("args", {}) if isinstance(req.get("args", {}), dict) else {}
|
||||
if not server_name or not tool_name:
|
||||
writer.write((json.dumps({"status": "error", "error": "missing_parameters"}) + "\n").encode("utf-8"))
|
||||
await writer.drain()
|
||||
return
|
||||
try:
|
||||
# perform the tool call
|
||||
result = await self.call_tool(server_name, tool_name, arguments)
|
||||
writer.write((json.dumps({"status": "ok", "result": result}) + "\n").encode("utf-8"))
|
||||
await writer.drain()
|
||||
except Exception as e:
|
||||
writer.write((json.dumps({"status": "error", "error": "call_failed", "message": str(e)}) + "\n").encode("utf-8"))
|
||||
await writer.drain()
|
||||
|
||||
else:
|
||||
resp = {"status": "error", "error": "unknown_cmd"}
|
||||
writer.write((json.dumps(resp) + "\n").encode("utf-8"))
|
||||
await writer.drain()
|
||||
except Exception:
|
||||
try:
|
||||
writer.write((json.dumps({"status": "error", "error": "internal"}) + "\n").encode("utf-8"))
|
||||
await writer.drain()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
writer.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Start the asyncio unix server
|
||||
loop = asyncio.get_running_loop()
|
||||
server = await asyncio.start_unix_server(_handle, path=path)
|
||||
self._control_server = server
|
||||
self._control_path = Path(path)
|
||||
|
||||
# Keep server serving in background task
|
||||
self._control_task = loop.create_task(server.serve_forever())
|
||||
# Restrict socket access to current user where possible
|
||||
try:
|
||||
os.chmod(path, 0o600)
|
||||
except Exception:
|
||||
pass
|
||||
return path
|
||||
|
||||
async def stop_control_server(self):
|
||||
try:
|
||||
if self._control_task:
|
||||
self._control_task.cancel()
|
||||
self._control_task = None
|
||||
if self._control_server:
|
||||
self._control_server.close()
|
||||
try:
|
||||
await self._control_server.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
self._control_server = None
|
||||
if self._control_path and self._control_path.exists():
|
||||
try:
|
||||
self._control_path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
self._control_path = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user