Merge pull request #14 from Gouryella/feat/ip-access-control

feat/ip access control
This commit is contained in:
Gouryella
2026-01-12 11:51:49 +08:00
committed by GitHub
19 changed files with 492 additions and 55 deletions

View File

@@ -192,6 +192,7 @@ sudo journalctl -u drip-server -f
**Security**
- TLS 1.3 encryption for all connections
- Token-based authentication
- IP whitelist/blacklist access control
- No legacy protocol support
**Flexibility**
@@ -248,6 +249,21 @@ drip http 8080 -a 172.17.0.3
drip tcp 5432 -a db-container
```
**IP Access Control**
```bash
# Only allow access from specific networks (CIDR)
drip http 3000 --allow-ip 192.168.0.0/16,10.0.0.0/8
# Only allow specific IP addresses
drip http 3000 --allow-ip 192.168.1.100,192.168.1.101
# Block specific IP addresses
drip http 3000 --deny-ip 1.2.3.4,5.6.7.8
# Combine whitelist and blacklist
drip tcp 5432 --allow-ip 192.168.1.0/24 --deny-ip 192.168.1.100
```
## Command Reference
```bash
@@ -258,6 +274,8 @@ drip http <port> [flags]
-d, --daemon Run in background
-s, --server Server address
-t, --token Auth token
--allow-ip Allow only these IPs or CIDR ranges
--deny-ip Deny these IPs or CIDR ranges
# HTTPS tunnel (same flags as http)
drip https <port> [flags]

View File

