feat: Add Bearer Token authentication support and optimize code structure

- Add Bearer Token authentication, supporting tunnel access control via the --auth-bearer parameter
- Refactor large modules into smaller, more focused components to improve code maintainability
- Update dependency versions, including golang.org/x/crypto, golang.org/x/net, etc.
- Add SilenceUsage and SilenceErrors configuration for all CLI commands
- Modify connector configuration structure to support the new authentication method
- Update recent change log in README with new feature descriptions

BREAKING CHANGE: Authentication via Bearer Token is now supported, requiring the new --auth-bearer parameter
This commit is contained in:
zhiqing
2026-01-29 14:40:53 +08:00
parent 3256a3486f
commit 307cf8e6cc
50 changed files with 3338 additions and 1611 deletions

View File

@@ -33,6 +33,13 @@
- **Actually free** - Use your own domain, no paid tiers or feature restrictions
- **Open source** - BSD 3-Clause License
## Recent Changes
### 2025-01-29
- **Bearer Token Authentication** - Added bearer token authentication support for tunnel access control
- **Code Optimization** - Refactored large modules into smaller, focused components for better maintainability
## Quick Start
### Install

View File

@@ -33,6 +33,13 @@
- **真的免费** - 用你自己的域名,没有付费档位或功能阉割
- **开源** - BSD 3-Clause 协议
## 最近更新
### 2025-01-29
- **Bearer Token 认证** - 新增 Bearer Token 认证支持,用于隧道访问控制
- **代码优化** - 将大型模块重构为更小、更专注的组件,提升可维护性
## 快速开始
### 安装

18
go.mod
View File

@@ -1,6 +1,6 @@
module drip
go 1.25.4
go 1.25.5
require (
github.com/charmbracelet/lipgloss v1.1.0
@@ -10,9 +10,9 @@ require (
github.com/prometheus/client_golang v1.23.2
github.com/spf13/cobra v1.10.2
go.uber.org/zap v1.27.1
golang.org/x/crypto v0.46.0
golang.org/x/net v0.48.0
golang.org/x/sys v0.39.0
golang.org/x/crypto v0.47.0
golang.org/x/net v0.49.0
golang.org/x/sys v0.40.0
gopkg.in/yaml.v3 v3.0.1
)
@@ -21,12 +21,12 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/ansi v0.11.3 // indirect
github.com/charmbracelet/x/ansi v0.11.4 // indirect
github.com/charmbracelet/x/cellbuf v0.0.14 // indirect
github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/clipperhouse/displaywidth v0.6.1 // indirect
github.com/clipperhouse/displaywidth v0.7.0 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
github.com/clipperhouse/uax29/v2 v2.3.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
@@ -34,13 +34,13 @@ require (
github.com/muesli/termenv v0.16.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.67.4 // indirect
github.com/prometheus/common v0.67.5 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.uber.org/multierr v1.11.0 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/text v0.32.0 // indirect
golang.org/x/text v0.33.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
)

32
go.sum
View File

@@ -8,18 +8,18 @@ github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.11.3 h1:6DcVaqWI82BBVM/atTyq6yBoRLZFBsnoDoX9GCu2YOI=
github.com/charmbracelet/x/ansi v0.11.3/go.mod h1:yI7Zslym9tCJcedxz5+WBq+eUGMJT0bM06Fqy1/Y4dI=
github.com/charmbracelet/x/ansi v0.11.4 h1:6G65PLu6HjmE858CnTUQY1LXT3ZUWwfvqEROLF8vqHI=
github.com/charmbracelet/x/ansi v0.11.4/go.mod h1:/5AZ+UfWExW3int5H5ugnsG/PWjNcSQcwYsHBlPFQN4=
github.com/charmbracelet/x/cellbuf v0.0.14 h1:iUEMryGyFTelKW3THW4+FfPgi4fkmKnnaLOXuc+/Kj4=
github.com/charmbracelet/x/cellbuf v0.0.14/go.mod h1:P447lJl49ywBbil/KjCk2HexGh4tEY9LH0/1QrZZ9rA=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/clipperhouse/displaywidth v0.6.1 h1:/zMlAezfDzT2xy6acHBzwIfyu2ic0hgkT83UX5EY2gY=
github.com/clipperhouse/displaywidth v0.6.1/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o=
github.com/clipperhouse/displaywidth v0.7.0 h1:QNv1GYsnLX9QBrcWUtMlogpTXuM5FVnBwKWp1O5NwmE=
github.com/clipperhouse/displaywidth v0.7.0/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4=
github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/clipperhouse/uax29/v2 v2.3.1 h1:RjM8gnVbFbgI67SBekIC7ihFpyXwRPYWXn9BZActHbw=
github.com/clipperhouse/uax29/v2 v2.3.1/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -57,8 +57,8 @@ github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+Lvsc=
github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI=
github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4=
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
@@ -84,17 +84,17 @@ go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -28,9 +28,11 @@ Examples:
drip attach tcp 5432 Attach to TCP tunnel on port 5432
Press Ctrl+C to detach (tunnel will continue running).`,
Aliases: []string{"logs", "tail"},
Args: cobra.MaximumNArgs(2),
RunE: runAttach,
Aliases: []string{"logs", "tail"},
Args: cobra.MaximumNArgs(2),
RunE: runAttach,
SilenceUsage: true,
SilenceErrors: true,
}
func init() {

View File

@@ -13,44 +13,56 @@ import (
)
var configCmd = &cobra.Command{
Use: "config",
Short: "Manage configuration",
Long: "Manage Drip client configuration (server, token, tunnels)",
Use: "config",
Short: "Manage configuration",
Long: "Manage Drip client configuration (server, token, tunnels)",
SilenceUsage: true,
SilenceErrors: true,
}
var configInitCmd = &cobra.Command{
Use: "init",
Short: "Initialize configuration interactively",
Long: "Initialize Drip configuration with interactive prompts",
RunE: runConfigInit,
Use: "init",
Short: "Initialize configuration interactively",
Long: "Initialize Drip configuration with interactive prompts",
RunE: runConfigInit,
SilenceUsage: true,
SilenceErrors: true,
}
var configShowCmd = &cobra.Command{
Use: "show",
Short: "Show current configuration",
Long: "Display the current Drip configuration",
RunE: runConfigShow,
Use: "show",
Short: "Show current configuration",
Long: "Display the current Drip configuration",
RunE: runConfigShow,
SilenceUsage: true,
SilenceErrors: true,
}
var configSetCmd = &cobra.Command{
Use: "set",
Short: "Set configuration values",
Long: "Set specific configuration values (server, token)",
RunE: runConfigSet,
Use: "set",
Short: "Set configuration values",
Long: "Set specific configuration values (server, token)",
RunE: runConfigSet,
SilenceUsage: true,
SilenceErrors: true,
}
var configResetCmd = &cobra.Command{
Use: "reset",
Short: "Reset configuration",
Long: "Delete the configuration file",
RunE: runConfigReset,
Use: "reset",
Short: "Reset configuration",
Long: "Delete the configuration file",
RunE: runConfigReset,
SilenceUsage: true,
SilenceErrors: true,
}
var configValidateCmd = &cobra.Command{
Use: "validate",
Short: "Validate configuration",
Long: "Validate the configuration file",
RunE: runConfigValidate,
Use: "validate",
Short: "Validate configuration",
Long: "Validate the configuration file",
RunE: runConfigValidate,
SilenceUsage: true,
SilenceErrors: true,
}
var (

View File

@@ -19,6 +19,7 @@ var (
allowIPs []string
denyIPs []string
authPass string
authBearer string
transport string
)
@@ -34,6 +35,7 @@ Example:
drip http 3000 --allow-ip 10.0.0.1 Allow single IP
drip http 3000 --deny-ip 1.2.3.4 Block specific IP
drip http 3000 --auth secret Enable proxy authentication with password
drip http 3000 --auth-bearer sk-xxx Enable proxy authentication with bearer token
drip http 3000 --transport wss Use WebSocket over TLS (CDN-friendly)
Configuration:
@@ -44,8 +46,10 @@ Transport options:
auto - Automatically select based on server address (default)
tcp - Direct TLS 1.3 connection
wss - WebSocket over TLS (works through CDN like Cloudflare)`,
Args: cobra.ExactArgs(1),
RunE: runHTTP,
Args: cobra.ExactArgs(1),
RunE: runHTTP,
SilenceUsage: true,
SilenceErrors: true,
}
func init() {
@@ -55,6 +59,7 @@ func init() {
httpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
httpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
httpCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication")
httpCmd.Flags().StringVar(&authBearer, "auth-bearer", "", "Bearer token for proxy authentication")
httpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
httpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
httpCmd.Flags().MarkHidden("daemon-child")
@@ -71,6 +76,10 @@ func runHTTP(_ *cobra.Command, args []string) error {
return StartDaemon("http", port, buildDaemonArgs("http", args, subdomain, localAddress))
}
if authPass != "" && authBearer != "" {
return fmt.Errorf("cannot use --auth and --auth-bearer together")
}
serverAddr, token, err := resolveServerAddrAndToken("http", port)
if err != nil {
return err
@@ -87,6 +96,7 @@ func runHTTP(_ *cobra.Command, args []string) error {
AllowIPs: allowIPs,
DenyIPs: denyIPs,
AuthPass: authPass,
AuthBearer: authBearer,
Transport: parseTransport(transport),
}

View File

@@ -22,6 +22,7 @@ Example:
drip https 443 --allow-ip 10.0.0.1 Allow single IP
drip https 443 --deny-ip 1.2.3.4 Block specific IP
drip https 443 --auth secret Enable proxy authentication with password
drip https 443 --auth-bearer sk-xxx Enable proxy authentication with bearer token
drip https 443 --transport wss Use WebSocket over TLS (CDN-friendly)
Configuration:
@@ -32,8 +33,10 @@ Transport options:
auto - Automatically select based on server address (default)
tcp - Direct TLS 1.3 connection
wss - WebSocket over TLS (works through CDN like Cloudflare)`,
Args: cobra.ExactArgs(1),
RunE: runHTTPS,
Args: cobra.ExactArgs(1),
RunE: runHTTPS,
SilenceUsage: true,
SilenceErrors: true,
}
func init() {
@@ -43,6 +46,7 @@ func init() {
httpsCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
httpsCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
httpsCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication")
httpsCmd.Flags().StringVar(&authBearer, "auth-bearer", "", "Bearer token for proxy authentication")
httpsCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
httpsCmd.Flags().MarkHidden("daemon-child")
@@ -59,6 +63,10 @@ func runHTTPS(_ *cobra.Command, args []string) error {
return StartDaemon("https", port, buildDaemonArgs("https", args, subdomain, localAddress))
}
if authPass != "" && authBearer != "" {
return fmt.Errorf("cannot use --auth and --auth-bearer together")
}
serverAddr, token, err := resolveServerAddrAndToken("https", port)
if err != nil {
return err
@@ -75,6 +83,7 @@ func runHTTPS(_ *cobra.Command, args []string) error {
AllowIPs: allowIPs,
DenyIPs: denyIPs,
AuthPass: authPass,
AuthBearer: authBearer,
Transport: parseTransport(transport),
}

View File

@@ -33,8 +33,10 @@ This command shows:
In interactive mode, you can select a tunnel to:
- Attach: View real-time logs
- Stop: Terminate the tunnel`,
Aliases: []string{"ls", "ps", "status"},
RunE: runList,
Aliases: []string{"ls", "ps", "status"},
RunE: runList,
SilenceUsage: true,
SilenceErrors: true,
}
func init() {

View File

@@ -45,6 +45,8 @@ Features:
✓ Auto-save configuration
✓ Custom subdomains
✓ Authentication via token`,
SilenceUsage: true,
SilenceErrors: true,
}
func init() {

View File

@@ -22,28 +22,30 @@ import (
)
var (
serverPort int
serverPublicPort int
serverDomain string
serverTunnelDomain string
serverAuthToken string
serverMetricsToken string
serverDebug bool
serverTCPPortMin int
serverTCPPortMax int
serverTLSCert string
serverTLSKey string
serverPprofPort int
serverTransports string
serverTunnelTypes string
serverConfigFile string
serverPort int
serverPublicPort int
serverDomain string
serverTunnelDomain string
serverAuthToken string
serverMetricsToken string
serverDebug bool
serverTCPPortMin int
serverTCPPortMax int
serverTLSCert string
serverTLSKey string
serverPprofPort int
serverTransports string
serverTunnelTypes string
serverConfigFile string
)
var serverCmd = &cobra.Command{
Use: "server",
Short: "Start Drip server",
Long: `Start the Drip tunnel server to accept client connections`,
RunE: runServer,
Use: "server",
Short: "Start Drip server",
Long: `Start the Drip tunnel server to accept client connections`,
RunE: runServer,
SilenceUsage: true,
SilenceErrors: true,
}
func init() {
@@ -285,11 +287,29 @@ func runServer(cmd *cobra.Command, _ []string) error {
listenAddr := fmt.Sprintf("0.0.0.0:%d", cfg.Port)
httpHandler := proxy.NewHandler(tunnelManager, logger, cfg.Domain, cfg.TunnelDomain, cfg.AuthToken, cfg.MetricsToken)
httpHandler := proxy.NewHandler(proxy.HandlerConfig{
Manager: tunnelManager,
Logger: logger,
ServerDomain: cfg.Domain,
TunnelDomain: cfg.TunnelDomain,
AuthToken: cfg.AuthToken,
MetricsToken: cfg.MetricsToken,
})
httpHandler.SetAllowedTransports(cfg.AllowedTransports)
httpHandler.SetAllowedTunnelTypes(cfg.AllowedTunnelTypes)
listener := tcp.NewListener(listenAddr, tlsConfig, cfg.AuthToken, tunnelManager, logger, portAllocator, cfg.Domain, cfg.TunnelDomain, cfg.PublicPort, httpHandler)
listener := tcp.NewListener(tcp.ListenerConfig{
Address: listenAddr,
TLSConfig: tlsConfig,
AuthToken: cfg.AuthToken,
Manager: tunnelManager,
Logger: logger,
PortAlloc: portAllocator,
Domain: cfg.Domain,
TunnelDomain: cfg.TunnelDomain,
PublicPort: cfg.PublicPort,
HTTPHandler: httpHandler,
})
listener.SetAllowedTransports(cfg.AllowedTransports)
listener.SetAllowedTunnelTypes(cfg.AllowedTunnelTypes)

View File

@@ -54,7 +54,9 @@ Configuration file example (~/.drip/config.yaml):
allow_ips:
- 192.168.0.0/16
- 10.0.0.0/8`,
RunE: runStart,
RunE: runStart,
SilenceUsage: true,
SilenceErrors: true,
}
func init() {
@@ -238,6 +240,7 @@ func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp
AllowIPs: t.AllowIPs,
DenyIPs: t.DenyIPs,
AuthPass: t.Auth,
AuthBearer: t.AuthBearer,
Transport: transport,
}
}

View File

@@ -19,9 +19,11 @@ Examples:
drip stop all Stop all running tunnels
Use 'drip list' to see running tunnels.`,
Aliases: []string{"kill"},
Args: cobra.MinimumNArgs(1),
RunE: runStop,
Aliases: []string{"kill"},
Args: cobra.MinimumNArgs(1),
RunE: runStop,
SilenceUsage: true,
SilenceErrors: true,
}
func init() {

View File

@@ -41,8 +41,10 @@ Transport options:
Note: TCP tunnels require dynamic port allocation on the server.
When using CDN (--transport wss), the server must still expose the allocated port directly.`,
Args: cobra.ExactArgs(1),
RunE: runTCP,
Args: cobra.ExactArgs(1),
RunE: runTCP,
SilenceUsage: true,
SilenceErrors: true,
}
func init() {

View File

@@ -24,6 +24,12 @@ func buildDaemonArgs(tunnelType string, args []string, subdomain string, localAd
if authToken != "" {
daemonArgs = append(daemonArgs, "--token", authToken)
}
if authPass != "" {
daemonArgs = append(daemonArgs, "--auth", authPass)
}
if authBearer != "" {
daemonArgs = append(daemonArgs, "--auth-bearer", authBearer)
}
if insecure {
daemonArgs = append(daemonArgs, "--insecure")
}

View File

@@ -0,0 +1,197 @@
package tcp
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"strings"
"time"
json "github.com/goccy/go-json"
"github.com/gorilla/websocket"
"go.uber.org/zap"
"drip/internal/shared/wsutil"
)
// serverCapabilities holds the discovered server capabilities
type serverCapabilities struct {
Transports []string `json:"transports"`
Preferred string `json:"preferred"`
}
// ConnectionDialer handles establishing connections to the server.
type ConnectionDialer struct {
serverAddr string
tlsConfig *tls.Config
token string
transport TransportType
logger *zap.Logger
}
// NewConnectionDialer creates a new connection dialer.
func NewConnectionDialer(
serverAddr string,
tlsConfig *tls.Config,
token string,
transport TransportType,
logger *zap.Logger,
) *ConnectionDialer {
return &ConnectionDialer{
serverAddr: serverAddr,
tlsConfig: tlsConfig,
token: token,
transport: transport,
logger: logger,
}
}
// Dial establishes a connection using the appropriate transport.
func (d *ConnectionDialer) Dial() (net.Conn, error) {
switch d.transport {
case TransportWebSocket:
return d.dialWebSocket()
case TransportTCP:
// User explicitly requested TCP, verify server supports it
caps := d.discoverServerCapabilities()
if caps != nil && len(caps.Transports) > 0 {
tcpSupported := false
for _, t := range caps.Transports {
if t == "tcp" {
tcpSupported = true
break
}
}
if !tcpSupported {
return nil, fmt.Errorf("server only supports %v transport(s), but --transport tcp was specified. Use --transport wss instead", caps.Transports)
}
}
return d.dialTLS()
default: // TransportAuto
// Check if server address indicates WebSocket
if strings.HasPrefix(d.serverAddr, "wss://") {
return d.dialWebSocket()
}
// Query server for preferred transport
caps := d.discoverServerCapabilities()
if caps != nil && caps.Preferred == "wss" {
return d.dialWebSocket()
}
// Default to TCP
return d.dialTLS()
}
}
// dialTLS establishes a TLS connection to the server.
func (d *ConnectionDialer) dialTLS() (net.Conn, error) {
dialer := &net.Dialer{Timeout: 10 * time.Second}
conn, err := tls.DialWithDialer(dialer, "tcp", d.serverAddr, d.tlsConfig)
if err != nil {
return nil, fmt.Errorf("failed to connect: %w", err)
}
state := conn.ConnectionState()
if state.Version != tls.VersionTLS13 {
_ = conn.Close()
return nil, fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
}
if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok {
_ = tcpConn.SetNoDelay(true)
_ = tcpConn.SetKeepAlive(true)
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
_ = tcpConn.SetReadBuffer(256 * 1024)
_ = tcpConn.SetWriteBuffer(256 * 1024)
}
return conn, nil
}
// dialWebSocket establishes a WebSocket connection to the server over TLS.
func (d *ConnectionDialer) dialWebSocket() (net.Conn, error) {
// Build WebSocket URL
host, port, err := net.SplitHostPort(d.serverAddr)
if err != nil {
// No port specified, use default
host = d.serverAddr
port = "443"
}
wsURL := fmt.Sprintf("wss://%s:%s/_drip/ws", host, port)
d.logger.Debug("Connecting via WebSocket over TLS",
zap.String("url", wsURL),
)
dialer := websocket.Dialer{
TLSClientConfig: d.tlsConfig,
HandshakeTimeout: 10 * time.Second,
ReadBufferSize: 256 * 1024,
WriteBufferSize: 256 * 1024,
}
// Add authorization header if token is set
header := http.Header{}
if d.token != "" {
header.Set("Authorization", "Bearer "+d.token)
}
ws, resp, err := dialer.Dial(wsURL, header)
if err != nil {
if resp != nil {
return nil, fmt.Errorf("WebSocket dial failed (status %d): %w", resp.StatusCode, err)
}
return nil, fmt.Errorf("WebSocket dial failed: %w", err)
}
d.logger.Info("Connected via WebSocket over TLS",
zap.String("url", wsURL),
)
// Wrap WebSocket connection to implement net.Conn with ping loop for CDN keep-alive
return wsutil.NewConnWithPing(ws, 30*time.Second), nil
}
// discoverServerCapabilities queries the server for its capabilities.
func (d *ConnectionDialer) discoverServerCapabilities() *serverCapabilities {
host, port, err := net.SplitHostPort(d.serverAddr)
if err != nil {
host = d.serverAddr
port = "443"
}
discoverURL := fmt.Sprintf("https://%s:%s/_drip/discover", host, port)
client := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
TLSClientConfig: d.tlsConfig,
},
}
resp, err := client.Get(discoverURL)
if err != nil {
d.logger.Debug("Failed to discover server capabilities",
zap.Error(err),
)
return nil
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil
}
var caps serverCapabilities
if err := json.NewDecoder(resp.Body).Decode(&caps); err != nil {
return nil
}
d.logger.Debug("Discovered server capabilities",
zap.Strings("transports", caps.Transports),
zap.String("preferred", caps.Preferred),
)
return &caps
}

