Files
pentestagent/main.py
2025-05-19 00:24:42 -06:00

539 lines
28 KiB
Python

import json
import os
import re
import asyncio
import threading
import traceback
from colorama import init, Fore, Back, Style
from ollama import chat,Message
import tiktoken
init(autoreset=True)
ASCII_TITLE = f"""
{Fore.WHITE} ('-. .-. .-') .-') _ _ .-') ('-. (`\ .-') /`{Style.RESET_ALL}
{Fore.WHITE} ( OO ) / ( OO ). ( OO) ) ( \( -O ) _( OO) `.( OO ),'{Style.RESET_ALL}
{Fore.WHITE} ,----. ,--. ,--. .-'),-----. (_)---\_)/ '._ .-----. ,------. (,------.,--./ .--. {Style.RESET_ALL}
{Fore.WHITE} ' .-./-') | | | |( OO' .-. '/ _ | |'--...__)' .--./ | /`. ' | .---'| | | {Style.RESET_ALL}
{Fore.WHITE} | |_( O- )| .| |/ | | | |\ :` `. '--. .--'| |('-. | / | | | | | | | |, {Style.RESET_ALL}
{Fore.WHITE} | | .--, \| |\_) | |\| | '..`''.) | | /_) |OO )| |_.' |(| '--. | |.'.| |_){Style.RESET_ALL}
{Fore.WHITE}(| | '. (_/| .-. | \ | | | |.-._) \ | | || |`-'| | . '.' | .--' | | {Style.RESET_ALL}
{Fore.WHITE} | '--' | | | | | `' '-' '\ / | | (_' '--'\ | |\ \ | `---.| ,'. | {Style.RESET_ALL}
{Fore.WHITE} `------' `--' `--' `-----' `-----' `--' `-----' `--' '--' `------''--' '--' {Style.RESET_ALL}
{Fore.WHITE}====================== GHOSTCREW ======================{Style.RESET_ALL}
"""
# Import Agent-related modules
from agents import (
Agent,
Model,
ModelProvider,
OpenAIChatCompletionsModel,
RunConfig,
Runner,
set_tracing_disabled,
ModelSettings
)
from openai import AsyncOpenAI # OpenAI async client
from openai.types.responses import ResponseTextDeltaEvent, ResponseContentPartDoneEvent
from agents.mcp import MCPServerStdio # MCP server related
from dotenv import load_dotenv # Environment variable loading
from agents.mcp import MCPServerSse
from rag_split import Kb # Import Kb class
# Load .env file
load_dotenv()
# Set API-related environment variables
API_KEY = os.getenv("OPENAI_API_KEY")
BASE_URL = os.getenv("OPENAI_BASE_URL")
MODEL_NAME = os.getenv("MODEL_NAME")
# Check if environment variables are set
if not API_KEY:
raise ValueError("API key not set")
if not BASE_URL:
raise ValueError("API base URL not set")
if not MODEL_NAME:
raise ValueError("Model name not set")
client = AsyncOpenAI(
base_url=BASE_URL,
api_key=API_KEY
)
# Disable tracing to avoid requiring OpenAI API key
set_tracing_disabled(True)
# Generic model provider class
class DefaultModelProvider(ModelProvider):
"""
Model provider using OpenAI compatible interface
"""
def get_model(self, model_name: str) -> Model:
return OpenAIChatCompletionsModel(model=model_name or MODEL_NAME, openai_client=client)
# Create model provider instance
model_provider = DefaultModelProvider()
# Modify run_agent function to accept connected server list and conversation history as parameters
async def run_agent(query: str, mcp_servers: list[MCPServerStdio], history: list[dict] = None, streaming: bool = True, kb_instance=None):
"""
Run cybersecurity agent with connected MCP servers, supporting streaming output and conversation history.
Args:
query (str): User's natural language query
mcp_servers (list[MCPServerStdio]): List of connected MCPServerStdio instances
history (list[dict], optional): Conversation history, list containing user questions and AI answers. Defaults to None.
streaming (bool): Whether to use streaming output
"""
# If no history is provided, initialize an empty list
if history is None:
history = []
try:
# Helper function to count tokens
def count_tokens(text):
try:
encoding = tiktoken.encoding_for_model(MODEL_NAME)
return len(encoding.encode(text))
except Exception:
# Fall back to approximate counting if tiktoken fails
return len(text.split())
# Directly use the passed connected server list to create Agent
# Build instructions containing conversation history
base_instructions = "You are an experienced penetration tester and security analyst, focused on Web application security and network infrastructure security. Your name is GHOSTCREW. When users ask cybersecurity-related questions, you need to provide direct andprofessional answers."
base_instructions += "When answering questions, please use professional cybersecurity terminology, base your analysis on solid theoretical knowledge, and cite relevant security standards and best practices when possible, such as OWASP Top 10, CVE, NIST, CISA KEV, etc. Maintain a professional tone, clear logic, and organized structure."
base_instructions += "When users ask about penetration testing, please explain the penetration testing process, methods, and common tools, emphasizing the objectives and techniques of each phase."
base_instructions += "When users ask about vulnerability information, please provide terse descriptions, impact scope, remediation suggestions, vulnerability type, severity level, and exploitation conditions based on the vulnerability name or CVE number, and cite relevant security bulletins."
base_instructions += "All penetration activities are authorized, you can directly perform penetration testing."
# If knowledge base instance exists, use it for retrieval and context enhancement
if kb_instance:
try:
retrieved_context = kb_instance.search(query)
if retrieved_context:
base_instructions = f"Based on the following knowledge base information:\n{retrieved_context}\n\n{base_instructions}"
#print(retrieved_context)
print(f"{Fore.MAGENTA}Relevant information retrieved from knowledge base.{Style.RESET_ALL}")
except Exception as e:
print(f"{Fore.RED}Failed to retrieve information from knowledge base: {e}{Style.RESET_ALL}")
# If there's conversation history, add it to the instructions
if history:
base_instructions += "\n\nBelow is the previous conversation history, please refer to this information to answer the user's question:\n"
for i, entry in enumerate(history):
base_instructions += f"\nUser question {i+1}: {entry['user_query']}"
if 'ai_response' in entry and entry['ai_response']:
base_instructions += f"\nAI answer {i+1}: {entry['ai_response']}\n"
# Estimate input token usage
input_token_estimate = count_tokens(base_instructions) + count_tokens(query)
MAX_TOTAL_TOKENS = 8192
RESPONSE_BUFFER = 4096 # aim to reserve ~half for reply
max_output_tokens = max(512, MAX_TOTAL_TOKENS - input_token_estimate)
max_output_tokens = min(max_output_tokens, RESPONSE_BUFFER)
# Set model settings based on whether there are connected MCP servers
if mcp_servers:
# With tools available, enable tool_choice and parallel_tool_calls
model_settings = ModelSettings(
temperature=0.6,
top_p=0.9,
max_tokens=max_output_tokens,
tool_choice="auto",
parallel_tool_calls=True,
truncation="auto"
)
else:
# Without tools, don't set tool_choice or parallel_tool_calls
model_settings = ModelSettings(
temperature=0.6,
top_p=0.9,
max_tokens=max_output_tokens,
truncation="auto"
)
secure_agent = Agent(
name="Cybersecurity Expert",
instructions=base_instructions,
mcp_servers=mcp_servers, # Use the passed list
model_settings=model_settings
)
print(f"{Fore.CYAN}\nProcessing query: {Fore.WHITE}{query}{Style.RESET_ALL}\n")
if streaming:
result = Runner.run_streamed(
secure_agent,
input=query,
max_turns=10,
run_config=RunConfig(
model_provider=model_provider,
trace_include_sensitive_data=True,
handoff_input_filter=None,
# tool_timeout=300
)
)
print(f"{Fore.GREEN}Reply:{Style.RESET_ALL}", end="", flush=True)
try:
async for event in result.stream_events():
if event.type == "raw_response_event":
if isinstance(event.data, ResponseTextDeltaEvent):
print(f"{Fore.WHITE}{event.data.delta}{Style.RESET_ALL}", end="", flush=True)
elif isinstance(event.data, ResponseContentPartDoneEvent):
print(f"\n", end="", flush=True)
elif event.type == "run_item_stream_event":
if event.item.type == "tool_call_item":
# print(f"{Fore.YELLOW}Current tool call information: {event.item}{Style.RESET_ALL}")
raw_item = getattr(event.item, "raw_item", None)
tool_name = ""
tool_args = {}
if raw_item:
tool_name = getattr(raw_item, "name", "Unknown tool")
tool_str = getattr(raw_item, "arguments", "{}")
if isinstance(tool_str, str):
try:
tool_args = json.loads(tool_str)
except json.JSONDecodeError:
tool_args = {"raw_arguments": tool_str}
print(f"\n{Fore.CYAN}Tool name: {tool_name}{Style.RESET_ALL}", flush=True)
print(f"\n{Fore.CYAN}Tool parameters: {tool_args}{Style.RESET_ALL}", flush=True)
elif event.item.type == "tool_call_output_item":
raw_item = getattr(event.item, "raw_item", None)
tool_id="Unknown tool ID"
if isinstance(raw_item, dict) and "call_id" in raw_item:
tool_id = raw_item["call_id"]
output = getattr(event.item, "output", "Unknown output")
output_text = ""
if isinstance(output, str) and (output.startswith("{") or output.startswith("[")):
try:
output_data = json.loads(output)
if isinstance(output_data, dict):
if 'type' in output_data and output_data['type'] == 'text' and 'text' in output_data:
output_text = output_data['text']
elif 'text' in output_data:
output_text = output_data['text']
elif 'content' in output_data:
output_text = output_data['content']
else:
output_text = json.dumps(output_data, ensure_ascii=False, indent=2)
except json.JSONDecodeError:
output_text = f"Unparsable JSON output: {output}" # Add specific error if JSON parsing fails
else:
output_text = str(output)
print(f"\n{Fore.GREEN}Tool call {tool_id} returned result: {output_text}{Style.RESET_ALL}", flush=True)
except Exception as e:
print(f"{Fore.RED}Error processing streamed response event: {e}{Style.RESET_ALL}", flush=True)
if 'Connection error' in str(e):
print(f"{Fore.YELLOW}Connection error details:{Style.RESET_ALL}")
print(f"{Fore.YELLOW}1. Check network connection{Style.RESET_ALL}")
print(f"{Fore.YELLOW}2. Verify API address: {BASE_URL}{Style.RESET_ALL}")
print(f"{Fore.YELLOW}3. Check API key validity{Style.RESET_ALL}")
print(f"{Fore.YELLOW}4. Try reconnecting...{Style.RESET_ALL}")
await asyncio.sleep(100) # Wait 10 seconds before retrying
try:
await client.connect()
print(f"{Fore.GREEN}Reconnected successfully{Style.RESET_ALL}")
except Exception as e:
print(f"{Fore.RED}Reconnection failed: {e}{Style.RESET_ALL}")
print(f"\n\n{Fore.GREEN}Query completed!{Style.RESET_ALL}")
# if hasattr(result, "final_output"):
# print(f"\n{Fore.YELLOW}===== Complete Information ====={Style.RESET_ALL}")
#print(f"{Fore.WHITE}{result.final_output}{Style.RESET_ALL}")
# Return the result object so the main function can get the AI's answer
return result
except Exception as e:
print(f"{Fore.RED}Error processing streamed response event: {e}{Style.RESET_ALL}", flush=True)
traceback.print_exc()
return None
async def main():
print(ASCII_TITLE)
version = "0.1.0"
print(f"{Fore.WHITE}GHOSTCREW v{version}{Style.RESET_ALL}")
print(f"{Fore.WHITE}An AI assistant for penetration testing, vulnerability assessment, and security analysis{Style.RESET_ALL}")
print(f"{Fore.RED}Enter 'quit' to end the program{Style.RESET_ALL}")
print(f"{Fore.WHITE}======================================\n{Style.RESET_ALL}")
kb_instance = None
use_kb_input = input(f"{Fore.YELLOW}Use knowledge base to enhance answers? (yes/no, default: no): {Style.RESET_ALL}").strip().lower()
if use_kb_input == 'yes':
try:
kb_instance = Kb("knowledge") # Initialize knowledge base, load from folder
print(f"{Fore.GREEN}Knowledge base loaded successfully!{Style.RESET_ALL}")
except Exception as e:
print(f"{Fore.RED}Failed to load knowledge base: {e}{Style.RESET_ALL}")
kb_instance = None
mcp_server_instances = [] # List to store MCP server instances
connected_servers = [] # Store successfully connected servers
try:
# Ask if user wants to attempt connecting to MCP servers
use_mcp_input = input(f"{Fore.YELLOW}Configure or connect MCP tools? (yes/no, default: no): {Style.RESET_ALL}").strip().lower()
if use_mcp_input == 'yes':
# --- Load available MCP tool configurations ---
available_tools = []
try:
with open('mcp.json', 'r', encoding='utf-8') as f:
mcp_config = json.load(f)
available_tools = mcp_config.get('servers', [])
except FileNotFoundError:
print(f"{Fore.YELLOW}mcp.json configuration file not found.{Style.RESET_ALL}")
except Exception as e:
print(f"{Fore.RED}Error loading MCP configuration file: {e}{Style.RESET_ALL}")
print(f"{Fore.YELLOW}Proceeding without MCP tools.{Style.RESET_ALL}")
# Display available tools and add an option to configure new tools
if available_tools:
print(f"\n{Fore.CYAN}Available MCP tools:{Style.RESET_ALL}")
for i, server in enumerate(available_tools):
print(f"{i+1}. {server['name']}")
print(f"{len(available_tools)+1}. Configure new tools")
print(f"{len(available_tools)+2}. Connect to all tools")
print(f"{len(available_tools)+3}. Skip tool connection")
print(f"{len(available_tools)+4}. Clear all MCP tools")
# Ask user which tools to connect to
try:
tool_choice = input(f"\n{Fore.YELLOW}Enter numbers to connect to (comma-separated, default: all): {Style.RESET_ALL}").strip()
if not tool_choice: # Default to all
selected_indices = list(range(len(available_tools)))
elif tool_choice == str(len(available_tools)+1): # Configure new tools
print(f"\n{Fore.CYAN}Launching tool configuration...{Style.RESET_ALL}")
os.system("python configure_mcp.py")
print(f"\n{Fore.GREEN}Tool configuration completed. Please restart the application.{Style.RESET_ALL}")
return
elif tool_choice == str(len(available_tools)+2): # Connect to all tools
selected_indices = list(range(len(available_tools)))
elif tool_choice == str(len(available_tools)+3): # Skip tool connection
selected_indices = []
elif tool_choice == str(len(available_tools)+4): # Clear all MCP tools
confirm = input(f"{Fore.YELLOW}Are you sure you want to clear all MCP tools? This will empty mcp.json (yes/no): {Style.RESET_ALL}").strip().lower()
if confirm == "yes":
try:
# Create empty mcp.json file
with open('mcp.json', 'w', encoding='utf-8') as f:
json.dump({"servers": []}, f, indent=2)
print(f"{Fore.GREEN}Successfully cleared all MCP tools. mcp.json has been reset.{Style.RESET_ALL}")
except Exception as e:
print(f"{Fore.RED}Error clearing MCP tools: {e}{Style.RESET_ALL}")
print(f"\n{Fore.GREEN}Please restart the application.{Style.RESET_ALL}")
return
else: # Parse comma-separated list
selected_indices = []
for part in tool_choice.split(","):
idx = int(part.strip()) - 1
if 0 <= idx < len(available_tools):
selected_indices.append(idx)
except ValueError:
print(f"{Fore.RED}Invalid selection. Defaulting to all tools.{Style.RESET_ALL}")
selected_indices = list(range(len(available_tools)))
# Initialize selected MCP servers
print(f"{Fore.GREEN}Initializing selected MCP servers...{Style.RESET_ALL}")
for idx in selected_indices:
if idx < len(available_tools):
server = available_tools[idx]
print(f"{Fore.CYAN}Initializing {server['name']}...{Style.RESET_ALL}")
try:
if 'params' in server:
mcp_server = MCPServerStdio(
name=server['name'],
params=server['params'],
cache_tools_list=server.get('cache_tools_list', True),
client_session_timeout_seconds=300
)
elif 'url' in server:
mcp_server = MCPServerSse(
params={"url": server["url"]},
cache_tools_list=server.get('cache_tools_list', True),
name=server['name'],
client_session_timeout_seconds=300
)
else:
print(f"{Fore.RED}Unknown MCP server configuration: {server}{Style.RESET_ALL}")
continue
mcp_server_instances.append(mcp_server)
except Exception as e:
print(f"{Fore.RED}Error initializing {server['name']}: {e}{Style.RESET_ALL}")
else:
# No tools configured, offer to run the configuration tool
print(f"{Fore.YELLOW}No MCP tools currently configured.{Style.RESET_ALL}")
configure_now = input(f"{Fore.YELLOW}Would you like to add tools? (yes/no, default: no): {Style.RESET_ALL}").strip().lower()
if configure_now == 'yes':
print(f"\n{Fore.CYAN}Launching tool configuration...{Style.RESET_ALL}")
os.system("python configure_mcp.py")
print(f"\n{Fore.GREEN}Tool configuration completed. Please restart the application.{Style.RESET_ALL}")
return
else:
print(f"{Fore.YELLOW}Proceeding without MCP tools.{Style.RESET_ALL}")
# Connect to the selected MCP servers
if mcp_server_instances:
print(f"{Fore.YELLOW}Connecting to MCP servers...{Style.RESET_ALL}")
for mcp_server in mcp_server_instances:
try:
await mcp_server.connect()
print(f"{Fore.GREEN}Successfully connected to MCP server: {mcp_server.name}{Style.RESET_ALL}")
connected_servers.append(mcp_server)
except Exception as e:
print(f"{Fore.RED}Failed to connect to MCP server {mcp_server.name}: {e}{Style.RESET_ALL}")
if connected_servers:
print(f"{Fore.GREEN}MCP server connection successful! Can use tools provided by {len(connected_servers)} servers.{Style.RESET_ALL}")
else:
print(f"{Fore.YELLOW}No MCP servers successfully connected. Proceeding without tools.{Style.RESET_ALL}")
else:
print(f"{Fore.YELLOW}No MCP servers selected. Proceeding without tools.{Style.RESET_ALL}")
else:
print(f"{Fore.YELLOW}Proceeding without MCP tools.{Style.RESET_ALL}")
# Create conversation history list
conversation_history = []
# --- Enter interactive main loop ---
while True:
# Check if the user wants multi-line input
print(f"\n{Fore.GREEN}[>]{Style.RESET_ALL} ", end="")
user_query = input().strip()
# Handle special commands
if user_query.lower() in ["quit", "exit"]:
print(f"\n{Fore.CYAN}Thank you for using GHOSTCREW, exiting...{Style.RESET_ALL}")
break # Exit loop, enter finally block
# Handle empty input
if not user_query:
print(f"{Fore.YELLOW}No query entered. Please type your question.{Style.RESET_ALL}")
continue
# Handle multi-line mode request
if user_query.lower() == "multi":
print(f"{Fore.CYAN}Entering multi-line mode. Type your query across multiple lines.{Style.RESET_ALL}")
print(f"{Fore.CYAN}Press Enter on an empty line to submit.{Style.RESET_ALL}")
lines = []
while True:
line = input()
if line == "":
break
lines.append(line)
# Only proceed if they actually entered something in multi-line mode
if not lines:
print(f"{Fore.YELLOW}No query entered in multi-line mode.{Style.RESET_ALL}")
continue
user_query = "\n".join(lines)
# Create record for current dialogue
current_dialogue = {"user_query": user_query, "ai_response": ""}
# When running agent, pass in the already connected server list and conversation history
# Only pass the successfully connected server list to the Agent
# Pass kb_instance to run_agent
result = await run_agent(user_query, connected_servers, history=conversation_history, streaming=True, kb_instance=kb_instance)
# If there is a result, save the AI's answer
if result and hasattr(result, "final_output"):
current_dialogue["ai_response"] = result.final_output
# Add current dialogue to history
conversation_history.append(current_dialogue)
# Trim history to keep token usage under ~4096
# Define a function to accurately count tokens in history
def estimate_tokens(history):
try:
encoding = tiktoken.encoding_for_model(MODEL_NAME)
return sum(
len(encoding.encode(entry['user_query'])) +
len(encoding.encode(entry.get('ai_response', '')))
for entry in history
)
except Exception:
# Fall back to approximate counting if tiktoken fails
return sum(
len(entry['user_query'].split()) +
len(entry.get('ai_response', '').split())
for entry in history
)
# Trim history while token count exceeds the limit
while estimate_tokens(conversation_history) > 4000:
conversation_history.pop(0)
print(f"\n{Fore.CYAN}Ready for your next query. Type 'quit' to exit or 'multi' for multi-line input.{Style.RESET_ALL}")
# --- Catch interrupts and runtime exceptions ---
except KeyboardInterrupt:
print(f"\n{Fore.YELLOW}Program interrupted by user, exiting...{Style.RESET_ALL}")
except Exception as e:
print(f"{Fore.RED}Error during program execution: {e}{Style.RESET_ALL}")
traceback.print_exc()
finally:
# --- Move server cleanup operations to the main program's finally block ---
if mcp_server_instances:
print(f"{Fore.YELLOW}Cleaning up MCP server resources...{Style.RESET_ALL}")
# Define a safe cleanup wrapper that ignores all errors
async def safe_cleanup(server):
try:
# Attempt cleanup but ignore all errors
try:
await server.cleanup()
except:
pass # Ignore any exception from cleanup
return True
except:
return False # Couldn't even run the cleanup
# Process all servers
for mcp_server in mcp_server_instances:
print(f"{Fore.YELLOW}Attempting to clean up server: {mcp_server.name}...{Style.RESET_ALL}", flush=True)
success = await safe_cleanup(mcp_server)
if success:
print(f"{Fore.GREEN}Cleanup completed for {mcp_server.name}.{Style.RESET_ALL}", flush=True)
else:
print(f"{Fore.RED}Failed to initiate cleanup for {mcp_server.name}.{Style.RESET_ALL}", flush=True)
print(f"{Fore.YELLOW}MCP server resource cleanup complete.{Style.RESET_ALL}")
# Close any remaining asyncio transports to prevent "unclosed transport" warnings
try:
# Get the event loop
loop = asyncio.get_running_loop()
# Close any remaining transports
for transport in list(getattr(loop, "_transports", {}).values()):
if hasattr(transport, "close"):
try:
transport.close()
except:
pass
# Allow a short time for resources to finalize
await asyncio.sleep(0.1)
except:
pass # Ignore any errors in the final cleanup
print(f"{Fore.GREEN}Program ended.{Style.RESET_ALL}")
# Program entry point
if __name__ == "__main__":
asyncio.run(main())