From 85a0f44e44ed066a66be77c7ec3755df9e104cf6 Mon Sep 17 00:00:00 2001 From: Gouryella Date: Sun, 11 Jan 2026 14:22:41 +0800 Subject: [PATCH 1/3] feat: Add IP access control functionality - Implement IP whitelist/blacklist access control mechanism - Add --allow-ip and --deny-ip command-line arguments to configure IP access rules - Support CIDR format for IP range configuration - Enable IP access control in HTTP, HTTPS, and TCP tunnels - Add IP access check logic to server-side proxy handling - Update documentation to explain how to use IP access control --- README.md | 18 ++++ README_CN.md | 19 +++++ internal/client/cli/http.go | 9 ++ internal/client/cli/https.go | 7 ++ internal/client/cli/tcp.go | 7 ++ internal/client/tcp/connector.go | 3 + internal/client/tcp/pool_client.go | 12 +++ internal/server/proxy/handler.go | 34 ++++++++ internal/server/tcp/connection.go | 9 ++ internal/server/tcp/proxy.go | 18 ++++ internal/server/tcp/tunnel.go | 4 + internal/server/tunnel/connection.go | 29 +++++++ internal/shared/netutil/ipaccess.go | 119 +++++++++++++++++++++++++++ internal/shared/protocol/messages.go | 9 ++ 14 files changed, 297 insertions(+) create mode 100644 internal/shared/netutil/ipaccess.go diff --git a/README.md b/README.md index 705574f..605c2b0 100644 --- a/README.md +++ b/README.md @@ -192,6 +192,7 @@ sudo journalctl -u drip-server -f **Security** - TLS 1.3 encryption for all connections - Token-based authentication +- IP whitelist/blacklist access control - No legacy protocol support **Flexibility** @@ -248,6 +249,21 @@ drip http 8080 -a 172.17.0.3 drip tcp 5432 -a db-container ``` +**IP Access Control** +```bash +# Only allow access from specific networks (CIDR) +drip http 3000 --allow-ip 192.168.0.0/16,10.0.0.0/8 + +# Only allow specific IP addresses +drip http 3000 --allow-ip 192.168.1.100,192.168.1.101 + +# Block specific IP addresses +drip http 3000 --deny-ip 1.2.3.4,5.6.7.8 + +# Combine whitelist and blacklist +drip tcp 5432 --allow-ip 192.168.1.0/24 --deny-ip 192.168.1.100 +``` + ## Command Reference ```bash @@ -258,6 +274,8 @@ drip http [flags] -d, --daemon Run in background -s, --server Server address -t, --token Auth token + --allow-ip Allow only these IPs or CIDR ranges + --deny-ip Deny these IPs or CIDR ranges # HTTPS tunnel (same flags as http) drip https [flags] diff --git a/README_CN.md b/README_CN.md index b4751b2..78109ee 100644 --- a/README_CN.md +++ b/README_CN.md @@ -163,6 +163,7 @@ server { location / { proxy_pass https://127.0.0.1:8443; + proxy_ssl_protocols TLSv1.3; proxy_ssl_verify off; proxy_http_version 1.1; proxy_set_header Host $host; @@ -191,6 +192,7 @@ sudo journalctl -u drip-server -f **安全性** - 所有连接使用 TLS 1.3 加密 - 基于 Token 的身份验证 +- IP 白名单/黑名单访问控制 - 不支持任何遗留协议 **灵活性** @@ -247,6 +249,21 @@ drip http 8080 -a 172.17.0.3 drip tcp 5432 -a db-container ``` +**IP 访问控制** +```bash +# 只允许特定网段访问(CIDR) +drip http 3000 --allow-ip 192.168.0.0/16,10.0.0.0/8 + +# 只允许特定 IP 访问 +drip http 3000 --allow-ip 192.168.1.100,192.168.1.101 + +# 拒绝特定 IP +drip http 3000 --deny-ip 1.2.3.4,5.6.7.8 + +# 组合白名单和黑名单 +drip tcp 5432 --allow-ip 192.168.1.0/24 --deny-ip 192.168.1.100 +``` + ## 命令参考 ```bash @@ -257,6 +274,8 @@ drip http <端口> [参数] -d, --daemon 后台运行 -s, --server 服务器地址 -t, --token 认证 token + --allow-ip 只允许这些 IP 或 CIDR 访问 + --deny-ip 拒绝这些 IP 或 CIDR 访问 # HTTPS 隧道(参数同 http) drip https <端口> [参数] diff --git a/internal/client/cli/http.go b/internal/client/cli/http.go index d5f368c..2f1986f 100644 --- a/internal/client/cli/http.go +++ b/internal/client/cli/http.go @@ -15,6 +15,8 @@ var ( daemonMode bool daemonMarker bool localAddress string + allowIPs []string + denyIPs []string ) var httpCmd = &cobra.Command{ @@ -25,6 +27,9 @@ var httpCmd = &cobra.Command{ Example: drip http 3000 Tunnel localhost:3000 drip http 8080 --subdomain myapp Use custom subdomain + drip http 3000 --allow-ip 192.168.0.0/16 Only allow IPs from 192.168.x.x + drip http 3000 --allow-ip 10.0.0.1 Allow single IP + drip http 3000 --deny-ip 1.2.3.4 Block specific IP Configuration: First time: Run 'drip config init' to save server and token @@ -39,6 +44,8 @@ func init() { httpCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)") httpCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)") httpCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)") + httpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)") + httpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)") httpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process") httpCmd.Flags().MarkHidden("daemon-child") rootCmd.AddCommand(httpCmd) @@ -67,6 +74,8 @@ func runHTTP(_ *cobra.Command, args []string) error { LocalPort: port, Subdomain: subdomain, Insecure: insecure, + AllowIPs: allowIPs, + DenyIPs: denyIPs, } var daemon *DaemonInfo diff --git a/internal/client/cli/https.go b/internal/client/cli/https.go index 4607ee7..1fcbe6c 100644 --- a/internal/client/cli/https.go +++ b/internal/client/cli/https.go @@ -18,6 +18,9 @@ var httpsCmd = &cobra.Command{ Example: drip https 443 Tunnel localhost:443 drip https 8443 --subdomain myapp Use custom subdomain + drip https 443 --allow-ip 192.168.0.0/16 Only allow IPs from 192.168.x.x + drip https 443 --allow-ip 10.0.0.1 Allow single IP + drip https 443 --deny-ip 1.2.3.4 Block specific IP Configuration: First time: Run 'drip config init' to save server and token @@ -32,6 +35,8 @@ func init() { httpsCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)") httpsCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)") httpsCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)") + httpsCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)") + httpsCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)") httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process") httpsCmd.Flags().MarkHidden("daemon-child") rootCmd.AddCommand(httpsCmd) @@ -60,6 +65,8 @@ func runHTTPS(_ *cobra.Command, args []string) error { LocalPort: port, Subdomain: subdomain, Insecure: insecure, + AllowIPs: allowIPs, + DenyIPs: denyIPs, } var daemon *DaemonInfo diff --git a/internal/client/cli/tcp.go b/internal/client/cli/tcp.go index 19b8872..cced649 100644 --- a/internal/client/cli/tcp.go +++ b/internal/client/cli/tcp.go @@ -20,6 +20,9 @@ Example: drip tcp 3306 Tunnel MySQL drip tcp 22 Tunnel SSH drip tcp 6379 --subdomain myredis Tunnel Redis with custom subdomain + drip tcp 5432 --allow-ip 192.168.0.0/16 Only allow IPs from 192.168.x.x + drip tcp 22 --allow-ip 10.0.0.1 Allow single IP + drip tcp 22 --deny-ip 1.2.3.4 Block specific IP Supported Services: - Databases: PostgreSQL (5432), MySQL (3306), Redis (6379), MongoDB (27017) @@ -39,6 +42,8 @@ func init() { tcpCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)") tcpCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)") tcpCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)") + tcpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)") + tcpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)") tcpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process") tcpCmd.Flags().MarkHidden("daemon-child") rootCmd.AddCommand(tcpCmd) @@ -67,6 +72,8 @@ func runTCP(_ *cobra.Command, args []string) error { LocalPort: port, Subdomain: subdomain, Insecure: insecure, + AllowIPs: allowIPs, + DenyIPs: denyIPs, } var daemon *DaemonInfo diff --git a/internal/client/tcp/connector.go b/internal/client/tcp/connector.go index 890b237..8843c6c 100644 --- a/internal/client/tcp/connector.go +++ b/internal/client/tcp/connector.go @@ -24,6 +24,9 @@ type ConnectorConfig struct { PoolSize int PoolMin int PoolMax int + + AllowIPs []string + DenyIPs []string } type TunnelClient interface { diff --git a/internal/client/tcp/pool_client.go b/internal/client/tcp/pool_client.go index b6964ba..3237340 100644 --- a/internal/client/tcp/pool_client.go +++ b/internal/client/tcp/pool_client.go @@ -63,6 +63,9 @@ type PoolClient struct { lastScale time.Time logger *zap.Logger + + allowIPs []string + denyIPs []string } // NewPoolClient creates a new pool client. @@ -126,6 +129,8 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient { doneCh: make(chan struct{}), dataSessions: make(map[string]*sessionHandle), logger: logger, + allowIPs: cfg.AllowIPs, + denyIPs: cfg.DenyIPs, } if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS { @@ -156,6 +161,13 @@ func (c *PoolClient) Connect() error { }, } + if len(c.allowIPs) > 0 || len(c.denyIPs) > 0 { + req.IPAccess = &protocol.IPAccessControl{ + AllowIPs: c.allowIPs, + DenyIPs: c.denyIPs, + } + } + payload, err := json.Marshal(req) if err != nil { _ = primaryConn.Close() diff --git a/internal/server/proxy/handler.go b/internal/server/proxy/handler.go index 40d56c4..c4ce149 100644 --- a/internal/server/proxy/handler.go +++ b/internal/server/proxy/handler.go @@ -81,6 +81,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if tconn.HasIPAccessControl() { + clientIP := h.extractClientIP(r) + if !tconn.IsIPAllowed(clientIP) { + http.Error(w, "Access denied: your IP is not allowed", http.StatusForbidden) + return + } + } + tType := tconn.GetTunnelType() if tType != "" && tType != protocol.TunnelTypeHTTP && tType != protocol.TunnelTypeHTTPS { http.Error(w, "Tunnel does not accept HTTP traffic", http.StatusBadGateway) @@ -328,6 +336,32 @@ func (h *Handler) extractSubdomain(host string) string { return "" } +// extractClientIP extracts the client IP from the request. +// It checks X-Forwarded-For and X-Real-IP headers first (for reverse proxy setups), +// then falls back to the remote address. +func (h *Handler) extractClientIP(r *http.Request) string { + // Check X-Forwarded-For header (may contain multiple IPs) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Take the first IP (original client) + if idx := strings.Index(xff, ","); idx != -1 { + return strings.TrimSpace(xff[:idx]) + } + return strings.TrimSpace(xff) + } + + // Check X-Real-IP header + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return strings.TrimSpace(xri) + } + + // Fall back to remote address + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} + func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) { html := ` diff --git a/internal/server/tcp/connection.go b/internal/server/tcp/connection.go index d05cc29..b2e521a 100644 --- a/internal/server/tcp/connection.go +++ b/internal/server/tcp/connection.go @@ -181,6 +181,15 @@ func (c *Connection) Handle() error { c.tunnelConn.SetTunnelType(req.TunnelType) c.tunnelType = req.TunnelType + if req.IPAccess != nil && (len(req.IPAccess.AllowIPs) > 0 || len(req.IPAccess.DenyIPs) > 0) { + c.tunnelConn.SetIPAccessControl(req.IPAccess.AllowIPs, req.IPAccess.DenyIPs) + c.logger.Info("IP access control configured", + zap.String("subdomain", subdomain), + zap.Strings("allow_ips", req.IPAccess.AllowIPs), + zap.Strings("deny_ips", req.IPAccess.DenyIPs), + ) + } + c.logger.Info("Tunnel registered", zap.String("subdomain", subdomain), zap.String("tunnel_type", string(req.TunnelType)), diff --git a/internal/server/tcp/proxy.go b/internal/server/tcp/proxy.go index fd896ed..18bb68c 100644 --- a/internal/server/tcp/proxy.go +++ b/internal/server/tcp/proxy.go @@ -32,6 +32,8 @@ type Proxy struct { ctx context.Context cancel context.CancelFunc + + checkIPAccess func(ip string) bool } type trafficStats interface { @@ -66,6 +68,11 @@ func NewProxy(ctx context.Context, port int, subdomain string, openStream func() } } +// SetIPAccessCheck sets the IP access control check function. +func (p *Proxy) SetIPAccessCheck(check func(ip string) bool) { + p.checkIPAccess = check +} + func (p *Proxy) Start() error { addr := fmt.Sprintf("0.0.0.0:%d", p.port) @@ -156,6 +163,17 @@ func (p *Proxy) handleConn(conn net.Conn) { defer p.wg.Done() defer conn.Close() + if p.checkIPAccess != nil { + clientIP := netutil.ExtractIP(conn.RemoteAddr().String()) + if !p.checkIPAccess(clientIP) { + p.logger.Debug("IP access denied", + zap.String("ip", clientIP), + zap.Int("port", p.port), + ) + return + } + } + if p.sem != nil { select { case p.sem <- struct{}{}: diff --git a/internal/server/tcp/tunnel.go b/internal/server/tcp/tunnel.go index 8c3fb60..26c21c9 100644 --- a/internal/server/tcp/tunnel.go +++ b/internal/server/tcp/tunnel.go @@ -44,6 +44,10 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error { } c.proxy = NewProxy(c.ctx, c.port, c.subdomain, openStream, c.tunnelConn, c.logger) + if c.tunnelConn != nil && c.tunnelConn.HasIPAccessControl() { + c.proxy.SetIPAccessCheck(c.tunnelConn.IsIPAllowed) + } + if err := c.proxy.Start(); err != nil { return fmt.Errorf("failed to start tcp proxy: %w", err) } diff --git a/internal/server/tunnel/connection.go b/internal/server/tunnel/connection.go index 7cca58d..7866d2c 100644 --- a/internal/server/tunnel/connection.go +++ b/internal/server/tunnel/connection.go @@ -7,6 +7,7 @@ import ( "time" "drip/internal/server/metrics" + "drip/internal/shared/netutil" "drip/internal/shared/protocol" "github.com/gorilla/websocket" "go.uber.org/zap" @@ -29,6 +30,8 @@ type Connection struct { bytesIn atomic.Int64 bytesOut atomic.Int64 activeConnections atomic.Int64 + + ipAccessChecker *netutil.IPAccessChecker } // NewConnection creates a new tunnel connection @@ -182,6 +185,32 @@ func (c *Connection) GetActiveConnections() int64 { return c.activeConnections.Load() } +// SetIPAccessControl sets the IP access control rules for this tunnel. +func (c *Connection) SetIPAccessControl(allowCIDRs, denyIPs []string) { + c.mu.Lock() + defer c.mu.Unlock() + c.ipAccessChecker = netutil.NewIPAccessChecker(allowCIDRs, denyIPs) +} + +// IsIPAllowed checks if the given IP address is allowed to access this tunnel. +func (c *Connection) IsIPAllowed(ip string) bool { + c.mu.RLock() + checker := c.ipAccessChecker + c.mu.RUnlock() + + if checker == nil { + return true // No access control configured + } + return checker.IsAllowed(ip) +} + +// HasIPAccessControl returns true if IP access control is configured. +func (c *Connection) HasIPAccessControl() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.ipAccessChecker != nil && c.ipAccessChecker.HasRules() +} + // StartWritePump starts the write pump for sending messages func (c *Connection) StartWritePump() { if c.Conn == nil { diff --git a/internal/shared/netutil/ipaccess.go b/internal/shared/netutil/ipaccess.go new file mode 100644 index 0000000..561b7d2 --- /dev/null +++ b/internal/shared/netutil/ipaccess.go @@ -0,0 +1,119 @@ +package netutil + +import ( + "net" + "strings" +) + +// IPAccessChecker checks if an IP address is allowed based on whitelist/blacklist rules. +type IPAccessChecker struct { + allowNets []*net.IPNet // Allowed CIDR ranges (whitelist) + denyIPs []net.IP // Denied IP addresses (blacklist) + hasAllow bool // Whether whitelist is configured + hasDeny bool // Whether blacklist is configured +} + +// NewIPAccessChecker creates a new IP access checker from CIDR and IP lists. +// allowCIDRs: list of CIDR ranges to allow (e.g., "192.168.1.0/24", "10.0.0.0/8") +// denyIPs: list of IP addresses to deny (e.g., "1.2.3.4", "5.6.7.8") +func NewIPAccessChecker(allowCIDRs, denyIPs []string) *IPAccessChecker { + checker := &IPAccessChecker{} + + // Parse allowed CIDRs + for _, cidr := range allowCIDRs { + cidr = strings.TrimSpace(cidr) + if cidr == "" { + continue + } + + // If no "/" in the string, treat it as a single IP (/32 for IPv4, /128 for IPv6) + if !strings.Contains(cidr, "/") { + ip := net.ParseIP(cidr) + if ip != nil { + if ip.To4() != nil { + cidr = cidr + "/32" + } else { + cidr = cidr + "/128" + } + } + } + + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + continue + } + checker.allowNets = append(checker.allowNets, ipNet) + } + checker.hasAllow = len(checker.allowNets) > 0 + + // Parse denied IPs + for _, ipStr := range denyIPs { + ipStr = strings.TrimSpace(ipStr) + if ipStr == "" { + continue + } + + ip := net.ParseIP(ipStr) + if ip != nil { + checker.denyIPs = append(checker.denyIPs, ip) + } + } + checker.hasDeny = len(checker.denyIPs) > 0 + + return checker +} + +// IsAllowed checks if the given IP address is allowed. +// Rules: +// 1. If IP is in deny list, reject +// 2. If whitelist is configured and IP is not in whitelist, reject +// 3. Otherwise, allow +func (c *IPAccessChecker) IsAllowed(ipStr string) bool { + if c == nil || (!c.hasAllow && !c.hasDeny) { + return true // No rules configured, allow all + } + + ip := net.ParseIP(ipStr) + if ip == nil { + return false // Invalid IP, reject + } + + // Check deny list first (blacklist takes priority) + if c.hasDeny { + for _, denyIP := range c.denyIPs { + if ip.Equal(denyIP) { + return false + } + } + } + + // Check allow list (whitelist) + if c.hasAllow { + for _, allowNet := range c.allowNets { + if allowNet.Contains(ip) { + return true + } + } + return false // Whitelist configured but IP not in it + } + + return true // No whitelist, and not in blacklist +} + +// HasRules returns true if any access control rules are configured. +func (c *IPAccessChecker) HasRules() bool { + return c != nil && (c.hasAllow || c.hasDeny) +} + +// ExtractIP extracts the IP address from a remote address string (e.g., "192.168.1.1:12345"). +func ExtractIP(remoteAddr string) string { + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + // Maybe it's just an IP without port + if ip := net.ParseIP(remoteAddr); ip != nil { + return remoteAddr + } + return "" + } + return host +} diff --git a/internal/shared/protocol/messages.go b/internal/shared/protocol/messages.go index b6502db..a94d28f 100644 --- a/internal/shared/protocol/messages.go +++ b/internal/shared/protocol/messages.go @@ -8,6 +8,12 @@ type PoolCapabilities struct { Version int `json:"version"` // Protocol version for pool features } +// IPAccessControl defines IP-based access control rules for a tunnel +type IPAccessControl struct { + AllowIPs []string `json:"allow_ips,omitempty"` // Allowed IPs or CIDR ranges (whitelist) + DenyIPs []string `json:"deny_ips,omitempty"` // Denied IPs or CIDR ranges (blacklist) +} + // RegisterRequest is sent by client to register a tunnel type RegisterRequest struct { Token string `json:"token"` // Authentication token @@ -19,6 +25,9 @@ type RegisterRequest struct { ConnectionType string `json:"connection_type,omitempty"` // "primary" or empty for legacy TunnelID string `json:"tunnel_id,omitempty"` // For data connections to join PoolCapabilities *PoolCapabilities `json:"pool_capabilities,omitempty"` // Client pool capabilities + + // Access control (optional) + IPAccess *IPAccessControl `json:"ip_access,omitempty"` // IP-based access control rules } // RegisterResponse is sent by server after successful registration From d7b92a8b95259f867162c94f5cb23170be785303 Mon Sep 17 00:00:00 2001 From: Gouryella Date: Mon, 12 Jan 2026 10:55:27 +0800 Subject: [PATCH 2/3] feat(server): Add server configuration validation and optimize connection handling - Add Validate method to ServerConfig to validate port ranges, domain format, TCP port ranges, and other configuration items - Add configuration validation logic in server.go to ensure valid configuration before server startup - Improve channel naming in TCP connections for better code readability - Enhance data copying mechanism with context cancellation support to avoid resource leaks - Add private network definitions for secure validation of trusted proxy headers fix(proxy): Strengthen client IP extraction security and fix error handling - Trust X-Forwarded-For and X-Real-IP headers only when requests originate from private/loopback networks - Define RFC 1918 and other private network ranges for proxy header validation - Add JSON serialization error handling in TCP connections to prevent data loss - Fix context handling logic in pipe callbacks - Optimize error handling mechanism for data connection responses refactor(config): Improve client configuration validation and error handling - Add Validate method to ClientConfig to verify server address format and port validity - Change configuration validation from simple checks to full validation function calls - Provide more detailed error messages to help users correctly configure server address formats --- internal/client/cli/server.go | 24 ++++--- internal/client/tcp/pool_handler.go | 6 +- internal/server/proxy/handler.go | 108 +++++++++++++++++++--------- internal/server/tcp/connection.go | 33 +++++---- internal/shared/netutil/pipe.go | 20 +++--- pkg/config/client_config.go | 31 +++++++- pkg/config/config.go | 45 ++++++++++++ 7 files changed, 197 insertions(+), 70 deletions(-) diff --git a/internal/client/cli/server.go b/internal/client/cli/server.go index 51d5fc8..c9435a5 100644 --- a/internal/client/cli/server.go +++ b/internal/client/cli/server.go @@ -21,17 +21,17 @@ import ( ) var ( - serverPort int - serverPublicPort int - serverDomain string - serverAuthToken string + serverPort int + serverPublicPort int + serverDomain string + serverAuthToken string serverMetricsToken string - serverDebug bool - serverTCPPortMin int - serverTCPPortMax int - serverTLSCert string - serverTLSKey string - serverPprofPort int + serverDebug bool + serverTCPPortMin int + serverTCPPortMax int + serverTLSCert string + serverTLSKey string + serverPprofPort int ) var serverCmd = &cobra.Command{ @@ -113,6 +113,10 @@ func runServer(_ *cobra.Command, _ []string) error { Debug: serverDebug, } + if err := serverConfig.Validate(); err != nil { + logger.Fatal("Invalid server configuration", zap.Error(err)) + } + tlsConfig, err := serverConfig.LoadTLSConfig() if err != nil { logger.Fatal("Failed to load TLS configuration", zap.Error(err)) diff --git a/internal/client/tcp/pool_handler.go b/internal/client/tcp/pool_handler.go index e9ecc0a..8056ee8 100644 --- a/internal/client/tcp/pool_handler.go +++ b/internal/client/tcp/pool_handler.go @@ -127,12 +127,12 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) { return } - done := make(chan struct{}) + copyDone := make(chan struct{}) go func() { select { case <-ctx.Done(): stream.Close() - case <-done: + case <-copyDone: } }() @@ -150,7 +150,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) { break } } - close(done) + close(copyDone) } func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) { diff --git a/internal/server/proxy/handler.go b/internal/server/proxy/handler.go index c4ce149..8e34c22 100644 --- a/internal/server/proxy/handler.go +++ b/internal/server/proxy/handler.go @@ -41,6 +41,24 @@ type Handler struct { metricsToken string } +var privateNetworks []*net.IPNet + +func init() { + privateCIDRs := []string{ + "127.0.0.0/8", // IPv4 loopback + "10.0.0.0/8", // RFC 1918 Class A + "172.16.0.0/12", // RFC 1918 Class B + "192.168.0.0/16", // RFC 1918 Class C + "::1/128", // IPv6 loopback + "fc00::/7", // IPv6 unique local + "fe80::/10", // IPv6 link-local + } + for _, cidr := range privateCIDRs { + _, ipNet, _ := net.ParseCIDR(cidr) + privateNetworks = append(privateNetworks, ipNet) + } +} + func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, authToken string, metricsToken string) *Handler { return &Handler{ manager: manager, @@ -167,23 +185,23 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(statusCode) + // Use pooled buffer for zero-copy optimization + buf := pool.GetBuffer(pool.SizeLarge) + defer pool.PutBuffer(buf) + + // Copy with context cancellation support ctx := r.Context() - done := make(chan struct{}) + copyDone := make(chan struct{}) go func() { select { case <-ctx.Done(): stream.Close() - case <-done: + case <-copyDone: } }() - // Use pooled buffer for zero-copy optimization - buf := pool.GetBuffer(pool.SizeLarge) _, _ = io.CopyBuffer(w, resp.Body, (*buf)[:]) - pool.PutBuffer(buf) - - close(done) - stream.Close() + close(copyDone) } func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, error) { @@ -192,24 +210,23 @@ func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, err err error } ch := make(chan result, 1) - done := make(chan struct{}) - defer close(done) go func() { s, err := tconn.OpenStream() - select { - case ch <- result{s, err}: - case <-done: - if s != nil { - s.Close() - } - } + ch <- result{s, err} }() select { case r := <-ch: return r.stream, r.err case <-time.After(openStreamTimeout): + // Goroutine will eventually complete and send to buffered channel + // which will be garbage collected. If stream was opened, it needs cleanup. + go func() { + if r := <-ch; r.stream != nil { + r.stream.Close() + } + }() return nil, fmt.Errorf("open stream timeout") } } @@ -337,31 +354,58 @@ func (h *Handler) extractSubdomain(host string) string { } // extractClientIP extracts the client IP from the request. -// It checks X-Forwarded-For and X-Real-IP headers first (for reverse proxy setups), -// then falls back to the remote address. +// It only trusts X-Forwarded-For and X-Real-IP headers when the request +// comes from a private/loopback network (typical reverse proxy setup). func (h *Handler) extractClientIP(r *http.Request) string { - // Check X-Forwarded-For header (may contain multiple IPs) - if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - // Take the first IP (original client) - if idx := strings.Index(xff, ","); idx != -1 { - return strings.TrimSpace(xff[:idx]) - } - return strings.TrimSpace(xff) - } + // First, get the direct remote address + remoteIP := h.extractRemoteIP(r.RemoteAddr) - // Check X-Real-IP header - if xri := r.Header.Get("X-Real-IP"); xri != "" { - return strings.TrimSpace(xri) + // Only trust proxy headers if the request comes from a private network + if isPrivateIP(remoteIP) { + // Check X-Forwarded-For header (may contain multiple IPs) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Take the first IP (original client) + if idx := strings.Index(xff, ","); idx != -1 { + return strings.TrimSpace(xff[:idx]) + } + return strings.TrimSpace(xff) + } + + // Check X-Real-IP header + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return strings.TrimSpace(xri) + } } // Fall back to remote address - host, _, err := net.SplitHostPort(r.RemoteAddr) + return remoteIP +} + +// extractRemoteIP extracts the IP address from a remote address string (host:port format). +func (h *Handler) extractRemoteIP(remoteAddr string) string { + host, _, err := net.SplitHostPort(remoteAddr) if err != nil { - return r.RemoteAddr + return remoteAddr } return host } +// isPrivateIP checks if the given IP is a private/loopback address. +func isPrivateIP(ip string) bool { + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + return false + } + + for _, network := range privateNetworks { + if network.Contains(parsedIP) { + return true + } + } + + return false +} + func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) { html := ` diff --git a/internal/server/tcp/connection.go b/internal/server/tcp/connection.go index b2e521a..8dc50b3 100644 --- a/internal/server/tcp/connection.go +++ b/internal/server/tcp/connection.go @@ -235,7 +235,10 @@ func (c *Connection) Handle() error { RecommendedConns: recommendedConns, } - respData, _ := json.Marshal(resp) + respData, err := json.Marshal(resp) + if err != nil { + return fmt.Errorf("failed to marshal registration response: %w", err) + } ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData) err = protocol.WriteFrame(c.conn, ackFrame) @@ -409,13 +412,6 @@ func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error { } } -func min(a, b int) int { - if a < b { - return a - } - return b -} - func parseTCPSubdomainPort(subdomain string) (int, bool) { if !strings.HasPrefix(subdomain, "tcp-") { return 0, false @@ -525,11 +521,15 @@ func (c *Connection) sendError(code, message string) { Code: code, Message: message, } - data, _ := json.Marshal(errMsg) + data, err := json.Marshal(errMsg) + if err != nil { + c.logger.Error("Failed to marshal error message", zap.Error(err)) + return + } errFrame := protocol.NewFrame(protocol.FrameTypeError, data) if c.frameWriter == nil { - protocol.WriteFrame(c.conn, errFrame) + _ = protocol.WriteFrame(c.conn, errFrame) } else { c.frameWriter.WriteFrame(errFrame) } @@ -676,7 +676,10 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read Message: "Data connection accepted", } - respData, _ := json.Marshal(resp) + respData, err := json.Marshal(resp) + if err != nil { + return fmt.Errorf("failed to marshal data connect response: %w", err) + } ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData) if err := protocol.WriteFrame(c.conn, ackFrame); err != nil { @@ -732,7 +735,11 @@ func (c *Connection) sendDataConnectError(code, message string) { Accepted: false, Message: fmt.Sprintf("%s: %s", code, message), } - respData, _ := json.Marshal(resp) + respData, err := json.Marshal(resp) + if err != nil { + c.logger.Error("Failed to marshal data connect error", zap.Error(err)) + return + } frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData) - protocol.WriteFrame(c.conn, frame) + _ = protocol.WriteFrame(c.conn, frame) } diff --git a/internal/shared/netutil/pipe.go b/internal/shared/netutil/pipe.go index 75e951d..04eb7dc 100644 --- a/internal/shared/netutil/pipe.go +++ b/internal/shared/netutil/pipe.go @@ -64,16 +64,6 @@ func PipeWithCallbacksAndBufferSize(ctx context.Context, a, b io.ReadWriteCloser errCh := make(chan error, 2) - if ctx != nil { - go func() { - select { - case <-ctx.Done(): - closeAll() - case <-stopCh: - } - }() - } - go func() { defer wg.Done() err := pipeBuffer(b, a, bufSize, onAToB, stopCh) @@ -92,6 +82,16 @@ func PipeWithCallbacksAndBufferSize(ctx context.Context, a, b io.ReadWriteCloser closeAll() }() + if ctx != nil { + go func() { + select { + case <-ctx.Done(): + closeAll() + case <-stopCh: + } + }() + } + wg.Wait() select { diff --git a/pkg/config/client_config.go b/pkg/config/client_config.go index 477195b..a883a4c 100644 --- a/pkg/config/client_config.go +++ b/pkg/config/client_config.go @@ -2,8 +2,10 @@ package config import ( "fmt" + "net" "os" "path/filepath" + "strings" "gopkg.in/yaml.v3" ) @@ -15,6 +17,31 @@ type ClientConfig struct { TLS bool `yaml:"tls"` // Use TLS (always true for production) } +// Validate checks if the client configuration is valid +func (c *ClientConfig) Validate() error { + if c.Server == "" { + return fmt.Errorf("server address is required") + } + + host, port, err := net.SplitHostPort(c.Server) + if err != nil { + if strings.Contains(err.Error(), "missing port") { + return fmt.Errorf("server address must include port (e.g., example.com:443), got: %s", c.Server) + } + return fmt.Errorf("invalid server address format: %s (expected host:port)", c.Server) + } + + if host == "" { + return fmt.Errorf("server host is required") + } + + if port == "" { + return fmt.Errorf("server port is required") + } + + return nil +} + // DefaultClientConfig returns the default configuration path func DefaultClientConfigPath() string { home, err := os.UserHomeDir() @@ -43,8 +70,8 @@ func LoadClientConfig(path string) (*ClientConfig, error) { return nil, fmt.Errorf("failed to parse config file: %w", err) } - if config.Server == "" { - return nil, fmt.Errorf("server address is required in config") + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("invalid config: %w", err) } return &config, nil diff --git a/pkg/config/config.go b/pkg/config/config.go index 3ce88ec..5eec9ff 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "fmt" "os" + "strings" ) // ServerConfig holds the server configuration @@ -30,6 +31,50 @@ type ServerConfig struct { Debug bool } +// Validate checks if the server configuration is valid +func (c *ServerConfig) Validate() error { + // Validate port + if c.Port < 1 || c.Port > 65535 { + return fmt.Errorf("invalid port %d: must be between 1 and 65535", c.Port) + } + + // Validate public port if set + if c.PublicPort != 0 && (c.PublicPort < 1 || c.PublicPort > 65535) { + return fmt.Errorf("invalid public port %d: must be between 1 and 65535", c.PublicPort) + } + + // Validate domain + if c.Domain == "" { + return fmt.Errorf("domain is required") + } + if strings.Contains(c.Domain, ":") { + return fmt.Errorf("domain should not contain port, got: %s", c.Domain) + } + + // Validate TCP port range + if c.TCPPortMin < 1 || c.TCPPortMin > 65535 { + return fmt.Errorf("invalid TCPPortMin %d: must be between 1 and 65535", c.TCPPortMin) + } + if c.TCPPortMax < 1 || c.TCPPortMax > 65535 { + return fmt.Errorf("invalid TCPPortMax %d: must be between 1 and 65535", c.TCPPortMax) + } + if c.TCPPortMin >= c.TCPPortMax { + return fmt.Errorf("TCPPortMin (%d) must be less than TCPPortMax (%d)", c.TCPPortMin, c.TCPPortMax) + } + + // Validate TLS settings + if c.TLSEnabled { + if c.TLSCertFile == "" { + return fmt.Errorf("TLS certificate file is required when TLS is enabled") + } + if c.TLSKeyFile == "" { + return fmt.Errorf("TLS key file is required when TLS is enabled") + } + } + + return nil +} + // LoadTLSConfig loads TLS configuration func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) { if !c.TLSEnabled { From 852dbb2ee696f0a926108a29841ef059244439f3 Mon Sep 17 00:00:00 2001 From: Gouryella Date: Mon, 12 Jan 2026 11:50:34 +0800 Subject: [PATCH 3/3] feat(netutil): extend IP access checker blacklist from single IP to CIDR ranges Rename denyIPs field to denyNets, supporting blacklist configuration with CIDR ranges. Now supports both individual IP addresses and CIDR subnet masks as deny rules, with IPv4 automatically converted to /32 and IPv6 to /128, using the Contains method for more flexible subnet matching. --- internal/shared/netutil/ipaccess.go | 31 ++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/internal/shared/netutil/ipaccess.go b/internal/shared/netutil/ipaccess.go index 561b7d2..3cdcdb0 100644 --- a/internal/shared/netutil/ipaccess.go +++ b/internal/shared/netutil/ipaccess.go @@ -8,14 +8,14 @@ import ( // IPAccessChecker checks if an IP address is allowed based on whitelist/blacklist rules. type IPAccessChecker struct { allowNets []*net.IPNet // Allowed CIDR ranges (whitelist) - denyIPs []net.IP // Denied IP addresses (blacklist) + denyNets []*net.IPNet // Denied CIDR ranges (blacklist) hasAllow bool // Whether whitelist is configured hasDeny bool // Whether blacklist is configured } // NewIPAccessChecker creates a new IP access checker from CIDR and IP lists. // allowCIDRs: list of CIDR ranges to allow (e.g., "192.168.1.0/24", "10.0.0.0/8") -// denyIPs: list of IP addresses to deny (e.g., "1.2.3.4", "5.6.7.8") +// denyIPs: list of CIDR ranges or IP addresses to deny (e.g., "192.168.0.0/16", "1.2.3.4") func NewIPAccessChecker(allowCIDRs, denyIPs []string) *IPAccessChecker { checker := &IPAccessChecker{} @@ -46,19 +46,32 @@ func NewIPAccessChecker(allowCIDRs, denyIPs []string) *IPAccessChecker { } checker.hasAllow = len(checker.allowNets) > 0 - // Parse denied IPs + // Parse denied IPs/CIDRs for _, ipStr := range denyIPs { ipStr = strings.TrimSpace(ipStr) if ipStr == "" { continue } - ip := net.ParseIP(ipStr) - if ip != nil { - checker.denyIPs = append(checker.denyIPs, ip) + // If no "/" in the string, treat it as a single IP (/32 for IPv4, /128 for IPv6) + if !strings.Contains(ipStr, "/") { + ip := net.ParseIP(ipStr) + if ip != nil { + if ip.To4() != nil { + ipStr = ipStr + "/32" + } else { + ipStr = ipStr + "/128" + } + } } + + _, ipNet, err := net.ParseCIDR(ipStr) + if err != nil { + continue + } + checker.denyNets = append(checker.denyNets, ipNet) } - checker.hasDeny = len(checker.denyIPs) > 0 + checker.hasDeny = len(checker.denyNets) > 0 return checker } @@ -80,8 +93,8 @@ func (c *IPAccessChecker) IsAllowed(ipStr string) bool { // Check deny list first (blacklist takes priority) if c.hasDeny { - for _, denyIP := range c.denyIPs { - if ip.Equal(denyIP) { + for _, denyNet := range c.denyNets { + if denyNet.Contains(ip) { return false } }