mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-26 14:21:17 +00:00
feat: add transport protocol option supporting TCP and WebSocket connections
Added --transport parameter to allow users to select transport protocol type: - auto: automatically choose based on server address (default) - tcp: direct TLS 1.3 connection - wss: WebSocket over TLS (CDN-friendly) Also updated client connector to support WebSocket transport, and added server-side discovery endpoint to query supported transport protocols.
This commit is contained in:
@@ -17,12 +17,14 @@ import (
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/httputil"
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/wsutil"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.uber.org/zap"
|
||||
@@ -94,6 +96,20 @@ type Handler struct {
|
||||
domain string
|
||||
authToken string
|
||||
metricsToken string
|
||||
publicPort int
|
||||
|
||||
// WebSocket tunnel support
|
||||
wsUpgrader websocket.Upgrader
|
||||
wsConnHandler WSConnectionHandler
|
||||
|
||||
// Server capabilities
|
||||
allowedTransports []string
|
||||
allowedTunnelTypes []string
|
||||
}
|
||||
|
||||
// WSConnectionHandler handles WebSocket tunnel connections
|
||||
type WSConnectionHandler interface {
|
||||
HandleWSConnection(conn net.Conn, remoteAddr string)
|
||||
}
|
||||
|
||||
var privateNetworks []*net.IPNet
|
||||
@@ -121,10 +137,86 @@ func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, auth
|
||||
domain: domain,
|
||||
authToken: authToken,
|
||||
metricsToken: metricsToken,
|
||||
wsUpgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 256 * 1024,
|
||||
WriteBufferSize: 256 * 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // Allow all origins for tunnel connections
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SetWSConnectionHandler sets the handler for WebSocket tunnel connections
|
||||
func (h *Handler) SetWSConnectionHandler(handler WSConnectionHandler) {
|
||||
h.wsConnHandler = handler
|
||||
}
|
||||
|
||||
// SetPublicPort sets the public port for URL generation
|
||||
func (h *Handler) SetPublicPort(port int) {
|
||||
h.publicPort = port
|
||||
}
|
||||
|
||||
// SetAllowedTransports sets the allowed transport protocols
|
||||
func (h *Handler) SetAllowedTransports(transports []string) {
|
||||
h.allowedTransports = transports
|
||||
}
|
||||
|
||||
// SetAllowedTunnelTypes sets the allowed tunnel types
|
||||
func (h *Handler) SetAllowedTunnelTypes(types []string) {
|
||||
h.allowedTunnelTypes = types
|
||||
}
|
||||
|
||||
// IsTransportAllowed checks if a transport is allowed
|
||||
func (h *Handler) IsTransportAllowed(transport string) bool {
|
||||
if len(h.allowedTransports) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, t := range h.allowedTransports {
|
||||
if strings.EqualFold(t, transport) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsTunnelTypeAllowed checks if a tunnel type is allowed
|
||||
func (h *Handler) IsTunnelTypeAllowed(tunnelType string) bool {
|
||||
if len(h.allowedTunnelTypes) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, t := range h.allowedTunnelTypes {
|
||||
if strings.EqualFold(t, tunnelType) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetPreferredTransport returns the preferred transport for auto-detection
|
||||
func (h *Handler) GetPreferredTransport() string {
|
||||
if len(h.allowedTransports) == 0 {
|
||||
return "tcp"
|
||||
}
|
||||
if len(h.allowedTransports) == 1 {
|
||||
return h.allowedTransports[0]
|
||||
}
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Discovery endpoint for client auto-detection
|
||||
if r.URL.Path == "/_drip/discover" {
|
||||
h.serveDiscovery(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// WebSocket tunnel endpoint - must be checked before other routes
|
||||
if r.URL.Path == "/_drip/ws" {
|
||||
h.handleTunnelWebSocket(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/health" {
|
||||
h.serveHealth(w, r)
|
||||
return
|
||||
@@ -849,3 +941,69 @@ func (h *Handler) serveLoginPage(w http.ResponseWriter, r *http.Request, subdoma
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(htmlContent))
|
||||
}
|
||||
|
||||
// handleTunnelWebSocket handles WebSocket connections for tunnel clients
|
||||
func (h *Handler) handleTunnelWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if WSS transport is allowed
|
||||
if !h.IsTransportAllowed("wss") {
|
||||
http.Error(w, "WebSocket transport not allowed on this server", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if h.wsConnHandler == nil {
|
||||
http.Error(w, "WebSocket tunnel not configured", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
ws, err := h.wsUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("WebSocket upgrade failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Configure WebSocket for tunnel use
|
||||
ws.SetReadLimit(protocol.MaxFrameSize + protocol.FrameHeaderSize + 1024)
|
||||
|
||||
// Extract real client IP (support CDN headers)
|
||||
remoteAddr := h.extractClientIP(r)
|
||||
|
||||
h.logger.Info("WebSocket tunnel connection established",
|
||||
zap.String("remote_addr", remoteAddr),
|
||||
)
|
||||
|
||||
// Wrap WebSocket as net.Conn with ping loop for CDN keep-alive
|
||||
conn := wsutil.NewConnWithPing(ws, 30*time.Second)
|
||||
|
||||
// Handle the connection using the registered handler
|
||||
h.wsConnHandler.HandleWSConnection(conn, remoteAddr)
|
||||
}
|
||||
|
||||
// serveDiscovery returns server capabilities for client auto-detection
|
||||
func (h *Handler) serveDiscovery(w http.ResponseWriter, r *http.Request) {
|
||||
transports := h.allowedTransports
|
||||
if len(transports) == 0 {
|
||||
transports = []string{"tcp", "wss"}
|
||||
}
|
||||
|
||||
tunnelTypes := h.allowedTunnelTypes
|
||||
if len(tunnelTypes) == 0 {
|
||||
tunnelTypes = []string{"http", "https", "tcp"}
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"transports": transports,
|
||||
"tunnel_types": tunnelTypes,
|
||||
"preferred": h.GetPreferredTransport(),
|
||||
"version": "1",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user