View File

@@ -41,7 +41,8 @@ type ConnectorConfig struct {
DenyIPs []string
// Proxy authentication
AuthPass string
AuthPass string
AuthBearer string
// Transport protocol selection
Transport TransportType

View File

@@ -13,7 +13,6 @@ import (
"sync/atomic"
"time"
"github.com/gorilla/websocket"
json "github.com/goccy/go-json"
"github.com/hashicorp/yamux"
"go.uber.org/zap"
@@ -22,7 +21,6 @@ import (
"drip/internal/shared/mux"
"drip/internal/shared/protocol"
"drip/internal/shared/stats"
"drip/internal/shared/wsutil"
"drip/pkg/config"
)
@@ -71,11 +69,18 @@ type PoolClient struct {
allowIPs []string
denyIPs []string
authPass string
authPass string
authBearer string
// Transport protocol selection
transport TransportType
insecure bool
// Connection dialer
dialer *ConnectionDialer
// Session scaler
scaler *SessionScaler
}
// NewPoolClient creates a new pool client.
@@ -169,8 +174,10 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
allowIPs: cfg.AllowIPs,
denyIPs: cfg.DenyIPs,
authPass: cfg.AuthPass,
authBearer: cfg.AuthBearer,
transport: transport,
insecure: cfg.Insecure,
dialer: NewConnectionDialer(serverAddr, tlsConfig, cfg.Token, transport, logger),
}
if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
@@ -183,7 +190,7 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
// Connect establishes the primary connection and starts background workers.
func (c *PoolClient) Connect() error {
primaryConn, err := c.dial()
primaryConn, err := c.dialer.Dial()
if err != nil {
return err
}
@@ -208,9 +215,16 @@ func (c *PoolClient) Connect() error {
}
}
if c.authPass != "" {
if c.authBearer != "" {
req.ProxyAuth = &protocol.ProxyAuth{
Enabled: true,
Type: "bearer",
Token: c.authBearer,
}
} else if c.authPass != "" {
req.ProxyAuth = &protocol.ProxyAuth{
Enabled: true,
Type: "password",
Password: c.authPass,
}
}
@@ -299,8 +313,9 @@ func (c *PoolClient) Connect() error {
c.warmupSessions()
c.wg.Add(1)
go c.scalerLoop()
// Initialize and start session scaler
c.scaler = NewSessionScaler(c, c.logger, c.stopCh, &c.wg)
c.scaler.Start()
}
go func() {
@@ -311,162 +326,6 @@ func (c *PoolClient) Connect() error {
return nil
}
func (c *PoolClient) dialTLS() (net.Conn, error) {
dialer := &net.Dialer{Timeout: 10 * time.Second}
conn, err := tls.DialWithDialer(dialer, "tcp", c.serverAddr, c.tlsConfig)
if err != nil {
return nil, fmt.Errorf("failed to connect: %w", err)
}
state := conn.ConnectionState()
if state.Version != tls.VersionTLS13 {
_ = conn.Close()
return nil, fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
}
if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok {
_ = tcpConn.SetNoDelay(true)
_ = tcpConn.SetKeepAlive(true)
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
_ = tcpConn.SetReadBuffer(256 * 1024)
_ = tcpConn.SetWriteBuffer(256 * 1024)
}
return conn, nil
}
// serverCapabilities holds the discovered server capabilities
type serverCapabilities struct {
Transports []string `json:"transports"`
Preferred string `json:"preferred"`
}
// dial selects the appropriate transport and establishes a connection
func (c *PoolClient) dial() (net.Conn, error) {
switch c.transport {
case TransportWebSocket:
return c.dialWebSocket()
case TransportTCP:
// User explicitly requested TCP, verify server supports it
caps := c.discoverServerCapabilities()
if caps != nil && len(caps.Transports) > 0 {
tcpSupported := false
for _, t := range caps.Transports {
if t == "tcp" {
tcpSupported = true
break
}
}
if !tcpSupported {
return nil, fmt.Errorf("server only supports %v transport(s), but --transport tcp was specified. Use --transport wss instead", caps.Transports)
}
}
return c.dialTLS()
default: // TransportAuto
// Check if server address indicates WebSocket
if strings.HasPrefix(c.serverAddr, "wss://") {
return c.dialWebSocket()
}
// Query server for preferred transport
caps := c.discoverServerCapabilities()
if caps != nil && caps.Preferred == "wss" {
return c.dialWebSocket()
}
// Default to TCP
return c.dialTLS()
}
}
// discoverServerCapabilities queries the server for its capabilities
func (c *PoolClient) discoverServerCapabilities() *serverCapabilities {
host, port, err := net.SplitHostPort(c.serverAddr)
if err != nil {
host = c.serverAddr
port = "443"
}
discoverURL := fmt.Sprintf("https://%s:%s/_drip/discover", host, port)
client := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
TLSClientConfig: c.tlsConfig,
},
}
resp, err := client.Get(discoverURL)
if err != nil {
c.logger.Debug("Failed to discover server capabilities",
zap.Error(err),
)
return nil
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil
}
var caps serverCapabilities
if err := json.NewDecoder(resp.Body).Decode(&caps); err != nil {
return nil
}
c.logger.Debug("Discovered server capabilities",
zap.Strings("transports", caps.Transports),
zap.String("preferred", caps.Preferred),
)
return &caps
}
// dialWebSocket establishes a WebSocket connection to the server over TLS
func (c *PoolClient) dialWebSocket() (net.Conn, error) {
// Build WebSocket URL
host, port, err := net.SplitHostPort(c.serverAddr)
if err != nil {
// No port specified, use default
host = c.serverAddr
port = "443"
}
wsURL := fmt.Sprintf("wss://%s:%s/_drip/ws", host, port)
c.logger.Debug("Connecting via WebSocket over TLS",
zap.String("url", wsURL),
)
dialer := websocket.Dialer{
TLSClientConfig: c.tlsConfig,
HandshakeTimeout: 10 * time.Second,
ReadBufferSize: 256 * 1024,
WriteBufferSize: 256 * 1024,
}
// Add authorization header if token is set
header := http.Header{}
if c.token != "" {
header.Set("Authorization", "Bearer "+c.token)
}
ws, resp, err := dialer.Dial(wsURL, header)
if err != nil {
if resp != nil {
return nil, fmt.Errorf("WebSocket dial failed (status %d): %w", resp.StatusCode, err)
}
return nil, fmt.Errorf("WebSocket dial failed: %w", err)
}
// Wrap WebSocket as net.Conn with ping loop for CDN keep-alive
conn := wsutil.NewConnWithPing(ws, 30*time.Second)
c.logger.Debug("WebSocket connection established",
zap.String("remote_addr", ws.RemoteAddr().String()),
)
return conn, nil
}
func (c *PoolClient) acceptLoop(h *sessionHandle, isPrimary bool) {
defer c.wg.Done()
@@ -623,12 +482,12 @@ func (c *PoolClient) Close() error {
return closeErr
}
func (c *PoolClient) Wait() { <-c.doneCh }
func (c *PoolClient) GetURL() string { return c.assignedURL }
func (c *PoolClient) GetSubdomain() string { return c.subdomain }
func (c *PoolClient) GetLatency() time.Duration { return time.Duration(c.latencyNanos.Load()) }
func (c *PoolClient) GetStats() *stats.TrafficStats { return c.stats }
func (c *PoolClient) IsClosed() bool { return c.closed.Load() }
func (c *PoolClient) Wait() { <-c.doneCh }
func (c *PoolClient) GetURL() string { return c.assignedURL }
func (c *PoolClient) GetSubdomain() string { return c.subdomain }
func (c *PoolClient) GetLatency() time.Duration { return time.Duration(c.latencyNanos.Load()) }
func (c *PoolClient) GetStats() *stats.TrafficStats { return c.stats }
func (c *PoolClient) IsClosed() bool { return c.closed.Load() }
func (c *PoolClient) SetLatencyCallback(cb LatencyCallback) {
if cb == nil {

View File

@@ -188,7 +188,7 @@ func (c *PoolClient) addDataSession() error {
return fmt.Errorf("server does not support data connections")
}
conn, err := c.dial()
conn, err := c.dialer.Dial()
if err != nil {
return err
}

View File

@@ -0,0 +1,150 @@
package tcp
import (
"sync"
"time"
"go.uber.org/zap"
)
// SessionScaler manages automatic scaling of yamux sessions based on load.
type SessionScaler struct {
client *PoolClient
logger *zap.Logger
stopCh <-chan struct{}
wg *sync.WaitGroup
// Scaling configuration
checkInterval time.Duration
scaleUpCooldown time.Duration
scaleDownCooldown time.Duration
capacityPerSession int64
scaleUpLoad float64
scaleDownLoad float64
burstThreshold float64
maxBurstAdd int
}
// NewSessionScaler creates a new session scaler.
func NewSessionScaler(
client *PoolClient,
logger *zap.Logger,
stopCh <-chan struct{},
wg *sync.WaitGroup,
) *SessionScaler {
return &SessionScaler{
client: client,
logger: logger,
stopCh: stopCh,
wg: wg,
checkInterval: 1 * time.Second,
scaleUpCooldown: 1 * time.Second,
scaleDownCooldown: 60 * time.Second,
capacityPerSession: 256,
scaleUpLoad: 0.6,
scaleDownLoad: 0.2,
burstThreshold: 0.9,
maxBurstAdd: 4,
}
}
// Start starts the scaler loop.
func (s *SessionScaler) Start() {
s.wg.Add(1)
go s.scalerLoop()
}
// scalerLoop monitors load and adjusts session count.
func (s *SessionScaler) scalerLoop() {
defer s.wg.Done()
ticker := time.NewTicker(s.checkInterval)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
}
s.client.mu.Lock()
desired := s.client.desiredTotal
if desired == 0 {
desired = s.client.initialSessions
s.client.desiredTotal = desired
}
lastScale := s.client.lastScale
s.client.mu.Unlock()
current := s.client.sessionCount()
if current == 0 {
continue
}
activeConns := s.client.stats.GetActiveConnections()
capacity := int64(current) * s.capacityPerSession
load := float64(activeConns) / float64(capacity)
now := time.Now()
// Burst scaling: rapid scale-up under extreme load
if load > s.burstThreshold && current < s.client.maxSessions {
toAdd := min(s.maxBurstAdd, s.client.maxSessions-current)
s.logger.Info("Burst scaling up sessions",
zap.Int("current", current),
zap.Int("adding", toAdd),
zap.Float64("load", load),
)
for i := 0; i < toAdd; i++ {
_ = s.client.addDataSession()
}
s.client.mu.Lock()
s.client.desiredTotal = current + toAdd
s.client.lastScale = now
s.client.mu.Unlock()
continue
}
// Scale up: add sessions when load is high
if load > s.scaleUpLoad && current < s.client.maxSessions {
if now.Sub(lastScale) < s.scaleUpCooldown {
continue
}
newDesired := min(desired+1, s.client.maxSessions)
if newDesired > desired {
s.logger.Debug("Scaling up sessions",
zap.Int("current", current),
zap.Int("desired", newDesired),
zap.Float64("load", load),
)
_ = s.client.addDataSession()
s.client.mu.Lock()
s.client.desiredTotal = newDesired
s.client.lastScale = now
s.client.mu.Unlock()
}
continue
}
// Scale down: remove idle sessions when load is low
if load < s.scaleDownLoad && current > s.client.minSessions {
if now.Sub(lastScale) < s.scaleDownCooldown {
continue
}
newDesired := max(desired-1, s.client.minSessions)
if newDesired < desired && current > newDesired {
s.logger.Debug("Scaling down sessions",
zap.Int("current", current),
zap.Int("desired", newDesired),
zap.Float64("load", load),
)
s.client.removeIdleSessions(1)
s.client.mu.Lock()
s.client.desiredTotal = newDesired
s.client.lastScale = now
s.client.mu.Unlock()
}
}
}
}

View File

@@ -0,0 +1,412 @@
package proxy
import (
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/hex"
"fmt"
"html"
"net/http"
"strings"
"sync"
"time"
"drip/internal/server/tunnel"
"drip/internal/shared/protocol"
)
const authCookieName = "drip_auth"
const authSessionDuration = 24 * time.Hour
const (
authRateLimitWindow = 1 * time.Minute
authRateLimitMax = 10
authRateLimitLockout = 5 * time.Minute
authRateLimitLockoutThreshold = 20
)
type authSession struct {
subdomain string
expiresAt time.Time
}
type authSessionStore struct {
mu sync.RWMutex
sessions map[string]*authSession
}
type authRateLimitEntry struct {
failures int
windowStart time.Time
lockedUntil time.Time
}
type authRateLimiter struct {
mu sync.RWMutex
entries map[string]*authRateLimitEntry
}
var sessionStore = &authSessionStore{
sessions: make(map[string]*authSession),
}
var authLimiter = &authRateLimiter{
entries: make(map[string]*authRateLimitEntry),
}
func init() {
go authLimiter.startCleanupLoop()
go sessionStore.startCleanupLoop()
}
func (rl *authRateLimiter) startCleanupLoop() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
rl.cleanup()
}
}
func (s *authSessionStore) startCleanupLoop() {
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()
for range ticker.C {
s.cleanup()
}
}
func (s *authSessionStore) cleanup() {
s.mu.Lock()
defer s.mu.Unlock()
now := time.Now()
for token, session := range s.sessions {
if now.After(session.expiresAt) {
delete(s.sessions, token)
}
}
}
func (rl *authRateLimiter) isRateLimited(ip string) bool {
if ip == "" {
return false
}
rl.mu.RLock()
entry, exists := rl.entries[ip]
rl.mu.RUnlock()
if !exists {
return false
}
now := time.Now()
if !entry.lockedUntil.IsZero() && now.Before(entry.lockedUntil) {
return true
}
if now.Sub(entry.windowStart) < authRateLimitWindow && entry.failures >= authRateLimitMax {
return true
}
return false
}
func (rl *authRateLimiter) recordFailure(ip string) {
if ip == "" {
return
}
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
entry, exists := rl.entries[ip]
if !exists {
rl.entries[ip] = &authRateLimitEntry{
failures: 1,
windowStart: now,
}
return
}
if now.Sub(entry.windowStart) >= authRateLimitWindow {
entry.failures = 1
entry.windowStart = now
entry.lockedUntil = time.Time{}
return
}
entry.failures++
if entry.failures >= authRateLimitLockoutThreshold {
entry.lockedUntil = now.Add(authRateLimitLockout)
}
}
func (rl *authRateLimiter) resetFailures(ip string) {
if ip == "" {
return
}
rl.mu.Lock()
delete(rl.entries, ip)
rl.mu.Unlock()
}
func (rl *authRateLimiter) cleanup() {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
for ip, entry := range rl.entries {
windowExpired := now.Sub(entry.windowStart) >= authRateLimitWindow
lockoutExpired := entry.lockedUntil.IsZero() || now.After(entry.lockedUntil)
if windowExpired && lockoutExpired {
delete(rl.entries, ip)
}
}
}
func (s *authSessionStore) create(subdomain string) string {
token := generateSessionToken()
s.mu.Lock()
s.sessions[token] = &authSession{
subdomain: subdomain,
expiresAt: time.Now().Add(authSessionDuration),
}
s.mu.Unlock()
return token
}
func (s *authSessionStore) validate(token, subdomain string) bool {
s.mu.RLock()
session, ok := s.sessions[token]
s.mu.RUnlock()
if !ok {
return false
}
if time.Now().After(session.expiresAt) {
s.mu.Lock()
delete(s.sessions, token)
s.mu.Unlock()
return false
}
return session.subdomain == subdomain
}
func generateSessionToken() string {
b := make([]byte, 32)
rand.Read(b)
hash := sha256.Sum256(b)
return hex.EncodeToString(hash[:])
}
func isBearerProxyAuth(auth *protocol.ProxyAuth) bool {
if auth == nil {
return false
}
if auth.Type != "" {
return strings.EqualFold(auth.Type, "bearer")
}
return auth.Token != ""
}
func extractBearerToken(header string) string {
if header == "" {
return ""
}
parts := strings.Fields(header)
if len(parts) < 2 {
return ""
}
if !strings.EqualFold(parts[0], "Bearer") {
return ""
}
return parts[1]
}
func (h *Handler) isProxyAuthenticated(r *http.Request, subdomain string) bool {
cookie, err := r.Cookie(authCookieName + "_" + subdomain)
if err != nil {
return false
}
return sessionStore.validate(cookie.Value, subdomain)
}
func (h *Handler) isBearerAuthenticated(r *http.Request, auth *protocol.ProxyAuth) bool {
token := extractBearerToken(r.Header.Get("Authorization"))
if token == "" {
return false
}
return subtle.ConstantTimeCompare([]byte(token), []byte(auth.Token)) == 1
}
func (h *Handler) serveBearerAuthRequired(w http.ResponseWriter, realm string) {
if realm == "" {
realm = "drip"
}
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="%s"`, realm))
w.Header().Set("Cache-Control", "no-store")
http.Error(w, "Unauthorized: provide bearer token via Authorization header", http.StatusUnauthorized)
}
func (h *Handler) handleProxyLogin(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection, subdomain string) {
h.handleProxyLoginWithRateLimit(w, r, tconn, subdomain, "")
}
func (h *Handler) handleProxyLoginWithRateLimit(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection, subdomain string, clientIP string) {
if r.Method != http.MethodPost {
h.serveLoginPage(w, r, subdomain, "")
return
}
if clientIP != "" && authLimiter.isRateLimited(clientIP) {
w.Header().Set("Retry-After", "60")
http.Error(w, "Too many failed authentication attempts. Please try again later.", http.StatusTooManyRequests)
return
}
if err := r.ParseForm(); err != nil {
h.serveLoginPage(w, r, subdomain, "Invalid form data")
return
}
password := r.FormValue("password")
if !tconn.ValidateProxyAuth(password) {
if clientIP != "" {
authLimiter.recordFailure(clientIP)
}
h.serveLoginPage(w, r, subdomain, "Invalid password")
return
}
if clientIP != "" {
authLimiter.resetFailures(clientIP)
}
token := sessionStore.create(subdomain)
http.SetCookie(w, &http.Cookie{
Name: authCookieName + "_" + subdomain,
Value: token,
Path: "/",
MaxAge: int(authSessionDuration.Seconds()),
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
redirectURL := r.FormValue("redirect")
if redirectURL == "" || redirectURL == "/_drip/login" {
redirectURL = "/"
}
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
func (h *Handler) serveLoginPage(w http.ResponseWriter, r *http.Request, subdomain string, errorMsg string) {
redirectURL := r.URL.Path
if r.URL.RawQuery != "" {
redirectURL += "?" + r.URL.RawQuery
}
if redirectURL == "/_drip/login" {
redirectURL = "/"
}
errorHTML := ""
if errorMsg != "" {
errorHTML = fmt.Sprintf(`<p class="error">%s</p>`, html.EscapeString(errorMsg))
}
safeRedirectURL := html.EscapeString(redirectURL)
htmlContent := fmt.Sprintf(`<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>%s - Drip</title>
`+faviconLink+`
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #fff;
color: #24292f;
line-height: 1.6;
}
.container { max-width: 720px; margin: 0 auto; padding: 48px 24px; }
header { margin-bottom: 48px; }
h1 { font-size: 28px; font-weight: 600; margin-bottom: 8px; }
h1 span { margin-right: 8px; }
.desc { color: #57606a; font-size: 16px; }
p { margin-bottom: 24px; }
.error { color: #cf222e; margin-bottom: 16px; }
.input-wrap {
position: relative;
background: #f6f8fa;
border: 1px solid #d0d7de;
border-radius: 6px;
margin-bottom: 12px;
display: flex;
}
.input-wrap input {
flex: 1;
margin: 0;
padding: 12px 16px;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Menlo, Consolas, monospace;
font-size: 14px;
background: transparent;
border: none;
outline: none;
}
.input-wrap button {
background: #24292f;
color: #fff;
border: none;
padding: 8px 16px;
margin: 4px;
border-radius: 4px;
font-size: 14px;
cursor: pointer;
}
.input-wrap button:hover { background: #32383f; }
footer { margin-top: 48px; padding-top: 24px; border-top: 1px solid #d0d7de; }
footer a { color: #57606a; text-decoration: none; font-size: 14px; }
footer a:hover { color: #0969da; }
</style>
</head>
<body>
<div class="container">
<header>
<h1><span>🔒</span>%s</h1>
<p class="desc">This tunnel is password protected</p>
</header>
%s
<form method="POST" action="/_drip/login">
<input type="hidden" name="redirect" value="%s" />
<div class="input-wrap">
<input type="password" name="password" placeholder="Enter password" required autofocus />
<button type="submit">Continue</button>
</div>
</form>
<footer>
<a href="https://github.com/Gouryella/drip" target="_blank">GitHub</a>
</footer>
</div>
</body>
</html>`, subdomain, subdomain, errorHTML, safeRedirectURL)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate")
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(htmlContent))
}

View File

@@ -2,12 +2,8 @@ package proxy
import (
"bufio"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"crypto/subtle"
"fmt"
"html"
"io"
"net"
"net/http"
@@ -16,18 +12,14 @@ import (
"sync"
"time"
json "github.com/goccy/go-json"
"github.com/gorilla/websocket"
"go.uber.org/zap"
"drip/internal/server/tunnel"
"drip/internal/shared/httputil"
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"drip/internal/shared/wsutil"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.uber.org/zap"
)
// bufio.Reader pool to reduce allocations on hot path
@@ -38,56 +30,14 @@ var bufioReaderPool = sync.Pool{
}
const openStreamTimeout = 3 * time.Second
const authCookieName = "drip_auth"
const authSessionDuration = 24 * time.Hour
type authSession struct {
subdomain string
expiresAt time.Time
}
type authSessionStore struct {
mu sync.RWMutex
sessions map[string]*authSession
}
var sessionStore = &authSessionStore{
sessions: make(map[string]*authSession),
}
func (s *authSessionStore) create(subdomain string) string {
token := generateSessionToken()
s.mu.Lock()
s.sessions[token] = &authSession{
subdomain: subdomain,
expiresAt: time.Now().Add(authSessionDuration),
}
s.mu.Unlock()
return token
}
func (s *authSessionStore) validate(token, subdomain string) bool {
s.mu.RLock()
session, ok := s.sessions[token]
s.mu.RUnlock()
if !ok {
return false
}
if time.Now().After(session.expiresAt) {
s.mu.Lock()
delete(s.sessions, token)
s.mu.Unlock()
return false
}
return session.subdomain == subdomain
}
func generateSessionToken() string {
b := make([]byte, 32)
rand.Read(b)
hash := sha256.Sum256(b)
return hex.EncodeToString(hash[:])
type HandlerConfig struct {
Manager *tunnel.Manager
Logger *zap.Logger
ServerDomain string
TunnelDomain string
AuthToken string
MetricsToken string
}
type Handler struct {
@@ -113,32 +63,14 @@ type WSConnectionHandler interface {
HandleWSConnection(conn net.Conn, remoteAddr string)
}
var privateNetworks []*net.IPNet
func init() {
privateCIDRs := []string{
"127.0.0.0/8",
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"::1/128",
"fc00::/7",
"fe80::/10",
}
for _, cidr := range privateCIDRs {
_, ipNet, _ := net.ParseCIDR(cidr)
privateNetworks = append(privateNetworks, ipNet)
}
}
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, serverDomain, tunnelDomain string, authToken string, metricsToken string) *Handler {
func NewHandler(cfg HandlerConfig) *Handler {
return &Handler{
manager: manager,
logger: logger,
serverDomain: serverDomain,
tunnelDomain: tunnelDomain,
authToken: authToken,
metricsToken: metricsToken,
manager: cfg.Manager,
logger: cfg.Logger,
serverDomain: cfg.ServerDomain,
tunnelDomain: cfg.TunnelDomain,
authToken: cfg.AuthToken,
metricsToken: cfg.MetricsToken,
wsUpgrader: websocket.Upgrader{
ReadBufferSize: 256 * 1024,
WriteBufferSize: 256 * 1024,
@@ -253,22 +185,38 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
if tconn.HasIPAccessControl() {
clientIP := h.extractClientIP(r)
clientIP := netutil.ExtractClientIP(r)
if !tconn.IsIPAllowed(clientIP) {
http.Error(w, "Access denied: your IP is not allowed", http.StatusForbidden)
return
}
}
// Check proxy authentication
if tconn.HasProxyAuth() {
if r.URL.Path == "/_drip/login" {
h.handleProxyLogin(w, r, tconn, subdomain)
if auth := tconn.GetProxyAuth(); auth != nil && auth.Enabled {
clientIP := netutil.ExtractClientIP(r)
if authLimiter.isRateLimited(clientIP) {
w.Header().Set("Retry-After", "60")
http.Error(w, "Too many failed authentication attempts. Please try again later.", http.StatusTooManyRequests)
return
}
if !h.isProxyAuthenticated(r, subdomain) {
h.serveLoginPage(w, r, subdomain, "")
return
if isBearerProxyAuth(auth) {
if !h.isBearerAuthenticated(r, auth) {
authLimiter.recordFailure(clientIP)
h.serveBearerAuthRequired(w, "drip")
return
}
authLimiter.resetFailures(clientIP)
} else {
if r.URL.Path == "/_drip/login" {
h.handleProxyLoginWithRateLimit(w, r, tconn, subdomain, clientIP)
return
}
if !h.isProxyAuthenticated(r, subdomain) {
h.serveLoginPage(w, r, subdomain, "")
return
}
}
}
@@ -283,14 +231,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
if httputil.IsWebSocketUpgrade(r) {
if h.isWebSocketUpgrade(r) {
h.handleWebSocket(w, r, tconn)
return
}
stream, err := h.openStreamWithTimeout(tconn)
if err != nil {
w.Header().Set("Connection", "close")
httputil.SetCloseConnection(w)
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
return
}
@@ -305,7 +253,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
)
if err := r.Write(countingStream); err != nil {
w.Header().Set("Connection", "close")
httputil.SetCloseConnection(w)
_ = r.Body.Close()
http.Error(w, "Forward failed", http.StatusBadGateway)
return
@@ -316,7 +264,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
resp, err := http.ReadResponse(reader, r)
if err != nil {
bufioReaderPool.Put(reader)
w.Header().Set("Connection", "close")
httputil.SetCloseConnection(w)
http.Error(w, "Read response failed", http.StatusBadGateway)
return
}
@@ -334,7 +282,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodHead || statusCode == http.StatusNoContent || statusCode == http.StatusNotModified {
if resp.ContentLength >= 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
httputil.SetContentLength(w, resp.ContentLength)
} else {
w.Header().Del("Content-Length")
}
@@ -343,7 +291,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
if resp.ContentLength >= 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
httputil.SetContentLength(w, resp.ContentLength)
} else {
w.Header().Del("Content-Length")
}
@@ -396,58 +344,6 @@ func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, err
}
}
func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection) {
stream, err := h.openStreamWithTimeout(tconn)
if err != nil {
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
return
}
tconn.IncActiveConnections()
hj, ok := w.(http.Hijacker)
if !ok {
stream.Close()
tconn.DecActiveConnections()
http.Error(w, "WebSocket not supported", http.StatusInternalServerError)
return
}
clientConn, clientBuf, err := hj.Hijack()
if err != nil {
stream.Close()
tconn.DecActiveConnections()
http.Error(w, "Failed to hijack connection", http.StatusInternalServerError)
return
}
if err := r.Write(stream); err != nil {
stream.Close()
clientConn.Close()
tconn.DecActiveConnections()
return
}
go func() {
defer stream.Close()
defer clientConn.Close()
defer tconn.DecActiveConnections()
var clientRW io.ReadWriteCloser = clientConn
if clientBuf != nil && clientBuf.Reader.Buffered() > 0 {
clientRW = &bufferedReadWriteCloser{
Reader: clientBuf.Reader,
Conn: clientConn,
}
}
_ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW,
func(n int64) { tconn.AddBytesOut(n) },
func(n int64) { tconn.AddBytesIn(n) },
)
}()
}
func (h *Handler) copyResponseHeaders(dst http.Header, src http.Header, proxyHost string) {
for key, values := range src {
canonicalKey := http.CanonicalHeaderKey(key)
@@ -530,571 +426,18 @@ func (h *Handler) extractSubdomain(host string) (string, subdomainResult) {
return "", subdomainNotFound
}
// extractClientIP extracts the client IP from the request.
// It only trusts X-Forwarded-For and X-Real-IP headers when the request
// comes from a private/loopback network (typical reverse proxy setup).
func (h *Handler) extractClientIP(r *http.Request) string {
// First, get the direct remote address
remoteIP := h.extractRemoteIP(r.RemoteAddr)
// Only trust proxy headers if the request comes from a private network
if isPrivateIP(remoteIP) {
// Check X-Forwarded-For header (may contain multiple IPs)
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP (original client)
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// Check X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
func (h *Handler) validateMetricsAuth(w http.ResponseWriter, r *http.Request, realm string) bool {
if h.metricsToken == "" {
return true
}
// Fall back to remote address
return remoteIP
}
token := extractBearerToken(r.Header.Get("Authorization"))
// extractRemoteIP extracts the IP address from a remote address string (host:port format).
func (h *Handler) extractRemoteIP(remoteAddr string) string {
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return remoteAddr
}
return host
}
// isPrivateIP checks if the given IP is a private/loopback address.
func isPrivateIP(ip string) bool {
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
if subtle.ConstantTimeCompare([]byte(token), []byte(h.metricsToken)) != 1 {
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="%s"`, realm))
http.Error(w, "Unauthorized: provide metrics token via 'Authorization: Bearer <token>' header", http.StatusUnauthorized)
return false
}
for _, network := range privateNetworks {
if network.Contains(parsedIP) {
return true
}
}
return false
}
func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
html := `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Drip - Your Tunnel, Your Domain, Anywhere</title>
` + faviconLink + `
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #fff;
color: #24292f;
line-height: 1.6;
}
.container { max-width: 720px; margin: 0 auto; padding: 48px 24px; }
header { margin-bottom: 48px; }
h1 { font-size: 28px; font-weight: 600; margin-bottom: 8px; }
h1 span { margin-right: 8px; }
.desc { color: #57606a; font-size: 16px; }
h2 { font-size: 18px; font-weight: 600; margin: 32px 0 12px; }
.code-wrap {
position: relative;
background: #f6f8fa;
border: 1px solid #d0d7de;
border-radius: 6px;
margin-bottom: 12px;
}
.code-wrap pre {
margin: 0;
padding: 12px 16px;
padding-right: 60px;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Menlo, Consolas, monospace;
font-size: 14px;
overflow-x: auto;
white-space: pre-wrap;
word-break: break-all;
}
.copy-btn {
position: absolute;
top: 8px;
right: 8px;
background: #fff;
border: 1px solid #d0d7de;
border-radius: 6px;
padding: 4px 6px;
cursor: pointer;
color: #57606a;
display: flex;
align-items: center;
justify-content: center;
}
.copy-btn:hover { background: #f3f4f6; }
.copy-btn svg { width: 16px; height: 16px; }
.copy-btn .check { display: none; color: #1a7f37; }
.copy-btn.copied .copy { display: none; }
.copy-btn.copied .check { display: block; }
.links { margin-top: 32px; display: flex; gap: 24px; flex-wrap: wrap; }
.links a { color: #0969da; text-decoration: none; font-size: 14px; }
.links a:hover { text-decoration: underline; }
footer { margin-top: 48px; padding-top: 24px; border-top: 1px solid #d0d7de; }
footer a { color: #57606a; text-decoration: none; font-size: 14px; }
footer a:hover { color: #0969da; }
</style>
</head>
<body>
<div class="container">
<header>
<h1><span>💧</span>Drip</h1>
<p class="desc">Your Tunnel, Your Domain, Anywhere</p>
</header>
<p>A self-hosted tunneling solution to securely expose your services to the internet.</p>
<h2>Install</h2>
<div class="code-wrap">
<pre>bash &lt;(curl -fsSL https://driptunnel.app/install.sh)</pre>
<button class="copy-btn" onclick="copy(this)">
<svg class="copy" viewBox="0 0 16 16" fill="currentColor"><path d="M0 6.75C0 5.784.784 5 1.75 5h1.5a.75.75 0 0 1 0 1.5h-1.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-1.5a.75.75 0 0 1 1.5 0v1.5A1.75 1.75 0 0 1 9.25 16h-7.5A1.75 1.75 0 0 1 0 14.25Z"></path><path d="M5 1.75C5 .784 5.784 0 6.75 0h7.5C15.216 0 16 .784 16 1.75v7.5A1.75 1.75 0 0 1 14.25 11h-7.5A1.75 1.75 0 0 1 5 9.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path></svg>
<svg class="check" viewBox="0 0 16 16" fill="currentColor"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0Z"></path></svg>
</button>
</div>
<h2>Usage</h2>
<div class="code-wrap">
<pre>drip http 3000</pre>
<button class="copy-btn" onclick="copy(this)">
<svg class="copy" viewBox="0 0 16 16" fill="currentColor"><path d="M0 6.75C0 5.784.784 5 1.75 5h1.5a.75.75 0 0 1 0 1.5h-1.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-1.5a.75.75 0 0 1 1.5 0v1.5A1.75 1.75 0 0 1 9.25 16h-7.5A1.75 1.75 0 0 1 0 14.25Z"></path><path d="M5 1.75C5 .784 5.784 0 6.75 0h7.5C15.216 0 16 .784 16 1.75v7.5A1.75 1.75 0 0 1 14.25 11h-7.5A1.75 1.75 0 0 1 5 9.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path></svg>
<svg class="check" viewBox="0 0 16 16" fill="currentColor"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0Z"></path></svg>
</button>
</div>
<div class="code-wrap">
<pre>drip https 443</pre>
<button class="copy-btn" onclick="copy(this)">
<svg class="copy" viewBox="0 0 16 16" fill="currentColor"><path d="M0 6.75C0 5.784.784 5 1.75 5h1.5a.75.75 0 0 1 0 1.5h-1.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-1.5a.75.75 0 0 1 1.5 0v1.5A1.75 1.75 0 0 1 9.25 16h-7.5A1.75 1.75 0 0 1 0 14.25Z"></path><path d="M5 1.75C5 .784 5.784 0 6.75 0h7.5C15.216 0 16 .784 16 1.75v7.5A1.75 1.75 0 0 1 14.25 11h-7.5A1.75 1.75 0 0 1 5 9.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path></svg>
<svg class="check" viewBox="0 0 16 16" fill="currentColor"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0Z"></path></svg>
</button>
</div>
<div class="code-wrap">
<pre>drip tcp 5432</pre>
<button class="copy-btn" onclick="copy(this)">
<svg class="copy" viewBox="0 0 16 16" fill="currentColor"><path d="M0 6.75C0 5.784.784 5 1.75 5h1.5a.75.75 0 0 1 0 1.5h-1.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-1.5a.75.75 0 0 1 1.5 0v1.5A1.75 1.75 0 0 1 9.25 16h-7.5A1.75 1.75 0 0 1 0 14.25Z"></path><path d="M5 1.75C5 .784 5.784 0 6.75 0h7.5C15.216 0 16 .784 16 1.75v7.5A1.75 1.75 0 0 1 14.25 11h-7.5A1.75 1.75 0 0 1 5 9.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path></svg>
<svg class="check" viewBox="0 0 16 16" fill="currentColor"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0Z"></path></svg>
</button>
</div>
<div class="links">
<a href="/health">Health Check</a>
<a href="/stats">Statistics</a>
<a href="/metrics">Prometheus Metrics</a>
</div>
<footer>
<a href="https://github.com/Gouryella/drip" target="_blank">GitHub</a>
</footer>
</div>
<script>
function copy(btn) {
const text = btn.previousElementSibling.textContent;
navigator.clipboard.writeText(text).then(() => {
btn.classList.add('copied');
setTimeout(() => { btn.classList.remove('copied'); }, 2000);
});
}
</script>
</body>
</html>`
data := []byte(html)
w.Header().Set("Content-Type", "text/html")
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.Write(data)
}
func (h *Handler) serveTunnelNotFound(w http.ResponseWriter, r *http.Request) {
html := `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>404 - Tunnel Not Found</title>
` + faviconLink + `
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #fff;
color: #24292f;
line-height: 1.6;
}
.container { max-width: 720px; margin: 0 auto; padding: 48px 24px; }
header { margin-bottom: 48px; }
h1 { font-size: 28px; font-weight: 600; margin-bottom: 8px; }
h1 span { margin-right: 8px; }
.desc { color: #57606a; font-size: 16px; }
p { margin-bottom: 16px; }
.info-box {
background: #f6f8fa;
border: 1px solid #d0d7de;
border-radius: 6px;
padding: 16px;
margin: 24px 0;
}
.info-box ul {
margin: 12px 0 0 20px;
color: #57606a;
}
.info-box li { margin-bottom: 8px; }
footer { margin-top: 48px; padding-top: 24px; border-top: 1px solid #d0d7de; }
footer a { color: #57606a; text-decoration: none; font-size: 14px; }
footer a:hover { color: #0969da; }
</style>
</head>
<body>
<div class="container">
<header>
<h1><span>🔍</span>Tunnel Not Found</h1>
<p class="desc">The requested tunnel does not exist or has been closed.</p>
</header>
<div class="info-box">
<p>This could happen because:</p>
<ul>
<li>The tunnel was never created</li>
<li>The tunnel has been closed by the owner</li>
<li>The tunnel URL is incorrect</li>
</ul>
</div>
<p>If you are the tunnel owner, please restart your tunnel client.</p>
<footer>
<a href="https://github.com/Gouryella/drip" target="_blank">GitHub</a>
</footer>
</div>
</body>
</html>`
data := []byte(html)
w.Header().Set("Content-Type", "text/html")
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.WriteHeader(http.StatusNotFound)
w.Write(data)
}
func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) {
health := map[string]interface{}{
"status": "ok",
"active_tunnels": h.manager.Count(),
"timestamp": time.Now().Unix(),
}
data, err := json.Marshal(health)
if err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.Write(data)
}
func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
if h.metricsToken != "" {
// Only accept token via Authorization header (Bearer token)
// URL query parameters are insecure (logged, cached, visible in browser history)
var token string
authHeader := r.Header.Get("Authorization")
if strings.HasPrefix(authHeader, "Bearer ") {
token = strings.TrimPrefix(authHeader, "Bearer ")
}
if token != h.metricsToken {
w.Header().Set("WWW-Authenticate", `Bearer realm="stats"`)
http.Error(w, "Unauthorized: provide metrics token via 'Authorization: Bearer <token>' header", http.StatusUnauthorized)
return
}
}
connections := h.manager.List()
// Pre-allocate slice to avoid O(n²) reallocations
tunnelStats := make([]map[string]interface{}, 0, len(connections))
for _, conn := range connections {
if conn == nil {
continue
}
tunnelStats = append(tunnelStats, map[string]interface{}{
"subdomain": conn.Subdomain,
"tunnel_type": string(conn.GetTunnelType()),
"last_active": conn.LastActive.Unix(),
"bytes_in": conn.GetBytesIn(),
"bytes_out": conn.GetBytesOut(),
"active_connections": conn.GetActiveConnections(),
"total_bytes": conn.GetBytesIn() + conn.GetBytesOut(),
})
}
stats := map[string]interface{}{
"total_tunnels": len(tunnelStats),
"tunnels": tunnelStats,
}
data, err := json.Marshal(stats)
if err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.Write(data)
}
func (h *Handler) serveMetrics(w http.ResponseWriter, r *http.Request) {
if h.metricsToken != "" {
// Only accept token via Authorization header (Bearer token)
var token string
authHeader := r.Header.Get("Authorization")
if strings.HasPrefix(authHeader, "Bearer ") {
token = strings.TrimPrefix(authHeader, "Bearer ")
}
if token != h.metricsToken {
w.Header().Set("WWW-Authenticate", `Bearer realm="metrics"`)
http.Error(w, "Unauthorized: provide metrics token via 'Authorization: Bearer <token>' header", http.StatusUnauthorized)
return
}
}
// Serve Prometheus metrics
promhttp.Handler().ServeHTTP(w, r)
}
type bufferedReadWriteCloser struct {
*bufio.Reader
net.Conn
}
func (b *bufferedReadWriteCloser) Read(p []byte) (int, error) {
return b.Reader.Read(p)
}
func (h *Handler) isProxyAuthenticated(r *http.Request, subdomain string) bool {
cookie, err := r.Cookie(authCookieName + "_" + subdomain)
if err != nil {
return false
}
return sessionStore.validate(cookie.Value, subdomain)
}
func (h *Handler) handleProxyLogin(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection, subdomain string) {
if r.Method != http.MethodPost {
h.serveLoginPage(w, r, subdomain, "")
return
}
if err := r.ParseForm(); err != nil {
h.serveLoginPage(w, r, subdomain, "Invalid form data")
return
}
password := r.FormValue("password")
if !tconn.ValidateProxyAuth(password) {
h.serveLoginPage(w, r, subdomain, "Invalid password")
return
}
token := sessionStore.create(subdomain)
http.SetCookie(w, &http.Cookie{
Name: authCookieName + "_" + subdomain,
Value: token,
Path: "/",
MaxAge: int(authSessionDuration.Seconds()),
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
redirectURL := r.FormValue("redirect")
if redirectURL == "" || redirectURL == "/_drip/login" {
redirectURL = "/"
}
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
func (h *Handler) serveLoginPage(w http.ResponseWriter, r *http.Request, subdomain string, errorMsg string) {
redirectURL := r.URL.Path
if r.URL.RawQuery != "" {
redirectURL += "?" + r.URL.RawQuery
}
if redirectURL == "/_drip/login" {
redirectURL = "/"
}
errorHTML := ""
if errorMsg != "" {
errorHTML = fmt.Sprintf(`<p class="error">%s</p>`, html.EscapeString(errorMsg))
}
safeRedirectURL := html.EscapeString(redirectURL)
htmlContent := fmt.Sprintf(`<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>%s - Drip</title>
`+faviconLink+`
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #fff;
color: #24292f;
line-height: 1.6;
}
.container { max-width: 720px; margin: 0 auto; padding: 48px 24px; }
header { margin-bottom: 48px; }
h1 { font-size: 28px; font-weight: 600; margin-bottom: 8px; }
h1 span { margin-right: 8px; }
.desc { color: #57606a; font-size: 16px; }
p { margin-bottom: 24px; }
.error { color: #cf222e; margin-bottom: 16px; }
.input-wrap {
position: relative;
background: #f6f8fa;
border: 1px solid #d0d7de;
border-radius: 6px;
margin-bottom: 12px;
display: flex;
}
.input-wrap input {
flex: 1;
margin: 0;
padding: 12px 16px;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Menlo, Consolas, monospace;
font-size: 14px;
background: transparent;
border: none;
outline: none;
}
.input-wrap button {
background: #24292f;
color: #fff;
border: none;
padding: 8px 16px;
margin: 4px;
border-radius: 4px;
font-size: 14px;
cursor: pointer;
}
.input-wrap button:hover { background: #32383f; }
footer { margin-top: 48px; padding-top: 24px; border-top: 1px solid #d0d7de; }
footer a { color: #57606a; text-decoration: none; font-size: 14px; }
footer a:hover { color: #0969da; }
</style>
</head>
<body>
<div class="container">
<header>
<h1><span>🔒</span>%s</h1>
<p class="desc">This tunnel is password protected</p>
</header>
%s
<form method="POST" action="/_drip/login">
<input type="hidden" name="redirect" value="%s" />
<div class="input-wrap">
<input type="password" name="password" placeholder="Enter password" required autofocus />
<button type="submit">Continue</button>
</div>
</form>
<footer>
<a href="https://github.com/Gouryella/drip" target="_blank">GitHub</a>
</footer>
</div>
</body>
</html>`, subdomain, subdomain, errorHTML, safeRedirectURL)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate")
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(htmlContent))
}
// handleTunnelWebSocket handles WebSocket connections for tunnel clients
func (h *Handler) handleTunnelWebSocket(w http.ResponseWriter, r *http.Request) {
// Check if WSS transport is allowed
if !h.IsTransportAllowed("wss") {
http.Error(w, "WebSocket transport not allowed on this server", http.StatusForbidden)
return
}
if h.wsConnHandler == nil {
http.Error(w, "WebSocket tunnel not configured", http.StatusServiceUnavailable)
return
}
ws, err := h.wsUpgrader.Upgrade(w, r, nil)
if err != nil {
h.logger.Error("WebSocket upgrade failed", zap.Error(err))
return
}
// Configure WebSocket for tunnel use
ws.SetReadLimit(protocol.MaxFrameSize + protocol.FrameHeaderSize + 1024)
// Extract real client IP (support CDN headers)
remoteAddr := h.extractClientIP(r)
h.logger.Info("WebSocket tunnel connection established",
zap.String("remote_addr", remoteAddr),
)
// Wrap WebSocket as net.Conn with ping loop for CDN keep-alive
conn := wsutil.NewConnWithPing(ws, 30*time.Second)
// Handle the connection using the registered handler
h.wsConnHandler.HandleWSConnection(conn, remoteAddr)
}
// serveDiscovery returns server capabilities for client auto-detection
func (h *Handler) serveDiscovery(w http.ResponseWriter, r *http.Request) {
transports := h.allowedTransports
if len(transports) == 0 {
transports = []string{"tcp", "wss"}
}
tunnelTypes := h.allowedTunnelTypes
if len(tunnelTypes) == 0 {
tunnelTypes = []string{"http", "https", "tcp"}
}
response := map[string]interface{}{
"transports": transports,
"tunnel_types": tunnelTypes,
"preferred": h.GetPreferredTransport(),
"version": "1",
}
data, err := json.Marshal(response)
if err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-cache")
w.Write(data)
return true
}

View File

@@ -0,0 +1,299 @@
package proxy
import (
"net/http"
"time"
json "github.com/goccy/go-json"
"github.com/prometheus/client_golang/prometheus/promhttp"
"drip/internal/shared/httputil"
)
func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
html := `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Drip - Your Tunnel, Your Domain, Anywhere</title>
` + faviconLink + `
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #fff;
color: #24292f;
line-height: 1.6;
}
.container { max-width: 720px; margin: 0 auto; padding: 48px 24px; }
header { margin-bottom: 48px; }
h1 { font-size: 28px; font-weight: 600; margin-bottom: 8px; }
h1 span { margin-right: 8px; }
.desc { color: #57606a; font-size: 16px; }
h2 { font-size: 18px; font-weight: 600; margin: 32px 0 12px; }
.code-wrap {
position: relative;
background: #f6f8fa;
border: 1px solid #d0d7de;
border-radius: 6px;
margin-bottom: 12px;
}
.code-wrap pre {
margin: 0;
padding: 12px 16px;
padding-right: 60px;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Menlo, Consolas, monospace;
font-size: 14px;
overflow-x: auto;
white-space: pre-wrap;
word-break: break-all;
}
.copy-btn {
position: absolute;
top: 8px;
right: 8px;
background: #fff;
border: 1px solid #d0d7de;
border-radius: 6px;
padding: 4px 6px;
cursor: pointer;
color: #57606a;
display: flex;
align-items: center;
justify-content: center;
}
.copy-btn:hover { background: #f3f4f6; }
.copy-btn svg { width: 16px; height: 16px; }
.copy-btn .check { display: none; color: #1a7f37; }
.copy-btn.copied .copy { display: none; }
.copy-btn.copied .check { display: block; }
.links { margin-top: 32px; display: flex; gap: 24px; flex-wrap: wrap; }
.links a { color: #0969da; text-decoration: none; font-size: 14px; }
.links a:hover { text-decoration: underline; }
footer { margin-top: 48px; padding-top: 24px; border-top: 1px solid #d0d7de; }
footer a { color: #57606a; text-decoration: none; font-size: 14px; }
footer a:hover { color: #0969da; }
</style>
</head>
<body>
<div class="container">
<header>
<h1><span>💧</span>Drip</h1>
<p class="desc">Your Tunnel, Your Domain, Anywhere</p>
</header>
<p>A self-hosted tunneling solution to securely expose your services to the internet.</p>
<h2>Install</h2>
<div class="code-wrap">
<pre>bash &lt;(curl -fsSL https://driptunnel.app/install.sh)</pre>
<button class="copy-btn" onclick="copy(this)">
<svg class="copy" viewBox="0 0 16 16" fill="currentColor"><path d="M0 6.75C0 5.784.784 5 1.75 5h1.5a.75.75 0 0 1 0 1.5h-1.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-1.5a.75.75 0 0 1 1.5 0v1.5A1.75 1.75 0 0 1 9.25 16h-7.5A1.75 1.75 0 0 1 0 14.25Z"></path><path d="M5 1.75C5 .784 5.784 0 6.75 0h7.5C15.216 0 16 .784 16 1.75v7.5A1.75 1.75 0 0 1 14.25 11h-7.5A1.75 1.75 0 0 1 5 9.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path></svg>
<svg class="check" viewBox="0 0 16 16" fill="currentColor"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0Z"></path></svg>
</button>
</div>
<h2>Usage</h2>
<div class="code-wrap">
<pre>drip http 3000</pre>
<button class="copy-btn" onclick="copy(this)">
<svg class="copy" viewBox="0 0 16 16" fill="currentColor"><path d="M0 6.75C0 5.784.784 5 1.75 5h1.5a.75.75 0 0 1 0 1.5h-1.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-1.5a.75.75 0 0 1 1.5 0v1.5A1.75 1.75 0 0 1 9.25 16h-7.5A1.75 1.75 0 0 1 0 14.25Z"></path><path d="M5 1.75C5 .784 5.784 0 6.75 0h7.5C15.216 0 16 .784 16 1.75v7.5A1.75 1.75 0 0 1 14.25 11h-7.5A1.75 1.75 0 0 1 5 9.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path></svg>
<svg class="check" viewBox="0 0 16 16" fill="currentColor"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0Z"></path></svg>
</button>
</div>
<div class="code-wrap">
<pre>drip https 443</pre>
<button class="copy-btn" onclick="copy(this)">
<svg class="copy" viewBox="0 0 16 16" fill="currentColor"><path d="M0 6.75C0 5.784.784 5 1.75 5h1.5a.75.75 0 0 1 0 1.5h-1.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-1.5a.75.75 0 0 1 1.5 0v1.5A1.75 1.75 0 0 1 9.25 16h-7.5A1.75 1.75 0 0 1 0 14.25Z"></path><path d="M5 1.75C5 .784 5.784 0 6.75 0h7.5C15.216 0 16 .784 16 1.75v7.5A1.75 1.75 0 0 1 14.25 11h-7.5A1.75 1.75 0 0 1 5 9.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path></svg>
<svg class="check" viewBox="0 0 16 16" fill="currentColor"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0Z"></path></svg>
</button>
</div>
<div class="code-wrap">
<pre>drip tcp 5432</pre>
<button class="copy-btn" onclick="copy(this)">
<svg class="copy" viewBox="0 0 16 16" fill="currentColor"><path d="M0 6.75C0 5.784.784 5 1.75 5h1.5a.75.75 0 0 1 0 1.5h-1.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-1.5a.75.75 0 0 1 1.5 0v1.5A1.75 1.75 0 0 1 9.25 16h-7.5A1.75 1.75 0 0 1 0 14.25Z"></path><path d="M5 1.75C5 .784 5.784 0 6.75 0h7.5C15.216 0 16 .784 16 1.75v7.5A1.75 1.75 0 0 1 14.25 11h-7.5A1.75 1.75 0 0 1 5 9.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path></svg>
<svg class="check" viewBox="0 0 16 16" fill="currentColor"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0Z"></path></svg>
</button>
</div>
<div class="links">
<a href="/health">Health Check</a>
<a href="/stats">Statistics</a>
<a href="/metrics">Prometheus Metrics</a>
</div>
<footer>
<a href="https://github.com/Gouryella/drip" target="_blank">GitHub</a>
</footer>
</div>
<script>
function copy(btn) {
const text = btn.previousElementSibling.textContent;
navigator.clipboard.writeText(text).then(() => {
btn.classList.add('copied');
setTimeout(() => { btn.classList.remove('copied'); }, 2000);
});
}
</script>
</body>
</html>`
httputil.WriteHTML(w, []byte(html))
}
func (h *Handler) serveTunnelNotFound(w http.ResponseWriter, r *http.Request) {
html := `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>404 - Tunnel Not Found</title>
` + faviconLink + `
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #fff;
color: #24292f;
line-height: 1.6;
}
.container { max-width: 720px; margin: 0 auto; padding: 48px 24px; }
header { margin-bottom: 48px; }
h1 { font-size: 28px; font-weight: 600; margin-bottom: 8px; }
h1 span { margin-right: 8px; }
.desc { color: #57606a; font-size: 16px; }
p { margin-bottom: 16px; }
.info-box {
background: #f6f8fa;
border: 1px solid #d0d7de;
border-radius: 6px;
padding: 16px;
margin: 24px 0;
}
.info-box ul {
margin: 12px 0 0 20px;
color: #57606a;
}
.info-box li { margin-bottom: 8px; }
footer { margin-top: 48px; padding-top: 24px; border-top: 1px solid #d0d7de; }
footer a { color: #57606a; text-decoration: none; font-size: 14px; }
footer a:hover { color: #0969da; }
</style>
</head>
<body>
<div class="container">
<header>
<h1><span>🔍</span>Tunnel Not Found</h1>
<p class="desc">The requested tunnel does not exist or has been closed.</p>
</header>
<div class="info-box">
<p>This could happen because:</p>
<ul>
<li>The tunnel was never created</li>
<li>The tunnel has been closed by the owner</li>
<li>The tunnel URL is incorrect</li>
</ul>
</div>
<p>If you are the tunnel owner, please restart your tunnel client.</p>
<footer>
<a href="https://github.com/Gouryella/drip" target="_blank">GitHub</a>
</footer>
</div>
</body>
</html>`
httputil.WriteHTMLWithStatus(w, []byte(html), http.StatusNotFound)
}
func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) {
health := map[string]interface{}{
"status": "ok",
"active_tunnels": h.manager.Count(),
"timestamp": time.Now().Unix(),
}
data, err := json.Marshal(health)
if err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
return
}
httputil.WriteJSON(w, data)
}
func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
if !h.validateMetricsAuth(w, r, "stats") {
return
}
connections := h.manager.List()
tunnelStats := make([]map[string]interface{}, 0, len(connections))
for _, conn := range connections {
if conn == nil {
continue
}
tunnelStats = append(tunnelStats, map[string]interface{}{
"subdomain": conn.Subdomain,
"tunnel_type": string(conn.GetTunnelType()),
"last_active": conn.LastActive.Unix(),
"bytes_in": conn.GetBytesIn(),
"bytes_out": conn.GetBytesOut(),
"active_connections": conn.GetActiveConnections(),
"total_bytes": conn.GetBytesIn() + conn.GetBytesOut(),
})
}
stats := map[string]interface{}{
"total_tunnels": len(tunnelStats),
"tunnels": tunnelStats,
}
data, err := json.Marshal(stats)
if err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
return
}
httputil.WriteJSON(w, data)
}
func (h *Handler) serveMetrics(w http.ResponseWriter, r *http.Request) {
if !h.validateMetricsAuth(w, r, "metrics") {
return
}
promhttp.Handler().ServeHTTP(w, r)
}
func (h *Handler) serveDiscovery(w http.ResponseWriter, r *http.Request) {
transports := h.allowedTransports
if len(transports) == 0 {
transports = []string{"tcp", "wss"}
}
tunnelTypes := h.allowedTunnelTypes
if len(tunnelTypes) == 0 {
tunnelTypes = []string{"http", "https", "tcp"}
}
response := map[string]interface{}{
"transports": transports,
"tunnel_types": tunnelTypes,
"preferred": h.GetPreferredTransport(),
"version": "1",
}
data, err := json.Marshal(response)
if err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
return
}
w.Header().Set("Cache-Control", "no-cache")
httputil.WriteJSON(w, data)
}

