mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 21:00:44 +00:00
Merge pull request #16 from Gouryella/feat/add-wss-transport
feat: add WebSocket transport protocol and GoReleaser integration
This commit is contained in:
64
.github/workflows/release.yml
vendored
64
.github/workflows/release.yml
vendored
@@ -3,72 +3,30 @@ name: Release
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- 'v*.*.*'
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build and Release
|
||||
goreleaser:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.23'
|
||||
go-version: '1.25'
|
||||
|
||||
- name: Get version
|
||||
id: version
|
||||
run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Build for all platforms
|
||||
run: |
|
||||
VERSION=${{ steps.version.outputs.VERSION }}
|
||||
COMMIT=${{ github.sha }}
|
||||
COMMIT_SHORT=${COMMIT:0:10}
|
||||
BUILD_TIME=$(date -u '+%Y-%m-%d_%H:%M:%S')
|
||||
LDFLAGS="-s -w -X main.Version=${VERSION} -X main.GitCommit=${COMMIT_SHORT} -X main.BuildTime=${BUILD_TIME}"
|
||||
|
||||
# Linux amd64
|
||||
GOOS=linux GOARCH=amd64 go build -ldflags="${LDFLAGS}" -o drip-${VERSION}-linux-amd64 ./cmd/drip
|
||||
|
||||
# Linux arm64
|
||||
GOOS=linux GOARCH=arm64 go build -ldflags="${LDFLAGS}" -o drip-${VERSION}-linux-arm64 ./cmd/drip
|
||||
|
||||
# macOS amd64
|
||||
GOOS=darwin GOARCH=amd64 go build -ldflags="${LDFLAGS}" -o drip-${VERSION}-darwin-amd64 ./cmd/drip
|
||||
|
||||
# macOS arm64
|
||||
GOOS=darwin GOARCH=arm64 go build -ldflags="${LDFLAGS}" -o drip-${VERSION}-darwin-arm64 ./cmd/drip
|
||||
|
||||
# Windows amd64
|
||||
GOOS=windows GOARCH=amd64 go build -ldflags="${LDFLAGS}" -o drip-${VERSION}-windows-amd64.exe ./cmd/drip
|
||||
|
||||
# Windows arm64
|
||||
GOOS=windows GOARCH=arm64 go build -ldflags="${LDFLAGS}" -o drip-${VERSION}-windows-arm64.exe ./cmd/drip
|
||||
|
||||
- name: Generate checksums
|
||||
run: |
|
||||
sha256sum drip-${{ steps.version.outputs.VERSION }}-* > checksums.txt
|
||||
|
||||
- name: Create Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v6
|
||||
with:
|
||||
files: |
|
||||
drip-${{ steps.version.outputs.VERSION }}-linux-amd64
|
||||
drip-${{ steps.version.outputs.VERSION }}-linux-arm64
|
||||
drip-${{ steps.version.outputs.VERSION }}-darwin-amd64
|
||||
drip-${{ steps.version.outputs.VERSION }}-darwin-arm64
|
||||
drip-${{ steps.version.outputs.VERSION }}-windows-amd64.exe
|
||||
drip-${{ steps.version.outputs.VERSION }}-windows-arm64.exe
|
||||
checksums.txt
|
||||
draft: false
|
||||
prerelease: false
|
||||
generate_release_notes: true
|
||||
distribution: goreleaser
|
||||
version: latest
|
||||
args: release --clean
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
14
README.md
14
README.md
@@ -200,6 +200,7 @@ sudo journalctl -u drip-server -f
|
||||
- Forward to localhost or any LAN address
|
||||
- Custom subdomains or auto-generated
|
||||
- Daemon mode for persistent tunnels
|
||||
- Multiple transport protocols (TCP, WebSocket)
|
||||
|
||||
**Performance**
|
||||
- Binary protocol with msgpack encoding
|
||||
@@ -264,6 +265,18 @@ drip http 3000 --deny-ip 1.2.3.4,5.6.7.8
|
||||
drip tcp 5432 --allow-ip 192.168.1.0/24 --deny-ip 192.168.1.100
|
||||
```
|
||||
|
||||
**Transport Protocols**
|
||||
```bash
|
||||
# Auto-select transport based on server (default)
|
||||
drip http 3000 --transport auto
|
||||
|
||||
# Use direct TLS 1.3 connection
|
||||
drip http 3000 --transport tcp
|
||||
|
||||
# Use WebSocket over TLS (CDN-friendly, works through Cloudflare)
|
||||
drip http 3000 --transport wss
|
||||
```
|
||||
|
||||
## Command Reference
|
||||
|
||||
```bash
|
||||
@@ -276,6 +289,7 @@ drip http <port> [flags]
|
||||
-t, --token Auth token
|
||||
--allow-ip Allow only these IPs or CIDR ranges
|
||||
--deny-ip Deny these IPs or CIDR ranges
|
||||
--transport Transport protocol: auto, tcp, wss (default: auto)
|
||||
|
||||
# HTTPS tunnel (same flags as http)
|
||||
drip https <port> [flags]
|
||||
|
||||
14
README_CN.md
14
README_CN.md
@@ -200,6 +200,7 @@ sudo journalctl -u drip-server -f
|
||||
- 可以转发到 localhost 或任何局域网地址
|
||||
- 自定义子域名或自动生成
|
||||
- 守护模式保持隧道持久运行
|
||||
- 多种传输协议(TCP、WebSocket)
|
||||
|
||||
**性能**
|
||||
- 二进制协议 + msgpack 编码
|
||||
@@ -264,6 +265,18 @@ drip http 3000 --deny-ip 1.2.3.4,5.6.7.8
|
||||
drip tcp 5432 --allow-ip 192.168.1.0/24 --deny-ip 192.168.1.100
|
||||
```
|
||||
|
||||
**传输协议**
|
||||
```bash
|
||||
# 根据服务器自动选择传输协议(默认)
|
||||
drip http 3000 --transport auto
|
||||
|
||||
# 使用直接 TLS 1.3 连接
|
||||
drip http 3000 --transport tcp
|
||||
|
||||
# 使用 WebSocket over TLS(CDN 友好,可穿透 Cloudflare)
|
||||
drip http 3000 --transport wss
|
||||
```
|
||||
|
||||
## 命令参考
|
||||
|
||||
```bash
|
||||
@@ -276,6 +289,7 @@ drip http <端口> [参数]
|
||||
-t, --token 认证 token
|
||||
--allow-ip 只允许这些 IP 或 CIDR 访问
|
||||
--deny-ip 拒绝这些 IP 或 CIDR 访问
|
||||
--transport 传输协议:auto, tcp, wss(默认:auto)
|
||||
|
||||
# HTTPS 隧道(参数同 http)
|
||||
drip https <端口> [参数]
|
||||
|
||||
@@ -3,6 +3,7 @@ package cli
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/protocol"
|
||||
@@ -18,6 +19,7 @@ var (
|
||||
allowIPs []string
|
||||
denyIPs []string
|
||||
authPass string
|
||||
transport string
|
||||
)
|
||||
|
||||
var httpCmd = &cobra.Command{
|
||||
@@ -32,12 +34,16 @@ Example:
|
||||
drip http 3000 --allow-ip 10.0.0.1 Allow single IP
|
||||
drip http 3000 --deny-ip 1.2.3.4 Block specific IP
|
||||
drip http 3000 --auth secret Enable proxy authentication with password
|
||||
drip http 3000 --transport wss Use WebSocket over TLS (CDN-friendly)
|
||||
|
||||
Configuration:
|
||||
First time: Run 'drip config init' to save server and token
|
||||
Subsequent: Just run 'drip http <port>'
|
||||
|
||||
Note: Uses TCP over TLS 1.3 for secure communication`,
|
||||
Transport options:
|
||||
auto - Automatically select based on server address (default)
|
||||
tcp - Direct TLS 1.3 connection
|
||||
wss - WebSocket over TLS (works through CDN like Cloudflare)`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runHTTP,
|
||||
}
|
||||
@@ -49,6 +55,7 @@ func init() {
|
||||
httpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
|
||||
httpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
|
||||
httpCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication")
|
||||
httpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
|
||||
httpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
httpCmd.Flags().MarkHidden("daemon-child")
|
||||
rootCmd.AddCommand(httpCmd)
|
||||
@@ -80,6 +87,7 @@ func runHTTP(_ *cobra.Command, args []string) error {
|
||||
AllowIPs: allowIPs,
|
||||
DenyIPs: denyIPs,
|
||||
AuthPass: authPass,
|
||||
Transport: parseTransport(transport),
|
||||
}
|
||||
|
||||
var daemon *DaemonInfo
|
||||
@@ -89,3 +97,15 @@ func runHTTP(_ *cobra.Command, args []string) error {
|
||||
|
||||
return runTunnelWithUI(connConfig, daemon)
|
||||
}
|
||||
|
||||
// parseTransport converts a string to TransportType
|
||||
func parseTransport(s string) tcp.TransportType {
|
||||
switch strings.ToLower(s) {
|
||||
case "wss":
|
||||
return tcp.TransportWebSocket
|
||||
case "tcp", "tls":
|
||||
return tcp.TransportTCP
|
||||
default:
|
||||
return tcp.TransportAuto
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,12 +22,16 @@ Example:
|
||||
drip https 443 --allow-ip 10.0.0.1 Allow single IP
|
||||
drip https 443 --deny-ip 1.2.3.4 Block specific IP
|
||||
drip https 443 --auth secret Enable proxy authentication with password
|
||||
drip https 443 --transport wss Use WebSocket over TLS (CDN-friendly)
|
||||
|
||||
Configuration:
|
||||
First time: Run 'drip config init' to save server and token
|
||||
Subsequent: Just run 'drip https <port>'
|
||||
|
||||
Note: Uses TCP over TLS 1.3 for secure communication`,
|
||||
Transport options:
|
||||
auto - Automatically select based on server address (default)
|
||||
tcp - Direct TLS 1.3 connection
|
||||
wss - WebSocket over TLS (works through CDN like Cloudflare)`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runHTTPS,
|
||||
}
|
||||
@@ -39,6 +43,7 @@ func init() {
|
||||
httpsCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
|
||||
httpsCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
|
||||
httpsCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication")
|
||||
httpsCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
|
||||
httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
httpsCmd.Flags().MarkHidden("daemon-child")
|
||||
rootCmd.AddCommand(httpsCmd)
|
||||
@@ -70,6 +75,7 @@ func runHTTPS(_ *cobra.Command, args []string) error {
|
||||
AllowIPs: allowIPs,
|
||||
DenyIPs: denyIPs,
|
||||
AuthPass: authPass,
|
||||
Transport: parseTransport(transport),
|
||||
}
|
||||
|
||||
var daemon *DaemonInfo
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"drip/internal/server/proxy"
|
||||
@@ -21,17 +22,20 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
serverPort int
|
||||
serverPublicPort int
|
||||
serverDomain string
|
||||
serverAuthToken string
|
||||
serverMetricsToken string
|
||||
serverDebug bool
|
||||
serverTCPPortMin int
|
||||
serverTCPPortMax int
|
||||
serverTLSCert string
|
||||
serverTLSKey string
|
||||
serverPprofPort int
|
||||
serverPort int
|
||||
serverPublicPort int
|
||||
serverDomain string
|
||||
serverTunnelDomain string
|
||||
serverAuthToken string
|
||||
serverMetricsToken string
|
||||
serverDebug bool
|
||||
serverTCPPortMin int
|
||||
serverTCPPortMax int
|
||||
serverTLSCert string
|
||||
serverTLSKey string
|
||||
serverPprofPort int
|
||||
serverTransports string
|
||||
serverTunnelTypes string
|
||||
)
|
||||
|
||||
var serverCmd = &cobra.Command{
|
||||
@@ -47,7 +51,8 @@ func init() {
|
||||
// Command line flags with environment variable defaults
|
||||
serverCmd.Flags().IntVarP(&serverPort, "port", "p", getEnvInt("DRIP_PORT", 8443), "Server port (env: DRIP_PORT)")
|
||||
serverCmd.Flags().IntVar(&serverPublicPort, "public-port", getEnvInt("DRIP_PUBLIC_PORT", 0), "Public port to display in URLs (env: DRIP_PUBLIC_PORT)")
|
||||
serverCmd.Flags().StringVarP(&serverDomain, "domain", "d", getEnvString("DRIP_DOMAIN", constants.DefaultDomain), "Server domain (env: DRIP_DOMAIN)")
|
||||
serverCmd.Flags().StringVarP(&serverDomain, "domain", "d", getEnvString("DRIP_DOMAIN", constants.DefaultDomain), "Server domain for client connections (env: DRIP_DOMAIN)")
|
||||
serverCmd.Flags().StringVar(&serverTunnelDomain, "tunnel-domain", getEnvString("DRIP_TUNNEL_DOMAIN", ""), "Domain for tunnel URLs, defaults to --domain (env: DRIP_TUNNEL_DOMAIN)")
|
||||
serverCmd.Flags().StringVarP(&serverAuthToken, "token", "t", getEnvString("DRIP_TOKEN", ""), "Authentication token (env: DRIP_TOKEN)")
|
||||
serverCmd.Flags().StringVar(&serverMetricsToken, "metrics-token", getEnvString("DRIP_METRICS_TOKEN", ""), "Metrics and stats token (env: DRIP_METRICS_TOKEN)")
|
||||
serverCmd.Flags().BoolVar(&serverDebug, "debug", false, "Enable debug logging")
|
||||
@@ -60,6 +65,10 @@ func init() {
|
||||
|
||||
// Performance profiling
|
||||
serverCmd.Flags().IntVar(&serverPprofPort, "pprof", getEnvInt("DRIP_PPROF_PORT", 0), "Enable pprof on specified port (env: DRIP_PPROF_PORT)")
|
||||
|
||||
// Transport and tunnel type restrictions
|
||||
serverCmd.Flags().StringVar(&serverTransports, "transports", getEnvString("DRIP_TRANSPORTS", "tcp,wss"), "Allowed transports: tcp,wss (env: DRIP_TRANSPORTS)")
|
||||
serverCmd.Flags().StringVar(&serverTunnelTypes, "tunnel-types", getEnvString("DRIP_TUNNEL_TYPES", "http,https,tcp"), "Allowed tunnel types: http,https,tcp (env: DRIP_TUNNEL_TYPES)")
|
||||
}
|
||||
|
||||
func runServer(_ *cobra.Command, _ []string) error {
|
||||
@@ -100,17 +109,26 @@ func runServer(_ *cobra.Command, _ []string) error {
|
||||
displayPort = serverPort
|
||||
}
|
||||
|
||||
// Use tunnel domain if set, otherwise fall back to domain
|
||||
tunnelDomain := serverTunnelDomain
|
||||
if tunnelDomain == "" {
|
||||
tunnelDomain = serverDomain
|
||||
}
|
||||
|
||||
serverConfig := &config.ServerConfig{
|
||||
Port: serverPort,
|
||||
PublicPort: displayPort,
|
||||
Domain: serverDomain,
|
||||
TCPPortMin: serverTCPPortMin,
|
||||
TCPPortMax: serverTCPPortMax,
|
||||
TLSEnabled: true,
|
||||
TLSCertFile: serverTLSCert,
|
||||
TLSKeyFile: serverTLSKey,
|
||||
AuthToken: serverAuthToken,
|
||||
Debug: serverDebug,
|
||||
Port: serverPort,
|
||||
PublicPort: displayPort,
|
||||
Domain: serverDomain,
|
||||
TunnelDomain: tunnelDomain,
|
||||
TCPPortMin: serverTCPPortMin,
|
||||
TCPPortMax: serverTCPPortMax,
|
||||
TLSEnabled: true,
|
||||
TLSCertFile: serverTLSCert,
|
||||
TLSKeyFile: serverTLSKey,
|
||||
AuthToken: serverAuthToken,
|
||||
Debug: serverDebug,
|
||||
AllowedTransports: parseCommaSeparated(serverTransports),
|
||||
AllowedTunnelTypes: parseCommaSeparated(serverTunnelTypes),
|
||||
}
|
||||
|
||||
if err := serverConfig.Validate(); err != nil {
|
||||
@@ -136,9 +154,13 @@ func runServer(_ *cobra.Command, _ []string) error {
|
||||
|
||||
listenAddr := fmt.Sprintf("0.0.0.0:%d", serverPort)
|
||||
|
||||
httpHandler := proxy.NewHandler(tunnelManager, logger, serverDomain, serverAuthToken, serverMetricsToken)
|
||||
httpHandler := proxy.NewHandler(tunnelManager, logger, tunnelDomain, serverAuthToken, serverMetricsToken)
|
||||
httpHandler.SetAllowedTransports(serverConfig.AllowedTransports)
|
||||
httpHandler.SetAllowedTunnelTypes(serverConfig.AllowedTunnelTypes)
|
||||
|
||||
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, displayPort, httpHandler)
|
||||
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, tunnelDomain, displayPort, httpHandler)
|
||||
listener.SetAllowedTransports(serverConfig.AllowedTransports)
|
||||
listener.SetAllowedTunnelTypes(serverConfig.AllowedTunnelTypes)
|
||||
|
||||
if err := listener.Start(); err != nil {
|
||||
logger.Fatal("Failed to start TCP listener", zap.Error(err))
|
||||
@@ -147,7 +169,10 @@ func runServer(_ *cobra.Command, _ []string) error {
|
||||
logger.Info("Drip Server started",
|
||||
zap.String("address", listenAddr),
|
||||
zap.String("domain", serverDomain),
|
||||
zap.String("tunnel_domain", tunnelDomain),
|
||||
zap.String("protocol", "TCP over TLS 1.3"),
|
||||
zap.Strings("transports", serverConfig.AllowedTransports),
|
||||
zap.Strings("tunnel_types", serverConfig.AllowedTunnelTypes),
|
||||
)
|
||||
|
||||
quit := make(chan os.Signal, 1)
|
||||
@@ -182,3 +207,19 @@ func getEnvString(key string, defaultVal string) string {
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
// parseCommaSeparated splits a comma-separated string into a slice
|
||||
func parseCommaSeparated(s string) []string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(s, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
result = append(result, strings.ToLower(p))
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ Example:
|
||||
drip tcp 5432 --allow-ip 192.168.0.0/16 Only allow IPs from 192.168.x.x
|
||||
drip tcp 22 --allow-ip 10.0.0.1 Allow single IP
|
||||
drip tcp 22 --deny-ip 1.2.3.4 Block specific IP
|
||||
drip tcp 22 --transport wss Use WebSocket over TLS (CDN-friendly)
|
||||
|
||||
Supported Services:
|
||||
- Databases: PostgreSQL (5432), MySQL (3306), Redis (6379), MongoDB (27017)
|
||||
@@ -33,7 +34,13 @@ Configuration:
|
||||
First time: Run 'drip config init' to save server and token
|
||||
Subsequent: Just run 'drip tcp <port>'
|
||||
|
||||
Note: Uses TCP over TLS 1.3 for secure communication`,
|
||||
Transport options:
|
||||
auto - Automatically select based on server address (default)
|
||||
tcp - Direct TLS 1.3 connection
|
||||
wss - WebSocket over TLS (works through CDN like Cloudflare)
|
||||
|
||||
Note: TCP tunnels require dynamic port allocation on the server.
|
||||
When using CDN (--transport wss), the server must still expose the allocated port directly.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runTCP,
|
||||
}
|
||||
@@ -44,6 +51,7 @@ func init() {
|
||||
tcpCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
|
||||
tcpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
|
||||
tcpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
|
||||
tcpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
|
||||
tcpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
tcpCmd.Flags().MarkHidden("daemon-child")
|
||||
rootCmd.AddCommand(tcpCmd)
|
||||
@@ -74,6 +82,7 @@ func runTCP(_ *cobra.Command, args []string) error {
|
||||
Insecure: insecure,
|
||||
AllowIPs: allowIPs,
|
||||
DenyIPs: denyIPs,
|
||||
Transport: parseTransport(transport),
|
||||
}
|
||||
|
||||
var daemon *DaemonInfo
|
||||
|
||||
@@ -41,8 +41,13 @@ func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) er
|
||||
fmt.Println(ui.RenderConnecting(connConfig.ServerAddr, reconnectAttempts, maxReconnectAttempts))
|
||||
|
||||
if err := connector.Connect(); err != nil {
|
||||
if isConfigurationError(err) {
|
||||
fmt.Println(ui.Warning(fmt.Sprintf("Configuration error: %v", err)))
|
||||
os.Exit(1)
|
||||
}
|
||||
if isNonRetryableError(err) {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
fmt.Println(ui.RenderConnectionFailed(err))
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
reconnectAttempts++
|
||||
@@ -228,3 +233,10 @@ func isNonRetryableError(err error) bool {
|
||||
strings.Contains(errStr, "authentication") ||
|
||||
strings.Contains(errStr, "Invalid authentication token")
|
||||
}
|
||||
|
||||
// isConfigurationError returns true for errors caused by user configuration
|
||||
// that won't be fixed by retrying (e.g., wrong transport type)
|
||||
func isConfigurationError(err error) bool {
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "server only supports")
|
||||
}
|
||||
|
||||
@@ -10,6 +10,18 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TransportType defines the transport protocol for tunnel connections
|
||||
type TransportType string
|
||||
|
||||
const (
|
||||
// TransportAuto automatically selects transport based on server address
|
||||
TransportAuto TransportType = "auto"
|
||||
// TransportTCP uses direct TLS 1.3 connection
|
||||
TransportTCP TransportType = "tcp"
|
||||
// TransportWebSocket uses WebSocket over TLS (CDN-friendly)
|
||||
TransportWebSocket TransportType = "wss"
|
||||
)
|
||||
|
||||
type LatencyCallback func(latency time.Duration)
|
||||
|
||||
type ConnectorConfig struct {
|
||||
@@ -30,6 +42,9 @@ type ConnectorConfig struct {
|
||||
|
||||
// Proxy authentication
|
||||
AuthPass string
|
||||
|
||||
// Transport protocol selection
|
||||
Transport TransportType
|
||||
}
|
||||
|
||||
type TunnelClient interface {
|
||||
|
||||
@@ -6,11 +6,14 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/hashicorp/yamux"
|
||||
"go.uber.org/zap"
|
||||
@@ -19,6 +22,7 @@ import (
|
||||
"drip/internal/shared/mux"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/stats"
|
||||
"drip/internal/shared/wsutil"
|
||||
"drip/pkg/config"
|
||||
)
|
||||
|
||||
@@ -68,16 +72,41 @@ type PoolClient struct {
|
||||
denyIPs []string
|
||||
|
||||
authPass string
|
||||
|
||||
// Transport protocol selection
|
||||
transport TransportType
|
||||
insecure bool
|
||||
}
|
||||
|
||||
// NewPoolClient creates a new pool client.
|
||||
func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
|
||||
// Parse server address to get host for TLS config
|
||||
serverAddr := cfg.ServerAddr
|
||||
host := serverAddr
|
||||
|
||||
// Handle wss:// prefix
|
||||
if strings.HasPrefix(serverAddr, "wss://") {
|
||||
if u, err := url.Parse(serverAddr); err == nil {
|
||||
host = u.Host
|
||||
// Normalize server address for internal use
|
||||
if u.Port() == "" {
|
||||
host = u.Host + ":443"
|
||||
}
|
||||
serverAddr = host
|
||||
}
|
||||
}
|
||||
|
||||
// Extract hostname without port for TLS
|
||||
hostOnly, _, _ := net.SplitHostPort(host)
|
||||
if hostOnly == "" {
|
||||
hostOnly = host
|
||||
}
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
if cfg.Insecure {
|
||||
tlsConfig = config.GetClientTLSConfigInsecure()
|
||||
} else {
|
||||
host, _, _ := net.SplitHostPort(cfg.ServerAddr)
|
||||
tlsConfig = config.GetClientTLSConfig(host)
|
||||
tlsConfig = config.GetClientTLSConfig(hostOnly)
|
||||
}
|
||||
|
||||
localHost := cfg.LocalHost
|
||||
@@ -111,10 +140,16 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
|
||||
}
|
||||
initialSessions = min(max(initialSessions, minSessions), maxSessions)
|
||||
|
||||
// Determine transport type
|
||||
transport := cfg.Transport
|
||||
if transport == "" {
|
||||
transport = TransportAuto
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
c := &PoolClient{
|
||||
serverAddr: cfg.ServerAddr,
|
||||
serverAddr: serverAddr,
|
||||
tlsConfig: tlsConfig,
|
||||
token: cfg.Token,
|
||||
tunnelType: tunnelType,
|
||||
@@ -134,6 +169,8 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
|
||||
allowIPs: cfg.AllowIPs,
|
||||
denyIPs: cfg.DenyIPs,
|
||||
authPass: cfg.AuthPass,
|
||||
transport: transport,
|
||||
insecure: cfg.Insecure,
|
||||
}
|
||||
|
||||
if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
|
||||
@@ -146,7 +183,7 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
|
||||
|
||||
// Connect establishes the primary connection and starts background workers.
|
||||
func (c *PoolClient) Connect() error {
|
||||
primaryConn, err := c.dialTLS()
|
||||
primaryConn, err := c.dial()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -298,6 +335,138 @@ func (c *PoolClient) dialTLS() (net.Conn, error) {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// serverCapabilities holds the discovered server capabilities
|
||||
type serverCapabilities struct {
|
||||
Transports []string `json:"transports"`
|
||||
Preferred string `json:"preferred"`
|
||||
}
|
||||
|
||||
// dial selects the appropriate transport and establishes a connection
|
||||
func (c *PoolClient) dial() (net.Conn, error) {
|
||||
switch c.transport {
|
||||
case TransportWebSocket:
|
||||
return c.dialWebSocket()
|
||||
case TransportTCP:
|
||||
// User explicitly requested TCP, verify server supports it
|
||||
caps := c.discoverServerCapabilities()
|
||||
if caps != nil && len(caps.Transports) > 0 {
|
||||
tcpSupported := false
|
||||
for _, t := range caps.Transports {
|
||||
if t == "tcp" {
|
||||
tcpSupported = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !tcpSupported {
|
||||
return nil, fmt.Errorf("server only supports %v transport(s), but --transport tcp was specified. Use --transport wss instead", caps.Transports)
|
||||
}
|
||||
}
|
||||
return c.dialTLS()
|
||||
default: // TransportAuto
|
||||
// Check if server address indicates WebSocket
|
||||
if strings.HasPrefix(c.serverAddr, "wss://") {
|
||||
return c.dialWebSocket()
|
||||
}
|
||||
// Query server for preferred transport
|
||||
caps := c.discoverServerCapabilities()
|
||||
if caps != nil && caps.Preferred == "wss" {
|
||||
return c.dialWebSocket()
|
||||
}
|
||||
// Default to TCP
|
||||
return c.dialTLS()
|
||||
}
|
||||
}
|
||||
|
||||
// discoverServerCapabilities queries the server for its capabilities
|
||||
func (c *PoolClient) discoverServerCapabilities() *serverCapabilities {
|
||||
host, port, err := net.SplitHostPort(c.serverAddr)
|
||||
if err != nil {
|
||||
host = c.serverAddr
|
||||
port = "443"
|
||||
}
|
||||
|
||||
discoverURL := fmt.Sprintf("https://%s:%s/_drip/discover", host, port)
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: c.tlsConfig,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Get(discoverURL)
|
||||
if err != nil {
|
||||
c.logger.Debug("Failed to discover server capabilities",
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
var caps serverCapabilities
|
||||
if err := json.NewDecoder(resp.Body).Decode(&caps); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Debug("Discovered server capabilities",
|
||||
zap.Strings("transports", caps.Transports),
|
||||
zap.String("preferred", caps.Preferred),
|
||||
)
|
||||
|
||||
return &caps
|
||||
}
|
||||
|
||||
// dialWebSocket establishes a WebSocket connection to the server over TLS
|
||||
func (c *PoolClient) dialWebSocket() (net.Conn, error) {
|
||||
// Build WebSocket URL
|
||||
host, port, err := net.SplitHostPort(c.serverAddr)
|
||||
if err != nil {
|
||||
// No port specified, use default
|
||||
host = c.serverAddr
|
||||
port = "443"
|
||||
}
|
||||
|
||||
wsURL := fmt.Sprintf("wss://%s:%s/_drip/ws", host, port)
|
||||
|
||||
c.logger.Debug("Connecting via WebSocket over TLS",
|
||||
zap.String("url", wsURL),
|
||||
)
|
||||
|
||||
dialer := websocket.Dialer{
|
||||
TLSClientConfig: c.tlsConfig,
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
ReadBufferSize: 256 * 1024,
|
||||
WriteBufferSize: 256 * 1024,
|
||||
}
|
||||
|
||||
// Add authorization header if token is set
|
||||
header := http.Header{}
|
||||
if c.token != "" {
|
||||
header.Set("Authorization", "Bearer "+c.token)
|
||||
}
|
||||
|
||||
ws, resp, err := dialer.Dial(wsURL, header)
|
||||
if err != nil {
|
||||
if resp != nil {
|
||||
return nil, fmt.Errorf("WebSocket dial failed (status %d): %w", resp.StatusCode, err)
|
||||
}
|
||||
return nil, fmt.Errorf("WebSocket dial failed: %w", err)
|
||||
}
|
||||
|
||||
// Wrap WebSocket as net.Conn with ping loop for CDN keep-alive
|
||||
conn := wsutil.NewConnWithPing(ws, 30*time.Second)
|
||||
|
||||
c.logger.Debug("WebSocket connection established",
|
||||
zap.String("remote_addr", ws.RemoteAddr().String()),
|
||||
)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *PoolClient) acceptLoop(h *sessionHandle, isPrimary bool) {
|
||||
defer c.wg.Done()
|
||||
|
||||
|
||||
@@ -188,7 +188,7 @@ func (c *PoolClient) addDataSession() error {
|
||||
return fmt.Errorf("server does not support data connections")
|
||||
}
|
||||
|
||||
conn, err := c.dialTLS()
|
||||
conn, err := c.dial()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -17,12 +17,14 @@ import (
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/httputil"
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/wsutil"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.uber.org/zap"
|
||||
@@ -94,6 +96,20 @@ type Handler struct {
|
||||
domain string
|
||||
authToken string
|
||||
metricsToken string
|
||||
publicPort int
|
||||
|
||||
// WebSocket tunnel support
|
||||
wsUpgrader websocket.Upgrader
|
||||
wsConnHandler WSConnectionHandler
|
||||
|
||||
// Server capabilities
|
||||
allowedTransports []string
|
||||
allowedTunnelTypes []string
|
||||
}
|
||||
|
||||
// WSConnectionHandler handles WebSocket tunnel connections
|
||||
type WSConnectionHandler interface {
|
||||
HandleWSConnection(conn net.Conn, remoteAddr string)
|
||||
}
|
||||
|
||||
var privateNetworks []*net.IPNet
|
||||
@@ -121,10 +137,86 @@ func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, auth
|
||||
domain: domain,
|
||||
authToken: authToken,
|
||||
metricsToken: metricsToken,
|
||||
wsUpgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 256 * 1024,
|
||||
WriteBufferSize: 256 * 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // Allow all origins for tunnel connections
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SetWSConnectionHandler sets the handler for WebSocket tunnel connections
|
||||
func (h *Handler) SetWSConnectionHandler(handler WSConnectionHandler) {
|
||||
h.wsConnHandler = handler
|
||||
}
|
||||
|
||||
// SetPublicPort sets the public port for URL generation
|
||||
func (h *Handler) SetPublicPort(port int) {
|
||||
h.publicPort = port
|
||||
}
|
||||
|
||||
// SetAllowedTransports sets the allowed transport protocols
|
||||
func (h *Handler) SetAllowedTransports(transports []string) {
|
||||
h.allowedTransports = transports
|
||||
}
|
||||
|
||||
// SetAllowedTunnelTypes sets the allowed tunnel types
|
||||
func (h *Handler) SetAllowedTunnelTypes(types []string) {
|
||||
h.allowedTunnelTypes = types
|
||||
}
|
||||
|
||||
// IsTransportAllowed checks if a transport is allowed
|
||||
func (h *Handler) IsTransportAllowed(transport string) bool {
|
||||
if len(h.allowedTransports) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, t := range h.allowedTransports {
|
||||
if strings.EqualFold(t, transport) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsTunnelTypeAllowed checks if a tunnel type is allowed
|
||||
func (h *Handler) IsTunnelTypeAllowed(tunnelType string) bool {
|
||||
if len(h.allowedTunnelTypes) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, t := range h.allowedTunnelTypes {
|
||||
if strings.EqualFold(t, tunnelType) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetPreferredTransport returns the preferred transport for auto-detection
|
||||
func (h *Handler) GetPreferredTransport() string {
|
||||
if len(h.allowedTransports) == 0 {
|
||||
return "tcp"
|
||||
}
|
||||
if len(h.allowedTransports) == 1 {
|
||||
return h.allowedTransports[0]
|
||||
}
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Discovery endpoint for client auto-detection
|
||||
if r.URL.Path == "/_drip/discover" {
|
||||
h.serveDiscovery(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// WebSocket tunnel endpoint - must be checked before other routes
|
||||
if r.URL.Path == "/_drip/ws" {
|
||||
h.handleTunnelWebSocket(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/health" {
|
||||
h.serveHealth(w, r)
|
||||
return
|
||||
@@ -849,3 +941,69 @@ func (h *Handler) serveLoginPage(w http.ResponseWriter, r *http.Request, subdoma
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(htmlContent))
|
||||
}
|
||||
|
||||
// handleTunnelWebSocket handles WebSocket connections for tunnel clients
|
||||
func (h *Handler) handleTunnelWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if WSS transport is allowed
|
||||
if !h.IsTransportAllowed("wss") {
|
||||
http.Error(w, "WebSocket transport not allowed on this server", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if h.wsConnHandler == nil {
|
||||
http.Error(w, "WebSocket tunnel not configured", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
ws, err := h.wsUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("WebSocket upgrade failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Configure WebSocket for tunnel use
|
||||
ws.SetReadLimit(protocol.MaxFrameSize + protocol.FrameHeaderSize + 1024)
|
||||
|
||||
// Extract real client IP (support CDN headers)
|
||||
remoteAddr := h.extractClientIP(r)
|
||||
|
||||
h.logger.Info("WebSocket tunnel connection established",
|
||||
zap.String("remote_addr", remoteAddr),
|
||||
)
|
||||
|
||||
// Wrap WebSocket as net.Conn with ping loop for CDN keep-alive
|
||||
conn := wsutil.NewConnWithPing(ws, 30*time.Second)
|
||||
|
||||
// Handle the connection using the registered handler
|
||||
h.wsConnHandler.HandleWSConnection(conn, remoteAddr)
|
||||
}
|
||||
|
||||
// serveDiscovery returns server capabilities for client auto-detection
|
||||
func (h *Handler) serveDiscovery(w http.ResponseWriter, r *http.Request) {
|
||||
transports := h.allowedTransports
|
||||
if len(transports) == 0 {
|
||||
transports = []string{"tcp", "wss"}
|
||||
}
|
||||
|
||||
tunnelTypes := h.allowedTunnelTypes
|
||||
if len(tunnelTypes) == 0 {
|
||||
tunnelTypes = []string{"http", "https", "tcp"}
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"transports": transports,
|
||||
"tunnel_types": tunnelTypes,
|
||||
"preferred": h.GetPreferredTransport(),
|
||||
"version": "1",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ type Connection struct {
|
||||
subdomain string
|
||||
port int
|
||||
domain string
|
||||
tunnelDomain string
|
||||
publicPort int
|
||||
portAlloc *PortAllocator
|
||||
tunnelConn *tunnel.Connection
|
||||
@@ -57,10 +58,13 @@ type Connection struct {
|
||||
groupManager *ConnectionGroupManager
|
||||
httpListener *connQueueListener
|
||||
handedOff bool
|
||||
|
||||
// Server capabilities
|
||||
allowedTunnelTypes []string
|
||||
}
|
||||
|
||||
// NewConnection creates a new connection handler
|
||||
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager, httpListener *connQueueListener) *Connection {
|
||||
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager, httpListener *connQueueListener) *Connection {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
c := &Connection{
|
||||
conn: conn,
|
||||
@@ -69,6 +73,7 @@ func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, log
|
||||
logger: logger,
|
||||
portAlloc: portAlloc,
|
||||
domain: domain,
|
||||
tunnelDomain: tunnelDomain,
|
||||
publicPort: publicPort,
|
||||
httpHandler: httpHandler,
|
||||
stopCh: make(chan struct{}),
|
||||
@@ -130,6 +135,12 @@ func (c *Connection) Handle() error {
|
||||
|
||||
c.tunnelType = req.TunnelType
|
||||
|
||||
// Check if tunnel type is allowed
|
||||
if !c.isTunnelTypeAllowed(string(req.TunnelType)) {
|
||||
c.sendError("tunnel_type_not_allowed", fmt.Sprintf("Tunnel type '%s' is not allowed on this server", req.TunnelType))
|
||||
return fmt.Errorf("tunnel type not allowed: %s", req.TunnelType)
|
||||
}
|
||||
|
||||
if c.authToken != "" && req.Token != c.authToken {
|
||||
c.sendError("authentication_failed", "Invalid authentication token")
|
||||
return fmt.Errorf("authentication failed")
|
||||
@@ -207,12 +218,12 @@ func (c *Connection) Handle() error {
|
||||
var tunnelURL string
|
||||
if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS {
|
||||
if c.publicPort == 443 {
|
||||
tunnelURL = fmt.Sprintf("https://%s.%s", subdomain, c.domain)
|
||||
tunnelURL = fmt.Sprintf("https://%s.%s", subdomain, c.tunnelDomain)
|
||||
} else {
|
||||
tunnelURL = fmt.Sprintf("https://%s.%s:%d", subdomain, c.domain, c.publicPort)
|
||||
tunnelURL = fmt.Sprintf("https://%s.%s:%d", subdomain, c.tunnelDomain, c.publicPort)
|
||||
}
|
||||
} else {
|
||||
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.domain, c.port)
|
||||
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.tunnelDomain, c.port)
|
||||
}
|
||||
|
||||
var tunnelID string
|
||||
@@ -750,3 +761,21 @@ func (c *Connection) sendDataConnectError(code, message string) {
|
||||
frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
_ = protocol.WriteFrame(c.conn, frame)
|
||||
}
|
||||
|
||||
// SetAllowedTunnelTypes sets the allowed tunnel types for this connection
|
||||
func (c *Connection) SetAllowedTunnelTypes(types []string) {
|
||||
c.allowedTunnelTypes = types
|
||||
}
|
||||
|
||||
// isTunnelTypeAllowed checks if a tunnel type is allowed
|
||||
func (c *Connection) isTunnelTypeAllowed(tunnelType string) bool {
|
||||
if len(c.allowedTunnelTypes) == 0 {
|
||||
return true // Allow all by default
|
||||
}
|
||||
for _, t := range c.allowedTunnelTypes {
|
||||
if strings.EqualFold(t, tunnelType) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"drip/internal/server/metrics"
|
||||
"drip/internal/server/proxy"
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/recovery"
|
||||
@@ -27,6 +28,7 @@ type Listener struct {
|
||||
portAlloc *PortAllocator
|
||||
logger *zap.Logger
|
||||
domain string
|
||||
tunnelDomain string
|
||||
publicPort int
|
||||
httpHandler http.Handler
|
||||
listener net.Listener
|
||||
@@ -40,9 +42,13 @@ type Listener struct {
|
||||
groupManager *ConnectionGroupManager
|
||||
httpServer *http.Server
|
||||
httpListener *connQueueListener
|
||||
|
||||
// Server capabilities
|
||||
allowedTransports []string
|
||||
allowedTunnelTypes []string
|
||||
}
|
||||
|
||||
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler) *Listener {
|
||||
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler) *Listener {
|
||||
numCPU := pool.NumCPU()
|
||||
workers := numCPU * 5
|
||||
queueSize := workers * 20
|
||||
@@ -60,7 +66,7 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
|
||||
// Initialize worker pool metrics
|
||||
metrics.WorkerPoolSize.Set(float64(workers))
|
||||
|
||||
return &Listener{
|
||||
l := &Listener{
|
||||
address: address,
|
||||
tlsConfig: tlsConfig,
|
||||
authToken: authToken,
|
||||
@@ -68,6 +74,7 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
|
||||
portAlloc: portAlloc,
|
||||
logger: logger,
|
||||
domain: domain,
|
||||
tunnelDomain: tunnelDomain,
|
||||
publicPort: publicPort,
|
||||
httpHandler: httpHandler,
|
||||
stopCh: make(chan struct{}),
|
||||
@@ -77,6 +84,14 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
|
||||
panicMetrics: panicMetrics,
|
||||
groupManager: NewConnectionGroupManager(logger),
|
||||
}
|
||||
|
||||
// Set up WebSocket connection handler if httpHandler supports it
|
||||
if h, ok := httpHandler.(*proxy.Handler); ok {
|
||||
h.SetWSConnectionHandler(l)
|
||||
h.SetPublicPort(publicPort)
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *Listener) Start() error {
|
||||
@@ -182,6 +197,14 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
l.connMu.Unlock()
|
||||
})
|
||||
|
||||
// Check if TCP transport is allowed
|
||||
if !l.IsTransportAllowed("tcp") {
|
||||
l.logger.Warn("TCP transport not allowed, rejecting connection",
|
||||
zap.String("remote_addr", netConn.RemoteAddr().String()),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
tlsConn, ok := netConn.(*tls.Conn)
|
||||
if !ok {
|
||||
l.logger.Error("Connection is not TLS")
|
||||
@@ -234,7 +257,8 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
|
||||
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
|
||||
conn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
|
||||
|
||||
connID := netConn.RemoteAddr().String()
|
||||
l.connMu.Lock()
|
||||
@@ -334,3 +358,92 @@ func (l *Listener) GetActiveConnections() int {
|
||||
defer l.connMu.RUnlock()
|
||||
return len(l.connections)
|
||||
}
|
||||
|
||||
// HandleWSConnection implements proxy.WSConnectionHandler for WebSocket tunnel connections
|
||||
func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) {
|
||||
l.wg.Add(1)
|
||||
defer l.wg.Done()
|
||||
|
||||
connID := remoteAddr
|
||||
if connID == "" {
|
||||
connID = conn.RemoteAddr().String()
|
||||
}
|
||||
|
||||
l.logger.Info("Handling WebSocket tunnel connection",
|
||||
zap.String("remote_addr", connID),
|
||||
)
|
||||
|
||||
// Create connection handler (no TLS verification needed - already done by HTTP server)
|
||||
tcpConn := NewConnection(conn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
|
||||
tcpConn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
|
||||
|
||||
l.connMu.Lock()
|
||||
l.connections[connID] = tcpConn
|
||||
l.connMu.Unlock()
|
||||
|
||||
metrics.TotalConnections.Inc()
|
||||
metrics.ActiveConnections.Inc()
|
||||
|
||||
defer func() {
|
||||
l.connMu.Lock()
|
||||
delete(l.connections, connID)
|
||||
l.connMu.Unlock()
|
||||
|
||||
metrics.ActiveConnections.Dec()
|
||||
|
||||
if !tcpConn.IsHandedOff() {
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if err := tcpConn.Handle(); err != nil {
|
||||
errStr := err.Error()
|
||||
|
||||
if strings.Contains(errStr, "EOF") ||
|
||||
strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection refused") ||
|
||||
strings.Contains(errStr, "websocket: close") {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(errStr, "payload too large") ||
|
||||
strings.Contains(errStr, "failed to read registration frame") ||
|
||||
strings.Contains(errStr, "expected register frame") ||
|
||||
strings.Contains(errStr, "failed to parse registration request") ||
|
||||
strings.Contains(errStr, "tunnel type not allowed") {
|
||||
l.logger.Warn("WebSocket tunnel protocol validation failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
l.logger.Error("WebSocket tunnel connection handling failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetAllowedTransports sets the allowed transport protocols
|
||||
func (l *Listener) SetAllowedTransports(transports []string) {
|
||||
l.allowedTransports = transports
|
||||
}
|
||||
|
||||
// SetAllowedTunnelTypes sets the allowed tunnel types
|
||||
func (l *Listener) SetAllowedTunnelTypes(types []string) {
|
||||
l.allowedTunnelTypes = types
|
||||
}
|
||||
|
||||
// IsTransportAllowed checks if a transport is allowed
|
||||
func (l *Listener) IsTransportAllowed(transport string) bool {
|
||||
if len(l.allowedTransports) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, t := range l.allowedTransports {
|
||||
if strings.EqualFold(t, transport) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
169
internal/shared/wsutil/conn.go
Normal file
169
internal/shared/wsutil/conn.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package wsutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// Conn wraps a gorilla/websocket.Conn to implement net.Conn.
|
||||
// It uses binary messages for data transfer, making it compatible
|
||||
// with yamux and the existing frame protocol.
|
||||
type Conn struct {
|
||||
ws *websocket.Conn
|
||||
reader io.Reader
|
||||
readMu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
pingStop chan struct{}
|
||||
pingOnce sync.Once
|
||||
}
|
||||
|
||||
// NewConn wraps a WebSocket connection as a net.Conn.
|
||||
func NewConn(ws *websocket.Conn) *Conn {
|
||||
c := &Conn{
|
||||
ws: ws,
|
||||
localAddr: ws.LocalAddr(),
|
||||
remoteAddr: ws.RemoteAddr(),
|
||||
pingStop: make(chan struct{}),
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// NewConnWithPing wraps a WebSocket connection and starts a ping loop
|
||||
// to keep the connection alive through CDN/proxies.
|
||||
func NewConnWithPing(ws *websocket.Conn, pingInterval time.Duration) *Conn {
|
||||
c := NewConn(ws)
|
||||
c.startPingLoop(pingInterval)
|
||||
return c
|
||||
}
|
||||
|
||||
// Read reads data from the WebSocket connection.
|
||||
// It handles WebSocket message boundaries transparently, presenting
|
||||
// a continuous byte stream to the caller.
|
||||
func (c *Conn) Read(p []byte) (int, error) {
|
||||
c.readMu.Lock()
|
||||
defer c.readMu.Unlock()
|
||||
|
||||
for {
|
||||
if c.reader == nil {
|
||||
messageType, reader, err := c.ws.NextReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// Only accept binary messages for tunnel data
|
||||
if messageType != websocket.BinaryMessage {
|
||||
// Skip non-binary messages (text, ping/pong handled by gorilla)
|
||||
continue
|
||||
}
|
||||
c.reader = reader
|
||||
}
|
||||
|
||||
n, err := c.reader.Read(p)
|
||||
if err == io.EOF {
|
||||
// Current message exhausted, get next message
|
||||
c.reader = nil
|
||||
if n > 0 {
|
||||
return n, nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes data to the WebSocket connection as a binary message.
|
||||
func (c *Conn) Write(p []byte) (int, error) {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
err := c.ws.WriteMessage(websocket.BinaryMessage, p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Close closes the WebSocket connection.
|
||||
func (c *Conn) Close() error {
|
||||
c.pingOnce.Do(func() {
|
||||
close(c.pingStop)
|
||||
})
|
||||
|
||||
// Send close message before closing
|
||||
c.writeMu.Lock()
|
||||
_ = c.ws.WriteMessage(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
c.writeMu.Unlock()
|
||||
|
||||
return c.ws.Close()
|
||||
}
|
||||
|
||||
// LocalAddr returns the local network address.
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.localAddr
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote network address.
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
// SetDeadline sets the read and write deadlines.
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
if err := c.ws.SetReadDeadline(t); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.ws.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline.
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return c.ws.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the write deadline.
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return c.ws.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// startPingLoop starts a goroutine that sends periodic ping messages
|
||||
// to keep the connection alive through CDN/proxies like Cloudflare.
|
||||
func (c *Conn) startPingLoop(interval time.Duration) {
|
||||
if interval <= 0 {
|
||||
interval = 30 * time.Second
|
||||
}
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.pingStop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.writeMu.Lock()
|
||||
err := c.ws.WriteControl(
|
||||
websocket.PingMessage,
|
||||
[]byte{},
|
||||
time.Now().Add(10*time.Second),
|
||||
)
|
||||
c.writeMu.Unlock()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// UnderlyingConn returns the underlying WebSocket connection.
|
||||
// Use with caution as direct access bypasses the mutex protection.
|
||||
func (c *Conn) UnderlyingConn() *websocket.Conn {
|
||||
return c.ws
|
||||
}
|
||||
@@ -9,26 +9,31 @@ import (
|
||||
|
||||
// ServerConfig holds the server configuration
|
||||
type ServerConfig struct {
|
||||
// Server settings
|
||||
Port int
|
||||
PublicPort int // Port to display in URLs (for reverse proxy scenarios)
|
||||
Domain string
|
||||
Port int
|
||||
PublicPort int // Port to display in URLs (for reverse proxy scenarios)
|
||||
Domain string // Domain for client connections (e.g., connect.example.com)
|
||||
TunnelDomain string // Domain for tunnel URLs (e.g., example.com for *.example.com)
|
||||
|
||||
// TCP tunnel dynamic port allocation
|
||||
TCPPortMin int
|
||||
TCPPortMax int
|
||||
|
||||
// TLS/SSL settings
|
||||
// TLS settings
|
||||
TLSEnabled bool
|
||||
TLSCertFile string
|
||||
TLSKeyFile string
|
||||
AutoTLS bool // Automatic Let's Encrypt
|
||||
|
||||
// Security
|
||||
AuthToken string
|
||||
|
||||
// Logging
|
||||
Debug bool
|
||||
|
||||
// Allowed transports: "tcp", "wss", or "tcp,wss" (default: "tcp,wss")
|
||||
AllowedTransports []string
|
||||
|
||||
// Allowed tunnel types: "http", "https", "tcp" (default: all)
|
||||
AllowedTunnelTypes []string
|
||||
}
|
||||
|
||||
// Validate checks if the server configuration is valid
|
||||
@@ -51,6 +56,11 @@ func (c *ServerConfig) Validate() error {
|
||||
return fmt.Errorf("domain should not contain port, got: %s", c.Domain)
|
||||
}
|
||||
|
||||
// Validate tunnel domain if set
|
||||
if c.TunnelDomain != "" && strings.Contains(c.TunnelDomain, ":") {
|
||||
return fmt.Errorf("tunnel domain should not contain port, got: %s", c.TunnelDomain)
|
||||
}
|
||||
|
||||
// Validate TCP port range
|
||||
if c.TCPPortMin < 1 || c.TCPPortMin > 65535 {
|
||||
return fmt.Errorf("invalid TCPPortMin %d: must be between 1 and 65535", c.TCPPortMin)
|
||||
@@ -118,10 +128,10 @@ func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
|
||||
func GetClientTLSConfig(serverName string) *tls.Config {
|
||||
return &tls.Config{
|
||||
ServerName: serverName,
|
||||
MinVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
MaxVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
ClientSessionCache: tls.NewLRUClientSessionCache(0), // Enable session resumption (0 = default size)
|
||||
PreferServerCipherSuites: true, // Prefer server cipher suites (ignored in TLS 1.3 but set for consistency)
|
||||
MinVersion: tls.VersionTLS13,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
ClientSessionCache: tls.NewLRUClientSessionCache(0),
|
||||
PreferServerCipherSuites: true,
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_AES_128_GCM_SHA256,
|
||||
tls.TLS_AES_256_GCM_SHA384,
|
||||
@@ -135,10 +145,10 @@ func GetClientTLSConfig(serverName string) *tls.Config {
|
||||
func GetClientTLSConfigInsecure() *tls.Config {
|
||||
return &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
MinVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
MaxVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
ClientSessionCache: tls.NewLRUClientSessionCache(0), // Enable session resumption (0 = default size)
|
||||
PreferServerCipherSuites: true, // Prefer server cipher suites (ignored in TLS 1.3 but set for consistency)
|
||||
MinVersion: tls.VersionTLS13,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
ClientSessionCache: tls.NewLRUClientSessionCache(0),
|
||||
PreferServerCipherSuites: true,
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_AES_128_GCM_SHA256,
|
||||
tls.TLS_AES_256_GCM_SHA384,
|
||||
@@ -146,22 +156,3 @@ func GetClientTLSConfigInsecure() *tls.Config {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetServerURL returns the server URL based on configuration
|
||||
func (c *ServerConfig) GetServerURL() string {
|
||||
protocol := "http"
|
||||
if c.TLSEnabled {
|
||||
protocol = "https"
|
||||
}
|
||||
|
||||
if c.Port == 80 || (c.TLSEnabled && c.Port == 443) {
|
||||
return fmt.Sprintf("%s://%s", protocol, c.Domain)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s:%d", protocol, c.Domain, c.Port)
|
||||
}
|
||||
|
||||
// GetTCPAddress returns the TCP address for tunnel connections
|
||||
func (c *ServerConfig) GetTCPAddress() string {
|
||||
return fmt.Sprintf("%s:%d", c.Domain, c.Port)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user