mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-27 14:50:52 +00:00
362 lines
12 KiB
Go
362 lines
12 KiB
Go
package cli
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"os/signal"
|
|
"strconv"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"drip/internal/client/tcp"
|
|
"drip/internal/shared/protocol"
|
|
"drip/internal/shared/utils"
|
|
"drip/pkg/config"
|
|
|
|
"github.com/spf13/cobra"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
var tcpCmd = &cobra.Command{
|
|
Use: "tcp <port>",
|
|
Short: "Start TCP tunnel",
|
|
Long: `Start a TCP tunnel to expose any TCP service.
|
|
|
|
Example:
|
|
drip tcp 5432 Tunnel PostgreSQL
|
|
drip tcp 3306 Tunnel MySQL
|
|
drip tcp 22 Tunnel SSH
|
|
drip tcp 6379 --subdomain myredis Tunnel Redis with custom subdomain
|
|
|
|
Supported Services:
|
|
- Databases: PostgreSQL (5432), MySQL (3306), Redis (6379), MongoDB (27017)
|
|
- SSH: Port 22
|
|
- Any TCP service
|
|
|
|
Configuration:
|
|
First time: Run 'drip config init' to save server and token
|
|
Subsequent: Just run 'drip tcp <port>'
|
|
|
|
Note: Uses TCP over TLS 1.3 for secure communication`,
|
|
Args: cobra.ExactArgs(1),
|
|
RunE: runTCP,
|
|
}
|
|
|
|
func init() {
|
|
tcpCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)")
|
|
tcpCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)")
|
|
tcpCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
|
|
tcpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
|
tcpCmd.Flags().MarkHidden("daemon-child")
|
|
rootCmd.AddCommand(tcpCmd)
|
|
}
|
|
|
|
func runTCP(cmd *cobra.Command, args []string) error {
|
|
// Parse port
|
|
port, err := strconv.Atoi(args[0])
|
|
if err != nil || port < 1 || port > 65535 {
|
|
return fmt.Errorf("invalid port number: %s", args[0])
|
|
}
|
|
|
|
// Handle daemon mode
|
|
if daemonMode && !daemonMarker {
|
|
// Start as daemon
|
|
daemonArgs := append([]string{"tcp"}, args...)
|
|
daemonArgs = append(daemonArgs, "--daemon-child")
|
|
if subdomain != "" {
|
|
daemonArgs = append(daemonArgs, "--subdomain", subdomain)
|
|
}
|
|
if localAddress != "127.0.0.1" {
|
|
daemonArgs = append(daemonArgs, "--address", localAddress)
|
|
}
|
|
if serverURL != "" {
|
|
daemonArgs = append(daemonArgs, "--server", serverURL)
|
|
}
|
|
if authToken != "" {
|
|
daemonArgs = append(daemonArgs, "--token", authToken)
|
|
}
|
|
if insecure {
|
|
daemonArgs = append(daemonArgs, "--insecure")
|
|
}
|
|
if verbose {
|
|
daemonArgs = append(daemonArgs, "--verbose")
|
|
}
|
|
return StartDaemon("tcp", port, daemonArgs)
|
|
}
|
|
|
|
// Initialize logger
|
|
if err := utils.InitLogger(verbose); err != nil {
|
|
return fmt.Errorf("failed to initialize logger: %w", err)
|
|
}
|
|
defer utils.Sync()
|
|
|
|
logger := utils.GetLogger()
|
|
|
|
// Load configuration or use command line flags
|
|
var serverAddr, token string
|
|
|
|
if serverURL == "" {
|
|
// Try to load from config file
|
|
cfg, err := config.LoadClientConfig("")
|
|
if err != nil {
|
|
return fmt.Errorf(`configuration not found.
|
|
|
|
Please run 'drip config init' first, or use flags:
|
|
drip tcp %d --server SERVER:PORT --token TOKEN`, port)
|
|
}
|
|
serverAddr = cfg.Server
|
|
token = cfg.Token
|
|
} else {
|
|
// Use command line flags
|
|
serverAddr = serverURL
|
|
token = authToken
|
|
}
|
|
|
|
// Validate server address
|
|
if serverAddr == "" {
|
|
return fmt.Errorf("server address is required")
|
|
}
|
|
|
|
// Create connector config
|
|
connConfig := &tcp.ConnectorConfig{
|
|
ServerAddr: serverAddr,
|
|
Token: token,
|
|
TunnelType: protocol.TunnelTypeTCP,
|
|
LocalHost: localAddress,
|
|
LocalPort: port,
|
|
Subdomain: subdomain,
|
|
Insecure: insecure,
|
|
}
|
|
|
|
// Setup signal handler for graceful shutdown
|
|
quit := make(chan os.Signal, 1)
|
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
// Connection loop with reconnect support
|
|
reconnectAttempts := 0
|
|
serviceName := getServiceName(port)
|
|
for {
|
|
// Create connector
|
|
connector := tcp.NewConnector(connConfig, logger)
|
|
|
|
// Connect to server
|
|
if reconnectAttempts == 0 {
|
|
fmt.Printf("\033[36m🔌 Connecting to %s...\033[0m\n", serverAddr)
|
|
} else {
|
|
fmt.Printf("\033[33m🔄 Reconnecting to %s (attempt %d/%d)...\033[0m\n", serverAddr, reconnectAttempts, maxReconnectAttempts)
|
|
}
|
|
|
|
if err := connector.Connect(); err != nil {
|
|
// Check if this is a non-retryable error
|
|
if isNonRetryableErrorTCP(err) {
|
|
return fmt.Errorf("failed to connect: %w", err)
|
|
}
|
|
|
|
reconnectAttempts++
|
|
if reconnectAttempts >= maxReconnectAttempts {
|
|
return fmt.Errorf("failed to connect after %d attempts: %w", maxReconnectAttempts, err)
|
|
}
|
|
fmt.Printf("\033[31m✗ Connection failed: %v\033[0m\n", err)
|
|
fmt.Printf("\033[90m Retrying in %v...\033[0m\n", reconnectInterval)
|
|
|
|
// Wait before retry, but allow interrupt
|
|
select {
|
|
case <-quit:
|
|
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
|
|
return nil
|
|
case <-time.After(reconnectInterval):
|
|
continue
|
|
}
|
|
}
|
|
|
|
// Reset reconnect attempts on successful connection
|
|
reconnectAttempts = 0
|
|
|
|
// Save daemon info if running as daemon child
|
|
if daemonMarker {
|
|
daemonInfo := &DaemonInfo{
|
|
PID: os.Getpid(),
|
|
Type: "tcp",
|
|
Port: port,
|
|
Subdomain: subdomain,
|
|
Server: serverAddr,
|
|
URL: connector.GetURL(),
|
|
StartTime: time.Now(),
|
|
Executable: os.Args[0],
|
|
}
|
|
if err := SaveDaemonInfo(daemonInfo); err != nil {
|
|
// Log but don't fail
|
|
logger.Warn("Failed to save daemon info", zap.Error(err))
|
|
}
|
|
}
|
|
|
|
// Print tunnel information
|
|
fmt.Println()
|
|
fmt.Println("\033[1;35m╔══════════════════════════════════════════════════════════════════╗\033[0m")
|
|
fmt.Println("\033[1;35m║\033[0m \033[1;37m🔌 TCP Tunnel Connected Successfully!\033[0m \033[1;35m║\033[0m")
|
|
fmt.Println("\033[1;35m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
|
fmt.Printf("\033[1;35m║\033[0m \033[1;37mTunnel URL:\033[0m \033[1;35m║\033[0m\n")
|
|
fmt.Printf("\033[1;35m║\033[0m \033[1;36m%-60s\033[0m \033[1;35m║\033[0m\n", connector.GetURL())
|
|
fmt.Println("\033[1;35m║\033[0m \033[1;35m║\033[0m")
|
|
fmt.Printf("\033[1;35m║\033[0m \033[90mService:\033[0m \033[1;35m%-50s\033[0m \033[1;35m║\033[0m\n", serviceName)
|
|
displayAddr := localAddress
|
|
if displayAddr == "127.0.0.1" {
|
|
displayAddr = "localhost"
|
|
}
|
|
fmt.Printf("\033[1;35m║\033[0m \033[90mForwarding:\033[0m \033[1m%s:%d\033[0m → \033[36m%s\033[0m%-15s\033[1;35m║\033[0m\n", displayAddr, port, "public", "")
|
|
fmt.Printf("\033[1;35m║\033[0m \033[90mLatency:\033[0m \033[90mmeasuring...\033[0m%-40s\033[1;35m║\033[0m\n", "")
|
|
fmt.Printf("\033[1;35m║\033[0m \033[90mTraffic:\033[0m \033[90m↓ 0 B ↑ 0 B\033[0m%-32s\033[1;35m║\033[0m\n", "")
|
|
fmt.Printf("\033[1;35m║\033[0m \033[90mSpeed:\033[0m \033[90m↓ 0 B/s ↑ 0 B/s\033[0m%-28s\033[1;35m║\033[0m\n", "")
|
|
fmt.Printf("\033[1;35m║\033[0m \033[90mRequests:\033[0m \033[90m0\033[0m%-43s\033[1;35m║\033[0m\n", "")
|
|
fmt.Println("\033[1;35m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
|
fmt.Println("\033[1;35m║\033[0m \033[90mPress Ctrl+C to stop the tunnel\033[0m \033[1;35m║\033[0m")
|
|
fmt.Println("\033[1;35m╚══════════════════════════════════════════════════════════════════╝\033[0m")
|
|
fmt.Println()
|
|
|
|
// Setup latency display
|
|
latencyCh := make(chan time.Duration, 1)
|
|
connector.SetLatencyCallback(func(latency time.Duration) {
|
|
select {
|
|
case latencyCh <- latency:
|
|
default:
|
|
}
|
|
})
|
|
|
|
// Start stats display updater (updates every second)
|
|
stopDisplay := make(chan struct{})
|
|
disconnected := make(chan struct{})
|
|
|
|
go func() {
|
|
ticker := time.NewTicker(1 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
var lastLatency time.Duration
|
|
for {
|
|
select {
|
|
case latency := <-latencyCh:
|
|
lastLatency = latency
|
|
case <-ticker.C:
|
|
// Update speed calculation
|
|
stats := connector.GetStats()
|
|
if stats != nil {
|
|
stats.UpdateSpeed()
|
|
snapshot := stats.GetSnapshot()
|
|
|
|
// Move cursor up 8 lines to update display
|
|
fmt.Print("\033[8A")
|
|
|
|
// Update latency line
|
|
fmt.Printf("\r\033[1;35m║\033[0m \033[90mLatency:\033[0m %s%-40s\033[1;35m║\033[0m\n", formatLatencyTCP(lastLatency), "")
|
|
|
|
// Update traffic line
|
|
trafficStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatBytes(snapshot.TotalBytesIn), tcp.FormatBytes(snapshot.TotalBytesOut))
|
|
fmt.Printf("\r\033[1;35m║\033[0m \033[90mTraffic:\033[0m \033[36m%-48s\033[0m\033[1;35m║\033[0m\n", trafficStr)
|
|
|
|
// Update speed line
|
|
speedStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatSpeed(snapshot.SpeedIn), tcp.FormatSpeed(snapshot.SpeedOut))
|
|
fmt.Printf("\r\033[1;35m║\033[0m \033[90mSpeed:\033[0m \033[33m%-48s\033[0m\033[1;35m║\033[0m\n", speedStr)
|
|
|
|
// Update requests line
|
|
fmt.Printf("\r\033[1;35m║\033[0m \033[90mRequests:\033[0m \033[35m%-47d\033[0m\033[1;35m║\033[0m\n", snapshot.TotalRequests)
|
|
|
|
// Move back down 4 lines
|
|
fmt.Print("\033[4B")
|
|
}
|
|
case <-stopDisplay:
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Monitor connection in background
|
|
go func() {
|
|
connector.Wait()
|
|
close(disconnected)
|
|
}()
|
|
|
|
// Wait for signal or disconnection
|
|
select {
|
|
case <-quit:
|
|
close(stopDisplay)
|
|
fmt.Println("\n\n\033[33m🛑 Shutting down...\033[0m")
|
|
connector.Close()
|
|
if daemonMarker {
|
|
RemoveDaemonInfo("tcp", port)
|
|
}
|
|
fmt.Println("\033[32m✓\033[0m Tunnel closed")
|
|
return nil
|
|
case <-disconnected:
|
|
close(stopDisplay)
|
|
fmt.Println("\n\n\033[31m⚠ Connection lost!\033[0m")
|
|
reconnectAttempts++
|
|
if reconnectAttempts >= maxReconnectAttempts {
|
|
return fmt.Errorf("connection lost after %d reconnect attempts", maxReconnectAttempts)
|
|
}
|
|
fmt.Printf("\033[90m Reconnecting in %v...\033[0m\n", reconnectInterval)
|
|
|
|
// Wait before reconnect, but allow interrupt
|
|
select {
|
|
case <-quit:
|
|
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
|
|
return nil
|
|
case <-time.After(reconnectInterval):
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// getServiceName returns a friendly name for common port numbers
|
|
func getServiceName(port int) string {
|
|
services := map[int]string{
|
|
22: "SSH",
|
|
80: "HTTP",
|
|
443: "HTTPS",
|
|
3306: "MySQL",
|
|
5432: "PostgreSQL",
|
|
6379: "Redis",
|
|
27017: "MongoDB",
|
|
3389: "RDP",
|
|
5900: "VNC",
|
|
8080: "HTTP (Alt)",
|
|
8443: "HTTPS (Alt)",
|
|
}
|
|
|
|
if name, ok := services[port]; ok {
|
|
return fmt.Sprintf("%s (port %d)", name, port)
|
|
}
|
|
|
|
return fmt.Sprintf("TCP service on port %d", port)
|
|
}
|
|
|
|
// formatLatencyTCP formats latency with color based on value
|
|
func formatLatencyTCP(d time.Duration) string {
|
|
ms := d.Milliseconds()
|
|
if ms < 50 {
|
|
return fmt.Sprintf("\033[32m%dms\033[0m", ms) // Green: excellent
|
|
} else if ms < 100 {
|
|
return fmt.Sprintf("\033[33m%dms\033[0m", ms) // Yellow: good
|
|
} else if ms < 200 {
|
|
return fmt.Sprintf("\033[38;5;208m%dms\033[0m", ms) // Orange: moderate
|
|
}
|
|
return fmt.Sprintf("\033[31m%dms\033[0m", ms) // Red: poor
|
|
}
|
|
|
|
// isNonRetryableErrorTCP checks if an error should not be retried
|
|
func isNonRetryableErrorTCP(err error) bool {
|
|
errStr := err.Error()
|
|
// Subdomain conflicts - no point retrying
|
|
if strings.Contains(errStr, "subdomain is already taken") ||
|
|
strings.Contains(errStr, "subdomain is reserved") ||
|
|
strings.Contains(errStr, "invalid subdomain") {
|
|
return true
|
|
}
|
|
// Authentication errors - no point retrying
|
|
if strings.Contains(errStr, "authentication") ||
|
|
strings.Contains(errStr, "Invalid authentication token") {
|
|
return true
|
|
}
|
|
return false
|
|
}
|