feat: add transport protocol option supporting TCP and WebSocket connections

Added --transport parameter to allow users to select transport protocol type:
- auto: automatically choose based on server address (default)
- tcp: direct TLS 1.3 connection
- wss: WebSocket over TLS (CDN-friendly)

Also updated client connector to support WebSocket transport, and added server-side discovery endpoint to query supported transport protocols.
This commit is contained in:
Gouryella
2026-01-14 12:49:08 +08:00
parent 81f156f49c
commit 6139a9c0ed
13 changed files with 797 additions and 73 deletions

View File

@@ -17,12 +17,14 @@ import (
"time"
json "github.com/goccy/go-json"
"github.com/gorilla/websocket"
"drip/internal/server/tunnel"
"drip/internal/shared/httputil"
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"drip/internal/shared/wsutil"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.uber.org/zap"
@@ -94,6 +96,20 @@ type Handler struct {
domain string
authToken string
metricsToken string
publicPort int
// WebSocket tunnel support
wsUpgrader websocket.Upgrader
wsConnHandler WSConnectionHandler
// Server capabilities
allowedTransports []string
allowedTunnelTypes []string
}
// WSConnectionHandler handles WebSocket tunnel connections
type WSConnectionHandler interface {
HandleWSConnection(conn net.Conn, remoteAddr string)
}
var privateNetworks []*net.IPNet
@@ -121,10 +137,86 @@ func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, auth
domain: domain,
authToken: authToken,
metricsToken: metricsToken,
wsUpgrader: websocket.Upgrader{
ReadBufferSize: 256 * 1024,
WriteBufferSize: 256 * 1024,
CheckOrigin: func(r *http.Request) bool {
return true // Allow all origins for tunnel connections
},
},
}
}
// SetWSConnectionHandler sets the handler for WebSocket tunnel connections
func (h *Handler) SetWSConnectionHandler(handler WSConnectionHandler) {
h.wsConnHandler = handler
}
// SetPublicPort sets the public port for URL generation
func (h *Handler) SetPublicPort(port int) {
h.publicPort = port
}
// SetAllowedTransports sets the allowed transport protocols
func (h *Handler) SetAllowedTransports(transports []string) {
h.allowedTransports = transports
}
// SetAllowedTunnelTypes sets the allowed tunnel types
func (h *Handler) SetAllowedTunnelTypes(types []string) {
h.allowedTunnelTypes = types
}
// IsTransportAllowed checks if a transport is allowed
func (h *Handler) IsTransportAllowed(transport string) bool {
if len(h.allowedTransports) == 0 {
return true
}
for _, t := range h.allowedTransports {
if strings.EqualFold(t, transport) {
return true
}
}
return false
}
// IsTunnelTypeAllowed checks if a tunnel type is allowed
func (h *Handler) IsTunnelTypeAllowed(tunnelType string) bool {
if len(h.allowedTunnelTypes) == 0 {
return true
}
for _, t := range h.allowedTunnelTypes {
if strings.EqualFold(t, tunnelType) {
return true
}
}
return false
}
// GetPreferredTransport returns the preferred transport for auto-detection
func (h *Handler) GetPreferredTransport() string {
if len(h.allowedTransports) == 0 {
return "tcp"
}
if len(h.allowedTransports) == 1 {
return h.allowedTransports[0]
}
return "tcp"
}
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Discovery endpoint for client auto-detection
if r.URL.Path == "/_drip/discover" {
h.serveDiscovery(w, r)
return
}
// WebSocket tunnel endpoint - must be checked before other routes
if r.URL.Path == "/_drip/ws" {
h.handleTunnelWebSocket(w, r)
return
}
if r.URL.Path == "/health" {
h.serveHealth(w, r)
return
@@ -849,3 +941,69 @@ func (h *Handler) serveLoginPage(w http.ResponseWriter, r *http.Request, subdoma
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(htmlContent))
}
// handleTunnelWebSocket handles WebSocket connections for tunnel clients
func (h *Handler) handleTunnelWebSocket(w http.ResponseWriter, r *http.Request) {
// Check if WSS transport is allowed
if !h.IsTransportAllowed("wss") {
http.Error(w, "WebSocket transport not allowed on this server", http.StatusForbidden)
return
}
if h.wsConnHandler == nil {
http.Error(w, "WebSocket tunnel not configured", http.StatusServiceUnavailable)
return
}
ws, err := h.wsUpgrader.Upgrade(w, r, nil)
if err != nil {
h.logger.Error("WebSocket upgrade failed", zap.Error(err))
return
}
// Configure WebSocket for tunnel use
ws.SetReadLimit(protocol.MaxFrameSize + protocol.FrameHeaderSize + 1024)
// Extract real client IP (support CDN headers)
remoteAddr := h.extractClientIP(r)
h.logger.Info("WebSocket tunnel connection established",
zap.String("remote_addr", remoteAddr),
)
// Wrap WebSocket as net.Conn with ping loop for CDN keep-alive
conn := wsutil.NewConnWithPing(ws, 30*time.Second)
// Handle the connection using the registered handler
h.wsConnHandler.HandleWSConnection(conn, remoteAddr)
}
// serveDiscovery returns server capabilities for client auto-detection
func (h *Handler) serveDiscovery(w http.ResponseWriter, r *http.Request) {
transports := h.allowedTransports
if len(transports) == 0 {
transports = []string{"tcp", "wss"}
}
tunnelTypes := h.allowedTunnelTypes
if len(tunnelTypes) == 0 {
tunnelTypes = []string{"http", "https", "tcp"}
}
response := map[string]interface{}{
"transports": transports,
"tunnel_types": tunnelTypes,
"preferred": h.GetPreferredTransport(),
"version": "1",
}
data, err := json.Marshal(response)
if err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-cache")
w.Write(data)
}