diff --git a/README.md b/README.md index 705574f..605c2b0 100644 --- a/README.md +++ b/README.md @@ -192,6 +192,7 @@ sudo journalctl -u drip-server -f **Security** - TLS 1.3 encryption for all connections - Token-based authentication +- IP whitelist/blacklist access control - No legacy protocol support **Flexibility** @@ -248,6 +249,21 @@ drip http 8080 -a 172.17.0.3 drip tcp 5432 -a db-container ``` +**IP Access Control** +```bash +# Only allow access from specific networks (CIDR) +drip http 3000 --allow-ip 192.168.0.0/16,10.0.0.0/8 + +# Only allow specific IP addresses +drip http 3000 --allow-ip 192.168.1.100,192.168.1.101 + +# Block specific IP addresses +drip http 3000 --deny-ip 1.2.3.4,5.6.7.8 + +# Combine whitelist and blacklist +drip tcp 5432 --allow-ip 192.168.1.0/24 --deny-ip 192.168.1.100 +``` + ## Command Reference ```bash @@ -258,6 +274,8 @@ drip http [flags] -d, --daemon Run in background -s, --server Server address -t, --token Auth token + --allow-ip Allow only these IPs or CIDR ranges + --deny-ip Deny these IPs or CIDR ranges # HTTPS tunnel (same flags as http) drip https [flags] diff --git a/README_CN.md b/README_CN.md index b4751b2..78109ee 100644 --- a/README_CN.md +++ b/README_CN.md @@ -163,6 +163,7 @@ server { location / { proxy_pass https://127.0.0.1:8443; + proxy_ssl_protocols TLSv1.3; proxy_ssl_verify off; proxy_http_version 1.1; proxy_set_header Host $host; @@ -191,6 +192,7 @@ sudo journalctl -u drip-server -f **安全性** - 所有连接使用 TLS 1.3 加密 - 基于 Token 的身份验证 +- IP 白名单/黑名单访问控制 - 不支持任何遗留协议 **灵活性** @@ -247,6 +249,21 @@ drip http 8080 -a 172.17.0.3 drip tcp 5432 -a db-container ``` +**IP 访问控制** +```bash +# 只允许特定网段访问(CIDR) +drip http 3000 --allow-ip 192.168.0.0/16,10.0.0.0/8 + +# 只允许特定 IP 访问 +drip http 3000 --allow-ip 192.168.1.100,192.168.1.101 + +# 拒绝特定 IP +drip http 3000 --deny-ip 1.2.3.4,5.6.7.8 + +# 组合白名单和黑名单 +drip tcp 5432 --allow-ip 192.168.1.0/24 --deny-ip 192.168.1.100 +``` + ## 命令参考 ```bash @@ -257,6 +274,8 @@ drip http <端口> [参数] -d, --daemon 后台运行 -s, --server 服务器地址 -t, --token 认证 token + --allow-ip 只允许这些 IP 或 CIDR 访问 + --deny-ip 拒绝这些 IP 或 CIDR 访问 # HTTPS 隧道(参数同 http) drip https <端口> [参数] diff --git a/internal/client/cli/http.go b/internal/client/cli/http.go index d5f368c..2f1986f 100644 --- a/internal/client/cli/http.go +++ b/internal/client/cli/http.go @@ -15,6 +15,8 @@ var ( daemonMode bool daemonMarker bool localAddress string + allowIPs []string + denyIPs []string ) var httpCmd = &cobra.Command{ @@ -25,6 +27,9 @@ var httpCmd = &cobra.Command{ Example: drip http 3000 Tunnel localhost:3000 drip http 8080 --subdomain myapp Use custom subdomain + drip http 3000 --allow-ip 192.168.0.0/16 Only allow IPs from 192.168.x.x + drip http 3000 --allow-ip 10.0.0.1 Allow single IP + drip http 3000 --deny-ip 1.2.3.4 Block specific IP Configuration: First time: Run 'drip config init' to save server and token @@ -39,6 +44,8 @@ func init() { httpCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)") httpCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)") httpCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)") + httpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)") + httpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)") httpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process") httpCmd.Flags().MarkHidden("daemon-child") rootCmd.AddCommand(httpCmd) @@ -67,6 +74,8 @@ func runHTTP(_ *cobra.Command, args []string) error { LocalPort: port, Subdomain: subdomain, Insecure: insecure, + AllowIPs: allowIPs, + DenyIPs: denyIPs, } var daemon *DaemonInfo diff --git a/internal/client/cli/https.go b/internal/client/cli/https.go index 4607ee7..1fcbe6c 100644 --- a/internal/client/cli/https.go +++ b/internal/client/cli/https.go @@ -18,6 +18,9 @@ var httpsCmd = &cobra.Command{ Example: drip https 443 Tunnel localhost:443 drip https 8443 --subdomain myapp Use custom subdomain + drip https 443 --allow-ip 192.168.0.0/16 Only allow IPs from 192.168.x.x + drip https 443 --allow-ip 10.0.0.1 Allow single IP + drip https 443 --deny-ip 1.2.3.4 Block specific IP Configuration: First time: Run 'drip config init' to save server and token @@ -32,6 +35,8 @@ func init() { httpsCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)") httpsCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)") httpsCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)") + httpsCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)") + httpsCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)") httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process") httpsCmd.Flags().MarkHidden("daemon-child") rootCmd.AddCommand(httpsCmd) @@ -60,6 +65,8 @@ func runHTTPS(_ *cobra.Command, args []string) error { LocalPort: port, Subdomain: subdomain, Insecure: insecure, + AllowIPs: allowIPs, + DenyIPs: denyIPs, } var daemon *DaemonInfo diff --git a/internal/client/cli/server.go b/internal/client/cli/server.go index 51d5fc8..c9435a5 100644 --- a/internal/client/cli/server.go +++ b/internal/client/cli/server.go @@ -21,17 +21,17 @@ import ( ) var ( - serverPort int - serverPublicPort int - serverDomain string - serverAuthToken string + serverPort int + serverPublicPort int + serverDomain string + serverAuthToken string serverMetricsToken string - serverDebug bool - serverTCPPortMin int - serverTCPPortMax int - serverTLSCert string - serverTLSKey string - serverPprofPort int + serverDebug bool + serverTCPPortMin int + serverTCPPortMax int + serverTLSCert string + serverTLSKey string + serverPprofPort int ) var serverCmd = &cobra.Command{ @@ -113,6 +113,10 @@ func runServer(_ *cobra.Command, _ []string) error { Debug: serverDebug, } + if err := serverConfig.Validate(); err != nil { + logger.Fatal("Invalid server configuration", zap.Error(err)) + } + tlsConfig, err := serverConfig.LoadTLSConfig() if err != nil { logger.Fatal("Failed to load TLS configuration", zap.Error(err)) diff --git a/internal/client/cli/tcp.go b/internal/client/cli/tcp.go index 19b8872..cced649 100644 --- a/internal/client/cli/tcp.go +++ b/internal/client/cli/tcp.go @@ -20,6 +20,9 @@ Example: drip tcp 3306 Tunnel MySQL drip tcp 22 Tunnel SSH drip tcp 6379 --subdomain myredis Tunnel Redis with custom subdomain + drip tcp 5432 --allow-ip 192.168.0.0/16 Only allow IPs from 192.168.x.x + drip tcp 22 --allow-ip 10.0.0.1 Allow single IP + drip tcp 22 --deny-ip 1.2.3.4 Block specific IP Supported Services: - Databases: PostgreSQL (5432), MySQL (3306), Redis (6379), MongoDB (27017) @@ -39,6 +42,8 @@ func init() { tcpCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)") tcpCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)") tcpCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)") + tcpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)") + tcpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)") tcpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process") tcpCmd.Flags().MarkHidden("daemon-child") rootCmd.AddCommand(tcpCmd) @@ -67,6 +72,8 @@ func runTCP(_ *cobra.Command, args []string) error { LocalPort: port, Subdomain: subdomain, Insecure: insecure, + AllowIPs: allowIPs, + DenyIPs: denyIPs, } var daemon *DaemonInfo diff --git a/internal/client/tcp/connector.go b/internal/client/tcp/connector.go index 890b237..8843c6c 100644 --- a/internal/client/tcp/connector.go +++ b/internal/client/tcp/connector.go @@ -24,6 +24,9 @@ type ConnectorConfig struct { PoolSize int PoolMin int PoolMax int + + AllowIPs []string + DenyIPs []string } type TunnelClient interface { diff --git a/internal/client/tcp/pool_client.go b/internal/client/tcp/pool_client.go index b6964ba..3237340 100644 --- a/internal/client/tcp/pool_client.go +++ b/internal/client/tcp/pool_client.go @@ -63,6 +63,9 @@ type PoolClient struct { lastScale time.Time logger *zap.Logger + + allowIPs []string + denyIPs []string } // NewPoolClient creates a new pool client. @@ -126,6 +129,8 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient { doneCh: make(chan struct{}), dataSessions: make(map[string]*sessionHandle), logger: logger, + allowIPs: cfg.AllowIPs, + denyIPs: cfg.DenyIPs, } if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS { @@ -156,6 +161,13 @@ func (c *PoolClient) Connect() error { }, } + if len(c.allowIPs) > 0 || len(c.denyIPs) > 0 { + req.IPAccess = &protocol.IPAccessControl{ + AllowIPs: c.allowIPs, + DenyIPs: c.denyIPs, + } + } + payload, err := json.Marshal(req) if err != nil { _ = primaryConn.Close() diff --git a/internal/client/tcp/pool_handler.go b/internal/client/tcp/pool_handler.go index e9ecc0a..8056ee8 100644 --- a/internal/client/tcp/pool_handler.go +++ b/internal/client/tcp/pool_handler.go @@ -127,12 +127,12 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) { return } - done := make(chan struct{}) + copyDone := make(chan struct{}) go func() { select { case <-ctx.Done(): stream.Close() - case <-done: + case <-copyDone: } }() @@ -150,7 +150,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) { break } } - close(done) + close(copyDone) } func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) { diff --git a/internal/server/proxy/handler.go b/internal/server/proxy/handler.go index 40d56c4..8e34c22 100644 --- a/internal/server/proxy/handler.go +++ b/internal/server/proxy/handler.go @@ -41,6 +41,24 @@ type Handler struct { metricsToken string } +var privateNetworks []*net.IPNet + +func init() { + privateCIDRs := []string{ + "127.0.0.0/8", // IPv4 loopback + "10.0.0.0/8", // RFC 1918 Class A + "172.16.0.0/12", // RFC 1918 Class B + "192.168.0.0/16", // RFC 1918 Class C + "::1/128", // IPv6 loopback + "fc00::/7", // IPv6 unique local + "fe80::/10", // IPv6 link-local + } + for _, cidr := range privateCIDRs { + _, ipNet, _ := net.ParseCIDR(cidr) + privateNetworks = append(privateNetworks, ipNet) + } +} + func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, authToken string, metricsToken string) *Handler { return &Handler{ manager: manager, @@ -81,6 +99,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if tconn.HasIPAccessControl() { + clientIP := h.extractClientIP(r) + if !tconn.IsIPAllowed(clientIP) { + http.Error(w, "Access denied: your IP is not allowed", http.StatusForbidden) + return + } + } + tType := tconn.GetTunnelType() if tType != "" && tType != protocol.TunnelTypeHTTP && tType != protocol.TunnelTypeHTTPS { http.Error(w, "Tunnel does not accept HTTP traffic", http.StatusBadGateway) @@ -159,23 +185,23 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(statusCode) + // Use pooled buffer for zero-copy optimization + buf := pool.GetBuffer(pool.SizeLarge) + defer pool.PutBuffer(buf) + + // Copy with context cancellation support ctx := r.Context() - done := make(chan struct{}) + copyDone := make(chan struct{}) go func() { select { case <-ctx.Done(): stream.Close() - case <-done: + case <-copyDone: } }() - // Use pooled buffer for zero-copy optimization - buf := pool.GetBuffer(pool.SizeLarge) _, _ = io.CopyBuffer(w, resp.Body, (*buf)[:]) - pool.PutBuffer(buf) - - close(done) - stream.Close() + close(copyDone) } func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, error) { @@ -184,24 +210,23 @@ func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, err err error } ch := make(chan result, 1) - done := make(chan struct{}) - defer close(done) go func() { s, err := tconn.OpenStream() - select { - case ch <- result{s, err}: - case <-done: - if s != nil { - s.Close() - } - } + ch <- result{s, err} }() select { case r := <-ch: return r.stream, r.err case <-time.After(openStreamTimeout): + // Goroutine will eventually complete and send to buffered channel + // which will be garbage collected. If stream was opened, it needs cleanup. + go func() { + if r := <-ch; r.stream != nil { + r.stream.Close() + } + }() return nil, fmt.Errorf("open stream timeout") } } @@ -328,6 +353,59 @@ func (h *Handler) extractSubdomain(host string) string { return "" } +// extractClientIP extracts the client IP from the request. +// It only trusts X-Forwarded-For and X-Real-IP headers when the request +// comes from a private/loopback network (typical reverse proxy setup). +func (h *Handler) extractClientIP(r *http.Request) string { + // First, get the direct remote address + remoteIP := h.extractRemoteIP(r.RemoteAddr) + + // Only trust proxy headers if the request comes from a private network + if isPrivateIP(remoteIP) { + // Check X-Forwarded-For header (may contain multiple IPs) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Take the first IP (original client) + if idx := strings.Index(xff, ","); idx != -1 { + return strings.TrimSpace(xff[:idx]) + } + return strings.TrimSpace(xff) + } + + // Check X-Real-IP header + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return strings.TrimSpace(xri) + } + } + + // Fall back to remote address + return remoteIP +} + +// extractRemoteIP extracts the IP address from a remote address string (host:port format). +func (h *Handler) extractRemoteIP(remoteAddr string) string { + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return remoteAddr + } + return host +} + +// isPrivateIP checks if the given IP is a private/loopback address. +func isPrivateIP(ip string) bool { + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + return false + } + + for _, network := range privateNetworks { + if network.Contains(parsedIP) { + return true + } + } + + return false +} + func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) { html := ` diff --git a/internal/server/tcp/connection.go b/internal/server/tcp/connection.go index d05cc29..8dc50b3 100644 --- a/internal/server/tcp/connection.go +++ b/internal/server/tcp/connection.go @@ -181,6 +181,15 @@ func (c *Connection) Handle() error { c.tunnelConn.SetTunnelType(req.TunnelType) c.tunnelType = req.TunnelType + if req.IPAccess != nil && (len(req.IPAccess.AllowIPs) > 0 || len(req.IPAccess.DenyIPs) > 0) { + c.tunnelConn.SetIPAccessControl(req.IPAccess.AllowIPs, req.IPAccess.DenyIPs) + c.logger.Info("IP access control configured", + zap.String("subdomain", subdomain), + zap.Strings("allow_ips", req.IPAccess.AllowIPs), + zap.Strings("deny_ips", req.IPAccess.DenyIPs), + ) + } + c.logger.Info("Tunnel registered", zap.String("subdomain", subdomain), zap.String("tunnel_type", string(req.TunnelType)), @@ -226,7 +235,10 @@ func (c *Connection) Handle() error { RecommendedConns: recommendedConns, } - respData, _ := json.Marshal(resp) + respData, err := json.Marshal(resp) + if err != nil { + return fmt.Errorf("failed to marshal registration response: %w", err) + } ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData) err = protocol.WriteFrame(c.conn, ackFrame) @@ -400,13 +412,6 @@ func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error { } } -func min(a, b int) int { - if a < b { - return a - } - return b -} - func parseTCPSubdomainPort(subdomain string) (int, bool) { if !strings.HasPrefix(subdomain, "tcp-") { return 0, false @@ -516,11 +521,15 @@ func (c *Connection) sendError(code, message string) { Code: code, Message: message, } - data, _ := json.Marshal(errMsg) + data, err := json.Marshal(errMsg) + if err != nil { + c.logger.Error("Failed to marshal error message", zap.Error(err)) + return + } errFrame := protocol.NewFrame(protocol.FrameTypeError, data) if c.frameWriter == nil { - protocol.WriteFrame(c.conn, errFrame) + _ = protocol.WriteFrame(c.conn, errFrame) } else { c.frameWriter.WriteFrame(errFrame) } @@ -667,7 +676,10 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read Message: "Data connection accepted", } - respData, _ := json.Marshal(resp) + respData, err := json.Marshal(resp) + if err != nil { + return fmt.Errorf("failed to marshal data connect response: %w", err) + } ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData) if err := protocol.WriteFrame(c.conn, ackFrame); err != nil { @@ -723,7 +735,11 @@ func (c *Connection) sendDataConnectError(code, message string) { Accepted: false, Message: fmt.Sprintf("%s: %s", code, message), } - respData, _ := json.Marshal(resp) + respData, err := json.Marshal(resp) + if err != nil { + c.logger.Error("Failed to marshal data connect error", zap.Error(err)) + return + } frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData) - protocol.WriteFrame(c.conn, frame) + _ = protocol.WriteFrame(c.conn, frame) } diff --git a/internal/server/tcp/proxy.go b/internal/server/tcp/proxy.go index fd896ed..18bb68c 100644 --- a/internal/server/tcp/proxy.go +++ b/internal/server/tcp/proxy.go @@ -32,6 +32,8 @@ type Proxy struct { ctx context.Context cancel context.CancelFunc + + checkIPAccess func(ip string) bool } type trafficStats interface { @@ -66,6 +68,11 @@ func NewProxy(ctx context.Context, port int, subdomain string, openStream func() } } +// SetIPAccessCheck sets the IP access control check function. +func (p *Proxy) SetIPAccessCheck(check func(ip string) bool) { + p.checkIPAccess = check +} + func (p *Proxy) Start() error { addr := fmt.Sprintf("0.0.0.0:%d", p.port) @@ -156,6 +163,17 @@ func (p *Proxy) handleConn(conn net.Conn) { defer p.wg.Done() defer conn.Close() + if p.checkIPAccess != nil { + clientIP := netutil.ExtractIP(conn.RemoteAddr().String()) + if !p.checkIPAccess(clientIP) { + p.logger.Debug("IP access denied", + zap.String("ip", clientIP), + zap.Int("port", p.port), + ) + return + } + } + if p.sem != nil { select { case p.sem <- struct{}{}: diff --git a/internal/server/tcp/tunnel.go b/internal/server/tcp/tunnel.go index 8c3fb60..26c21c9 100644 --- a/internal/server/tcp/tunnel.go +++ b/internal/server/tcp/tunnel.go @@ -44,6 +44,10 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error { } c.proxy = NewProxy(c.ctx, c.port, c.subdomain, openStream, c.tunnelConn, c.logger) + if c.tunnelConn != nil && c.tunnelConn.HasIPAccessControl() { + c.proxy.SetIPAccessCheck(c.tunnelConn.IsIPAllowed) + } + if err := c.proxy.Start(); err != nil { return fmt.Errorf("failed to start tcp proxy: %w", err) } diff --git a/internal/server/tunnel/connection.go b/internal/server/tunnel/connection.go index 7cca58d..7866d2c 100644 --- a/internal/server/tunnel/connection.go +++ b/internal/server/tunnel/connection.go @@ -7,6 +7,7 @@ import ( "time" "drip/internal/server/metrics" + "drip/internal/shared/netutil" "drip/internal/shared/protocol" "github.com/gorilla/websocket" "go.uber.org/zap" @@ -29,6 +30,8 @@ type Connection struct { bytesIn atomic.Int64 bytesOut atomic.Int64 activeConnections atomic.Int64 + + ipAccessChecker *netutil.IPAccessChecker } // NewConnection creates a new tunnel connection @@ -182,6 +185,32 @@ func (c *Connection) GetActiveConnections() int64 { return c.activeConnections.Load() } +// SetIPAccessControl sets the IP access control rules for this tunnel. +func (c *Connection) SetIPAccessControl(allowCIDRs, denyIPs []string) { + c.mu.Lock() + defer c.mu.Unlock() + c.ipAccessChecker = netutil.NewIPAccessChecker(allowCIDRs, denyIPs) +} + +// IsIPAllowed checks if the given IP address is allowed to access this tunnel. +func (c *Connection) IsIPAllowed(ip string) bool { + c.mu.RLock() + checker := c.ipAccessChecker + c.mu.RUnlock() + + if checker == nil { + return true // No access control configured + } + return checker.IsAllowed(ip) +} + +// HasIPAccessControl returns true if IP access control is configured. +func (c *Connection) HasIPAccessControl() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.ipAccessChecker != nil && c.ipAccessChecker.HasRules() +} + // StartWritePump starts the write pump for sending messages func (c *Connection) StartWritePump() { if c.Conn == nil { diff --git a/internal/shared/netutil/ipaccess.go b/internal/shared/netutil/ipaccess.go new file mode 100644 index 0000000..3cdcdb0 --- /dev/null +++ b/internal/shared/netutil/ipaccess.go @@ -0,0 +1,132 @@ +package netutil + +import ( + "net" + "strings" +) + +// IPAccessChecker checks if an IP address is allowed based on whitelist/blacklist rules. +type IPAccessChecker struct { + allowNets []*net.IPNet // Allowed CIDR ranges (whitelist) + denyNets []*net.IPNet // Denied CIDR ranges (blacklist) + hasAllow bool // Whether whitelist is configured + hasDeny bool // Whether blacklist is configured +} + +// NewIPAccessChecker creates a new IP access checker from CIDR and IP lists. +// allowCIDRs: list of CIDR ranges to allow (e.g., "192.168.1.0/24", "10.0.0.0/8") +// denyIPs: list of CIDR ranges or IP addresses to deny (e.g., "192.168.0.0/16", "1.2.3.4") +func NewIPAccessChecker(allowCIDRs, denyIPs []string) *IPAccessChecker { + checker := &IPAccessChecker{} + + // Parse allowed CIDRs + for _, cidr := range allowCIDRs { + cidr = strings.TrimSpace(cidr) + if cidr == "" { + continue + } + + // If no "/" in the string, treat it as a single IP (/32 for IPv4, /128 for IPv6) + if !strings.Contains(cidr, "/") { + ip := net.ParseIP(cidr) + if ip != nil { + if ip.To4() != nil { + cidr = cidr + "/32" + } else { + cidr = cidr + "/128" + } + } + } + + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + continue + } + checker.allowNets = append(checker.allowNets, ipNet) + } + checker.hasAllow = len(checker.allowNets) > 0 + + // Parse denied IPs/CIDRs + for _, ipStr := range denyIPs { + ipStr = strings.TrimSpace(ipStr) + if ipStr == "" { + continue + } + + // If no "/" in the string, treat it as a single IP (/32 for IPv4, /128 for IPv6) + if !strings.Contains(ipStr, "/") { + ip := net.ParseIP(ipStr) + if ip != nil { + if ip.To4() != nil { + ipStr = ipStr + "/32" + } else { + ipStr = ipStr + "/128" + } + } + } + + _, ipNet, err := net.ParseCIDR(ipStr) + if err != nil { + continue + } + checker.denyNets = append(checker.denyNets, ipNet) + } + checker.hasDeny = len(checker.denyNets) > 0 + + return checker +} + +// IsAllowed checks if the given IP address is allowed. +// Rules: +// 1. If IP is in deny list, reject +// 2. If whitelist is configured and IP is not in whitelist, reject +// 3. Otherwise, allow +func (c *IPAccessChecker) IsAllowed(ipStr string) bool { + if c == nil || (!c.hasAllow && !c.hasDeny) { + return true // No rules configured, allow all + } + + ip := net.ParseIP(ipStr) + if ip == nil { + return false // Invalid IP, reject + } + + // Check deny list first (blacklist takes priority) + if c.hasDeny { + for _, denyNet := range c.denyNets { + if denyNet.Contains(ip) { + return false + } + } + } + + // Check allow list (whitelist) + if c.hasAllow { + for _, allowNet := range c.allowNets { + if allowNet.Contains(ip) { + return true + } + } + return false // Whitelist configured but IP not in it + } + + return true // No whitelist, and not in blacklist +} + +// HasRules returns true if any access control rules are configured. +func (c *IPAccessChecker) HasRules() bool { + return c != nil && (c.hasAllow || c.hasDeny) +} + +// ExtractIP extracts the IP address from a remote address string (e.g., "192.168.1.1:12345"). +func ExtractIP(remoteAddr string) string { + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + // Maybe it's just an IP without port + if ip := net.ParseIP(remoteAddr); ip != nil { + return remoteAddr + } + return "" + } + return host +} diff --git a/internal/shared/netutil/pipe.go b/internal/shared/netutil/pipe.go index 75e951d..04eb7dc 100644 --- a/internal/shared/netutil/pipe.go +++ b/internal/shared/netutil/pipe.go @@ -64,16 +64,6 @@ func PipeWithCallbacksAndBufferSize(ctx context.Context, a, b io.ReadWriteCloser errCh := make(chan error, 2) - if ctx != nil { - go func() { - select { - case <-ctx.Done(): - closeAll() - case <-stopCh: - } - }() - } - go func() { defer wg.Done() err := pipeBuffer(b, a, bufSize, onAToB, stopCh) @@ -92,6 +82,16 @@ func PipeWithCallbacksAndBufferSize(ctx context.Context, a, b io.ReadWriteCloser closeAll() }() + if ctx != nil { + go func() { + select { + case <-ctx.Done(): + closeAll() + case <-stopCh: + } + }() + } + wg.Wait() select { diff --git a/internal/shared/protocol/messages.go b/internal/shared/protocol/messages.go index b6502db..a94d28f 100644 --- a/internal/shared/protocol/messages.go +++ b/internal/shared/protocol/messages.go @@ -8,6 +8,12 @@ type PoolCapabilities struct { Version int `json:"version"` // Protocol version for pool features } +// IPAccessControl defines IP-based access control rules for a tunnel +type IPAccessControl struct { + AllowIPs []string `json:"allow_ips,omitempty"` // Allowed IPs or CIDR ranges (whitelist) + DenyIPs []string `json:"deny_ips,omitempty"` // Denied IPs or CIDR ranges (blacklist) +} + // RegisterRequest is sent by client to register a tunnel type RegisterRequest struct { Token string `json:"token"` // Authentication token @@ -19,6 +25,9 @@ type RegisterRequest struct { ConnectionType string `json:"connection_type,omitempty"` // "primary" or empty for legacy TunnelID string `json:"tunnel_id,omitempty"` // For data connections to join PoolCapabilities *PoolCapabilities `json:"pool_capabilities,omitempty"` // Client pool capabilities + + // Access control (optional) + IPAccess *IPAccessControl `json:"ip_access,omitempty"` // IP-based access control rules } // RegisterResponse is sent by server after successful registration diff --git a/pkg/config/client_config.go b/pkg/config/client_config.go index 477195b..a883a4c 100644 --- a/pkg/config/client_config.go +++ b/pkg/config/client_config.go @@ -2,8 +2,10 @@ package config import ( "fmt" + "net" "os" "path/filepath" + "strings" "gopkg.in/yaml.v3" ) @@ -15,6 +17,31 @@ type ClientConfig struct { TLS bool `yaml:"tls"` // Use TLS (always true for production) } +// Validate checks if the client configuration is valid +func (c *ClientConfig) Validate() error { + if c.Server == "" { + return fmt.Errorf("server address is required") + } + + host, port, err := net.SplitHostPort(c.Server) + if err != nil { + if strings.Contains(err.Error(), "missing port") { + return fmt.Errorf("server address must include port (e.g., example.com:443), got: %s", c.Server) + } + return fmt.Errorf("invalid server address format: %s (expected host:port)", c.Server) + } + + if host == "" { + return fmt.Errorf("server host is required") + } + + if port == "" { + return fmt.Errorf("server port is required") + } + + return nil +} + // DefaultClientConfig returns the default configuration path func DefaultClientConfigPath() string { home, err := os.UserHomeDir() @@ -43,8 +70,8 @@ func LoadClientConfig(path string) (*ClientConfig, error) { return nil, fmt.Errorf("failed to parse config file: %w", err) } - if config.Server == "" { - return nil, fmt.Errorf("server address is required in config") + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("invalid config: %w", err) } return &config, nil diff --git a/pkg/config/config.go b/pkg/config/config.go index 3ce88ec..5eec9ff 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "fmt" "os" + "strings" ) // ServerConfig holds the server configuration @@ -30,6 +31,50 @@ type ServerConfig struct { Debug bool } +// Validate checks if the server configuration is valid +func (c *ServerConfig) Validate() error { + // Validate port + if c.Port < 1 || c.Port > 65535 { + return fmt.Errorf("invalid port %d: must be between 1 and 65535", c.Port) + } + + // Validate public port if set + if c.PublicPort != 0 && (c.PublicPort < 1 || c.PublicPort > 65535) { + return fmt.Errorf("invalid public port %d: must be between 1 and 65535", c.PublicPort) + } + + // Validate domain + if c.Domain == "" { + return fmt.Errorf("domain is required") + } + if strings.Contains(c.Domain, ":") { + return fmt.Errorf("domain should not contain port, got: %s", c.Domain) + } + + // Validate TCP port range + if c.TCPPortMin < 1 || c.TCPPortMin > 65535 { + return fmt.Errorf("invalid TCPPortMin %d: must be between 1 and 65535", c.TCPPortMin) + } + if c.TCPPortMax < 1 || c.TCPPortMax > 65535 { + return fmt.Errorf("invalid TCPPortMax %d: must be between 1 and 65535", c.TCPPortMax) + } + if c.TCPPortMin >= c.TCPPortMax { + return fmt.Errorf("TCPPortMin (%d) must be less than TCPPortMax (%d)", c.TCPPortMin, c.TCPPortMax) + } + + // Validate TLS settings + if c.TLSEnabled { + if c.TLSCertFile == "" { + return fmt.Errorf("TLS certificate file is required when TLS is enabled") + } + if c.TLSKeyFile == "" { + return fmt.Errorf("TLS key file is required when TLS is enabled") + } + } + + return nil +} + // LoadTLSConfig loads TLS configuration func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) { if !c.TLSEnabled {