feat: add transport protocol option supporting TCP and WebSocket connections

Added --transport parameter to allow users to select transport protocol type:
- auto: automatically choose based on server address (default)
- tcp: direct TLS 1.3 connection
- wss: WebSocket over TLS (CDN-friendly)

Also updated client connector to support WebSocket transport, and added server-side discovery endpoint to query supported transport protocols.
This commit is contained in:
Gouryella
2026-01-14 12:49:08 +08:00
parent 81f156f49c
commit 6139a9c0ed
13 changed files with 797 additions and 73 deletions

View File

@@ -3,6 +3,7 @@ package cli
import (
"fmt"
"strconv"
"strings"
"drip/internal/client/tcp"
"drip/internal/shared/protocol"
@@ -18,6 +19,7 @@ var (
allowIPs []string
denyIPs []string
authPass string
transport string
)
var httpCmd = &cobra.Command{
@@ -32,12 +34,16 @@ Example:
drip http 3000 --allow-ip 10.0.0.1 Allow single IP
drip http 3000 --deny-ip 1.2.3.4 Block specific IP
drip http 3000 --auth secret Enable proxy authentication with password
drip http 3000 --transport wss Use WebSocket over TLS (CDN-friendly)
Configuration:
First time: Run 'drip config init' to save server and token
Subsequent: Just run 'drip http <port>'
Note: Uses TCP over TLS 1.3 for secure communication`,
Transport options:
auto - Automatically select based on server address (default)
tcp - Direct TLS 1.3 connection
wss - WebSocket over TLS (works through CDN like Cloudflare)`,
Args: cobra.ExactArgs(1),
RunE: runHTTP,
}
@@ -49,6 +55,7 @@ func init() {
httpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
httpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
httpCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication")
httpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
httpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
httpCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(httpCmd)
@@ -80,6 +87,7 @@ func runHTTP(_ *cobra.Command, args []string) error {
AllowIPs: allowIPs,
DenyIPs: denyIPs,
AuthPass: authPass,
Transport: parseTransport(transport),
}
var daemon *DaemonInfo
@@ -89,3 +97,15 @@ func runHTTP(_ *cobra.Command, args []string) error {
return runTunnelWithUI(connConfig, daemon)
}
// parseTransport converts a string to TransportType
func parseTransport(s string) tcp.TransportType {
switch strings.ToLower(s) {
case "wss":
return tcp.TransportWebSocket
case "tcp", "tls":
return tcp.TransportTCP
default:
return tcp.TransportAuto
}
}

View File

@@ -22,12 +22,16 @@ Example:
drip https 443 --allow-ip 10.0.0.1 Allow single IP
drip https 443 --deny-ip 1.2.3.4 Block specific IP
drip https 443 --auth secret Enable proxy authentication with password
drip https 443 --transport wss Use WebSocket over TLS (CDN-friendly)
Configuration:
First time: Run 'drip config init' to save server and token
Subsequent: Just run 'drip https <port>'
Note: Uses TCP over TLS 1.3 for secure communication`,
Transport options:
auto - Automatically select based on server address (default)
tcp - Direct TLS 1.3 connection
wss - WebSocket over TLS (works through CDN like Cloudflare)`,
Args: cobra.ExactArgs(1),
RunE: runHTTPS,
}
@@ -39,6 +43,7 @@ func init() {
httpsCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
httpsCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
httpsCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication")
httpsCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
httpsCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(httpsCmd)
@@ -70,6 +75,7 @@ func runHTTPS(_ *cobra.Command, args []string) error {
AllowIPs: allowIPs,
DenyIPs: denyIPs,
AuthPass: authPass,
Transport: parseTransport(transport),
}
var daemon *DaemonInfo

View File

@@ -7,6 +7,7 @@ import (
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"drip/internal/server/proxy"
@@ -21,17 +22,20 @@ import (
)
var (
serverPort int
serverPublicPort int
serverDomain string
serverAuthToken string
serverMetricsToken string
serverDebug bool
serverTCPPortMin int
serverTCPPortMax int
serverTLSCert string
serverTLSKey string
serverPprofPort int
serverPort int
serverPublicPort int
serverDomain string
serverTunnelDomain string
serverAuthToken string
serverMetricsToken string
serverDebug bool
serverTCPPortMin int
serverTCPPortMax int
serverTLSCert string
serverTLSKey string
serverPprofPort int
serverTransports string
serverTunnelTypes string
)
var serverCmd = &cobra.Command{
@@ -47,7 +51,8 @@ func init() {
// Command line flags with environment variable defaults
serverCmd.Flags().IntVarP(&serverPort, "port", "p", getEnvInt("DRIP_PORT", 8443), "Server port (env: DRIP_PORT)")
serverCmd.Flags().IntVar(&serverPublicPort, "public-port", getEnvInt("DRIP_PUBLIC_PORT", 0), "Public port to display in URLs (env: DRIP_PUBLIC_PORT)")
serverCmd.Flags().StringVarP(&serverDomain, "domain", "d", getEnvString("DRIP_DOMAIN", constants.DefaultDomain), "Server domain (env: DRIP_DOMAIN)")
serverCmd.Flags().StringVarP(&serverDomain, "domain", "d", getEnvString("DRIP_DOMAIN", constants.DefaultDomain), "Server domain for client connections (env: DRIP_DOMAIN)")
serverCmd.Flags().StringVar(&serverTunnelDomain, "tunnel-domain", getEnvString("DRIP_TUNNEL_DOMAIN", ""), "Domain for tunnel URLs, defaults to --domain (env: DRIP_TUNNEL_DOMAIN)")
serverCmd.Flags().StringVarP(&serverAuthToken, "token", "t", getEnvString("DRIP_TOKEN", ""), "Authentication token (env: DRIP_TOKEN)")
serverCmd.Flags().StringVar(&serverMetricsToken, "metrics-token", getEnvString("DRIP_METRICS_TOKEN", ""), "Metrics and stats token (env: DRIP_METRICS_TOKEN)")
serverCmd.Flags().BoolVar(&serverDebug, "debug", false, "Enable debug logging")
@@ -60,6 +65,10 @@ func init() {
// Performance profiling
serverCmd.Flags().IntVar(&serverPprofPort, "pprof", getEnvInt("DRIP_PPROF_PORT", 0), "Enable pprof on specified port (env: DRIP_PPROF_PORT)")
// Transport and tunnel type restrictions
serverCmd.Flags().StringVar(&serverTransports, "transports", getEnvString("DRIP_TRANSPORTS", "tcp,wss"), "Allowed transports: tcp,wss (env: DRIP_TRANSPORTS)")
serverCmd.Flags().StringVar(&serverTunnelTypes, "tunnel-types", getEnvString("DRIP_TUNNEL_TYPES", "http,https,tcp"), "Allowed tunnel types: http,https,tcp (env: DRIP_TUNNEL_TYPES)")
}
func runServer(_ *cobra.Command, _ []string) error {
@@ -100,17 +109,26 @@ func runServer(_ *cobra.Command, _ []string) error {
displayPort = serverPort
}
// Use tunnel domain if set, otherwise fall back to domain
tunnelDomain := serverTunnelDomain
if tunnelDomain == "" {
tunnelDomain = serverDomain
}
serverConfig := &config.ServerConfig{
Port: serverPort,
PublicPort: displayPort,
Domain: serverDomain,
TCPPortMin: serverTCPPortMin,
TCPPortMax: serverTCPPortMax,
TLSEnabled: true,
TLSCertFile: serverTLSCert,
TLSKeyFile: serverTLSKey,
AuthToken: serverAuthToken,
Debug: serverDebug,
Port: serverPort,
PublicPort: displayPort,
Domain: serverDomain,
TunnelDomain: tunnelDomain,
TCPPortMin: serverTCPPortMin,
TCPPortMax: serverTCPPortMax,
TLSEnabled: true,
TLSCertFile: serverTLSCert,
TLSKeyFile: serverTLSKey,
AuthToken: serverAuthToken,
Debug: serverDebug,
AllowedTransports: parseCommaSeparated(serverTransports),
AllowedTunnelTypes: parseCommaSeparated(serverTunnelTypes),
}
if err := serverConfig.Validate(); err != nil {
@@ -136,9 +154,13 @@ func runServer(_ *cobra.Command, _ []string) error {
listenAddr := fmt.Sprintf("0.0.0.0:%d", serverPort)
httpHandler := proxy.NewHandler(tunnelManager, logger, serverDomain, serverAuthToken, serverMetricsToken)
httpHandler := proxy.NewHandler(tunnelManager, logger, tunnelDomain, serverAuthToken, serverMetricsToken)
httpHandler.SetAllowedTransports(serverConfig.AllowedTransports)
httpHandler.SetAllowedTunnelTypes(serverConfig.AllowedTunnelTypes)
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, displayPort, httpHandler)
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, tunnelDomain, displayPort, httpHandler)
listener.SetAllowedTransports(serverConfig.AllowedTransports)
listener.SetAllowedTunnelTypes(serverConfig.AllowedTunnelTypes)
if err := listener.Start(); err != nil {
logger.Fatal("Failed to start TCP listener", zap.Error(err))
@@ -147,7 +169,10 @@ func runServer(_ *cobra.Command, _ []string) error {
logger.Info("Drip Server started",
zap.String("address", listenAddr),
zap.String("domain", serverDomain),
zap.String("tunnel_domain", tunnelDomain),
zap.String("protocol", "TCP over TLS 1.3"),
zap.Strings("transports", serverConfig.AllowedTransports),
zap.Strings("tunnel_types", serverConfig.AllowedTunnelTypes),
)
quit := make(chan os.Signal, 1)
@@ -182,3 +207,19 @@ func getEnvString(key string, defaultVal string) string {
}
return defaultVal
}
// parseCommaSeparated splits a comma-separated string into a slice
func parseCommaSeparated(s string) []string {
if s == "" {
return nil
}
parts := strings.Split(s, ",")
result := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
result = append(result, strings.ToLower(p))
}
}
return result
}

View File

@@ -23,6 +23,7 @@ Example:
drip tcp 5432 --allow-ip 192.168.0.0/16 Only allow IPs from 192.168.x.x
drip tcp 22 --allow-ip 10.0.0.1 Allow single IP
drip tcp 22 --deny-ip 1.2.3.4 Block specific IP
drip tcp 22 --transport wss Use WebSocket over TLS (CDN-friendly)
Supported Services:
- Databases: PostgreSQL (5432), MySQL (3306), Redis (6379), MongoDB (27017)
@@ -33,7 +34,13 @@ Configuration:
First time: Run 'drip config init' to save server and token
Subsequent: Just run 'drip tcp <port>'
Note: Uses TCP over TLS 1.3 for secure communication`,
Transport options:
auto - Automatically select based on server address (default)
tcp - Direct TLS 1.3 connection
wss - WebSocket over TLS (works through CDN like Cloudflare)
Note: TCP tunnels require dynamic port allocation on the server.
When using CDN (--transport wss), the server must still expose the allocated port directly.`,
Args: cobra.ExactArgs(1),
RunE: runTCP,
}
@@ -44,6 +51,7 @@ func init() {
tcpCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
tcpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
tcpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
tcpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
tcpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
tcpCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(tcpCmd)
@@ -74,6 +82,7 @@ func runTCP(_ *cobra.Command, args []string) error {
Insecure: insecure,
AllowIPs: allowIPs,
DenyIPs: denyIPs,
Transport: parseTransport(transport),
}
var daemon *DaemonInfo

View File

@@ -41,8 +41,13 @@ func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) er
fmt.Println(ui.RenderConnecting(connConfig.ServerAddr, reconnectAttempts, maxReconnectAttempts))
if err := connector.Connect(); err != nil {
if isConfigurationError(err) {
fmt.Println(ui.Warning(fmt.Sprintf("Configuration error: %v", err)))
os.Exit(1)
}
if isNonRetryableError(err) {
return fmt.Errorf("failed to connect: %w", err)
fmt.Println(ui.RenderConnectionFailed(err))
os.Exit(1)
}
reconnectAttempts++
@@ -228,3 +233,10 @@ func isNonRetryableError(err error) bool {
strings.Contains(errStr, "authentication") ||
strings.Contains(errStr, "Invalid authentication token")
}
// isConfigurationError returns true for errors caused by user configuration
// that won't be fixed by retrying (e.g., wrong transport type)
func isConfigurationError(err error) bool {
errStr := err.Error()
return strings.Contains(errStr, "server only supports")
}

View File

@@ -10,6 +10,18 @@ import (
"go.uber.org/zap"
)
// TransportType defines the transport protocol for tunnel connections
type TransportType string
const (
// TransportAuto automatically selects transport based on server address
TransportAuto TransportType = "auto"
// TransportTCP uses direct TLS 1.3 connection
TransportTCP TransportType = "tcp"
// TransportWebSocket uses WebSocket over TLS (CDN-friendly)
TransportWebSocket TransportType = "wss"
)
type LatencyCallback func(latency time.Duration)
type ConnectorConfig struct {
@@ -30,6 +42,9 @@ type ConnectorConfig struct {
// Proxy authentication
AuthPass string
// Transport protocol selection
Transport TransportType
}
type TunnelClient interface {

View File

@@ -6,11 +6,14 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
json "github.com/goccy/go-json"
"github.com/hashicorp/yamux"
"go.uber.org/zap"
@@ -19,6 +22,7 @@ import (
"drip/internal/shared/mux"
"drip/internal/shared/protocol"
"drip/internal/shared/stats"
"drip/internal/shared/wsutil"
"drip/pkg/config"
)
@@ -68,16 +72,41 @@ type PoolClient struct {
denyIPs []string
authPass string
// Transport protocol selection
transport TransportType
insecure bool
}
// NewPoolClient creates a new pool client.
func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
// Parse server address to get host for TLS config
serverAddr := cfg.ServerAddr
host := serverAddr
// Handle wss:// prefix
if strings.HasPrefix(serverAddr, "wss://") {
if u, err := url.Parse(serverAddr); err == nil {
host = u.Host
// Normalize server address for internal use
if u.Port() == "" {
host = u.Host + ":443"
}
serverAddr = host
}
}
// Extract hostname without port for TLS
hostOnly, _, _ := net.SplitHostPort(host)
if hostOnly == "" {
hostOnly = host
}
var tlsConfig *tls.Config
if cfg.Insecure {
tlsConfig = config.GetClientTLSConfigInsecure()
} else {
host, _, _ := net.SplitHostPort(cfg.ServerAddr)
tlsConfig = config.GetClientTLSConfig(host)
tlsConfig = config.GetClientTLSConfig(hostOnly)
}
localHost := cfg.LocalHost
@@ -111,10 +140,16 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
}
initialSessions = min(max(initialSessions, minSessions), maxSessions)
// Determine transport type
transport := cfg.Transport
if transport == "" {
transport = TransportAuto
}
ctx, cancel := context.WithCancel(context.Background())
c := &PoolClient{
serverAddr: cfg.ServerAddr,
serverAddr: serverAddr,
tlsConfig: tlsConfig,
token: cfg.Token,
tunnelType: tunnelType,
@@ -134,6 +169,8 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
allowIPs: cfg.AllowIPs,
denyIPs: cfg.DenyIPs,
authPass: cfg.AuthPass,
transport: transport,
insecure: cfg.Insecure,
}
if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
@@ -146,7 +183,7 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
// Connect establishes the primary connection and starts background workers.
func (c *PoolClient) Connect() error {
primaryConn, err := c.dialTLS()
primaryConn, err := c.dial()
if err != nil {
return err
}
@@ -298,6 +335,138 @@ func (c *PoolClient) dialTLS() (net.Conn, error) {
return conn, nil
}
// serverCapabilities holds the discovered server capabilities
type serverCapabilities struct {
Transports []string `json:"transports"`
Preferred string `json:"preferred"`
}
// dial selects the appropriate transport and establishes a connection
func (c *PoolClient) dial() (net.Conn, error) {
switch c.transport {
case TransportWebSocket:
return c.dialWebSocket()
case TransportTCP:
// User explicitly requested TCP, verify server supports it
caps := c.discoverServerCapabilities()
if caps != nil && len(caps.Transports) > 0 {
tcpSupported := false
for _, t := range caps.Transports {
if t == "tcp" {
tcpSupported = true
break
}
}
if !tcpSupported {
return nil, fmt.Errorf("server only supports %v transport(s), but --transport tcp was specified. Use --transport wss instead", caps.Transports)
}
}
return c.dialTLS()
default: // TransportAuto
// Check if server address indicates WebSocket
if strings.HasPrefix(c.serverAddr, "wss://") {
return c.dialWebSocket()
}
// Query server for preferred transport
caps := c.discoverServerCapabilities()
if caps != nil && caps.Preferred == "wss" {
return c.dialWebSocket()
}
// Default to TCP
return c.dialTLS()
}
}
// discoverServerCapabilities queries the server for its capabilities
func (c *PoolClient) discoverServerCapabilities() *serverCapabilities {
host, port, err := net.SplitHostPort(c.serverAddr)
if err != nil {
host = c.serverAddr
port = "443"
}
discoverURL := fmt.Sprintf("https://%s:%s/_drip/discover", host, port)
client := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
TLSClientConfig: c.tlsConfig,
},
}
resp, err := client.Get(discoverURL)
if err != nil {
c.logger.Debug("Failed to discover server capabilities",
zap.Error(err),
)
return nil
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil
}
var caps serverCapabilities
if err := json.NewDecoder(resp.Body).Decode(&caps); err != nil {
return nil
}
c.logger.Debug("Discovered server capabilities",
zap.Strings("transports", caps.Transports),
zap.String("preferred", caps.Preferred),
)
return &caps
}
// dialWebSocket establishes a WebSocket connection to the server over TLS
func (c *PoolClient) dialWebSocket() (net.Conn, error) {
// Build WebSocket URL
host, port, err := net.SplitHostPort(c.serverAddr)
if err != nil {
// No port specified, use default
host = c.serverAddr
port = "443"
}
wsURL := fmt.Sprintf("wss://%s:%s/_drip/ws", host, port)
c.logger.Debug("Connecting via WebSocket over TLS",
zap.String("url", wsURL),
)
dialer := websocket.Dialer{
TLSClientConfig: c.tlsConfig,
HandshakeTimeout: 10 * time.Second,
ReadBufferSize: 256 * 1024,
WriteBufferSize: 256 * 1024,
}
// Add authorization header if token is set
header := http.Header{}
if c.token != "" {
header.Set("Authorization", "Bearer "+c.token)
}
ws, resp, err := dialer.Dial(wsURL, header)
if err != nil {
if resp != nil {
return nil, fmt.Errorf("WebSocket dial failed (status %d): %w", resp.StatusCode, err)
}
return nil, fmt.Errorf("WebSocket dial failed: %w", err)
}
// Wrap WebSocket as net.Conn with ping loop for CDN keep-alive
conn := wsutil.NewConnWithPing(ws, 30*time.Second)
c.logger.Debug("WebSocket connection established",
zap.String("remote_addr", ws.RemoteAddr().String()),
)
return conn, nil
}
func (c *PoolClient) acceptLoop(h *sessionHandle, isPrimary bool) {
defer c.wg.Done()

View File

@@ -188,7 +188,7 @@ func (c *PoolClient) addDataSession() error {
return fmt.Errorf("server does not support data connections")
}
conn, err := c.dialTLS()
conn, err := c.dial()
if err != nil {
return err
}

View File

@@ -17,12 +17,14 @@ import (
"time"
json "github.com/goccy/go-json"
"github.com/gorilla/websocket"
"drip/internal/server/tunnel"
"drip/internal/shared/httputil"
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"drip/internal/shared/wsutil"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.uber.org/zap"
@@ -94,6 +96,20 @@ type Handler struct {
domain string
authToken string
metricsToken string
publicPort int
// WebSocket tunnel support
wsUpgrader websocket.Upgrader
wsConnHandler WSConnectionHandler
// Server capabilities
allowedTransports []string
allowedTunnelTypes []string
}
// WSConnectionHandler handles WebSocket tunnel connections
type WSConnectionHandler interface {
HandleWSConnection(conn net.Conn, remoteAddr string)
}
var privateNetworks []*net.IPNet
@@ -121,10 +137,86 @@ func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, auth
domain: domain,
authToken: authToken,
metricsToken: metricsToken,
wsUpgrader: websocket.Upgrader{
ReadBufferSize: 256 * 1024,
WriteBufferSize: 256 * 1024,
CheckOrigin: func(r *http.Request) bool {
return true // Allow all origins for tunnel connections
},
},
}
}
// SetWSConnectionHandler sets the handler for WebSocket tunnel connections
func (h *Handler) SetWSConnectionHandler(handler WSConnectionHandler) {
h.wsConnHandler = handler
}
// SetPublicPort sets the public port for URL generation
func (h *Handler) SetPublicPort(port int) {
h.publicPort = port
}
// SetAllowedTransports sets the allowed transport protocols
func (h *Handler) SetAllowedTransports(transports []string) {
h.allowedTransports = transports
}
// SetAllowedTunnelTypes sets the allowed tunnel types
func (h *Handler) SetAllowedTunnelTypes(types []string) {
h.allowedTunnelTypes = types
}
// IsTransportAllowed checks if a transport is allowed
func (h *Handler) IsTransportAllowed(transport string) bool {
if len(h.allowedTransports) == 0 {
return true
}
for _, t := range h.allowedTransports {
if strings.EqualFold(t, transport) {
return true
}
}
return false
}
// IsTunnelTypeAllowed checks if a tunnel type is allowed
func (h *Handler) IsTunnelTypeAllowed(tunnelType string) bool {
if len(h.allowedTunnelTypes) == 0 {
return true
}
for _, t := range h.allowedTunnelTypes {
if strings.EqualFold(t, tunnelType) {
return true
}
}
return false
}
// GetPreferredTransport returns the preferred transport for auto-detection
func (h *Handler) GetPreferredTransport() string {
if len(h.allowedTransports) == 0 {
return "tcp"
}
if len(h.allowedTransports) == 1 {
return h.allowedTransports[0]
}
return "tcp"
}
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Discovery endpoint for client auto-detection
if r.URL.Path == "/_drip/discover" {
h.serveDiscovery(w, r)
return
}
// WebSocket tunnel endpoint - must be checked before other routes
if r.URL.Path == "/_drip/ws" {
h.handleTunnelWebSocket(w, r)
return
}
if r.URL.Path == "/health" {
h.serveHealth(w, r)
return
@@ -849,3 +941,69 @@ func (h *Handler) serveLoginPage(w http.ResponseWriter, r *http.Request, subdoma
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(htmlContent))
}
// handleTunnelWebSocket handles WebSocket connections for tunnel clients
func (h *Handler) handleTunnelWebSocket(w http.ResponseWriter, r *http.Request) {
// Check if WSS transport is allowed
if !h.IsTransportAllowed("wss") {
http.Error(w, "WebSocket transport not allowed on this server", http.StatusForbidden)
return
}
if h.wsConnHandler == nil {
http.Error(w, "WebSocket tunnel not configured", http.StatusServiceUnavailable)
return
}
ws, err := h.wsUpgrader.Upgrade(w, r, nil)
if err != nil {
h.logger.Error("WebSocket upgrade failed", zap.Error(err))
return
}
// Configure WebSocket for tunnel use
ws.SetReadLimit(protocol.MaxFrameSize + protocol.FrameHeaderSize + 1024)
// Extract real client IP (support CDN headers)
remoteAddr := h.extractClientIP(r)
h.logger.Info("WebSocket tunnel connection established",
zap.String("remote_addr", remoteAddr),
)
// Wrap WebSocket as net.Conn with ping loop for CDN keep-alive
conn := wsutil.NewConnWithPing(ws, 30*time.Second)
// Handle the connection using the registered handler
h.wsConnHandler.HandleWSConnection(conn, remoteAddr)
}
// serveDiscovery returns server capabilities for client auto-detection
func (h *Handler) serveDiscovery(w http.ResponseWriter, r *http.Request) {
transports := h.allowedTransports
if len(transports) == 0 {
transports = []string{"tcp", "wss"}
}
tunnelTypes := h.allowedTunnelTypes
if len(tunnelTypes) == 0 {
tunnelTypes = []string{"http", "https", "tcp"}
}
response := map[string]interface{}{
"transports": transports,
"tunnel_types": tunnelTypes,
"preferred": h.GetPreferredTransport(),
"version": "1",
}
data, err := json.Marshal(response)
if err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-cache")
w.Write(data)
}

View File

@@ -39,6 +39,7 @@ type Connection struct {
subdomain string
port int
domain string
tunnelDomain string
publicPort int
portAlloc *PortAllocator
tunnelConn *tunnel.Connection
@@ -57,10 +58,13 @@ type Connection struct {
groupManager *ConnectionGroupManager
httpListener *connQueueListener
handedOff bool
// Server capabilities
allowedTunnelTypes []string
}
// NewConnection creates a new connection handler
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager, httpListener *connQueueListener) *Connection {
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager, httpListener *connQueueListener) *Connection {
ctx, cancel := context.WithCancel(context.Background())
c := &Connection{
conn: conn,
@@ -69,6 +73,7 @@ func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, log
logger: logger,
portAlloc: portAlloc,
domain: domain,
tunnelDomain: tunnelDomain,
publicPort: publicPort,
httpHandler: httpHandler,
stopCh: make(chan struct{}),
@@ -130,6 +135,12 @@ func (c *Connection) Handle() error {
c.tunnelType = req.TunnelType
// Check if tunnel type is allowed
if !c.isTunnelTypeAllowed(string(req.TunnelType)) {
c.sendError("tunnel_type_not_allowed", fmt.Sprintf("Tunnel type '%s' is not allowed on this server", req.TunnelType))
return fmt.Errorf("tunnel type not allowed: %s", req.TunnelType)
}
if c.authToken != "" && req.Token != c.authToken {
c.sendError("authentication_failed", "Invalid authentication token")
return fmt.Errorf("authentication failed")
@@ -207,12 +218,12 @@ func (c *Connection) Handle() error {
var tunnelURL string
if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS {
if c.publicPort == 443 {
tunnelURL = fmt.Sprintf("https://%s.%s", subdomain, c.domain)
tunnelURL = fmt.Sprintf("https://%s.%s", subdomain, c.tunnelDomain)
} else {
tunnelURL = fmt.Sprintf("https://%s.%s:%d", subdomain, c.domain, c.publicPort)
tunnelURL = fmt.Sprintf("https://%s.%s:%d", subdomain, c.tunnelDomain, c.publicPort)
}
} else {
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.domain, c.port)
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.tunnelDomain, c.port)
}
var tunnelID string
@@ -750,3 +761,21 @@ func (c *Connection) sendDataConnectError(code, message string) {
frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
_ = protocol.WriteFrame(c.conn, frame)
}
// SetAllowedTunnelTypes sets the allowed tunnel types for this connection
func (c *Connection) SetAllowedTunnelTypes(types []string) {
c.allowedTunnelTypes = types
}
// isTunnelTypeAllowed checks if a tunnel type is allowed
func (c *Connection) isTunnelTypeAllowed(tunnelType string) bool {
if len(c.allowedTunnelTypes) == 0 {
return true // Allow all by default
}
for _, t := range c.allowedTunnelTypes {
if strings.EqualFold(t, tunnelType) {
return true
}
}
return false
}

View File

@@ -11,6 +11,7 @@ import (
"time"
"drip/internal/server/metrics"
"drip/internal/server/proxy"
"drip/internal/server/tunnel"
"drip/internal/shared/pool"
"drip/internal/shared/recovery"
@@ -27,6 +28,7 @@ type Listener struct {
portAlloc *PortAllocator
logger *zap.Logger
domain string
tunnelDomain string
publicPort int
httpHandler http.Handler
listener net.Listener
@@ -40,9 +42,13 @@ type Listener struct {
groupManager *ConnectionGroupManager
httpServer *http.Server
httpListener *connQueueListener
// Server capabilities
allowedTransports []string
allowedTunnelTypes []string
}
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler) *Listener {
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler) *Listener {
numCPU := pool.NumCPU()
workers := numCPU * 5
queueSize := workers * 20
@@ -60,7 +66,7 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
// Initialize worker pool metrics
metrics.WorkerPoolSize.Set(float64(workers))
return &Listener{
l := &Listener{
address: address,
tlsConfig: tlsConfig,
authToken: authToken,
@@ -68,6 +74,7 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
portAlloc: portAlloc,
logger: logger,
domain: domain,
tunnelDomain: tunnelDomain,
publicPort: publicPort,
httpHandler: httpHandler,
stopCh: make(chan struct{}),
@@ -77,6 +84,14 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
panicMetrics: panicMetrics,
groupManager: NewConnectionGroupManager(logger),
}
// Set up WebSocket connection handler if httpHandler supports it
if h, ok := httpHandler.(*proxy.Handler); ok {
h.SetWSConnectionHandler(l)
h.SetPublicPort(publicPort)
}
return l
}
func (l *Listener) Start() error {
@@ -234,7 +249,8 @@ func (l *Listener) handleConnection(netConn net.Conn) {
return
}
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
conn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
connID := netConn.RemoteAddr().String()
l.connMu.Lock()
@@ -334,3 +350,92 @@ func (l *Listener) GetActiveConnections() int {
defer l.connMu.RUnlock()
return len(l.connections)
}
// HandleWSConnection implements proxy.WSConnectionHandler for WebSocket tunnel connections
func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) {
l.wg.Add(1)
defer l.wg.Done()
connID := remoteAddr
if connID == "" {
connID = conn.RemoteAddr().String()
}
l.logger.Info("Handling WebSocket tunnel connection",
zap.String("remote_addr", connID),
)
// Create connection handler (no TLS verification needed - already done by HTTP server)
tcpConn := NewConnection(conn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
tcpConn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
l.connMu.Lock()
l.connections[connID] = tcpConn
l.connMu.Unlock()
metrics.TotalConnections.Inc()
metrics.ActiveConnections.Inc()
defer func() {
l.connMu.Lock()
delete(l.connections, connID)
l.connMu.Unlock()
metrics.ActiveConnections.Dec()
if !tcpConn.IsHandedOff() {
conn.Close()
}
}()
if err := tcpConn.Handle(); err != nil {
errStr := err.Error()
if strings.Contains(errStr, "EOF") ||
strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection refused") ||
strings.Contains(errStr, "websocket: close") {
return
}
if strings.Contains(errStr, "payload too large") ||
strings.Contains(errStr, "failed to read registration frame") ||
strings.Contains(errStr, "expected register frame") ||
strings.Contains(errStr, "failed to parse registration request") ||
strings.Contains(errStr, "tunnel type not allowed") {
l.logger.Warn("WebSocket tunnel protocol validation failed",
zap.String("remote_addr", connID),
zap.Error(err),
)
} else {
l.logger.Error("WebSocket tunnel connection handling failed",
zap.String("remote_addr", connID),
zap.Error(err),
)
}
}
}
// SetAllowedTransports sets the allowed transport protocols
func (l *Listener) SetAllowedTransports(transports []string) {
l.allowedTransports = transports
}
// SetAllowedTunnelTypes sets the allowed tunnel types
func (l *Listener) SetAllowedTunnelTypes(types []string) {
l.allowedTunnelTypes = types
}
// IsTransportAllowed checks if a transport is allowed
func (l *Listener) IsTransportAllowed(transport string) bool {
if len(l.allowedTransports) == 0 {
return true
}
for _, t := range l.allowedTransports {
if strings.EqualFold(t, transport) {
return true
}
}
return false
}

View File

@@ -0,0 +1,169 @@
package wsutil
import (
"io"
"net"
"sync"
"time"
"github.com/gorilla/websocket"
)
// Conn wraps a gorilla/websocket.Conn to implement net.Conn.
// It uses binary messages for data transfer, making it compatible
// with yamux and the existing frame protocol.
type Conn struct {
ws *websocket.Conn
reader io.Reader
readMu sync.Mutex
writeMu sync.Mutex
localAddr net.Addr
remoteAddr net.Addr
pingStop chan struct{}
pingOnce sync.Once
}
// NewConn wraps a WebSocket connection as a net.Conn.
func NewConn(ws *websocket.Conn) *Conn {
c := &Conn{
ws: ws,
localAddr: ws.LocalAddr(),
remoteAddr: ws.RemoteAddr(),
pingStop: make(chan struct{}),
}
return c
}
// NewConnWithPing wraps a WebSocket connection and starts a ping loop
// to keep the connection alive through CDN/proxies.
func NewConnWithPing(ws *websocket.Conn, pingInterval time.Duration) *Conn {
c := NewConn(ws)
c.startPingLoop(pingInterval)
return c
}
// Read reads data from the WebSocket connection.
// It handles WebSocket message boundaries transparently, presenting
// a continuous byte stream to the caller.
func (c *Conn) Read(p []byte) (int, error) {
c.readMu.Lock()
defer c.readMu.Unlock()
for {
if c.reader == nil {
messageType, reader, err := c.ws.NextReader()
if err != nil {
return 0, err
}
// Only accept binary messages for tunnel data
if messageType != websocket.BinaryMessage {
// Skip non-binary messages (text, ping/pong handled by gorilla)
continue
}
c.reader = reader
}
n, err := c.reader.Read(p)
if err == io.EOF {
// Current message exhausted, get next message
c.reader = nil
if n > 0 {
return n, nil
}
continue
}
return n, err
}
}
// Write writes data to the WebSocket connection as a binary message.
func (c *Conn) Write(p []byte) (int, error) {
c.writeMu.Lock()
defer c.writeMu.Unlock()
err := c.ws.WriteMessage(websocket.BinaryMessage, p)
if err != nil {
return 0, err
}
return len(p), nil
}
// Close closes the WebSocket connection.
func (c *Conn) Close() error {
c.pingOnce.Do(func() {
close(c.pingStop)
})
// Send close message before closing
c.writeMu.Lock()
_ = c.ws.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
c.writeMu.Unlock()
return c.ws.Close()
}
// LocalAddr returns the local network address.
func (c *Conn) LocalAddr() net.Addr {
return c.localAddr
}
// RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
// SetDeadline sets the read and write deadlines.
func (c *Conn) SetDeadline(t time.Time) error {
if err := c.ws.SetReadDeadline(t); err != nil {
return err
}
return c.ws.SetWriteDeadline(t)
}
// SetReadDeadline sets the read deadline.
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.ws.SetReadDeadline(t)
}
// SetWriteDeadline sets the write deadline.
func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.ws.SetWriteDeadline(t)
}
// startPingLoop starts a goroutine that sends periodic ping messages
// to keep the connection alive through CDN/proxies like Cloudflare.
func (c *Conn) startPingLoop(interval time.Duration) {
if interval <= 0 {
interval = 30 * time.Second
}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-c.pingStop:
return
case <-ticker.C:
c.writeMu.Lock()
err := c.ws.WriteControl(
websocket.PingMessage,
[]byte{},
time.Now().Add(10*time.Second),
)
c.writeMu.Unlock()
if err != nil {
return
}
}
}
}()
}
// UnderlyingConn returns the underlying WebSocket connection.
// Use with caution as direct access bypasses the mutex protection.
func (c *Conn) UnderlyingConn() *websocket.Conn {
return c.ws
}

View File

@@ -9,26 +9,31 @@ import (
// ServerConfig holds the server configuration
type ServerConfig struct {
// Server settings
Port int
PublicPort int // Port to display in URLs (for reverse proxy scenarios)
Domain string
Port int
PublicPort int // Port to display in URLs (for reverse proxy scenarios)
Domain string // Domain for client connections (e.g., connect.example.com)
TunnelDomain string // Domain for tunnel URLs (e.g., example.com for *.example.com)
// TCP tunnel dynamic port allocation
TCPPortMin int
TCPPortMax int
// TLS/SSL settings
// TLS settings
TLSEnabled bool
TLSCertFile string
TLSKeyFile string
AutoTLS bool // Automatic Let's Encrypt
// Security
AuthToken string
// Logging
Debug bool
// Allowed transports: "tcp", "wss", or "tcp,wss" (default: "tcp,wss")
AllowedTransports []string
// Allowed tunnel types: "http", "https", "tcp" (default: all)
AllowedTunnelTypes []string
}
// Validate checks if the server configuration is valid
@@ -51,6 +56,11 @@ func (c *ServerConfig) Validate() error {
return fmt.Errorf("domain should not contain port, got: %s", c.Domain)
}
// Validate tunnel domain if set
if c.TunnelDomain != "" && strings.Contains(c.TunnelDomain, ":") {
return fmt.Errorf("tunnel domain should not contain port, got: %s", c.TunnelDomain)
}
// Validate TCP port range
if c.TCPPortMin < 1 || c.TCPPortMin > 65535 {
return fmt.Errorf("invalid TCPPortMin %d: must be between 1 and 65535", c.TCPPortMin)
@@ -118,10 +128,10 @@ func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
func GetClientTLSConfig(serverName string) *tls.Config {
return &tls.Config{
ServerName: serverName,
MinVersion: tls.VersionTLS13, // Only TLS 1.3
MaxVersion: tls.VersionTLS13, // Only TLS 1.3
ClientSessionCache: tls.NewLRUClientSessionCache(0), // Enable session resumption (0 = default size)
PreferServerCipherSuites: true, // Prefer server cipher suites (ignored in TLS 1.3 but set for consistency)
MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
ClientSessionCache: tls.NewLRUClientSessionCache(0),
PreferServerCipherSuites: true,
CipherSuites: []uint16{
tls.TLS_AES_128_GCM_SHA256,
tls.TLS_AES_256_GCM_SHA384,
@@ -135,10 +145,10 @@ func GetClientTLSConfig(serverName string) *tls.Config {
func GetClientTLSConfigInsecure() *tls.Config {
return &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS13, // Only TLS 1.3
MaxVersion: tls.VersionTLS13, // Only TLS 1.3
ClientSessionCache: tls.NewLRUClientSessionCache(0), // Enable session resumption (0 = default size)
PreferServerCipherSuites: true, // Prefer server cipher suites (ignored in TLS 1.3 but set for consistency)
MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
ClientSessionCache: tls.NewLRUClientSessionCache(0),
PreferServerCipherSuites: true,
CipherSuites: []uint16{
tls.TLS_AES_128_GCM_SHA256,
tls.TLS_AES_256_GCM_SHA384,
@@ -146,22 +156,3 @@ func GetClientTLSConfigInsecure() *tls.Config {
},
}
}
// GetServerURL returns the server URL based on configuration
func (c *ServerConfig) GetServerURL() string {
protocol := "http"
if c.TLSEnabled {
protocol = "https"
}
if c.Port == 80 || (c.TLSEnabled && c.Port == 443) {
return fmt.Sprintf("%s://%s", protocol, c.Domain)
}
return fmt.Sprintf("%s://%s:%d", protocol, c.Domain, c.Port)
}
// GetTCPAddress returns the TCP address for tunnel connections
func (c *ServerConfig) GetTCPAddress() string {
return fmt.Sprintf("%s:%d", c.Domain, c.Port)
}