View File

@@ -0,0 +1,113 @@
package proxy
import (
"bufio"
"context"
"io"
"net"
"net/http"
"time"
"go.uber.org/zap"
"drip/internal/server/tunnel"
"drip/internal/shared/httputil"
"drip/internal/shared/netutil"
"drip/internal/shared/protocol"
"drip/internal/shared/wsutil"
)
type bufferedReadWriteCloser struct {
*bufio.Reader
net.Conn
}
func (b *bufferedReadWriteCloser) Read(p []byte) (int, error) {
return b.Reader.Read(p)
}
func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection) {
stream, err := h.openStreamWithTimeout(tconn)
if err != nil {
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
return
}
tconn.IncActiveConnections()
hj, ok := w.(http.Hijacker)
if !ok {
stream.Close()
tconn.DecActiveConnections()
http.Error(w, "WebSocket not supported", http.StatusInternalServerError)
return
}
clientConn, clientBuf, err := hj.Hijack()
if err != nil {
stream.Close()
tconn.DecActiveConnections()
http.Error(w, "Failed to hijack connection", http.StatusInternalServerError)
return
}
if err := r.Write(stream); err != nil {
stream.Close()
clientConn.Close()
tconn.DecActiveConnections()
return
}
go func() {
defer stream.Close()
defer clientConn.Close()
defer tconn.DecActiveConnections()
var clientRW io.ReadWriteCloser = clientConn
if clientBuf != nil && clientBuf.Reader.Buffered() > 0 {
clientRW = &bufferedReadWriteCloser{
Reader: clientBuf.Reader,
Conn: clientConn,
}
}
_ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW,
func(n int64) { tconn.AddBytesOut(n) },
func(n int64) { tconn.AddBytesIn(n) },
)
}()
}
func (h *Handler) handleTunnelWebSocket(w http.ResponseWriter, r *http.Request) {
if !h.IsTransportAllowed("wss") {
http.Error(w, "WebSocket transport not allowed on this server", http.StatusForbidden)
return
}
if h.wsConnHandler == nil {
http.Error(w, "WebSocket tunnel not configured", http.StatusServiceUnavailable)
return
}
ws, err := h.wsUpgrader.Upgrade(w, r, nil)
if err != nil {
h.logger.Error("WebSocket upgrade failed", zap.Error(err))
return
}
ws.SetReadLimit(protocol.MaxFrameSize + protocol.FrameHeaderSize + 1024)
remoteAddr := netutil.ExtractClientIP(r)
h.logger.Info("WebSocket tunnel connection established",
zap.String("remote_addr", remoteAddr),
)
conn := wsutil.NewConnWithPing(ws, 30*time.Second)
h.wsConnHandler.HandleWSConnection(conn, remoteAddr)
}
func (h *Handler) isWebSocketUpgrade(r *http.Request) bool {
return httputil.IsWebSocketUpgrade(r)
}