@@ -163,6 +163,7 @@ server {
location / {
proxy_pass https://127.0.0.1:8443;
proxy_ssl_protocols TLSv1.3;
proxy_ssl_verify off;
proxy_http_version 1.1;
proxy_set_header Host $host;
@@ -191,6 +192,7 @@ sudo journalctl -u drip-server -f
**安全性**
- 所有连接使用 TLS 1.3 加密
- 基于 Token 的身份验证
- IP 白名单/黑名单访问控制
- 不支持任何遗留协议
**灵活性**
@@ -247,6 +249,21 @@ drip http 8080 -a 172.17.0.3
drip tcp 5432 -a db-container
```
**IP 访问控制**
```bash
# 只允许特定网段访问CIDR
drip http 3000 --allow-ip 192.168.0.0/16,10.0.0.0/8
# 只允许特定 IP 访问
drip http 3000 --allow-ip 192.168.1.100,192.168.1.101
# 拒绝特定 IP
drip http 3000 --deny-ip 1.2.3.4,5.6.7.8
# 组合白名单和黑名单
drip tcp 5432 --allow-ip 192.168.1.0/24 --deny-ip 192.168.1.100
```
## 命令参考
```bash
@@ -257,6 +274,8 @@ drip http <端口> [参数]
-d, --daemon 后台运行
-s, --server 服务器地址
-t, --token 认证 token
--allow-ip 只允许这些 IP 或 CIDR 访问
--deny-ip 拒绝这些 IP 或 CIDR 访问
# HTTPS 隧道(参数同 http
drip https <端口> [参数]

View File

@@ -15,6 +15,8 @@ var (
daemonMode bool
daemonMarker bool
localAddress string
allowIPs []string
denyIPs []string
)
var httpCmd = &cobra.Command{
@@ -25,6 +27,9 @@ var httpCmd = &cobra.Command{
Example:
drip http 3000 Tunnel localhost:3000
drip http 8080 --subdomain myapp Use custom subdomain
drip http 3000 --allow-ip 192.168.0.0/16 Only allow IPs from 192.168.x.x
drip http 3000 --allow-ip 10.0.0.1 Allow single IP
drip http 3000 --deny-ip 1.2.3.4 Block specific IP
Configuration:
First time: Run 'drip config init' to save server and token
@@ -39,6 +44,8 @@ func init() {
httpCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)")
httpCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)")
httpCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
httpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
httpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
httpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
httpCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(httpCmd)
@@ -67,6 +74,8 @@ func runHTTP(_ *cobra.Command, args []string) error {
LocalPort: port,
Subdomain: subdomain,
Insecure: insecure,
AllowIPs: allowIPs,
DenyIPs: denyIPs,
}
var daemon *DaemonInfo

View File

@@ -18,6 +18,9 @@ var httpsCmd = &cobra.Command{
Example:
drip https 443 Tunnel localhost:443
drip https 8443 --subdomain myapp Use custom subdomain
drip https 443 --allow-ip 192.168.0.0/16 Only allow IPs from 192.168.x.x
drip https 443 --allow-ip 10.0.0.1 Allow single IP
drip https 443 --deny-ip 1.2.3.4 Block specific IP
Configuration:
First time: Run 'drip config init' to save server and token
@@ -32,6 +35,8 @@ func init() {
httpsCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)")
httpsCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)")
httpsCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
httpsCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
httpsCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
httpsCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(httpsCmd)
@@ -60,6 +65,8 @@ func runHTTPS(_ *cobra.Command, args []string) error {
LocalPort: port,
Subdomain: subdomain,
Insecure: insecure,
AllowIPs: allowIPs,
DenyIPs: denyIPs,
}
var daemon *DaemonInfo

View File

@@ -21,17 +21,17 @@ import (
)
var (
serverPort int
serverPublicPort int
serverDomain string
serverAuthToken string
serverPort int
serverPublicPort int
serverDomain string
serverAuthToken string
serverMetricsToken string
serverDebug bool
serverTCPPortMin int
serverTCPPortMax int
serverTLSCert string
serverTLSKey string
serverPprofPort int
serverDebug bool
serverTCPPortMin int
serverTCPPortMax int
serverTLSCert string
serverTLSKey string
serverPprofPort int
)
var serverCmd = &cobra.Command{
@@ -113,6 +113,10 @@ func runServer(_ *cobra.Command, _ []string) error {
Debug: serverDebug,
}
if err := serverConfig.Validate(); err != nil {
logger.Fatal("Invalid server configuration", zap.Error(err))
}
tlsConfig, err := serverConfig.LoadTLSConfig()
if err != nil {
logger.Fatal("Failed to load TLS configuration", zap.Error(err))

View File

@@ -20,6 +20,9 @@ Example:
drip tcp 3306 Tunnel MySQL
drip tcp 22 Tunnel SSH
drip tcp 6379 --subdomain myredis Tunnel Redis with custom subdomain
drip tcp 5432 --allow-ip 192.168.0.0/16 Only allow IPs from 192.168.x.x
drip tcp 22 --allow-ip 10.0.0.1 Allow single IP
drip tcp 22 --deny-ip 1.2.3.4 Block specific IP
Supported Services:
- Databases: PostgreSQL (5432), MySQL (3306), Redis (6379), MongoDB (27017)
@@ -39,6 +42,8 @@ func init() {
tcpCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)")
tcpCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)")
tcpCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
tcpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
tcpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
tcpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
tcpCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(tcpCmd)
@@ -67,6 +72,8 @@ func runTCP(_ *cobra.Command, args []string) error {
LocalPort: port,
Subdomain: subdomain,
Insecure: insecure,
AllowIPs: allowIPs,
DenyIPs: denyIPs,
}
var daemon *DaemonInfo

View File

@@ -24,6 +24,9 @@ type ConnectorConfig struct {
PoolSize int
PoolMin int
PoolMax int
AllowIPs []string
DenyIPs []string
}
type TunnelClient interface {

View File

@@ -63,6 +63,9 @@ type PoolClient struct {
lastScale time.Time
logger *zap.Logger
allowIPs []string
denyIPs []string
}
// NewPoolClient creates a new pool client.
@@ -126,6 +129,8 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
doneCh: make(chan struct{}),
dataSessions: make(map[string]*sessionHandle),
logger: logger,
allowIPs: cfg.AllowIPs,
denyIPs: cfg.DenyIPs,
}
if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
@@ -156,6 +161,13 @@ func (c *PoolClient) Connect() error {
},
}
if len(c.allowIPs) > 0 || len(c.denyIPs) > 0 {
req.IPAccess = &protocol.IPAccessControl{
AllowIPs: c.allowIPs,
DenyIPs: c.denyIPs,
}
}
payload, err := json.Marshal(req)
if err != nil {
_ = primaryConn.Close()

View File

@@ -127,12 +127,12 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
return
}
done := make(chan struct{})
copyDone := make(chan struct{})
go func() {
select {
case <-ctx.Done():
stream.Close()
case <-done:
case <-copyDone:
}
}()
@@ -150,7 +150,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
break
}
}
close(done)
close(copyDone)
}
func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {

View File

@@ -41,6 +41,24 @@ type Handler struct {
metricsToken string
}
var privateNetworks []*net.IPNet
func init() {
privateCIDRs := []string{
"127.0.0.0/8", // IPv4 loopback
"10.0.0.0/8", // RFC 1918 Class A
"172.16.0.0/12", // RFC 1918 Class B
"192.168.0.0/16", // RFC 1918 Class C
"::1/128", // IPv6 loopback
"fc00::/7", // IPv6 unique local
"fe80::/10", // IPv6 link-local
}
for _, cidr := range privateCIDRs {
_, ipNet, _ := net.ParseCIDR(cidr)
privateNetworks = append(privateNetworks, ipNet)
}
}
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, authToken string, metricsToken string) *Handler {
return &Handler{
manager: manager,
@@ -81,6 +99,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
if tconn.HasIPAccessControl() {
clientIP := h.extractClientIP(r)
if !tconn.IsIPAllowed(clientIP) {
http.Error(w, "Access denied: your IP is not allowed", http.StatusForbidden)
return
}
}
tType := tconn.GetTunnelType()
if tType != "" && tType != protocol.TunnelTypeHTTP && tType != protocol.TunnelTypeHTTPS {
http.Error(w, "Tunnel does not accept HTTP traffic", http.StatusBadGateway)
@@ -159,23 +185,23 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(statusCode)
// Use pooled buffer for zero-copy optimization
buf := pool.GetBuffer(pool.SizeLarge)
defer pool.PutBuffer(buf)
// Copy with context cancellation support
ctx := r.Context()
done := make(chan struct{})
copyDone := make(chan struct{})
go func() {
select {
case <-ctx.Done():
stream.Close()
case <-done:
case <-copyDone:
}
}()
// Use pooled buffer for zero-copy optimization
buf := pool.GetBuffer(pool.SizeLarge)
_, _ = io.CopyBuffer(w, resp.Body, (*buf)[:])
pool.PutBuffer(buf)
close(done)
stream.Close()
close(copyDone)
}
func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, error) {
@@ -184,24 +210,23 @@ func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, err
err error
}
ch := make(chan result, 1)
done := make(chan struct{})
defer close(done)
go func() {
s, err := tconn.OpenStream()
select {
case ch <- result{s, err}:
case <-done:
if s != nil {
s.Close()
}
}
ch <- result{s, err}
}()
select {
case r := <-ch:
return r.stream, r.err
case <-time.After(openStreamTimeout):
// Goroutine will eventually complete and send to buffered channel
// which will be garbage collected. If stream was opened, it needs cleanup.
go func() {
if r := <-ch; r.stream != nil {
r.stream.Close()
}
}()
return nil, fmt.Errorf("open stream timeout")
}
}
@@ -328,6 +353,59 @@ func (h *Handler) extractSubdomain(host string) string {
return ""
}
// extractClientIP extracts the client IP from the request.
// It only trusts X-Forwarded-For and X-Real-IP headers when the request
// comes from a private/loopback network (typical reverse proxy setup).
func (h *Handler) extractClientIP(r *http.Request) string {
// First, get the direct remote address
remoteIP := h.extractRemoteIP(r.RemoteAddr)
// Only trust proxy headers if the request comes from a private network
if isPrivateIP(remoteIP) {
// Check X-Forwarded-For header (may contain multiple IPs)
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP (original client)
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// Check X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
}
// Fall back to remote address
return remoteIP
}
// extractRemoteIP extracts the IP address from a remote address string (host:port format).
func (h *Handler) extractRemoteIP(remoteAddr string) string {
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return remoteAddr
}
return host
}
// isPrivateIP checks if the given IP is a private/loopback address.
func isPrivateIP(ip string) bool {
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false
}
for _, network := range privateNetworks {
if network.Contains(parsedIP) {
return true
}
}
return false
}
func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
html := `<!DOCTYPE html>
<html lang="en">

View File

@@ -181,6 +181,15 @@ func (c *Connection) Handle() error {
c.tunnelConn.SetTunnelType(req.TunnelType)
c.tunnelType = req.TunnelType
if req.IPAccess != nil && (len(req.IPAccess.AllowIPs) > 0 || len(req.IPAccess.DenyIPs) > 0) {
c.tunnelConn.SetIPAccessControl(req.IPAccess.AllowIPs, req.IPAccess.DenyIPs)
c.logger.Info("IP access control configured",
zap.String("subdomain", subdomain),
zap.Strings("allow_ips", req.IPAccess.AllowIPs),
zap.Strings("deny_ips", req.IPAccess.DenyIPs),
)
}
c.logger.Info("Tunnel registered",
zap.String("subdomain", subdomain),
zap.String("tunnel_type", string(req.TunnelType)),
@@ -226,7 +235,10 @@ func (c *Connection) Handle() error {
RecommendedConns: recommendedConns,
}
respData, _ := json.Marshal(resp)
respData, err := json.Marshal(resp)
if err != nil {
return fmt.Errorf("failed to marshal registration response: %w", err)
}
ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData)
err = protocol.WriteFrame(c.conn, ackFrame)
@@ -400,13 +412,6 @@ func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
}
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func parseTCPSubdomainPort(subdomain string) (int, bool) {
if !strings.HasPrefix(subdomain, "tcp-") {
return 0, false
@@ -516,11 +521,15 @@ func (c *Connection) sendError(code, message string) {
Code: code,
Message: message,
}
data, _ := json.Marshal(errMsg)
data, err := json.Marshal(errMsg)
if err != nil {
c.logger.Error("Failed to marshal error message", zap.Error(err))
return
}
errFrame := protocol.NewFrame(protocol.FrameTypeError, data)
if c.frameWriter == nil {
protocol.WriteFrame(c.conn, errFrame)
_ = protocol.WriteFrame(c.conn, errFrame)
} else {
c.frameWriter.WriteFrame(errFrame)
}
@@ -667,7 +676,10 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
Message: "Data connection accepted",
}
respData, _ := json.Marshal(resp)
respData, err := json.Marshal(resp)
if err != nil {
return fmt.Errorf("failed to marshal data connect response: %w", err)
}
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
@@ -723,7 +735,11 @@ func (c *Connection) sendDataConnectError(code, message string) {
Accepted: false,
Message: fmt.Sprintf("%s: %s", code, message),
}
respData, _ := json.Marshal(resp)
respData, err := json.Marshal(resp)
if err != nil {
c.logger.Error("Failed to marshal data connect error", zap.Error(err))
return
}
frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
protocol.WriteFrame(c.conn, frame)
_ = protocol.WriteFrame(c.conn, frame)
}

View File

@@ -32,6 +32,8 @@ type Proxy struct {
ctx context.Context
cancel context.CancelFunc
checkIPAccess func(ip string) bool
}
type trafficStats interface {
@@ -66,6 +68,11 @@ func NewProxy(ctx context.Context, port int, subdomain string, openStream func()
}
}
// SetIPAccessCheck sets the IP access control check function.
func (p *Proxy) SetIPAccessCheck(check func(ip string) bool) {
p.checkIPAccess = check
}
func (p *Proxy) Start() error {
addr := fmt.Sprintf("0.0.0.0:%d", p.port)
@@ -156,6 +163,17 @@ func (p *Proxy) handleConn(conn net.Conn) {
defer p.wg.Done()
defer conn.Close()
if p.checkIPAccess != nil {
clientIP := netutil.ExtractIP(conn.RemoteAddr().String())
if !p.checkIPAccess(clientIP) {
p.logger.Debug("IP access denied",
zap.String("ip", clientIP),
zap.Int("port", p.port),
)
return
}
}
if p.sem != nil {
select {
case p.sem <- struct{}{}:

View File

@@ -44,6 +44,10 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
}
c.proxy = NewProxy(c.ctx, c.port, c.subdomain, openStream, c.tunnelConn, c.logger)
if c.tunnelConn != nil && c.tunnelConn.HasIPAccessControl() {
c.proxy.SetIPAccessCheck(c.tunnelConn.IsIPAllowed)
}
if err := c.proxy.Start(); err != nil {
return fmt.Errorf("failed to start tcp proxy: %w", err)
}

View File

@@ -7,6 +7,7 @@ import (
"time"
"drip/internal/server/metrics"
"drip/internal/shared/netutil"
"drip/internal/shared/protocol"
"github.com/gorilla/websocket"
"go.uber.org/zap"
@@ -29,6 +30,8 @@ type Connection struct {
bytesIn atomic.Int64
bytesOut atomic.Int64
activeConnections atomic.Int64
ipAccessChecker *netutil.IPAccessChecker
}
// NewConnection creates a new tunnel connection
@@ -182,6 +185,32 @@ func (c *Connection) GetActiveConnections() int64 {
return c.activeConnections.Load()
}
// SetIPAccessControl sets the IP access control rules for this tunnel.
func (c *Connection) SetIPAccessControl(allowCIDRs, denyIPs []string) {
c.mu.Lock()
defer c.mu.Unlock()
c.ipAccessChecker = netutil.NewIPAccessChecker(allowCIDRs, denyIPs)
}
// IsIPAllowed checks if the given IP address is allowed to access this tunnel.
func (c *Connection) IsIPAllowed(ip string) bool {
c.mu.RLock()
checker := c.ipAccessChecker
c.mu.RUnlock()
if checker == nil {
return true // No access control configured
}
return checker.IsAllowed(ip)
}
// HasIPAccessControl returns true if IP access control is configured.
func (c *Connection) HasIPAccessControl() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ipAccessChecker != nil && c.ipAccessChecker.HasRules()
}
// StartWritePump starts the write pump for sending messages
func (c *Connection) StartWritePump() {
if c.Conn == nil {

View File

@@ -0,0 +1,132 @@
package netutil
import (
"net"
"strings"
)
// IPAccessChecker checks if an IP address is allowed based on whitelist/blacklist rules.
type IPAccessChecker struct {
allowNets []*net.IPNet // Allowed CIDR ranges (whitelist)
denyNets []*net.IPNet // Denied CIDR ranges (blacklist)
hasAllow bool // Whether whitelist is configured
hasDeny bool // Whether blacklist is configured
}
// NewIPAccessChecker creates a new IP access checker from CIDR and IP lists.
// allowCIDRs: list of CIDR ranges to allow (e.g., "192.168.1.0/24", "10.0.0.0/8")
// denyIPs: list of CIDR ranges or IP addresses to deny (e.g., "192.168.0.0/16", "1.2.3.4")
func NewIPAccessChecker(allowCIDRs, denyIPs []string) *IPAccessChecker {
checker := &IPAccessChecker{}
// Parse allowed CIDRs
for _, cidr := range allowCIDRs {
cidr = strings.TrimSpace(cidr)
if cidr == "" {
continue
}
// If no "/" in the string, treat it as a single IP (/32 for IPv4, /128 for IPv6)
if !strings.Contains(cidr, "/") {
ip := net.ParseIP(cidr)
if ip != nil {
if ip.To4() != nil {
cidr = cidr + "/32"
} else {
cidr = cidr + "/128"
}
}
}
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
continue
}
checker.allowNets = append(checker.allowNets, ipNet)
}
checker.hasAllow = len(checker.allowNets) > 0
// Parse denied IPs/CIDRs
for _, ipStr := range denyIPs {
ipStr = strings.TrimSpace(ipStr)
if ipStr == "" {
continue
}
// If no "/" in the string, treat it as a single IP (/32 for IPv4, /128 for IPv6)
if !strings.Contains(ipStr, "/") {
ip := net.ParseIP(ipStr)
if ip != nil {
if ip.To4() != nil {
ipStr = ipStr + "/32"
} else {
ipStr = ipStr + "/128"
}
}
}
_, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
continue
}
checker.denyNets = append(checker.denyNets, ipNet)
}
checker.hasDeny = len(checker.denyNets) > 0
return checker
}
// IsAllowed checks if the given IP address is allowed.
// Rules:
// 1. If IP is in deny list, reject
// 2. If whitelist is configured and IP is not in whitelist, reject
// 3. Otherwise, allow
func (c *IPAccessChecker) IsAllowed(ipStr string) bool {
if c == nil || (!c.hasAllow && !c.hasDeny) {
return true // No rules configured, allow all
}
ip := net.ParseIP(ipStr)
if ip == nil {
return false // Invalid IP, reject
}
// Check deny list first (blacklist takes priority)
if c.hasDeny {
for _, denyNet := range c.denyNets {
if denyNet.Contains(ip) {
return false
}
}
}
// Check allow list (whitelist)
if c.hasAllow {
for _, allowNet := range c.allowNets {
if allowNet.Contains(ip) {
return true
}
}
return false // Whitelist configured but IP not in it
}
return true // No whitelist, and not in blacklist
}
// HasRules returns true if any access control rules are configured.
func (c *IPAccessChecker) HasRules() bool {
return c != nil && (c.hasAllow || c.hasDeny)
}
// ExtractIP extracts the IP address from a remote address string (e.g., "192.168.1.1:12345").
func ExtractIP(remoteAddr string) string {
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
// Maybe it's just an IP without port
if ip := net.ParseIP(remoteAddr); ip != nil {
return remoteAddr
}
return ""
}
return host
}

View File

@@ -64,16 +64,6 @@ func PipeWithCallbacksAndBufferSize(ctx context.Context, a, b io.ReadWriteCloser
errCh := make(chan error, 2)
if ctx != nil {
go func() {
select {
case <-ctx.Done():
closeAll()
case <-stopCh:
}
}()
}
go func() {
defer wg.Done()
err := pipeBuffer(b, a, bufSize, onAToB, stopCh)
@@ -92,6 +82,16 @@ func PipeWithCallbacksAndBufferSize(ctx context.Context, a, b io.ReadWriteCloser
closeAll()
}()
if ctx != nil {
go func() {
select {
case <-ctx.Done():
closeAll()
case <-stopCh:
}
}()
}
wg.Wait()
select {

View File

@@ -8,6 +8,12 @@ type PoolCapabilities struct {
Version int `json:"version"` // Protocol version for pool features
}
// IPAccessControl defines IP-based access control rules for a tunnel
type IPAccessControl struct {
AllowIPs []string `json:"allow_ips,omitempty"` // Allowed IPs or CIDR ranges (whitelist)
DenyIPs []string `json:"deny_ips,omitempty"` // Denied IPs or CIDR ranges (blacklist)
}
// RegisterRequest is sent by client to register a tunnel
type RegisterRequest struct {
Token string `json:"token"` // Authentication token
@@ -19,6 +25,9 @@ type RegisterRequest struct {
ConnectionType string `json:"connection_type,omitempty"` // "primary" or empty for legacy
TunnelID string `json:"tunnel_id,omitempty"` // For data connections to join
PoolCapabilities *PoolCapabilities `json:"pool_capabilities,omitempty"` // Client pool capabilities
// Access control (optional)
IPAccess *IPAccessControl `json:"ip_access,omitempty"` // IP-based access control rules
}
// RegisterResponse is sent by server after successful registration

View File

@@ -2,8 +2,10 @@ package config
import (
"fmt"
"net"
"os"
"path/filepath"
"strings"
"gopkg.in/yaml.v3"
)
@@ -15,6 +17,31 @@ type ClientConfig struct {
TLS bool `yaml:"tls"` // Use TLS (always true for production)
}
// Validate checks if the client configuration is valid
func (c *ClientConfig) Validate() error {
if c.Server == "" {
return fmt.Errorf("server address is required")
}
host, port, err := net.SplitHostPort(c.Server)
if err != nil {
if strings.Contains(err.Error(), "missing port") {
return fmt.Errorf("server address must include port (e.g., example.com:443), got: %s", c.Server)
}
return fmt.Errorf("invalid server address format: %s (expected host:port)", c.Server)
}
if host == "" {
return fmt.Errorf("server host is required")
}
if port == "" {
return fmt.Errorf("server port is required")
}
return nil
}
// DefaultClientConfig returns the default configuration path
func DefaultClientConfigPath() string {
home, err := os.UserHomeDir()
@@ -43,8 +70,8 @@ func LoadClientConfig(path string) (*ClientConfig, error) {
return nil, fmt.Errorf("failed to parse config file: %w", err)
}
if config.Server == "" {
return nil, fmt.Errorf("server address is required in config")
if err := config.Validate(); err != nil {
return nil, fmt.Errorf("invalid config: %w", err)
}
return &config, nil

View File

@@ -4,6 +4,7 @@ import (
"crypto/tls"
"fmt"
"os"
"strings"
)
// ServerConfig holds the server configuration
@@ -30,6 +31,50 @@ type ServerConfig struct {
Debug bool
}
// Validate checks if the server configuration is valid
func (c *ServerConfig) Validate() error {
// Validate port
if c.Port < 1 || c.Port > 65535 {
return fmt.Errorf("invalid port %d: must be between 1 and 65535", c.Port)
}
// Validate public port if set
if c.PublicPort != 0 && (c.PublicPort < 1 || c.PublicPort > 65535) {
return fmt.Errorf("invalid public port %d: must be between 1 and 65535", c.PublicPort)
}
// Validate domain
if c.Domain == "" {
return fmt.Errorf("domain is required")
}
if strings.Contains(c.Domain, ":") {
return fmt.Errorf("domain should not contain port, got: %s", c.Domain)
}
// Validate TCP port range
if c.TCPPortMin < 1 || c.TCPPortMin > 65535 {
return fmt.Errorf("invalid TCPPortMin %d: must be between 1 and 65535", c.TCPPortMin)
}
if c.TCPPortMax < 1 || c.TCPPortMax > 65535 {
return fmt.Errorf("invalid TCPPortMax %d: must be between 1 and 65535", c.TCPPortMax)
}
if c.TCPPortMin >= c.TCPPortMax {
return fmt.Errorf("TCPPortMin (%d) must be less than TCPPortMax (%d)", c.TCPPortMin, c.TCPPortMax)
}
// Validate TLS settings
if c.TLSEnabled {
if c.TLSCertFile == "" {
return fmt.Errorf("TLS certificate file is required when TLS is enabled")
}
if c.TLSKeyFile == "" {
return fmt.Errorf("TLS key file is required when TLS is enabled")
}
}
return nil
}
// LoadTLSConfig loads TLS configuration
func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
if !c.TLSEnabled {