View File

@@ -5,7 +5,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"strconv"
@@ -18,46 +17,54 @@ import (
"drip/internal/server/tunnel"
"drip/internal/shared/constants"
"drip/internal/shared/mux"
"drip/internal/shared/httputil"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// bufioWriterPool reuses bufio.Writer instances to reduce GC pressure
var bufioWriterPool = sync.Pool{
New: func() interface{} {
return bufio.NewWriterSize(nil, 4096)
},
type ConnectionConfig struct {
Conn net.Conn
AuthToken string
Manager *tunnel.Manager
Logger *zap.Logger
PortAlloc *PortAllocator
Domain string
TunnelDomain string
PublicPort int
HTTPHandler http.Handler
GroupManager *ConnectionGroupManager
HTTPListener *connQueueListener
}
type Connection struct {
conn net.Conn
authToken string
manager *tunnel.Manager
logger *zap.Logger
subdomain string
port int
domain string
tunnelDomain string
publicPort int
portAlloc *PortAllocator
tunnelConn *tunnel.Connection
stopCh chan struct{}
once sync.Once
lastHeartbeat time.Time
mu sync.RWMutex
frameWriter *protocol.FrameWriter
httpHandler http.Handler
tunnelType protocol.TunnelType
ctx context.Context
cancel context.CancelFunc
session *yamux.Session
proxy *Proxy
tunnelID string
groupManager *ConnectionGroupManager
httpListener *connQueueListener
handedOff bool
conn net.Conn
authToken string
manager *tunnel.Manager
logger *zap.Logger
subdomain string
port int
domain string
tunnelDomain string
publicPort int
portAlloc *PortAllocator
tunnelConn *tunnel.Connection
stopCh chan struct{}
once sync.Once
lastHeartbeat time.Time
mu sync.RWMutex
frameWriter *protocol.FrameWriter
httpHandler http.Handler
tunnelType protocol.TunnelType
ctx context.Context
cancel context.CancelFunc
session *yamux.Session
proxy *Proxy
tunnelID string
groupManager *ConnectionGroupManager
httpListener *connQueueListener
handedOff bool
lifecycleManager *ConnectionLifecycleManager
// Server capabilities
allowedTunnelTypes []string
@@ -65,25 +72,32 @@ type Connection struct {
}
// NewConnection creates a new connection handler
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager, httpListener *connQueueListener) *Connection {
func NewConnection(cfg ConnectionConfig) *Connection {
ctx, cancel := context.WithCancel(context.Background())
stopCh := make(chan struct{})
c := &Connection{
conn: conn,
authToken: authToken,
manager: manager,
logger: logger,
portAlloc: portAlloc,
domain: domain,
tunnelDomain: tunnelDomain,
publicPort: publicPort,
httpHandler: httpHandler,
stopCh: make(chan struct{}),
lastHeartbeat: time.Now(),
ctx: ctx,
cancel: cancel,
groupManager: groupManager,
httpListener: httpListener,
conn: cfg.Conn,
authToken: cfg.AuthToken,
manager: cfg.Manager,
logger: cfg.Logger,
portAlloc: cfg.PortAlloc,
domain: cfg.Domain,
tunnelDomain: cfg.TunnelDomain,
publicPort: cfg.PublicPort,
httpHandler: cfg.HTTPHandler,
stopCh: stopCh,
lastHeartbeat: time.Now(),
ctx: ctx,
cancel: cancel,
groupManager: cfg.GroupManager,
httpListener: cfg.HTTPListener,
lifecycleManager: NewConnectionLifecycleManager(stopCh, cancel, cfg.Logger),
}
// Set connection in lifecycle manager
c.lifecycleManager.SetConnection(cfg.Conn)
return c
}
@@ -99,17 +113,7 @@ func (c *Connection) Handle() error {
return fmt.Errorf("failed to peek connection: %w", err)
}
peekStr := string(peek)
httpMethods := []string{"GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC"}
isHTTP := false
for _, method := range httpMethods {
if strings.HasPrefix(peekStr, method) {
isHTTP = true
break
}
}
if isHTTP {
if httputil.IsHTTPRequest(peek) {
c.logger.Info("Detected HTTP request on TCP port, handling as HTTP")
return c.handleHTTPRequest(reader)
}
@@ -128,7 +132,24 @@ func (c *Connection) Handle() error {
defer sf.Close()
if sf.Frame.Type == protocol.FrameTypeDataConnect {
return c.handleDataConnect(sf.Frame, reader)
handler := NewDataConnectionHandler(
c.conn,
reader,
c.authToken,
c.groupManager,
c.stopCh,
c.logger,
)
handler.SetSessionCreatedHandler(func(session *yamux.Session) {
c.session = session
if c.lifecycleManager != nil {
c.lifecycleManager.SetSession(session)
}
})
handler.SetTunnelIDHandler(func(tunnelID string) {
c.tunnelID = tunnelID
})
return handler.Handle(sf.Frame)
}
if sf.Frame.Type != protocol.FrameTypeRegister {
@@ -153,121 +174,70 @@ func (c *Connection) Handle() error {
return fmt.Errorf("authentication failed")
}
if req.TunnelType == protocol.TunnelTypeTCP {
if c.portAlloc == nil {
return fmt.Errorf("port allocator not configured")
}
if requestedPort, ok := parseTCPSubdomainPort(req.CustomSubdomain); ok {
port, err := c.portAlloc.AllocateSpecific(requestedPort)
if err != nil {
c.sendError("port_allocation_failed", err.Error())
return fmt.Errorf("failed to allocate requested port %d: %w", requestedPort, err)
}
c.port = port
} else {
port, err := c.portAlloc.Allocate()
if err != nil {
c.sendError("port_allocation_failed", err.Error())
return fmt.Errorf("failed to allocate port: %w", err)
}
c.port = port
if req.CustomSubdomain == "" {
req.CustomSubdomain = fmt.Sprintf("tcp-%d", port)
}
}
}
subdomain, err := c.manager.Register(nil, req.CustomSubdomain)
if err != nil {
c.sendError("registration_failed", err.Error())
c.portAlloc.Release(c.port)
c.port = 0
return fmt.Errorf("tunnel registration failed: %w", err)
}
c.subdomain = subdomain
tunnelConn, ok := c.manager.Get(subdomain)
if !ok {
return fmt.Errorf("failed to get registered tunnel")
}
c.tunnelConn = tunnelConn
c.tunnelConn.Conn = nil
c.tunnelConn.SetTunnelType(req.TunnelType)
c.tunnelType = req.TunnelType
if req.IPAccess != nil && (len(req.IPAccess.AllowIPs) > 0 || len(req.IPAccess.DenyIPs) > 0) {
c.tunnelConn.SetIPAccessControl(req.IPAccess.AllowIPs, req.IPAccess.DenyIPs)
c.logger.Info("IP access control configured",
zap.String("subdomain", subdomain),
zap.Strings("allow_ips", req.IPAccess.AllowIPs),
zap.Strings("deny_ips", req.IPAccess.DenyIPs),
)
}
if req.ProxyAuth != nil && req.ProxyAuth.Enabled {
c.tunnelConn.SetProxyAuth(req.ProxyAuth)
c.logger.Info("Proxy authentication configured",
zap.String("subdomain", subdomain),
)
}
c.logger.Info("Tunnel registered",
zap.String("subdomain", subdomain),
zap.String("tunnel_type", string(req.TunnelType)),
zap.Int("local_port", req.LocalPort),
zap.Int("remote_port", c.port),
// Use RegistrationHandler for registration logic
regHandler := NewRegistrationHandler(
c.manager,
c.portAlloc,
c.groupManager,
c.domain,
c.tunnelDomain,
c.publicPort,
c.logger,
)
var tunnelURL string
if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS {
if c.publicPort == 443 {
tunnelURL = fmt.Sprintf("https://%s.%s", subdomain, c.tunnelDomain)
} else {
tunnelURL = fmt.Sprintf("https://%s.%s:%d", subdomain, c.tunnelDomain, c.publicPort)
}
} else {
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.tunnelDomain, c.port)
regReq := &RegistrationRequest{
TunnelType: req.TunnelType,
CustomSubdomain: req.CustomSubdomain,
Token: req.Token,
ConnectionType: req.ConnectionType,
PoolCapabilities: req.PoolCapabilities,
IPAccess: req.IPAccess,
ProxyAuth: req.ProxyAuth,
LocalPort: req.LocalPort,
}
var tunnelID string
var supportsDataConn bool
recommendedConns := 0
result, err := regHandler.Register(regReq)
if err != nil {
c.sendError("registration_failed", err.Error())
return fmt.Errorf("registration failed: %w", err)
}
if req.PoolCapabilities != nil && req.ConnectionType == "primary" && c.groupManager != nil {
group := c.groupManager.CreateGroup(subdomain, req.Token, c, req.TunnelType)
tunnelID = group.TunnelID
c.tunnelID = tunnelID
supportsDataConn = true
recommendedConns = 4
// Store registration results
c.subdomain = result.Subdomain
c.port = result.Port
c.tunnelConn = result.TunnelConn
c.tunnelConn.Conn = nil
// Update lifecycle manager with registration info
if c.lifecycleManager != nil {
c.lifecycleManager.SetPortAllocation(c.portAlloc, c.port)
c.lifecycleManager.SetTunnelRegistration(c.manager, c.subdomain, "", c.groupManager)
}
// Handle connection groups
if result.SupportsDataConn && c.groupManager != nil {
group := c.groupManager.CreateGroup(result.Subdomain, req.Token, c, req.TunnelType)
result.TunnelID = group.TunnelID
c.tunnelID = result.TunnelID
// Update lifecycle manager with tunnel ID
if c.lifecycleManager != nil {
c.lifecycleManager.SetTunnelRegistration(c.manager, c.subdomain, c.tunnelID, c.groupManager)
}
c.logger.Info("Created connection group for multi-connection support",
zap.String("tunnel_id", tunnelID),
zap.String("tunnel_id", result.TunnelID),
zap.Int("max_data_conns", req.PoolCapabilities.MaxDataConns),
)
}
resp := protocol.RegisterResponse{
Subdomain: subdomain,
Port: c.port,
URL: tunnelURL,
Message: "Tunnel registered successfully",
TunnelID: tunnelID,
SupportsDataConn: supportsDataConn,
RecommendedConns: recommendedConns,
// Build and send registration response
resp, err := regHandler.BuildRegistrationResponse(result)
if err != nil {
return fmt.Errorf("failed to build registration response: %w", err)
}
respData, err := json.Marshal(resp)
if err != nil {
return fmt.Errorf("failed to marshal registration response: %w", err)
}
ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData)
err = protocol.WriteFrame(c.conn, ackFrame)
if err != nil {
if err := regHandler.SendRegistrationResponse(c.conn, resp); err != nil {
return fmt.Errorf("failed to send registration ack: %w", err)
}
@@ -282,6 +252,11 @@ func (c *Connection) Handle() error {
c.frameWriter = protocol.NewFrameWriter(c.conn)
// Update lifecycle manager with frame writer
if c.lifecycleManager != nil {
c.lifecycleManager.SetFrameWriter(c.frameWriter)
}
c.frameWriter.SetWriteErrorHandler(func(err error) {
c.logger.Error("Write error detected, closing connection", zap.Error(err))
c.Close()
@@ -289,39 +264,30 @@ func (c *Connection) Handle() error {
go c.heartbeatChecker()
return c.handleFrames(reader)
// Use FrameHandler for frame processing
frameHandler := NewFrameHandler(c.conn, reader, c.stopCh, c.frameWriter, c.logger)
frameHandler.SetHeartbeatHandler(func() {
c.handleHeartbeat()
})
frameHandler.SetCloseHandler(func() {
c.logger.Info("Client requested close")
})
return frameHandler.HandleFrames()
}
func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
if c.httpListener == nil {
return c.handleHTTPRequestLegacy(reader)
}
c.conn.SetReadDeadline(time.Time{})
wrappedConn := &bufferedConn{
Conn: c.conn,
reader: reader,
}
if !c.httpListener.Enqueue(wrappedConn) {
c.logger.Warn("HTTP listener queue full, rejecting connection")
response := "HTTP/1.1 503 Service Unavailable\r\n" +
"Content-Type: text/plain\r\n" +
"Content-Length: 32\r\n" +
"Connection: close\r\n" +
"\r\n" +
"Server busy, please retry later\r\n"
c.conn.Write([]byte(response))
return fmt.Errorf("http listener queue full")
}
c.mu.Lock()
c.conn = nil
c.handedOff = true
c.mu.Unlock()
return nil
handler := NewHTTPRequestHandler(
c.conn,
reader,
c.httpHandler,
c.httpListener,
c.ctx,
c.logger,
&c.mu,
&c.handedOff,
)
return handler.Handle()
}
func (c *Connection) IsHandedOff() bool {
@@ -330,113 +296,6 @@ func (c *Connection) IsHandedOff() bool {
return c.handedOff
}
func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
if c.httpHandler == nil {
c.logger.Warn("HTTP request received but no HTTP handler configured")
response := "HTTP/1.1 503 Service Unavailable\r\n" +
"Content-Type: text/plain\r\n" +
"Content-Length: 47\r\n" +
"\r\n" +
"HTTP handler not configured for this TCP port\r\n"
c.conn.Write([]byte(response))
return fmt.Errorf("HTTP handler not configured")
}
c.conn.SetReadDeadline(time.Time{})
for {
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
req, err := http.ReadRequest(reader)
if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
c.logger.Debug("Client closed HTTP connection")
return nil
}
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
c.logger.Debug("HTTP keep-alive timeout")
return nil
}
errStr := err.Error()
if errors.Is(err, net.ErrClosed) || strings.Contains(errStr, "use of closed network connection") {
c.logger.Debug("HTTP connection closed during read", zap.Error(err))
return nil
}
if strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection refused") {
c.logger.Debug("Client disconnected abruptly", zap.Error(err))
return nil
}
if strings.Contains(errStr, "malformed HTTP") {
c.logger.Warn("Received malformed HTTP request",
zap.Error(err),
zap.String("error_snippet", errStr[:min(len(errStr), 100)]),
)
return nil
}
c.logger.Error("Failed to parse HTTP request", zap.Error(err))
return fmt.Errorf("failed to parse HTTP request: %w", err)
}
if c.ctx != nil {
req = req.WithContext(c.ctx)
}
c.logger.Info("Processing HTTP request on TCP port",
zap.String("method", req.Method),
zap.String("url", req.URL.String()),
zap.String("host", req.Host),
)
// Get writer from pool to reduce GC pressure
pooledWriter := bufioWriterPool.Get().(*bufio.Writer)
pooledWriter.Reset(c.conn)
respWriter := &httpResponseWriter{
conn: c.conn,
writer: pooledWriter,
header: make(http.Header),
}
c.httpHandler.ServeHTTP(respWriter, req)
if err := respWriter.writer.Flush(); err != nil {
c.logger.Debug("Failed to flush HTTP response", zap.Error(err))
}
// Return writer to pool
pooledWriter.Reset(nil) // Clear reference to connection
bufioWriterPool.Put(pooledWriter)
// Keep TCP_NODELAY enabled for low latency HTTP responses
// (removed the toggle that was disabling it)
c.logger.Debug("HTTP request processing completed",
zap.String("method", req.Method),
zap.String("url", req.URL.String()),
)
shouldClose := false
if req.Close {
shouldClose = true
} else if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
if req.Header.Get("Connection") != "keep-alive" {
shouldClose = true
}
}
if respWriter.headerWritten && respWriter.header.Get("Connection") == "close" {
shouldClose = true
}
if shouldClose {
c.logger.Debug("Closing connection as requested by client or server")
return nil
}
}
}
func parseTCPSubdomainPort(subdomain string) (int, bool) {
if !strings.HasPrefix(subdomain, "tcp-") {
return 0, false
@@ -455,55 +314,6 @@ func parseTCPSubdomainPort(subdomain string) (int, bool) {
return port, true
}
func (c *Connection) handleFrames(reader *bufio.Reader) error {
for {
select {
case <-c.stopCh:
return nil
default:
}
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
frame, err := protocol.ReadFrame(reader)
if err != nil {
if isTimeoutError(err) {
c.logger.Warn("Read timeout, connection may be dead")
return fmt.Errorf("read timeout")
}
if err.Error() == "failed to read frame header: EOF" || err.Error() == "EOF" {
c.logger.Info("Client disconnected")
return nil
}
select {
case <-c.stopCh:
c.logger.Debug("Connection closed during shutdown")
return nil
default:
return fmt.Errorf("failed to read frame: %w", err)
}
}
sf := protocol.WithFrame(frame)
switch sf.Frame.Type {
case protocol.FrameTypeHeartbeat:
c.handleHeartbeat()
sf.Close()
case protocol.FrameTypeClose:
sf.Close()
c.logger.Info("Client requested close")
return nil
default:
sf.Close()
c.logger.Warn("Unexpected frame type",
zap.String("type", sf.Frame.Type.String()),
)
}
}
}
func (c *Connection) handleHeartbeat() {
c.mu.Lock()
c.lastHeartbeat = time.Now()
@@ -562,188 +372,71 @@ func (c *Connection) sendError(code, message string) {
func (c *Connection) Close() {
c.once.Do(func() {
protocol.UnregisterConnection()
close(c.stopCh)
// Check if connection was handed off to HTTP handler
c.mu.RLock()
handedOff := c.handedOff
c.mu.RUnlock()
if c.cancel != nil {
c.cancel()
// If handed off, don't close the connection - HTTP handler owns it now
if handedOff {
c.logger.Debug("Connection handed off to HTTP handler, skipping close")
return
}
// Prevent race with handleHTTPRequest setting c.conn = nil
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
// Use lifecycle manager for cleanup
if c.lifecycleManager != nil {
c.lifecycleManager.Close()
} else {
// Fallback if lifecycle manager not initialized
protocol.UnregisterConnection()
close(c.stopCh)
if conn != nil {
_ = conn.SetDeadline(time.Now())
}
if c.frameWriter != nil {
c.frameWriter.Close()
}
if c.proxy != nil {
c.proxy.Stop()
}
if c.session != nil {
_ = c.session.Close()
}
if conn != nil {
conn.Close()
}
if c.port > 0 && c.portAlloc != nil {
c.portAlloc.Release(c.port)
}
if c.subdomain != "" {
c.manager.Unregister(c.subdomain)
if c.tunnelID != "" && c.groupManager != nil {
c.groupManager.RemoveGroup(c.tunnelID)
if c.cancel != nil {
c.cancel()
}
}
c.logger.Info("Connection closed",
zap.String("subdomain", c.subdomain),
)
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
if conn != nil {
_ = conn.SetDeadline(time.Now())
}
if c.frameWriter != nil {
c.frameWriter.Close()
}
if c.proxy != nil {
c.proxy.Stop()
}
if c.session != nil {
_ = c.session.Close()
}
if conn != nil {
conn.Close()
}
if c.port > 0 && c.portAlloc != nil {
c.portAlloc.Release(c.port)
}
if c.subdomain != "" {
c.manager.Unregister(c.subdomain)
if c.tunnelID != "" && c.groupManager != nil {
c.groupManager.RemoveGroup(c.tunnelID)
}
}
c.logger.Info("Connection closed",
zap.String("subdomain", c.subdomain),
)
}
})
}
type httpResponseWriter struct {
conn net.Conn
writer *bufio.Writer
header http.Header
statusCode int
headerWritten bool
}
func (w *httpResponseWriter) Header() http.Header {
return w.header
}
func (w *httpResponseWriter) WriteHeader(statusCode int) {
if w.headerWritten {
return
}
w.statusCode = statusCode
w.headerWritten = true
statusText := http.StatusText(statusCode)
if statusText == "" {
statusText = "Unknown"
}
w.writer.WriteString("HTTP/1.1 ")
w.writer.WriteString(fmt.Sprintf("%d", statusCode))
w.writer.WriteByte(' ')
w.writer.WriteString(statusText)
w.writer.WriteString("\r\n")
for key, values := range w.header {
for _, value := range values {
w.writer.WriteString(key)
w.writer.WriteString(": ")
w.writer.WriteString(value)
w.writer.WriteString("\r\n")
}
}
w.writer.WriteString("\r\n")
}
func (w *httpResponseWriter) Write(data []byte) (int, error) {
if !w.headerWritten {
w.WriteHeader(http.StatusOK)
}
return w.writer.Write(data)
}
func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Reader) error {
var req protocol.DataConnectRequest
if err := json.Unmarshal(frame.Payload, &req); err != nil {
c.sendError("invalid_request", "Failed to parse data connect request")
return fmt.Errorf("failed to parse data connect request: %w", err)
}
c.logger.Info("Data connection request received",
zap.String("tunnel_id", req.TunnelID),
zap.String("connection_id", req.ConnectionID),
)
if c.groupManager == nil {
c.sendDataConnectError("not_supported", "Multi-connection not supported")
return fmt.Errorf("group manager not available")
}
if c.authToken != "" && req.Token != c.authToken {
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
return fmt.Errorf("authentication failed for data connection")
}
group, ok := c.groupManager.GetGroup(req.TunnelID)
if !ok || group == nil {
c.sendDataConnectError("join_failed", "Tunnel not found")
return fmt.Errorf("tunnel not found: %s", req.TunnelID)
}
if group.Token != "" && req.Token != group.Token {
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
return fmt.Errorf("authentication failed for data connection")
}
c.tunnelID = req.TunnelID
resp := protocol.DataConnectResponse{
Accepted: true,
ConnectionID: req.ConnectionID,
Message: "Data connection accepted",
}
respData, err := json.Marshal(resp)
if err != nil {
return fmt.Errorf("failed to marshal data connect response: %w", err)
}
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
return fmt.Errorf("failed to send data connect ack: %w", err)
}
c.logger.Info("Data connection established",
zap.String("tunnel_id", req.TunnelID),
zap.String("connection_id", req.ConnectionID),
)
_ = c.conn.SetReadDeadline(time.Time{})
// Server acts as yamux Client, client connector acts as yamux Server
bc := &bufferedConn{
Conn: c.conn,
reader: reader,
}
// Use optimized mux config for server
cfg := mux.NewServerConfig()
session, err := yamux.Client(bc, cfg)
if err != nil {
return fmt.Errorf("failed to init yamux session: %w", err)
}
c.session = session
group.AddSession(req.ConnectionID, session)
defer group.RemoveSession(req.ConnectionID)
select {
case <-c.stopCh:
return nil
case <-session.CloseChan():
return nil
}
}
func isTimeoutError(err error) bool {
if err == nil {
return false
@@ -755,20 +448,6 @@ func isTimeoutError(err error) bool {
return strings.Contains(err.Error(), "i/o timeout")
}
func (c *Connection) sendDataConnectError(code, message string) {
resp := protocol.DataConnectResponse{
Accepted: false,
Message: fmt.Sprintf("%s: %s", code, message),
}
respData, err := json.Marshal(resp)
if err != nil {
c.logger.Error("Failed to marshal data connect error", zap.Error(err))
return
}
frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
_ = protocol.WriteFrame(c.conn, frame)
}
// SetAllowedTunnelTypes sets the allowed tunnel types for this connection
func (c *Connection) SetAllowedTunnelTypes(types []string) {
c.allowedTunnelTypes = types

View File

@@ -17,10 +17,10 @@ import (
// sessionEntry represents a session with its current stream count for heap operations
type sessionEntry struct {
id string
session *yamux.Session
streams int
heapIdx int // index in the heap, managed by heap.Interface
id string
session *yamux.Session
streams int
heapIdx int // index in the heap, managed by heap.Interface
}
// sessionHeap implements heap.Interface for O(log n) session selection

View File

@@ -0,0 +1,161 @@
package tcp
import (
"bufio"
"fmt"
"net"
"time"
json "github.com/goccy/go-json"
"github.com/hashicorp/yamux"
"go.uber.org/zap"
"drip/internal/shared/mux"
"drip/internal/shared/protocol"
)
// DataConnectionHandler handles data connection requests for multi-connection support.
type DataConnectionHandler struct {
conn net.Conn
reader *bufio.Reader
authToken string
groupManager *ConnectionGroupManager
stopCh <-chan struct{}
logger *zap.Logger
onSessionCreated func(*yamux.Session)
onTunnelIDSet func(string)
}
// NewDataConnectionHandler creates a new data connection handler.
func NewDataConnectionHandler(
conn net.Conn,
reader *bufio.Reader,
authToken string,
groupManager *ConnectionGroupManager,
stopCh <-chan struct{},
logger *zap.Logger,
) *DataConnectionHandler {
return &DataConnectionHandler{
conn: conn,
reader: reader,
authToken: authToken,
groupManager: groupManager,
stopCh: stopCh,
logger: logger,
}
}
// SetSessionCreatedHandler sets the callback for when a session is created.
func (h *DataConnectionHandler) SetSessionCreatedHandler(handler func(*yamux.Session)) {
h.onSessionCreated = handler
}
// SetTunnelIDHandler sets the callback for when tunnel ID is set.
func (h *DataConnectionHandler) SetTunnelIDHandler(handler func(string)) {
h.onTunnelIDSet = handler
}
// Handle processes the data connection request.
func (h *DataConnectionHandler) Handle(frame *protocol.Frame) error {
var req protocol.DataConnectRequest
if err := json.Unmarshal(frame.Payload, &req); err != nil {
h.sendError("invalid_request", "Failed to parse data connect request")
return fmt.Errorf("failed to parse data connect request: %w", err)
}
h.logger.Info("Data connection request received",
zap.String("tunnel_id", req.TunnelID),
zap.String("connection_id", req.ConnectionID),
)
if h.groupManager == nil {
h.sendError("not_supported", "Multi-connection not supported")
return fmt.Errorf("group manager not available")
}
if h.authToken != "" && req.Token != h.authToken {
h.sendError("authentication_failed", "Invalid authentication token")
return fmt.Errorf("authentication failed for data connection")
}
group, ok := h.groupManager.GetGroup(req.TunnelID)
if !ok || group == nil {
h.sendError("join_failed", "Tunnel not found")
return fmt.Errorf("tunnel not found: %s", req.TunnelID)
}
if group.Token != "" && req.Token != group.Token {
h.sendError("authentication_failed", "Invalid authentication token")
return fmt.Errorf("authentication failed for data connection")
}
if h.onTunnelIDSet != nil {
h.onTunnelIDSet(req.TunnelID)
}
resp := protocol.DataConnectResponse{
Accepted: true,
ConnectionID: req.ConnectionID,
Message: "Data connection accepted",
}
respData, err := json.Marshal(resp)
if err != nil {
return fmt.Errorf("failed to marshal data connect response: %w", err)
}
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
if err := protocol.WriteFrame(h.conn, ackFrame); err != nil {
return fmt.Errorf("failed to send data connect ack: %w", err)
}
h.logger.Info("Data connection established",
zap.String("tunnel_id", req.TunnelID),
zap.String("connection_id", req.ConnectionID),
)
_ = h.conn.SetReadDeadline(time.Time{})
// Server acts as yamux Client, client connector acts as yamux Server
bc := &bufferedConn{
Conn: h.conn,
reader: h.reader,
}
// Use optimized mux config for server
cfg := mux.NewServerConfig()
session, err := yamux.Client(bc, cfg)
if err != nil {
return fmt.Errorf("failed to init yamux session: %w", err)
}
if h.onSessionCreated != nil {
h.onSessionCreated(session)
}
group.AddSession(req.ConnectionID, session)
defer group.RemoveSession(req.ConnectionID)
select {
case <-h.stopCh:
return nil
case <-session.CloseChan():
return nil
}
}
// sendError sends an error response to the client.
func (h *DataConnectionHandler) sendError(code, message string) {
resp := protocol.DataConnectResponse{
Accepted: false,
Message: fmt.Sprintf("%s: %s", code, message),
}
respData, err := json.Marshal(resp)
if err != nil {
h.logger.Error("Failed to marshal data connect error", zap.Error(err))
return
}
frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
_ = protocol.WriteFrame(h.conn, frame)
}

View File

@@ -0,0 +1,122 @@
package tcp
import (
"bufio"
"fmt"
"net"
"time"
"drip/internal/shared/constants"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// FrameHandler handles protocol frame reading and processing.
type FrameHandler struct {
conn net.Conn
reader *bufio.Reader
stopCh <-chan struct{}
logger *zap.Logger
frameWriter *protocol.FrameWriter
// Heartbeat tracking
onHeartbeat func()
onClose func()
}
// NewFrameHandler creates a new frame handler.
func NewFrameHandler(
conn net.Conn,
reader *bufio.Reader,
stopCh <-chan struct{},
frameWriter *protocol.FrameWriter,
logger *zap.Logger,
) *FrameHandler {
return &FrameHandler{
conn: conn,
reader: reader,
stopCh: stopCh,
frameWriter: frameWriter,
logger: logger,
}
}
// SetHeartbeatHandler sets the callback for heartbeat frames.
func (fh *FrameHandler) SetHeartbeatHandler(handler func()) {
fh.onHeartbeat = handler
}
// SetCloseHandler sets the callback for close frames.
func (fh *FrameHandler) SetCloseHandler(handler func()) {
fh.onClose = handler
}
// HandleFrames processes incoming frames in a loop.
func (fh *FrameHandler) HandleFrames() error {
for {
select {
case <-fh.stopCh:
return nil
default:
}
fh.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
frame, err := protocol.ReadFrame(fh.reader)
if err != nil {
return fh.handleReadError(err)
}
sf := protocol.WithFrame(frame)
err = fh.processFrame(sf)
sf.Close()
if err != nil {
return err
}
}
}
// handleReadError handles errors that occur while reading frames.
func (fh *FrameHandler) handleReadError(err error) error {
if isTimeoutError(err) {
fh.logger.Warn("Read timeout, connection may be dead")
return fmt.Errorf("read timeout")
}
if err.Error() == "failed to read frame header: EOF" || err.Error() == "EOF" {
fh.logger.Info("Client disconnected")
return nil
}
select {
case <-fh.stopCh:
fh.logger.Debug("Connection closed during shutdown")
return nil
default:
return fmt.Errorf("failed to read frame: %w", err)
}
}
// processFrame processes a single frame based on its type.
func (fh *FrameHandler) processFrame(sf *protocol.SafeFrame) error {
switch sf.Frame.Type {
case protocol.FrameTypeHeartbeat:
if fh.onHeartbeat != nil {
fh.onHeartbeat()
}
return nil
case protocol.FrameTypeClose:
fh.logger.Info("Client requested close")
if fh.onClose != nil {
fh.onClose()
}
return fmt.Errorf("client requested close")
default:
fh.logger.Warn("Unexpected frame type",
zap.String("type", sf.Frame.Type.String()),
)
return nil
}
}

View File

@@ -0,0 +1,251 @@
package tcp
import (
"bufio"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
"go.uber.org/zap"
)
// bufioWriterPool reuses bufio.Writer instances to reduce GC pressure
var bufioWriterPool = sync.Pool{
New: func() interface{} {
return bufio.NewWriterSize(nil, 4096)
},
}
// HTTPRequestHandler handles HTTP requests on TCP connections.
type HTTPRequestHandler struct {
conn net.Conn
reader *bufio.Reader
httpHandler http.Handler
httpListener *connQueueListener
ctx interface{ Done() <-chan struct{} }
logger *zap.Logger
mu *sync.RWMutex
handedOff *bool
}
// NewHTTPRequestHandler creates a new HTTP request handler.
func NewHTTPRequestHandler(
conn net.Conn,
reader *bufio.Reader,
httpHandler http.Handler,
httpListener *connQueueListener,
ctx interface{ Done() <-chan struct{} },
logger *zap.Logger,
mu *sync.RWMutex,
handedOff *bool,
) *HTTPRequestHandler {
return &HTTPRequestHandler{
conn: conn,
reader: reader,
httpHandler: httpHandler,
httpListener: httpListener,
ctx: ctx,
logger: logger,
mu: mu,
handedOff: handedOff,
}
}
// Handle processes the HTTP request.
func (h *HTTPRequestHandler) Handle() error {
if h.httpListener == nil {
return h.handleLegacy()
}
h.conn.SetReadDeadline(time.Time{})
wrappedConn := &bufferedConn{
Conn: h.conn,
reader: h.reader,
}
if !h.httpListener.Enqueue(wrappedConn) {
h.logger.Warn("HTTP listener queue full, rejecting connection")
response := "HTTP/1.1 503 Service Unavailable\r\n" +
"Content-Type: text/plain\r\n" +
"Content-Length: 32\r\n" +
"Connection: close\r\n" +
"\r\n" +
"Server busy, please retry later\r\n"
h.conn.Write([]byte(response))
return fmt.Errorf("http listener queue full")
}
h.mu.Lock()
*h.handedOff = true
h.mu.Unlock()
return nil
}
// handleLegacy processes HTTP requests using the legacy handler.
func (h *HTTPRequestHandler) handleLegacy() error {
if h.httpHandler == nil {
h.logger.Warn("HTTP request received but no HTTP handler configured")
response := "HTTP/1.1 503 Service Unavailable\r\n" +
"Content-Type: text/plain\r\n" +
"Content-Length: 47\r\n" +
"\r\n" +
"HTTP handler not configured for this TCP port\r\n"
h.conn.Write([]byte(response))
return fmt.Errorf("HTTP handler not configured")
}
h.conn.SetReadDeadline(time.Time{})
for {
h.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
req, err := http.ReadRequest(h.reader)
if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
h.logger.Debug("Client closed HTTP connection")
return nil
}
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
h.logger.Debug("HTTP keep-alive timeout")
return nil
}
errStr := err.Error()
if errors.Is(err, net.ErrClosed) || strings.Contains(errStr, "use of closed network connection") {
h.logger.Debug("HTTP connection closed during read", zap.Error(err))
return nil
}
if strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection refused") {
h.logger.Debug("Client disconnected abruptly", zap.Error(err))
return nil
}
if strings.Contains(errStr, "malformed HTTP") {
h.logger.Warn("Received malformed HTTP request",
zap.Error(err),
zap.String("error_snippet", errStr[:min(len(errStr), 100)]),
)
return nil
}
h.logger.Error("Failed to parse HTTP request", zap.Error(err))
return fmt.Errorf("failed to parse HTTP request: %w", err)
}
if h.ctx != nil {
if ctxWithContext, ok := h.ctx.(interface{ Done() <-chan struct{} }); ok {
req = req.WithContext(ctxWithContext.(interface {
Done() <-chan struct{}
Deadline() (deadline time.Time, ok bool)
Err() error
Value(key interface{}) interface{}
}))
}
}
h.logger.Info("Processing HTTP request on TCP port",
zap.String("method", req.Method),
zap.String("url", req.URL.String()),
zap.String("host", req.Host),
)
// Get writer from pool to reduce GC pressure
pooledWriter := bufioWriterPool.Get().(*bufio.Writer)
pooledWriter.Reset(h.conn)
respWriter := &httpResponseWriter{
conn: h.conn,
writer: pooledWriter,
header: make(http.Header),
}
h.httpHandler.ServeHTTP(respWriter, req)
if err := respWriter.writer.Flush(); err != nil {
h.logger.Debug("Failed to flush HTTP response", zap.Error(err))
}
// Return writer to pool
pooledWriter.Reset(nil) // Clear reference to connection
bufioWriterPool.Put(pooledWriter)
h.logger.Debug("HTTP request processing completed",
zap.String("method", req.Method),
zap.String("url", req.URL.String()),
)
shouldClose := false
if req.Close {
shouldClose = true
} else if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
if req.Header.Get("Connection") != "keep-alive" {
shouldClose = true
}
}
if respWriter.headerWritten && respWriter.header.Get("Connection") == "close" {
shouldClose = true
}
if shouldClose {
h.logger.Debug("Closing connection as requested by client or server")
return nil
}
}
}
// httpResponseWriter implements http.ResponseWriter for raw TCP connections.
type httpResponseWriter struct {
conn net.Conn
writer *bufio.Writer
header http.Header
statusCode int
headerWritten bool
}
func (w *httpResponseWriter) Header() http.Header {
return w.header
}
func (w *httpResponseWriter) WriteHeader(statusCode int) {
if w.headerWritten {
return
}
w.statusCode = statusCode
w.headerWritten = true
statusText := http.StatusText(statusCode)
if statusText == "" {
statusText = "Unknown"
}
w.writer.WriteString("HTTP/1.1 ")
w.writer.WriteString(fmt.Sprintf("%d", statusCode))
w.writer.WriteByte(' ')
w.writer.WriteString(statusText)
w.writer.WriteString("\r\n")
for key, values := range w.header {
for _, value := range values {
w.writer.WriteString(key)
w.writer.WriteString(": ")
w.writer.WriteString(value)
w.writer.WriteString("\r\n")
}
}
w.writer.WriteString("\r\n")
}
func (w *httpResponseWriter) Write(data []byte) (int, error) {
if !w.headerWritten {
w.WriteHeader(http.StatusOK)
}
return w.writer.Write(data)
}

View File

@@ -0,0 +1,137 @@
package tcp
import (
"sync"
"time"
"github.com/hashicorp/yamux"
"go.uber.org/zap"
"drip/internal/server/tunnel"
"drip/internal/shared/protocol"
)
// ConnectionLifecycleManager manages the lifecycle of a connection.
type ConnectionLifecycleManager struct {
once sync.Once
stopCh chan struct{}
cancel func()
logger *zap.Logger
// Resources to clean up
conn interface {
Close() error
SetDeadline(time.Time) error
}
frameWriter *protocol.FrameWriter
proxy interface{ Stop() }
session *yamux.Session
portAlloc *PortAllocator
port int
manager *tunnel.Manager
subdomain string
tunnelID string
groupManager *ConnectionGroupManager
}
// NewConnectionLifecycleManager creates a new lifecycle manager.
func NewConnectionLifecycleManager(
stopCh chan struct{},
cancel func(),
logger *zap.Logger,
) *ConnectionLifecycleManager {
return &ConnectionLifecycleManager{
stopCh: stopCh,
cancel: cancel,
logger: logger,
}
}
// SetConnection sets the connection to manage.
func (clm *ConnectionLifecycleManager) SetConnection(conn interface {
Close() error
SetDeadline(time.Time) error
}) {
clm.conn = conn
}
// SetFrameWriter sets the frame writer to close.
func (clm *ConnectionLifecycleManager) SetFrameWriter(fw *protocol.FrameWriter) {
clm.frameWriter = fw
}
// SetProxy sets the proxy to stop.
func (clm *ConnectionLifecycleManager) SetProxy(proxy interface{ Stop() }) {
clm.proxy = proxy
}
// SetSession sets the yamux session to close.
func (clm *ConnectionLifecycleManager) SetSession(session *yamux.Session) {
clm.session = session
}
// SetPortAllocation sets the port allocation to release.
func (clm *ConnectionLifecycleManager) SetPortAllocation(portAlloc *PortAllocator, port int) {
clm.portAlloc = portAlloc
clm.port = port
}
// SetTunnelRegistration sets the tunnel registration to clean up.
func (clm *ConnectionLifecycleManager) SetTunnelRegistration(
manager *tunnel.Manager,
subdomain string,
tunnelID string,
groupManager *ConnectionGroupManager,
) {
clm.manager = manager
clm.subdomain = subdomain
clm.tunnelID = tunnelID
clm.groupManager = groupManager
}
// Close closes the connection and cleans up all resources.
func (clm *ConnectionLifecycleManager) Close() {
clm.once.Do(func() {
protocol.UnregisterConnection()
close(clm.stopCh)
if clm.cancel != nil {
clm.cancel()
}
if clm.conn != nil {
_ = clm.conn.SetDeadline(time.Now())
}
if clm.frameWriter != nil {
clm.frameWriter.Close()
}
if clm.proxy != nil {
clm.proxy.Stop()
}
if clm.session != nil {
_ = clm.session.Close()
}
if clm.conn != nil {
clm.conn.Close()
}
if clm.port > 0 && clm.portAlloc != nil {
clm.portAlloc.Release(clm.port)
}
if clm.subdomain != "" && clm.manager != nil {
clm.manager.Unregister(clm.subdomain)
if clm.tunnelID != "" && clm.groupManager != nil {
clm.groupManager.RemoveGroup(clm.tunnelID)
}
}
clm.logger.Info("Connection closed",
zap.String("subdomain", clm.subdomain),
)
})
}

View File

@@ -15,11 +15,25 @@ import (
"drip/internal/server/tunnel"
"drip/internal/shared/pool"
"drip/internal/shared/recovery"
"drip/internal/shared/utils"
"go.uber.org/zap"
"golang.org/x/net/http2"
)
type ListenerConfig struct {
Address string
TLSConfig *tls.Config
AuthToken string
Manager *tunnel.Manager
Logger *zap.Logger
PortAlloc *PortAllocator
Domain string
TunnelDomain string
PublicPort int
HTTPHandler http.Handler
}
type Listener struct {
address string
tlsConfig *tls.Config
@@ -48,47 +62,47 @@ type Listener struct {
allowedTunnelTypes []string
}
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler) *Listener {
func NewListener(cfg ListenerConfig) *Listener {
numCPU := pool.NumCPU()
workers := numCPU * 5
queueSize := workers * 20
workerPool := pool.NewWorkerPool(workers, queueSize)
logger.Info("Worker pool configured",
cfg.Logger.Info("Worker pool configured",
zap.Int("cpu_cores", numCPU),
zap.Int("workers", workers),
zap.Int("queue_size", queueSize),
)
panicMetrics := recovery.NewPanicMetrics(logger, nil)
recoverer := recovery.NewRecoverer(logger, panicMetrics)
panicMetrics := recovery.NewPanicMetrics(cfg.Logger, nil)
recoverer := recovery.NewRecoverer(cfg.Logger, panicMetrics)
// Initialize worker pool metrics
metrics.WorkerPoolSize.Set(float64(workers))
l := &Listener{
address: address,
tlsConfig: tlsConfig,
authToken: authToken,
manager: manager,
portAlloc: portAlloc,
logger: logger,
domain: domain,
tunnelDomain: tunnelDomain,
publicPort: publicPort,
httpHandler: httpHandler,
address: cfg.Address,
tlsConfig: cfg.TLSConfig,
authToken: cfg.AuthToken,
manager: cfg.Manager,
portAlloc: cfg.PortAlloc,
logger: cfg.Logger,
domain: cfg.Domain,
tunnelDomain: cfg.TunnelDomain,
publicPort: cfg.PublicPort,
httpHandler: cfg.HTTPHandler,
stopCh: make(chan struct{}),
connections: make(map[string]*Connection),
workerPool: workerPool,
recoverer: recoverer,
panicMetrics: panicMetrics,
groupManager: NewConnectionGroupManager(logger),
groupManager: NewConnectionGroupManager(cfg.Logger),
}
// Set up WebSocket connection handler if httpHandler supports it
if h, ok := httpHandler.(*proxy.Handler); ok {
if h, ok := cfg.HTTPHandler.(*proxy.Handler); ok {
h.SetWSConnectionHandler(l)
h.SetPublicPort(publicPort)
h.SetPublicPort(cfg.PublicPort)
}
return l
@@ -269,7 +283,19 @@ func (l *Listener) handleConnection(netConn net.Conn) {
)
}
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
conn := NewConnection(ConnectionConfig{
Conn: netConn,
AuthToken: l.authToken,
Manager: l.manager,
Logger: l.logger,
PortAlloc: l.portAlloc,
Domain: l.domain,
TunnelDomain: l.tunnelDomain,
PublicPort: l.publicPort,
HTTPHandler: l.httpHandler,
GroupManager: l.groupManager,
HTTPListener: l.httpListener,
})
conn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
conn.SetAllowedTransports(l.allowedTransports)
@@ -297,18 +323,11 @@ func (l *Listener) handleConnection(netConn net.Conn) {
if err := conn.Handle(); err != nil {
errStr := err.Error()
if strings.Contains(errStr, "EOF") ||
strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection refused") {
if utils.IsNetworkError(errStr) {
return
}
if strings.Contains(errStr, "payload too large") ||
strings.Contains(errStr, "failed to read registration frame") ||
strings.Contains(errStr, "expected register frame") ||
strings.Contains(errStr, "failed to parse registration request") ||
strings.Contains(errStr, "failed to parse HTTP request") {
if utils.IsProtocolError(errStr) {
l.logger.Warn("Protocol validation failed",
zap.String("remote_addr", connID),
zap.Error(err),
@@ -387,7 +406,19 @@ func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) {
)
// Create connection handler (no TLS verification needed - already done by HTTP server)
tcpConn := NewConnection(conn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
tcpConn := NewConnection(ConnectionConfig{
Conn: conn,
AuthToken: l.authToken,
Manager: l.manager,
Logger: l.logger,
PortAlloc: l.portAlloc,
Domain: l.domain,
TunnelDomain: l.tunnelDomain,
PublicPort: l.publicPort,
HTTPHandler: l.httpHandler,
GroupManager: l.groupManager,
HTTPListener: l.httpListener,
})
tcpConn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
l.connMu.Lock()
@@ -412,19 +443,11 @@ func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) {
if err := tcpConn.Handle(); err != nil {
errStr := err.Error()
if strings.Contains(errStr, "EOF") ||
strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection refused") ||
strings.Contains(errStr, "websocket: close") {
if utils.IsNetworkError(errStr) {
return
}
if strings.Contains(errStr, "payload too large") ||
strings.Contains(errStr, "failed to read registration frame") ||
strings.Contains(errStr, "expected register frame") ||
strings.Contains(errStr, "failed to parse registration request") ||
strings.Contains(errStr, "tunnel type not allowed") {
if utils.IsProtocolError(errStr) {
l.logger.Warn("WebSocket tunnel protocol validation failed",
zap.String("remote_addr", connID),
zap.Error(err),

View File

@@ -0,0 +1,186 @@
package tcp
import (
"fmt"
json "github.com/goccy/go-json"
"go.uber.org/zap"
"drip/internal/server/tunnel"
"drip/internal/shared/protocol"
"drip/internal/shared/utils"
)
// RegistrationHandler handles tunnel registration logic.
type RegistrationHandler struct {
manager *tunnel.Manager
portAlloc *PortAllocator
groupManager *ConnectionGroupManager
domain string
tunnelDomain string
publicPort int
logger *zap.Logger
}
// NewRegistrationHandler creates a new registration handler.
func NewRegistrationHandler(
manager *tunnel.Manager,
portAlloc *PortAllocator,
groupManager *ConnectionGroupManager,
domain, tunnelDomain string,
publicPort int,
logger *zap.Logger,
) *RegistrationHandler {
return &RegistrationHandler{
manager: manager,
portAlloc: portAlloc,
groupManager: groupManager,
domain: domain,
tunnelDomain: tunnelDomain,
publicPort: publicPort,
logger: logger,
}
}
// RegistrationRequest contains all information needed for registration.
type RegistrationRequest struct {
TunnelType protocol.TunnelType
CustomSubdomain string
Token string
ConnectionType string
PoolCapabilities *protocol.PoolCapabilities
IPAccess *protocol.IPAccessControl
ProxyAuth *protocol.ProxyAuth
LocalPort int
}
// RegistrationResult contains the result of a registration attempt.
type RegistrationResult struct {
Subdomain string
Port int
TunnelURL string
TunnelID string
SupportsDataConn bool
RecommendedConns int
TunnelConn *tunnel.Connection
}
// Register handles the tunnel registration process.
func (rh *RegistrationHandler) Register(req *RegistrationRequest) (*RegistrationResult, error) {
// Allocate port for TCP tunnels
port := 0
if req.TunnelType == protocol.TunnelTypeTCP {
if rh.portAlloc == nil {
return nil, fmt.Errorf("port allocator not configured")
}
if requestedPort, ok := parseTCPSubdomainPort(req.CustomSubdomain); ok {
allocatedPort, err := rh.portAlloc.AllocateSpecific(requestedPort)
if err != nil {
return nil, fmt.Errorf("failed to allocate requested port %d: %w", requestedPort, err)
}
port = allocatedPort
} else {
allocatedPort, err := rh.portAlloc.Allocate()
if err != nil {
return nil, fmt.Errorf("failed to allocate port: %w", err)
}
port = allocatedPort
if req.CustomSubdomain == "" {
req.CustomSubdomain = fmt.Sprintf("tcp-%d", port)
}
}
}
// Register with tunnel manager
subdomain, err := rh.manager.Register(nil, req.CustomSubdomain)
if err != nil {
if port > 0 && rh.portAlloc != nil {
rh.portAlloc.Release(port)
}
return nil, fmt.Errorf("tunnel registration failed: %w", err)
}
// Get tunnel connection
tunnelConn, ok := rh.manager.Get(subdomain)
if !ok {
return nil, fmt.Errorf("failed to get registered tunnel")
}
// Configure tunnel
tunnelConn.SetTunnelType(req.TunnelType)
if req.IPAccess != nil && (len(req.IPAccess.AllowIPs) > 0 || len(req.IPAccess.DenyIPs) > 0) {
tunnelConn.SetIPAccessControl(req.IPAccess.AllowIPs, req.IPAccess.DenyIPs)
rh.logger.Info("IP access control configured",
zap.String("subdomain", subdomain),
zap.Strings("allow_ips", req.IPAccess.AllowIPs),
zap.Strings("deny_ips", req.IPAccess.DenyIPs),
)
}
if req.ProxyAuth != nil && req.ProxyAuth.Enabled {
tunnelConn.SetProxyAuth(req.ProxyAuth)
rh.logger.Info("Proxy authentication configured",
zap.String("subdomain", subdomain),
)
}
// Build tunnel URL
urlBuilder := utils.NewTunnelURLBuilder(rh.tunnelDomain, rh.publicPort)
tunnelURL := urlBuilder.BuildURL(subdomain, req.TunnelType, port)
// Handle connection groups for multi-connection support
var tunnelID string
var supportsDataConn bool
recommendedConns := 0
if req.PoolCapabilities != nil && req.ConnectionType == "primary" && rh.groupManager != nil {
// This will be handled by the caller since it needs the connection instance
supportsDataConn = true
recommendedConns = 4
}
rh.logger.Info("Tunnel registered",
zap.String("subdomain", subdomain),
zap.String("tunnel_type", string(req.TunnelType)),
zap.Int("local_port", req.LocalPort),
zap.Int("remote_port", port),
)
return &RegistrationResult{
Subdomain: subdomain,
Port: port,
TunnelURL: tunnelURL,
TunnelID: tunnelID,
SupportsDataConn: supportsDataConn,
RecommendedConns: recommendedConns,
TunnelConn: tunnelConn,
}, nil
}
// BuildRegistrationResponse creates a protocol registration response.
func (rh *RegistrationHandler) BuildRegistrationResponse(result *RegistrationResult) (*protocol.RegisterResponse, error) {
resp := &protocol.RegisterResponse{
Subdomain: result.Subdomain,
Port: result.Port,
URL: result.TunnelURL,
Message: "Tunnel registered successfully",
TunnelID: result.TunnelID,
SupportsDataConn: result.SupportsDataConn,
RecommendedConns: result.RecommendedConns,
}
return resp, nil
}
// SendRegistrationResponse sends the registration response frame.
func (rh *RegistrationHandler) SendRegistrationResponse(conn interface{ Write([]byte) (int, error) }, resp *protocol.RegisterResponse) error {
respData, err := json.Marshal(resp)
if err != nil {
return fmt.Errorf("failed to marshal registration response: %w", err)
}
ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData)
return protocol.WriteFrame(conn, ackFrame)
}

View File

@@ -35,6 +35,11 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
}
c.session = session
// Update lifecycle manager with session
if c.lifecycleManager != nil {
c.lifecycleManager.SetSession(session)
}
openStream := session.Open
if c.groupManager != nil {
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
@@ -48,6 +53,11 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
c.proxy.SetIPAccessCheck(c.tunnelConn.IsIPAllowed)
}
// Update lifecycle manager with proxy
if c.lifecycleManager != nil {
c.lifecycleManager.SetProxy(c.proxy)
}
if err := c.proxy.Start(); err != nil {
return fmt.Errorf("failed to start tcp proxy: %w", err)
}
@@ -76,6 +86,11 @@ func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
}
c.session = session
// Update lifecycle manager with session
if c.lifecycleManager != nil {
c.lifecycleManager.SetSession(session)
}
openStream := session.Open
if c.groupManager != nil {
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {

View File

@@ -31,12 +31,6 @@ var (
ErrRateLimitExceeded = errors.New("rate limit exceeded, try again later")
)
// rateLimitEntry tracks registration attempts per IP
type rateLimitEntry struct {
count int
windowEnd time.Time
}
// shard holds a subset of tunnels with its own lock
type shard struct {
tunnels map[string]*Connection
@@ -52,16 +46,16 @@ type Manager struct {
// Limits
maxTunnels int
maxTunnelsPerIP int
rateLimit int
rateLimitWindow time.Duration
// Global counters (atomic for lock-free reads)
tunnelCount atomic.Int64
// Per-IP tracking (requires separate lock as it spans shards)
ipMu sync.RWMutex
tunnelsByIP map[string]int // IP -> tunnel count
rateLimits map[string]*rateLimitEntry // IP -> rate limit entry
tunnelsByIP map[string]int // IP -> tunnel count
// Rate limiting
rateLimiter *RateLimiter
// Lifecycle
stopCh chan struct{}
@@ -71,7 +65,7 @@ type Manager struct {
type ManagerConfig struct {
MaxTunnels int
MaxTunnelsPerIP int
RateLimit int // Registrations per IP per window
RateLimit int // Registrations per IP per window
RateLimitWindow time.Duration
}
@@ -117,10 +111,8 @@ func NewManagerWithConfig(logger *zap.Logger, cfg ManagerConfig) *Manager {
logger: logger,
maxTunnels: cfg.MaxTunnels,
maxTunnelsPerIP: cfg.MaxTunnelsPerIP,
rateLimit: cfg.RateLimit,
rateLimitWindow: cfg.RateLimitWindow,
tunnelsByIP: make(map[string]int),
rateLimits: make(map[string]*rateLimitEntry),
rateLimiter: NewRateLimiter(cfg.RateLimit, cfg.RateLimitWindow, logger),
stopCh: make(chan struct{}),
}
@@ -140,28 +132,6 @@ func (m *Manager) getShard(subdomain string) *shard {
return &m.shards[h.Sum32()%numShards]
}
// checkRateLimit checks if the IP has exceeded rate limit (caller must hold ipMu)
func (m *Manager) checkRateLimitLocked(ip string) bool {
now := time.Now()
entry, exists := m.rateLimits[ip]
if !exists || now.After(entry.windowEnd) {
// New window
m.rateLimits[ip] = &rateLimitEntry{
count: 1,
windowEnd: now.Add(m.rateLimitWindow),
}
return true
}
if entry.count >= m.rateLimit {
return false
}
entry.count++
return true
}
// Register registers a new tunnel connection with IP-based limits
func (m *Manager) Register(conn *websocket.Conn, customSubdomain string) (string, error) {
return m.RegisterWithIP(conn, customSubdomain, "")
@@ -193,19 +163,14 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r
// Check per-IP limits and reserve slot atomically
if remoteIP != "" {
m.ipMu.Lock()
if !m.checkRateLimitLocked(remoteIP) {
m.ipMu.Unlock()
// Check rate limit first (has its own lock)
if !m.rateLimiter.CheckAndIncrement(remoteIP) {
rollbackGlobal()
m.logger.Warn("Rate limit exceeded",
zap.String("ip", remoteIP),
zap.Int("limit", m.rateLimit),
)
metrics.RateLimitRejections.WithLabelValues("registration", remoteIP).Inc()
metrics.TunnelRegistrationFailures.WithLabelValues("rate_limit").Inc()
return "", ErrRateLimitExceeded
}
m.ipMu.Lock()
if m.tunnelsByIP[remoteIP] >= m.maxTunnelsPerIP {
currentPerIP := m.tunnelsByIP[remoteIP]
m.ipMu.Unlock()
@@ -427,14 +392,7 @@ func (m *Manager) CleanupStale(timeout time.Duration) int {
}
// Cleanup expired rate limit entries
m.ipMu.Lock()
now := time.Now()
for ip, entry := range m.rateLimits {
if now.After(entry.windowEnd) {
delete(m.rateLimits, ip)
}
}
m.ipMu.Unlock()
m.rateLimiter.Cleanup()
if totalCleaned > 0 {
m.logger.Info("Cleaned up stale tunnels",

View File

@@ -0,0 +1,86 @@
package tunnel
import (
"sync"
"time"
"drip/internal/server/metrics"
"go.uber.org/zap"
)
// rateLimitEntry tracks registration attempts per IP
type rateLimitEntry struct {
count int
windowEnd time.Time
}
// RateLimiter manages rate limiting for tunnel registrations.
type RateLimiter struct {
mu sync.RWMutex
rateLimits map[string]*rateLimitEntry
rateLimit int
rateLimitWindow time.Duration
logger *zap.Logger
}
// NewRateLimiter creates a new rate limiter.
func NewRateLimiter(rateLimit int, rateLimitWindow time.Duration, logger *zap.Logger) *RateLimiter {
return &RateLimiter{
rateLimits: make(map[string]*rateLimitEntry),
rateLimit: rateLimit,
rateLimitWindow: rateLimitWindow,
logger: logger,
}
}
// CheckAndIncrement checks if the IP has exceeded rate limit and increments the counter.
func (rl *RateLimiter) CheckAndIncrement(ip string) bool {
if ip == "" {
return true
}
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
entry, exists := rl.rateLimits[ip]
if !exists || now.After(entry.windowEnd) {
// New window
rl.rateLimits[ip] = &rateLimitEntry{
count: 1,
windowEnd: now.Add(rl.rateLimitWindow),
}
return true
}
if entry.count >= rl.rateLimit {
rl.logger.Warn("Rate limit exceeded",
zap.String("ip", ip),
zap.Int("limit", rl.rateLimit),
)
metrics.RateLimitRejections.WithLabelValues("registration", ip).Inc()
return false
}
entry.count++
return true
}
// Cleanup removes expired rate limit entries.
func (rl *RateLimiter) Cleanup() int {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
removed := 0
for ip, entry := range rl.rateLimits {
if now.After(entry.windowEnd) {
delete(rl.rateLimits, ip)
removed++
}
}
return removed
}

View File

@@ -0,0 +1,39 @@
package httputil
import "strings"
// HTTPMethods contains common HTTP method prefixes for protocol detection.
var HTTPMethods = []string{
"GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC",
}
// IsHTTPRequest checks if the given bytes represent the start of an HTTP request.
// It checks for common HTTP method prefixes.
func IsHTTPRequest(data []byte) bool {
if len(data) < 4 {
return false
}
dataStr := string(data[:4])
for _, method := range HTTPMethods {
if strings.HasPrefix(dataStr, method) {
return true
}
}
return false
}
// DetectHTTPMethod returns the HTTP method if the data starts with one, or empty string.
func DetectHTTPMethod(data []byte) string {
if len(data) < 4 {
return ""
}
dataStr := string(data)
for _, method := range HTTPMethods {
if strings.HasPrefix(dataStr, method) {
return strings.TrimSpace(method)
}
}
return ""
}

View File

@@ -0,0 +1,58 @@
package httputil
import (
"fmt"
"net"
)
// HTTPErrorResponse represents a standard HTTP error response.
type HTTPErrorResponse struct {
StatusCode int
StatusText string
Message string
}
// Common HTTP error responses
var (
ServiceUnavailable = &HTTPErrorResponse{
StatusCode: 503,
StatusText: "Service Unavailable",
Message: "Server busy, please retry later",
}
HandlerNotConfigured = &HTTPErrorResponse{
StatusCode: 503,
StatusText: "Service Unavailable",
Message: "HTTP handler not configured for this TCP port",
}
)
// WriteErrorResponse writes an HTTP error response to the connection.
func WriteErrorResponse(conn net.Conn, resp *HTTPErrorResponse) error {
response := fmt.Sprintf(
"HTTP/1.1 %d %s\r\n"+
"Content-Type: text/plain\r\n"+
"Content-Length: %d\r\n"+
"Connection: close\r\n"+
"\r\n"+
"%s\r\n",
resp.StatusCode,
resp.StatusText,
len(resp.Message)+2,
resp.Message,
)
_, err := conn.Write([]byte(response))
return err
}
// WriteServiceUnavailable writes a 503 Service Unavailable response.
func WriteServiceUnavailable(conn net.Conn, message string) error {
if message == "" {
message = ServiceUnavailable.Message
}
return WriteErrorResponse(conn, &HTTPErrorResponse{
StatusCode: 503,
StatusText: "Service Unavailable",
Message: message,
})
}

View File

@@ -0,0 +1,38 @@
package httputil
import (
"fmt"
"net/http"
)
// WriteJSON writes a JSON response with the appropriate headers.
func WriteJSON(w http.ResponseWriter, data []byte) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.Write(data)
}
// WriteHTML writes an HTML response with the appropriate headers.
func WriteHTML(w http.ResponseWriter, data []byte) {
w.Header().Set("Content-Type", "text/html")
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.Write(data)
}
// WriteHTMLWithStatus writes an HTML response with a custom status code.
func WriteHTMLWithStatus(w http.ResponseWriter, data []byte, statusCode int) {
w.Header().Set("Content-Type", "text/html")
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.WriteHeader(statusCode)
w.Write(data)
}
// SetContentLength sets the Content-Length header.
func SetContentLength(w http.ResponseWriter, length int64) {
w.Header().Set("Content-Length", fmt.Sprintf("%d", length))
}
// SetCloseConnection sets the Connection: close header.
func SetCloseConnection(w http.ResponseWriter) {
w.Header().Set("Connection", "close")
}

View File

@@ -0,0 +1,116 @@
package mux
import (
"bufio"
"net"
"github.com/hashicorp/yamux"
)
// BufferedConn wraps a connection with a buffered reader.
type BufferedConn struct {
net.Conn
reader *bufio.Reader
}
// NewBufferedConn creates a new buffered connection.
func NewBufferedConn(conn net.Conn, reader *bufio.Reader) *BufferedConn {
return &BufferedConn{
Conn: conn,
reader: reader,
}
}
// Read reads from the buffered reader if available, otherwise from the connection.
func (bc *BufferedConn) Read(p []byte) (int, error) {
if bc.reader != nil {
return bc.reader.Read(p)
}
return bc.Conn.Read(p)
}
// SessionBuilder helps build yamux sessions with consistent configuration.
type SessionBuilder struct {
conn net.Conn
reader *bufio.Reader
config *yamux.Config
isServer bool
}
// NewSessionBuilder creates a new session builder.
func NewSessionBuilder(conn net.Conn) *SessionBuilder {
return &SessionBuilder{
conn: conn,
}
}
// WithReader sets the buffered reader for the session.
func (sb *SessionBuilder) WithReader(reader *bufio.Reader) *SessionBuilder {
sb.reader = reader
return sb
}
// WithConfig sets the yamux configuration.
func (sb *SessionBuilder) WithConfig(config *yamux.Config) *SessionBuilder {
sb.config = config
return sb
}
// AsServer configures the session as a server.
func (sb *SessionBuilder) AsServer() *SessionBuilder {
sb.isServer = true
if sb.config == nil {
sb.config = NewServerConfig()
}
return sb
}
// AsClient configures the session as a client.
func (sb *SessionBuilder) AsClient() *SessionBuilder {
sb.isServer = false
if sb.config == nil {
sb.config = NewClientConfig()
}
return sb
}
// Build creates the yamux session.
func (sb *SessionBuilder) Build() (*yamux.Session, error) {
conn := sb.conn
// Wrap with buffered reader if provided
if sb.reader != nil {
conn = NewBufferedConn(sb.conn, sb.reader)
}
// Use default config if not set
if sb.config == nil {
if sb.isServer {
sb.config = NewServerConfig()
} else {
sb.config = NewClientConfig()
}
}
// Create session based on role
if sb.isServer {
return yamux.Server(conn, sb.config)
}
return yamux.Client(conn, sb.config)
}
// BuildClientSession creates a yamux client session with the given connection and reader.
func BuildClientSession(conn net.Conn, reader *bufio.Reader) (*yamux.Session, error) {
return NewSessionBuilder(conn).
WithReader(reader).
AsClient().
Build()
}
// BuildServerSession creates a yamux server session with the given connection and reader.
func BuildServerSession(conn net.Conn, reader *bufio.Reader) (*yamux.Session, error) {
return NewSessionBuilder(conn).
WithReader(reader).
AsServer().
Build()
}

View File

@@ -0,0 +1,78 @@
package netutil
import (
"net"
"net/http"
"strings"
)
var privateNetworks []*net.IPNet
func init() {
privateCIDRs := []string{
"127.0.0.0/8",
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"::1/128",
"fc00::/7",
"fe80::/10",
}
for _, cidr := range privateCIDRs {
_, ipNet, _ := net.ParseCIDR(cidr)
privateNetworks = append(privateNetworks, ipNet)
}
}
// ExtractRemoteIP extracts the IP address from a remote address string (host:port format).
func ExtractRemoteIP(remoteAddr string) string {
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return remoteAddr
}
return host
}
// IsPrivateIP checks if the given IP is a private/loopback address.
func IsPrivateIP(ip string) bool {
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false
}
for _, network := range privateNetworks {
if network.Contains(parsedIP) {
return true
}
}
return false
}
// ExtractClientIP extracts the client IP from the request.
// It only trusts X-Forwarded-For and X-Real-IP headers when the request
// comes from a private/loopback network (typical reverse proxy setup).
func ExtractClientIP(r *http.Request) string {
// First, get the direct remote address
remoteIP := ExtractRemoteIP(r.RemoteAddr)
// Only trust proxy headers if the request comes from a private network
if IsPrivateIP(remoteIP) {
// Check X-Forwarded-For header (may contain multiple IPs)
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP (original client)
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// Check X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
}
// Fall back to remote address
return remoteIP
}

View File

@@ -3,10 +3,10 @@ package pool
import "sync"
const (
SizeSmall = 4 * 1024 // 4KB - HTTP headers, small messages
SizeMedium = 32 * 1024 // 32KB - HTTP request/response bodies
SizeLarge = 256 * 1024 // 256KB - Data pipe, file transfers
SizeXLarge = 1024 * 1024 // 1MB - Large file transfers, bulk data
SizeSmall = 4 * 1024 // 4KB - HTTP headers, small messages
SizeMedium = 32 * 1024 // 32KB - HTTP request/response bodies
SizeLarge = 256 * 1024 // 256KB - Data pipe, file transfers
SizeXLarge = 1024 * 1024 // 1MB - Large file transfers, bulk data
)
type BufferPool struct {

View File

@@ -0,0 +1,48 @@
package pool
import (
"bufio"
"sync"
)
// BufioReaderPool provides a pool of bufio.Reader instances.
var BufioReaderPool = sync.Pool{
New: func() interface{} {
return bufio.NewReaderSize(nil, 32*1024)
},
}
// BufioWriterPool provides a pool of bufio.Writer instances.
var BufioWriterPool = sync.Pool{
New: func() interface{} {
return bufio.NewWriterSize(nil, 4096)
},
}
// GetReader gets a bufio.Reader from the pool and resets it to read from r.
func GetReader(r interface{}) *bufio.Reader {
reader := BufioReaderPool.Get().(*bufio.Reader)
if resetter, ok := r.(interface{ Reset(interface{}) }); ok {
resetter.Reset(r)
}
return reader
}
// PutReader returns a bufio.Reader to the pool.
func PutReader(reader *bufio.Reader) {
BufioReaderPool.Put(reader)
}
// GetWriter gets a bufio.Writer from the pool and resets it to write to w.
func GetWriter(w interface{}) *bufio.Writer {
writer := BufioWriterPool.Get().(*bufio.Writer)
if resetter, ok := w.(interface{ Reset(interface{}) }); ok {
resetter.Reset(w)
}
return writer
}
// PutWriter returns a bufio.Writer to the pool.
func PutWriter(writer *bufio.Writer) {
BufioWriterPool.Put(writer)
}

View File

@@ -0,0 +1,68 @@
package protocol
import (
"fmt"
"net"
json "github.com/goccy/go-json"
"go.uber.org/zap"
)
// ErrorSender handles sending error frames over connections.
type ErrorSender struct {
conn net.Conn
frameWriter *FrameWriter
logger *zap.Logger
}
// NewErrorSender creates a new error sender.
func NewErrorSender(conn net.Conn, frameWriter *FrameWriter, logger *zap.Logger) *ErrorSender {
return &ErrorSender{
conn: conn,
frameWriter: frameWriter,
logger: logger,
}
}
// SendError sends an error frame with the given code and message.
func (e *ErrorSender) SendError(code, message string) error {
errMsg := ErrorMessage{
Code: code,
Message: message,
}
data, err := json.Marshal(errMsg)
if err != nil {
e.logger.Error("Failed to marshal error message", zap.Error(err))
return fmt.Errorf("failed to marshal error: %w", err)
}
errFrame := NewFrame(FrameTypeError, data)
if e.frameWriter == nil {
return WriteFrame(e.conn, errFrame)
}
return e.frameWriter.WriteFrame(errFrame)
}
// SendAuthenticationError sends an authentication failed error.
func (e *ErrorSender) SendAuthenticationError() error {
return e.SendError("authentication_failed", "Invalid authentication token")
}
// SendRegistrationError sends a registration failed error.
func (e *ErrorSender) SendRegistrationError(message string) error {
return e.SendError("registration_failed", message)
}
// SendPortAllocationError sends a port allocation failed error.
func (e *ErrorSender) SendPortAllocationError(message string) error {
return e.SendError("port_allocation_failed", message)
}
// SendTunnelTypeNotAllowedError sends a tunnel type not allowed error.
func (e *ErrorSender) SendTunnelTypeNotAllowedError(tunnelType string) error {
return e.SendError("tunnel_type_not_allowed",
fmt.Sprintf("Tunnel type '%s' is not allowed on this server", tunnelType))
}

View File

@@ -0,0 +1,74 @@
package protocol
import (
"sync"
"time"
"go.uber.org/zap"
)
// HeartbeatManager manages heartbeat tracking and checking for connections.
type HeartbeatManager struct {
mu sync.RWMutex
lastHeartbeat time.Time
interval time.Duration
timeout time.Duration
logger *zap.Logger
frameWriter *FrameWriter
}
// NewHeartbeatManager creates a new heartbeat manager.
func NewHeartbeatManager(interval, timeout time.Duration, frameWriter *FrameWriter, logger *zap.Logger) *HeartbeatManager {
return &HeartbeatManager{
lastHeartbeat: time.Now(),
interval: interval,
timeout: timeout,
frameWriter: frameWriter,
logger: logger,
}
}
// UpdateLastHeartbeat updates the last heartbeat timestamp.
func (h *HeartbeatManager) UpdateLastHeartbeat() {
h.mu.Lock()
h.lastHeartbeat = time.Now()
h.mu.Unlock()
}
// GetLastHeartbeat returns the last heartbeat timestamp.
func (h *HeartbeatManager) GetLastHeartbeat() time.Time {
h.mu.RLock()
defer h.mu.RUnlock()
return h.lastHeartbeat
}
// IsAlive checks if the connection is still alive based on heartbeat timeout.
func (h *HeartbeatManager) IsAlive() bool {
h.mu.RLock()
defer h.mu.RUnlock()
return time.Since(h.lastHeartbeat) < h.timeout
}
// SendHeartbeatAck sends a heartbeat acknowledgment frame.
func (h *HeartbeatManager) SendHeartbeatAck() error {
ackFrame := NewFrame(FrameTypeHeartbeatAck, nil)
err := h.frameWriter.WriteControl(ackFrame)
if err != nil {
h.logger.Error("Failed to send heartbeat ack", zap.Error(err))
return err
}
return nil
}
// HandleHeartbeat handles a received heartbeat frame.
func (h *HeartbeatManager) HandleHeartbeat() error {
h.UpdateLastHeartbeat()
return h.SendHeartbeatAck()
}
// TimeSinceLastHeartbeat returns the duration since the last heartbeat.
func (h *HeartbeatManager) TimeSinceLastHeartbeat() time.Duration {
h.mu.RLock()
defer h.mu.RUnlock()
return time.Since(h.lastHeartbeat)
}

View File

@@ -14,7 +14,9 @@ type IPAccessControl struct {
type ProxyAuth struct {
Enabled bool `json:"enabled"`
Type string `json:"type,omitempty"`
Password string `json:"password,omitempty"`
Token string `json:"token,omitempty"`
}
type RegisterRequest struct {

View File

@@ -8,8 +8,8 @@ import (
)
var (
kernel32 = syscall.NewLazyDLL("kernel32.dll")
globalMemoryStatusEx = kernel32.NewProc("GlobalMemoryStatusEx")
kernel32 = syscall.NewLazyDLL("kernel32.dll")
globalMemoryStatusEx = kernel32.NewProc("GlobalMemoryStatusEx")
)
type memoryStatusEx struct {

View File

@@ -0,0 +1,49 @@
package utils
import "strings"
// AllowedList manages a list of allowed values with case-insensitive matching.
type AllowedList struct {
items []string
}
// NewAllowedList creates a new AllowedList from the given items.
func NewAllowedList(items []string) *AllowedList {
return &AllowedList{items: items}
}
// IsAllowed checks if a value is in the allowed list.
// If the list is empty, all values are allowed.
// Matching is case-insensitive.
func (a *AllowedList) IsAllowed(value string) bool {
if len(a.items) == 0 {
return true
}
for _, item := range a.items {
if strings.EqualFold(item, value) {
return true
}
}
return false
}
// GetPreferred returns the first item in the list, or the default value if empty.
func (a *AllowedList) GetPreferred(defaultValue string) string {
if len(a.items) == 0 {
return defaultValue
}
if len(a.items) == 1 {
return a.items[0]
}
return defaultValue
}
// Items returns the underlying slice of allowed items.
func (a *AllowedList) Items() []string {
return a.items
}
// IsEmpty returns true if the allowed list is empty.
func (a *AllowedList) IsEmpty() bool {
return len(a.items) == 0
}

View File

@@ -0,0 +1,35 @@
package utils
import "strings"
// IsNetworkError checks if an error message indicates a common network error
// that should be handled gracefully (not logged as severe errors).
func IsNetworkError(errStr string) bool {
return strings.Contains(errStr, "EOF") ||
strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection refused") ||
strings.Contains(errStr, "use of closed network connection") ||
strings.Contains(errStr, "websocket: close")
}
// IsProtocolError checks if an error message indicates a protocol-level error
// (invalid requests, malformed data, etc.).
func IsProtocolError(errStr string) bool {
return strings.Contains(errStr, "payload too large") ||
strings.Contains(errStr, "failed to read registration frame") ||
strings.Contains(errStr, "expected register frame") ||
strings.Contains(errStr, "failed to parse registration request") ||
strings.Contains(errStr, "failed to parse HTTP request") ||
strings.Contains(errStr, "tunnel type not allowed")
}
// ContainsAny checks if a string contains any of the given substrings.
func ContainsAny(s string, substrings ...string) bool {
for _, substr := range substrings {
if strings.Contains(s, substr) {
return true
}
}
return false
}

View File

@@ -0,0 +1,42 @@
package utils
import (
"fmt"
"drip/internal/shared/protocol"
)
// TunnelURLBuilder helps construct tunnel URLs consistently.
type TunnelURLBuilder struct {
tunnelDomain string
publicPort int
}
// NewTunnelURLBuilder creates a new URL builder.
func NewTunnelURLBuilder(tunnelDomain string, publicPort int) *TunnelURLBuilder {
return &TunnelURLBuilder{
tunnelDomain: tunnelDomain,
publicPort: publicPort,
}
}
// BuildHTTPURL builds an HTTP/HTTPS tunnel URL.
func (b *TunnelURLBuilder) BuildHTTPURL(subdomain string) string {
if b.publicPort == 443 {
return fmt.Sprintf("https://%s.%s", subdomain, b.tunnelDomain)
}
return fmt.Sprintf("https://%s.%s:%d", subdomain, b.tunnelDomain, b.publicPort)
}
// BuildTCPURL builds a TCP tunnel URL.
func (b *TunnelURLBuilder) BuildTCPURL(port int) string {
return fmt.Sprintf("tcp://%s:%d", b.tunnelDomain, port)
}
// BuildURL builds a tunnel URL based on the tunnel type.
func (b *TunnelURLBuilder) BuildURL(subdomain string, tunnelType protocol.TunnelType, port int) string {
if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
return b.BuildHTTPURL(subdomain)
}
return b.BuildTCPURL(port)
}

View File

@@ -12,15 +12,16 @@ import (
// TunnelConfig holds configuration for a predefined tunnel
type TunnelConfig struct {
Name string `yaml:"name"` // Tunnel name (required, unique identifier)
Type string `yaml:"type"` // Tunnel type: http, https, tcp (required)
Port int `yaml:"port"` // Local port to forward (required)
Address string `yaml:"address,omitempty"` // Local address (default: 127.0.0.1)
Subdomain string `yaml:"subdomain,omitempty"` // Custom subdomain
Transport string `yaml:"transport,omitempty"` // Transport: auto, tcp, wss
AllowIPs []string `yaml:"allow_ips,omitempty"` // Allowed IPs/CIDRs
DenyIPs []string `yaml:"deny_ips,omitempty"` // Denied IPs/CIDRs
Auth string `yaml:"auth,omitempty"` // Proxy authentication password (http/https only)
Name string `yaml:"name"` // Tunnel name (required, unique identifier)
Type string `yaml:"type"` // Tunnel type: http, https, tcp (required)
Port int `yaml:"port"` // Local port to forward (required)
Address string `yaml:"address,omitempty"` // Local address (default: 127.0.0.1)
Subdomain string `yaml:"subdomain,omitempty"` // Custom subdomain
Transport string `yaml:"transport,omitempty"` // Transport: auto, tcp, wss
AllowIPs []string `yaml:"allow_ips,omitempty"` // Allowed IPs/CIDRs
DenyIPs []string `yaml:"deny_ips,omitempty"` // Denied IPs/CIDRs
Auth string `yaml:"auth,omitempty"` // Proxy authentication password (http/https only)
AuthBearer string `yaml:"auth_bearer,omitempty"` // Proxy authentication bearer token (http/https only)
}
// Validate checks if the tunnel configuration is valid
@@ -44,6 +45,9 @@ func (t *TunnelConfig) Validate() error {
return fmt.Errorf("invalid transport '%s' for '%s': must be auto, tcp, or wss", t.Transport, t.Name)
}
}
if t.Auth != "" && t.AuthBearer != "" {
return fmt.Errorf("only one of auth or auth_bearer can be set for '%s'", t.Name)
}
return nil
}