mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 21:00:44 +00:00
feat: Add Bearer Token authentication support and optimize code structure
- Add Bearer Token authentication, supporting tunnel access control via the --auth-bearer parameter - Refactor large modules into smaller, more focused components to improve code maintainability - Update dependency versions, including golang.org/x/crypto, golang.org/x/net, etc. - Add SilenceUsage and SilenceErrors configuration for all CLI commands - Modify connector configuration structure to support the new authentication method - Update recent change log in README with new feature descriptions BREAKING CHANGE: Authentication via Bearer Token is now supported, requiring the new --auth-bearer parameter
This commit is contained in:
@@ -33,6 +33,13 @@
|
||||
- **Actually free** - Use your own domain, no paid tiers or feature restrictions
|
||||
- **Open source** - BSD 3-Clause License
|
||||
|
||||
## Recent Changes
|
||||
|
||||
### 2025-01-29
|
||||
|
||||
- **Bearer Token Authentication** - Added bearer token authentication support for tunnel access control
|
||||
- **Code Optimization** - Refactored large modules into smaller, focused components for better maintainability
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Install
|
||||
|
||||
@@ -33,6 +33,13 @@
|
||||
- **真的免费** - 用你自己的域名,没有付费档位或功能阉割
|
||||
- **开源** - BSD 3-Clause 协议
|
||||
|
||||
## 最近更新
|
||||
|
||||
### 2025-01-29
|
||||
|
||||
- **Bearer Token 认证** - 新增 Bearer Token 认证支持,用于隧道访问控制
|
||||
- **代码优化** - 将大型模块重构为更小、更专注的组件,提升可维护性
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 安装
|
||||
|
||||
18
go.mod
18
go.mod
@@ -1,6 +1,6 @@
|
||||
module drip
|
||||
|
||||
go 1.25.4
|
||||
go 1.25.5
|
||||
|
||||
require (
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
@@ -10,9 +10,9 @@ require (
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/spf13/cobra v1.10.2
|
||||
go.uber.org/zap v1.27.1
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/net v0.48.0
|
||||
golang.org/x/sys v0.39.0
|
||||
golang.org/x/crypto v0.47.0
|
||||
golang.org/x/net v0.49.0
|
||||
golang.org/x/sys v0.40.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -21,12 +21,12 @@ require (
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.1 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.11.3 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.11.4 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.14 // indirect
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.6.1 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.7.0 // indirect
|
||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.3.1 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
@@ -34,13 +34,13 @@ require (
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.67.4 // indirect
|
||||
github.com/prometheus/common v0.67.5 // indirect
|
||||
github.com/prometheus/procfs v0.19.2 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
golang.org/x/text v0.32.0 // indirect
|
||||
golang.org/x/text v0.33.0 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
)
|
||||
|
||||
32
go.sum
32
go.sum
@@ -8,18 +8,18 @@ github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco
|
||||
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.11.3 h1:6DcVaqWI82BBVM/atTyq6yBoRLZFBsnoDoX9GCu2YOI=
|
||||
github.com/charmbracelet/x/ansi v0.11.3/go.mod h1:yI7Zslym9tCJcedxz5+WBq+eUGMJT0bM06Fqy1/Y4dI=
|
||||
github.com/charmbracelet/x/ansi v0.11.4 h1:6G65PLu6HjmE858CnTUQY1LXT3ZUWwfvqEROLF8vqHI=
|
||||
github.com/charmbracelet/x/ansi v0.11.4/go.mod h1:/5AZ+UfWExW3int5H5ugnsG/PWjNcSQcwYsHBlPFQN4=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.14 h1:iUEMryGyFTelKW3THW4+FfPgi4fkmKnnaLOXuc+/Kj4=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.14/go.mod h1:P447lJl49ywBbil/KjCk2HexGh4tEY9LH0/1QrZZ9rA=
|
||||
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
|
||||
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
|
||||
github.com/clipperhouse/displaywidth v0.6.1 h1:/zMlAezfDzT2xy6acHBzwIfyu2ic0hgkT83UX5EY2gY=
|
||||
github.com/clipperhouse/displaywidth v0.6.1/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o=
|
||||
github.com/clipperhouse/displaywidth v0.7.0 h1:QNv1GYsnLX9QBrcWUtMlogpTXuM5FVnBwKWp1O5NwmE=
|
||||
github.com/clipperhouse/displaywidth v0.7.0/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o=
|
||||
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4=
|
||||
github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/clipperhouse/uax29/v2 v2.3.1 h1:RjM8gnVbFbgI67SBekIC7ihFpyXwRPYWXn9BZActHbw=
|
||||
github.com/clipperhouse/uax29/v2 v2.3.1/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -57,8 +57,8 @@ github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+Lvsc=
|
||||
github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI=
|
||||
github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4=
|
||||
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
|
||||
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
|
||||
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
@@ -84,17 +84,17 @@ go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
||||
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
|
||||
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
|
||||
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
@@ -28,9 +28,11 @@ Examples:
|
||||
drip attach tcp 5432 Attach to TCP tunnel on port 5432
|
||||
|
||||
Press Ctrl+C to detach (tunnel will continue running).`,
|
||||
Aliases: []string{"logs", "tail"},
|
||||
Args: cobra.MaximumNArgs(2),
|
||||
RunE: runAttach,
|
||||
Aliases: []string{"logs", "tail"},
|
||||
Args: cobra.MaximumNArgs(2),
|
||||
RunE: runAttach,
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -13,44 +13,56 @@ import (
|
||||
)
|
||||
|
||||
var configCmd = &cobra.Command{
|
||||
Use: "config",
|
||||
Short: "Manage configuration",
|
||||
Long: "Manage Drip client configuration (server, token, tunnels)",
|
||||
Use: "config",
|
||||
Short: "Manage configuration",
|
||||
Long: "Manage Drip client configuration (server, token, tunnels)",
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
var configInitCmd = &cobra.Command{
|
||||
Use: "init",
|
||||
Short: "Initialize configuration interactively",
|
||||
Long: "Initialize Drip configuration with interactive prompts",
|
||||
RunE: runConfigInit,
|
||||
Use: "init",
|
||||
Short: "Initialize configuration interactively",
|
||||
Long: "Initialize Drip configuration with interactive prompts",
|
||||
RunE: runConfigInit,
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
var configShowCmd = &cobra.Command{
|
||||
Use: "show",
|
||||
Short: "Show current configuration",
|
||||
Long: "Display the current Drip configuration",
|
||||
RunE: runConfigShow,
|
||||
Use: "show",
|
||||
Short: "Show current configuration",
|
||||
Long: "Display the current Drip configuration",
|
||||
RunE: runConfigShow,
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
var configSetCmd = &cobra.Command{
|
||||
Use: "set",
|
||||
Short: "Set configuration values",
|
||||
Long: "Set specific configuration values (server, token)",
|
||||
RunE: runConfigSet,
|
||||
Use: "set",
|
||||
Short: "Set configuration values",
|
||||
Long: "Set specific configuration values (server, token)",
|
||||
RunE: runConfigSet,
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
var configResetCmd = &cobra.Command{
|
||||
Use: "reset",
|
||||
Short: "Reset configuration",
|
||||
Long: "Delete the configuration file",
|
||||
RunE: runConfigReset,
|
||||
Use: "reset",
|
||||
Short: "Reset configuration",
|
||||
Long: "Delete the configuration file",
|
||||
RunE: runConfigReset,
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
var configValidateCmd = &cobra.Command{
|
||||
Use: "validate",
|
||||
Short: "Validate configuration",
|
||||
Long: "Validate the configuration file",
|
||||
RunE: runConfigValidate,
|
||||
Use: "validate",
|
||||
Short: "Validate configuration",
|
||||
Long: "Validate the configuration file",
|
||||
RunE: runConfigValidate,
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
@@ -19,6 +19,7 @@ var (
|
||||
allowIPs []string
|
||||
denyIPs []string
|
||||
authPass string
|
||||
authBearer string
|
||||
transport string
|
||||
)
|
||||
|
||||
@@ -34,6 +35,7 @@ Example:
|
||||
drip http 3000 --allow-ip 10.0.0.1 Allow single IP
|
||||
drip http 3000 --deny-ip 1.2.3.4 Block specific IP
|
||||
drip http 3000 --auth secret Enable proxy authentication with password
|
||||
drip http 3000 --auth-bearer sk-xxx Enable proxy authentication with bearer token
|
||||
drip http 3000 --transport wss Use WebSocket over TLS (CDN-friendly)
|
||||
|
||||
Configuration:
|
||||
@@ -44,8 +46,10 @@ Transport options:
|
||||
auto - Automatically select based on server address (default)
|
||||
tcp - Direct TLS 1.3 connection
|
||||
wss - WebSocket over TLS (works through CDN like Cloudflare)`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runHTTP,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runHTTP,
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -55,6 +59,7 @@ func init() {
|
||||
httpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
|
||||
httpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
|
||||
httpCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication")
|
||||
httpCmd.Flags().StringVar(&authBearer, "auth-bearer", "", "Bearer token for proxy authentication")
|
||||
httpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
|
||||
httpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
httpCmd.Flags().MarkHidden("daemon-child")
|
||||
@@ -71,6 +76,10 @@ func runHTTP(_ *cobra.Command, args []string) error {
|
||||
return StartDaemon("http", port, buildDaemonArgs("http", args, subdomain, localAddress))
|
||||
}
|
||||
|
||||
if authPass != "" && authBearer != "" {
|
||||
return fmt.Errorf("cannot use --auth and --auth-bearer together")
|
||||
}
|
||||
|
||||
serverAddr, token, err := resolveServerAddrAndToken("http", port)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -87,6 +96,7 @@ func runHTTP(_ *cobra.Command, args []string) error {
|
||||
AllowIPs: allowIPs,
|
||||
DenyIPs: denyIPs,
|
||||
AuthPass: authPass,
|
||||
AuthBearer: authBearer,
|
||||
Transport: parseTransport(transport),
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ Example:
|
||||
drip https 443 --allow-ip 10.0.0.1 Allow single IP
|
||||
drip https 443 --deny-ip 1.2.3.4 Block specific IP
|
||||
drip https 443 --auth secret Enable proxy authentication with password
|
||||
drip https 443 --auth-bearer sk-xxx Enable proxy authentication with bearer token
|
||||
drip https 443 --transport wss Use WebSocket over TLS (CDN-friendly)
|
||||
|
||||
Configuration:
|
||||
@@ -32,8 +33,10 @@ Transport options:
|
||||
auto - Automatically select based on server address (default)
|
||||
tcp - Direct TLS 1.3 connection
|
||||
wss - WebSocket over TLS (works through CDN like Cloudflare)`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runHTTPS,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runHTTPS,
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -43,6 +46,7 @@ func init() {
|
||||
httpsCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
|
||||
httpsCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
|
||||
httpsCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication")
|
||||
httpsCmd.Flags().StringVar(&authBearer, "auth-bearer", "", "Bearer token for proxy authentication")
|
||||
httpsCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
|
||||
httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
httpsCmd.Flags().MarkHidden("daemon-child")
|
||||
@@ -59,6 +63,10 @@ func runHTTPS(_ *cobra.Command, args []string) error {
|
||||
return StartDaemon("https", port, buildDaemonArgs("https", args, subdomain, localAddress))
|
||||
}
|
||||
|
||||
if authPass != "" && authBearer != "" {
|
||||
return fmt.Errorf("cannot use --auth and --auth-bearer together")
|
||||
}
|
||||
|
||||
serverAddr, token, err := resolveServerAddrAndToken("https", port)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -75,6 +83,7 @@ func runHTTPS(_ *cobra.Command, args []string) error {
|
||||
AllowIPs: allowIPs,
|
||||
DenyIPs: denyIPs,
|
||||
AuthPass: authPass,
|
||||
AuthBearer: authBearer,
|
||||
Transport: parseTransport(transport),
|
||||
}
|
||||
|
||||
|
||||
@@ -33,8 +33,10 @@ This command shows:
|
||||
In interactive mode, you can select a tunnel to:
|
||||
- Attach: View real-time logs
|
||||
- Stop: Terminate the tunnel`,
|
||||
Aliases: []string{"ls", "ps", "status"},
|
||||
RunE: runList,
|
||||
Aliases: []string{"ls", "ps", "status"},
|
||||
RunE: runList,
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -45,6 +45,8 @@ Features:
|
||||
✓ Auto-save configuration
|
||||
✓ Custom subdomains
|
||||
✓ Authentication via token`,
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -22,28 +22,30 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
serverPort int
|
||||
serverPublicPort int
|
||||
serverDomain string
|
||||
serverTunnelDomain string
|
||||
serverAuthToken string
|
||||
serverMetricsToken string
|
||||
serverDebug bool
|
||||
serverTCPPortMin int
|
||||
serverTCPPortMax int
|
||||
serverTLSCert string
|
||||
serverTLSKey string
|
||||
serverPprofPort int
|
||||
serverTransports string
|
||||
serverTunnelTypes string
|
||||
serverConfigFile string
|
||||
serverPort int
|
||||
serverPublicPort int
|
||||
serverDomain string
|
||||
serverTunnelDomain string
|
||||
serverAuthToken string
|
||||
serverMetricsToken string
|
||||
serverDebug bool
|
||||
serverTCPPortMin int
|
||||
serverTCPPortMax int
|
||||
serverTLSCert string
|
||||
serverTLSKey string
|
||||
serverPprofPort int
|
||||
serverTransports string
|
||||
serverTunnelTypes string
|
||||
serverConfigFile string
|
||||
)
|
||||
|
||||
var serverCmd = &cobra.Command{
|
||||
Use: "server",
|
||||
Short: "Start Drip server",
|
||||
Long: `Start the Drip tunnel server to accept client connections`,
|
||||
RunE: runServer,
|
||||
Use: "server",
|
||||
Short: "Start Drip server",
|
||||
Long: `Start the Drip tunnel server to accept client connections`,
|
||||
RunE: runServer,
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -285,11 +287,29 @@ func runServer(cmd *cobra.Command, _ []string) error {
|
||||
|
||||
listenAddr := fmt.Sprintf("0.0.0.0:%d", cfg.Port)
|
||||
|
||||
httpHandler := proxy.NewHandler(tunnelManager, logger, cfg.Domain, cfg.TunnelDomain, cfg.AuthToken, cfg.MetricsToken)
|
||||
httpHandler := proxy.NewHandler(proxy.HandlerConfig{
|
||||
Manager: tunnelManager,
|
||||
Logger: logger,
|
||||
ServerDomain: cfg.Domain,
|
||||
TunnelDomain: cfg.TunnelDomain,
|
||||
AuthToken: cfg.AuthToken,
|
||||
MetricsToken: cfg.MetricsToken,
|
||||
})
|
||||
httpHandler.SetAllowedTransports(cfg.AllowedTransports)
|
||||
httpHandler.SetAllowedTunnelTypes(cfg.AllowedTunnelTypes)
|
||||
|
||||
listener := tcp.NewListener(listenAddr, tlsConfig, cfg.AuthToken, tunnelManager, logger, portAllocator, cfg.Domain, cfg.TunnelDomain, cfg.PublicPort, httpHandler)
|
||||
listener := tcp.NewListener(tcp.ListenerConfig{
|
||||
Address: listenAddr,
|
||||
TLSConfig: tlsConfig,
|
||||
AuthToken: cfg.AuthToken,
|
||||
Manager: tunnelManager,
|
||||
Logger: logger,
|
||||
PortAlloc: portAllocator,
|
||||
Domain: cfg.Domain,
|
||||
TunnelDomain: cfg.TunnelDomain,
|
||||
PublicPort: cfg.PublicPort,
|
||||
HTTPHandler: httpHandler,
|
||||
})
|
||||
listener.SetAllowedTransports(cfg.AllowedTransports)
|
||||
listener.SetAllowedTunnelTypes(cfg.AllowedTunnelTypes)
|
||||
|
||||
|
||||
@@ -54,7 +54,9 @@ Configuration file example (~/.drip/config.yaml):
|
||||
allow_ips:
|
||||
- 192.168.0.0/16
|
||||
- 10.0.0.0/8`,
|
||||
RunE: runStart,
|
||||
RunE: runStart,
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -238,6 +240,7 @@ func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp
|
||||
AllowIPs: t.AllowIPs,
|
||||
DenyIPs: t.DenyIPs,
|
||||
AuthPass: t.Auth,
|
||||
AuthBearer: t.AuthBearer,
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,9 +19,11 @@ Examples:
|
||||
drip stop all Stop all running tunnels
|
||||
|
||||
Use 'drip list' to see running tunnels.`,
|
||||
Aliases: []string{"kill"},
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: runStop,
|
||||
Aliases: []string{"kill"},
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: runStop,
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -41,8 +41,10 @@ Transport options:
|
||||
|
||||
Note: TCP tunnels require dynamic port allocation on the server.
|
||||
When using CDN (--transport wss), the server must still expose the allocated port directly.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runTCP,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runTCP,
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -24,6 +24,12 @@ func buildDaemonArgs(tunnelType string, args []string, subdomain string, localAd
|
||||
if authToken != "" {
|
||||
daemonArgs = append(daemonArgs, "--token", authToken)
|
||||
}
|
||||
if authPass != "" {
|
||||
daemonArgs = append(daemonArgs, "--auth", authPass)
|
||||
}
|
||||
if authBearer != "" {
|
||||
daemonArgs = append(daemonArgs, "--auth-bearer", authBearer)
|
||||
}
|
||||
if insecure {
|
||||
daemonArgs = append(daemonArgs, "--insecure")
|
||||
}
|
||||
|
||||
197
internal/client/tcp/connection_dialer.go
Normal file
197
internal/client/tcp/connection_dialer.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"drip/internal/shared/wsutil"
|
||||
)
|
||||
|
||||
// serverCapabilities holds the discovered server capabilities
|
||||
type serverCapabilities struct {
|
||||
Transports []string `json:"transports"`
|
||||
Preferred string `json:"preferred"`
|
||||
}
|
||||
|
||||
// ConnectionDialer handles establishing connections to the server.
|
||||
type ConnectionDialer struct {
|
||||
serverAddr string
|
||||
tlsConfig *tls.Config
|
||||
token string
|
||||
transport TransportType
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewConnectionDialer creates a new connection dialer.
|
||||
func NewConnectionDialer(
|
||||
serverAddr string,
|
||||
tlsConfig *tls.Config,
|
||||
token string,
|
||||
transport TransportType,
|
||||
logger *zap.Logger,
|
||||
) *ConnectionDialer {
|
||||
return &ConnectionDialer{
|
||||
serverAddr: serverAddr,
|
||||
tlsConfig: tlsConfig,
|
||||
token: token,
|
||||
transport: transport,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Dial establishes a connection using the appropriate transport.
|
||||
func (d *ConnectionDialer) Dial() (net.Conn, error) {
|
||||
switch d.transport {
|
||||
case TransportWebSocket:
|
||||
return d.dialWebSocket()
|
||||
case TransportTCP:
|
||||
// User explicitly requested TCP, verify server supports it
|
||||
caps := d.discoverServerCapabilities()
|
||||
if caps != nil && len(caps.Transports) > 0 {
|
||||
tcpSupported := false
|
||||
for _, t := range caps.Transports {
|
||||
if t == "tcp" {
|
||||
tcpSupported = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !tcpSupported {
|
||||
return nil, fmt.Errorf("server only supports %v transport(s), but --transport tcp was specified. Use --transport wss instead", caps.Transports)
|
||||
}
|
||||
}
|
||||
return d.dialTLS()
|
||||
default: // TransportAuto
|
||||
// Check if server address indicates WebSocket
|
||||
if strings.HasPrefix(d.serverAddr, "wss://") {
|
||||
return d.dialWebSocket()
|
||||
}
|
||||
// Query server for preferred transport
|
||||
caps := d.discoverServerCapabilities()
|
||||
if caps != nil && caps.Preferred == "wss" {
|
||||
return d.dialWebSocket()
|
||||
}
|
||||
// Default to TCP
|
||||
return d.dialTLS()
|
||||
}
|
||||
}
|
||||
|
||||
// dialTLS establishes a TLS connection to the server.
|
||||
func (d *ConnectionDialer) dialTLS() (net.Conn, error) {
|
||||
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", d.serverAddr, d.tlsConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
state := conn.ConnectionState()
|
||||
if state.Version != tls.VersionTLS13 {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
|
||||
}
|
||||
|
||||
if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok {
|
||||
_ = tcpConn.SetNoDelay(true)
|
||||
_ = tcpConn.SetKeepAlive(true)
|
||||
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||||
_ = tcpConn.SetReadBuffer(256 * 1024)
|
||||
_ = tcpConn.SetWriteBuffer(256 * 1024)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// dialWebSocket establishes a WebSocket connection to the server over TLS.
|
||||
func (d *ConnectionDialer) dialWebSocket() (net.Conn, error) {
|
||||
// Build WebSocket URL
|
||||
host, port, err := net.SplitHostPort(d.serverAddr)
|
||||
if err != nil {
|
||||
// No port specified, use default
|
||||
host = d.serverAddr
|
||||
port = "443"
|
||||
}
|
||||
|
||||
wsURL := fmt.Sprintf("wss://%s:%s/_drip/ws", host, port)
|
||||
|
||||
d.logger.Debug("Connecting via WebSocket over TLS",
|
||||
zap.String("url", wsURL),
|
||||
)
|
||||
|
||||
dialer := websocket.Dialer{
|
||||
TLSClientConfig: d.tlsConfig,
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
ReadBufferSize: 256 * 1024,
|
||||
WriteBufferSize: 256 * 1024,
|
||||
}
|
||||
|
||||
// Add authorization header if token is set
|
||||
header := http.Header{}
|
||||
if d.token != "" {
|
||||
header.Set("Authorization", "Bearer "+d.token)
|
||||
}
|
||||
|
||||
ws, resp, err := dialer.Dial(wsURL, header)
|
||||
if err != nil {
|
||||
if resp != nil {
|
||||
return nil, fmt.Errorf("WebSocket dial failed (status %d): %w", resp.StatusCode, err)
|
||||
}
|
||||
return nil, fmt.Errorf("WebSocket dial failed: %w", err)
|
||||
}
|
||||
|
||||
d.logger.Info("Connected via WebSocket over TLS",
|
||||
zap.String("url", wsURL),
|
||||
)
|
||||
|
||||
// Wrap WebSocket connection to implement net.Conn with ping loop for CDN keep-alive
|
||||
return wsutil.NewConnWithPing(ws, 30*time.Second), nil
|
||||
}
|
||||
|
||||
// discoverServerCapabilities queries the server for its capabilities.
|
||||
func (d *ConnectionDialer) discoverServerCapabilities() *serverCapabilities {
|
||||
host, port, err := net.SplitHostPort(d.serverAddr)
|
||||
if err != nil {
|
||||
host = d.serverAddr
|
||||
port = "443"
|
||||
}
|
||||
|
||||
discoverURL := fmt.Sprintf("https://%s:%s/_drip/discover", host, port)
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: d.tlsConfig,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Get(discoverURL)
|
||||
if err != nil {
|
||||
d.logger.Debug("Failed to discover server capabilities",
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
var caps serverCapabilities
|
||||
if err := json.NewDecoder(resp.Body).Decode(&caps); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
d.logger.Debug("Discovered server capabilities",
|
||||
zap.Strings("transports", caps.Transports),
|
||||
zap.String("preferred", caps.Preferred),
|
||||
)
|
||||
|
||||
return &caps
|
||||
}
|
||||
@@ -41,7 +41,8 @@ type ConnectorConfig struct {
|
||||
DenyIPs []string
|
||||
|
||||
// Proxy authentication
|
||||
AuthPass string
|
||||
AuthPass string
|
||||
AuthBearer string
|
||||
|
||||
// Transport protocol selection
|
||||
Transport TransportType
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/hashicorp/yamux"
|
||||
"go.uber.org/zap"
|
||||
@@ -22,7 +21,6 @@ import (
|
||||
"drip/internal/shared/mux"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/stats"
|
||||
"drip/internal/shared/wsutil"
|
||||
"drip/pkg/config"
|
||||
)
|
||||
|
||||
@@ -71,11 +69,18 @@ type PoolClient struct {
|
||||
allowIPs []string
|
||||
denyIPs []string
|
||||
|
||||
authPass string
|
||||
authPass string
|
||||
authBearer string
|
||||
|
||||
// Transport protocol selection
|
||||
transport TransportType
|
||||
insecure bool
|
||||
|
||||
// Connection dialer
|
||||
dialer *ConnectionDialer
|
||||
|
||||
// Session scaler
|
||||
scaler *SessionScaler
|
||||
}
|
||||
|
||||
// NewPoolClient creates a new pool client.
|
||||
@@ -169,8 +174,10 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
|
||||
allowIPs: cfg.AllowIPs,
|
||||
denyIPs: cfg.DenyIPs,
|
||||
authPass: cfg.AuthPass,
|
||||
authBearer: cfg.AuthBearer,
|
||||
transport: transport,
|
||||
insecure: cfg.Insecure,
|
||||
dialer: NewConnectionDialer(serverAddr, tlsConfig, cfg.Token, transport, logger),
|
||||
}
|
||||
|
||||
if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
|
||||
@@ -183,7 +190,7 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
|
||||
|
||||
// Connect establishes the primary connection and starts background workers.
|
||||
func (c *PoolClient) Connect() error {
|
||||
primaryConn, err := c.dial()
|
||||
primaryConn, err := c.dialer.Dial()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -208,9 +215,16 @@ func (c *PoolClient) Connect() error {
|
||||
}
|
||||
}
|
||||
|
||||
if c.authPass != "" {
|
||||
if c.authBearer != "" {
|
||||
req.ProxyAuth = &protocol.ProxyAuth{
|
||||
Enabled: true,
|
||||
Type: "bearer",
|
||||
Token: c.authBearer,
|
||||
}
|
||||
} else if c.authPass != "" {
|
||||
req.ProxyAuth = &protocol.ProxyAuth{
|
||||
Enabled: true,
|
||||
Type: "password",
|
||||
Password: c.authPass,
|
||||
}
|
||||
}
|
||||
@@ -299,8 +313,9 @@ func (c *PoolClient) Connect() error {
|
||||
|
||||
c.warmupSessions()
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.scalerLoop()
|
||||
// Initialize and start session scaler
|
||||
c.scaler = NewSessionScaler(c, c.logger, c.stopCh, &c.wg)
|
||||
c.scaler.Start()
|
||||
}
|
||||
|
||||
go func() {
|
||||
@@ -311,162 +326,6 @@ func (c *PoolClient) Connect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PoolClient) dialTLS() (net.Conn, error) {
|
||||
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", c.serverAddr, c.tlsConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
state := conn.ConnectionState()
|
||||
if state.Version != tls.VersionTLS13 {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
|
||||
}
|
||||
|
||||
if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok {
|
||||
_ = tcpConn.SetNoDelay(true)
|
||||
_ = tcpConn.SetKeepAlive(true)
|
||||
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||||
_ = tcpConn.SetReadBuffer(256 * 1024)
|
||||
_ = tcpConn.SetWriteBuffer(256 * 1024)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// serverCapabilities holds the discovered server capabilities
|
||||
type serverCapabilities struct {
|
||||
Transports []string `json:"transports"`
|
||||
Preferred string `json:"preferred"`
|
||||
}
|
||||
|
||||
// dial selects the appropriate transport and establishes a connection
|
||||
func (c *PoolClient) dial() (net.Conn, error) {
|
||||
switch c.transport {
|
||||
case TransportWebSocket:
|
||||
return c.dialWebSocket()
|
||||
case TransportTCP:
|
||||
// User explicitly requested TCP, verify server supports it
|
||||
caps := c.discoverServerCapabilities()
|
||||
if caps != nil && len(caps.Transports) > 0 {
|
||||
tcpSupported := false
|
||||
for _, t := range caps.Transports {
|
||||
if t == "tcp" {
|
||||
tcpSupported = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !tcpSupported {
|
||||
return nil, fmt.Errorf("server only supports %v transport(s), but --transport tcp was specified. Use --transport wss instead", caps.Transports)
|
||||
}
|
||||
}
|
||||
return c.dialTLS()
|
||||
default: // TransportAuto
|
||||
// Check if server address indicates WebSocket
|
||||
if strings.HasPrefix(c.serverAddr, "wss://") {
|
||||
return c.dialWebSocket()
|
||||
}
|
||||
// Query server for preferred transport
|
||||
caps := c.discoverServerCapabilities()
|
||||
if caps != nil && caps.Preferred == "wss" {
|
||||
return c.dialWebSocket()
|
||||
}
|
||||
// Default to TCP
|
||||
return c.dialTLS()
|
||||
}
|
||||
}
|
||||
|
||||
// discoverServerCapabilities queries the server for its capabilities
|
||||
func (c *PoolClient) discoverServerCapabilities() *serverCapabilities {
|
||||
host, port, err := net.SplitHostPort(c.serverAddr)
|
||||
if err != nil {
|
||||
host = c.serverAddr
|
||||
port = "443"
|
||||
}
|
||||
|
||||
discoverURL := fmt.Sprintf("https://%s:%s/_drip/discover", host, port)
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: c.tlsConfig,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Get(discoverURL)
|
||||
if err != nil {
|
||||
c.logger.Debug("Failed to discover server capabilities",
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
var caps serverCapabilities
|
||||
if err := json.NewDecoder(resp.Body).Decode(&caps); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Debug("Discovered server capabilities",
|
||||
zap.Strings("transports", caps.Transports),
|
||||
zap.String("preferred", caps.Preferred),
|
||||
)
|
||||
|
||||
return &caps
|
||||
}
|
||||
|
||||
// dialWebSocket establishes a WebSocket connection to the server over TLS
|
||||
func (c *PoolClient) dialWebSocket() (net.Conn, error) {
|
||||
// Build WebSocket URL
|
||||
host, port, err := net.SplitHostPort(c.serverAddr)
|
||||
if err != nil {
|
||||
// No port specified, use default
|
||||
host = c.serverAddr
|
||||
port = "443"
|
||||
}
|
||||
|
||||
wsURL := fmt.Sprintf("wss://%s:%s/_drip/ws", host, port)
|
||||
|
||||
c.logger.Debug("Connecting via WebSocket over TLS",
|
||||
zap.String("url", wsURL),
|
||||
)
|
||||
|
||||
dialer := websocket.Dialer{
|
||||
TLSClientConfig: c.tlsConfig,
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
ReadBufferSize: 256 * 1024,
|
||||
WriteBufferSize: 256 * 1024,
|
||||
}
|
||||
|
||||
// Add authorization header if token is set
|
||||
header := http.Header{}
|
||||
if c.token != "" {
|
||||
header.Set("Authorization", "Bearer "+c.token)
|
||||
}
|
||||
|
||||
ws, resp, err := dialer.Dial(wsURL, header)
|
||||
if err != nil {
|
||||
if resp != nil {
|
||||
return nil, fmt.Errorf("WebSocket dial failed (status %d): %w", resp.StatusCode, err)
|
||||
}
|
||||
return nil, fmt.Errorf("WebSocket dial failed: %w", err)
|
||||
}
|
||||
|
||||
// Wrap WebSocket as net.Conn with ping loop for CDN keep-alive
|
||||
conn := wsutil.NewConnWithPing(ws, 30*time.Second)
|
||||
|
||||
c.logger.Debug("WebSocket connection established",
|
||||
zap.String("remote_addr", ws.RemoteAddr().String()),
|
||||
)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *PoolClient) acceptLoop(h *sessionHandle, isPrimary bool) {
|
||||
defer c.wg.Done()
|
||||
|
||||
@@ -623,12 +482,12 @@ func (c *PoolClient) Close() error {
|
||||
return closeErr
|
||||
}
|
||||
|
||||
func (c *PoolClient) Wait() { <-c.doneCh }
|
||||
func (c *PoolClient) GetURL() string { return c.assignedURL }
|
||||
func (c *PoolClient) GetSubdomain() string { return c.subdomain }
|
||||
func (c *PoolClient) GetLatency() time.Duration { return time.Duration(c.latencyNanos.Load()) }
|
||||
func (c *PoolClient) GetStats() *stats.TrafficStats { return c.stats }
|
||||
func (c *PoolClient) IsClosed() bool { return c.closed.Load() }
|
||||
func (c *PoolClient) Wait() { <-c.doneCh }
|
||||
func (c *PoolClient) GetURL() string { return c.assignedURL }
|
||||
func (c *PoolClient) GetSubdomain() string { return c.subdomain }
|
||||
func (c *PoolClient) GetLatency() time.Duration { return time.Duration(c.latencyNanos.Load()) }
|
||||
func (c *PoolClient) GetStats() *stats.TrafficStats { return c.stats }
|
||||
func (c *PoolClient) IsClosed() bool { return c.closed.Load() }
|
||||
|
||||
func (c *PoolClient) SetLatencyCallback(cb LatencyCallback) {
|
||||
if cb == nil {
|
||||
|
||||
@@ -188,7 +188,7 @@ func (c *PoolClient) addDataSession() error {
|
||||
return fmt.Errorf("server does not support data connections")
|
||||
}
|
||||
|
||||
conn, err := c.dial()
|
||||
conn, err := c.dialer.Dial()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
150
internal/client/tcp/session_scaler.go
Normal file
150
internal/client/tcp/session_scaler.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SessionScaler manages automatic scaling of yamux sessions based on load.
|
||||
type SessionScaler struct {
|
||||
client *PoolClient
|
||||
logger *zap.Logger
|
||||
stopCh <-chan struct{}
|
||||
wg *sync.WaitGroup
|
||||
|
||||
// Scaling configuration
|
||||
checkInterval time.Duration
|
||||
scaleUpCooldown time.Duration
|
||||
scaleDownCooldown time.Duration
|
||||
capacityPerSession int64
|
||||
scaleUpLoad float64
|
||||
scaleDownLoad float64
|
||||
burstThreshold float64
|
||||
maxBurstAdd int
|
||||
}
|
||||
|
||||
// NewSessionScaler creates a new session scaler.
|
||||
func NewSessionScaler(
|
||||
client *PoolClient,
|
||||
logger *zap.Logger,
|
||||
stopCh <-chan struct{},
|
||||
wg *sync.WaitGroup,
|
||||
) *SessionScaler {
|
||||
return &SessionScaler{
|
||||
client: client,
|
||||
logger: logger,
|
||||
stopCh: stopCh,
|
||||
wg: wg,
|
||||
checkInterval: 1 * time.Second,
|
||||
scaleUpCooldown: 1 * time.Second,
|
||||
scaleDownCooldown: 60 * time.Second,
|
||||
capacityPerSession: 256,
|
||||
scaleUpLoad: 0.6,
|
||||
scaleDownLoad: 0.2,
|
||||
burstThreshold: 0.9,
|
||||
maxBurstAdd: 4,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the scaler loop.
|
||||
func (s *SessionScaler) Start() {
|
||||
s.wg.Add(1)
|
||||
go s.scalerLoop()
|
||||
}
|
||||
|
||||
// scalerLoop monitors load and adjusts session count.
|
||||
func (s *SessionScaler) scalerLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(s.checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
s.client.mu.Lock()
|
||||
desired := s.client.desiredTotal
|
||||
if desired == 0 {
|
||||
desired = s.client.initialSessions
|
||||
s.client.desiredTotal = desired
|
||||
}
|
||||
lastScale := s.client.lastScale
|
||||
s.client.mu.Unlock()
|
||||
|
||||
current := s.client.sessionCount()
|
||||
if current == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
activeConns := s.client.stats.GetActiveConnections()
|
||||
capacity := int64(current) * s.capacityPerSession
|
||||
load := float64(activeConns) / float64(capacity)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Burst scaling: rapid scale-up under extreme load
|
||||
if load > s.burstThreshold && current < s.client.maxSessions {
|
||||
toAdd := min(s.maxBurstAdd, s.client.maxSessions-current)
|
||||
s.logger.Info("Burst scaling up sessions",
|
||||
zap.Int("current", current),
|
||||
zap.Int("adding", toAdd),
|
||||
zap.Float64("load", load),
|
||||
)
|
||||
for i := 0; i < toAdd; i++ {
|
||||
_ = s.client.addDataSession()
|
||||
}
|
||||
s.client.mu.Lock()
|
||||
s.client.desiredTotal = current + toAdd
|
||||
s.client.lastScale = now
|
||||
s.client.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
// Scale up: add sessions when load is high
|
||||
if load > s.scaleUpLoad && current < s.client.maxSessions {
|
||||
if now.Sub(lastScale) < s.scaleUpCooldown {
|
||||
continue
|
||||
}
|
||||
newDesired := min(desired+1, s.client.maxSessions)
|
||||
if newDesired > desired {
|
||||
s.logger.Debug("Scaling up sessions",
|
||||
zap.Int("current", current),
|
||||
zap.Int("desired", newDesired),
|
||||
zap.Float64("load", load),
|
||||
)
|
||||
_ = s.client.addDataSession()
|
||||
s.client.mu.Lock()
|
||||
s.client.desiredTotal = newDesired
|
||||
s.client.lastScale = now
|
||||
s.client.mu.Unlock()
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Scale down: remove idle sessions when load is low
|
||||
if load < s.scaleDownLoad && current > s.client.minSessions {
|
||||
if now.Sub(lastScale) < s.scaleDownCooldown {
|
||||
continue
|
||||
}
|
||||
newDesired := max(desired-1, s.client.minSessions)
|
||||
if newDesired < desired && current > newDesired {
|
||||
s.logger.Debug("Scaling down sessions",
|
||||
zap.Int("current", current),
|
||||
zap.Int("desired", newDesired),
|
||||
zap.Float64("load", load),
|
||||
)
|
||||
s.client.removeIdleSessions(1)
|
||||
s.client.mu.Lock()
|
||||
s.client.desiredTotal = newDesired
|
||||
s.client.lastScale = now
|
||||
s.client.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
412
internal/server/proxy/auth_handler.go
Normal file
412
internal/server/proxy/auth_handler.go
Normal file
@@ -0,0 +1,412 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"html"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/protocol"
|
||||
)
|
||||
|
||||
const authCookieName = "drip_auth"
|
||||
const authSessionDuration = 24 * time.Hour
|
||||
|
||||
const (
|
||||
authRateLimitWindow = 1 * time.Minute
|
||||
authRateLimitMax = 10
|
||||
authRateLimitLockout = 5 * time.Minute
|
||||
authRateLimitLockoutThreshold = 20
|
||||
)
|
||||
|
||||
type authSession struct {
|
||||
subdomain string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type authSessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*authSession
|
||||
}
|
||||
|
||||
type authRateLimitEntry struct {
|
||||
failures int
|
||||
windowStart time.Time
|
||||
lockedUntil time.Time
|
||||
}
|
||||
|
||||
type authRateLimiter struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]*authRateLimitEntry
|
||||
}
|
||||
|
||||
var sessionStore = &authSessionStore{
|
||||
sessions: make(map[string]*authSession),
|
||||
}
|
||||
|
||||
var authLimiter = &authRateLimiter{
|
||||
entries: make(map[string]*authRateLimitEntry),
|
||||
}
|
||||
|
||||
func init() {
|
||||
go authLimiter.startCleanupLoop()
|
||||
go sessionStore.startCleanupLoop()
|
||||
}
|
||||
|
||||
func (rl *authRateLimiter) startCleanupLoop() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rl.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *authSessionStore) startCleanupLoop() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
s.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *authSessionStore) cleanup() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for token, session := range s.sessions {
|
||||
if now.After(session.expiresAt) {
|
||||
delete(s.sessions, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *authRateLimiter) isRateLimited(ip string) bool {
|
||||
if ip == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
rl.mu.RLock()
|
||||
entry, exists := rl.entries[ip]
|
||||
rl.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
if !entry.lockedUntil.IsZero() && now.Before(entry.lockedUntil) {
|
||||
return true
|
||||
}
|
||||
|
||||
if now.Sub(entry.windowStart) < authRateLimitWindow && entry.failures >= authRateLimitMax {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (rl *authRateLimiter) recordFailure(ip string) {
|
||||
if ip == "" {
|
||||
return
|
||||
}
|
||||
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
entry, exists := rl.entries[ip]
|
||||
|
||||
if !exists {
|
||||
rl.entries[ip] = &authRateLimitEntry{
|
||||
failures: 1,
|
||||
windowStart: now,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if now.Sub(entry.windowStart) >= authRateLimitWindow {
|
||||
entry.failures = 1
|
||||
entry.windowStart = now
|
||||
entry.lockedUntil = time.Time{}
|
||||
return
|
||||
}
|
||||
|
||||
entry.failures++
|
||||
|
||||
if entry.failures >= authRateLimitLockoutThreshold {
|
||||
entry.lockedUntil = now.Add(authRateLimitLockout)
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *authRateLimiter) resetFailures(ip string) {
|
||||
if ip == "" {
|
||||
return
|
||||
}
|
||||
|
||||
rl.mu.Lock()
|
||||
delete(rl.entries, ip)
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
|
||||
func (rl *authRateLimiter) cleanup() {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for ip, entry := range rl.entries {
|
||||
windowExpired := now.Sub(entry.windowStart) >= authRateLimitWindow
|
||||
lockoutExpired := entry.lockedUntil.IsZero() || now.After(entry.lockedUntil)
|
||||
if windowExpired && lockoutExpired {
|
||||
delete(rl.entries, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *authSessionStore) create(subdomain string) string {
|
||||
token := generateSessionToken()
|
||||
s.mu.Lock()
|
||||
s.sessions[token] = &authSession{
|
||||
subdomain: subdomain,
|
||||
expiresAt: time.Now().Add(authSessionDuration),
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return token
|
||||
}
|
||||
|
||||
func (s *authSessionStore) validate(token, subdomain string) bool {
|
||||
s.mu.RLock()
|
||||
session, ok := s.sessions[token]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if time.Now().After(session.expiresAt) {
|
||||
s.mu.Lock()
|
||||
delete(s.sessions, token)
|
||||
s.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
return session.subdomain == subdomain
|
||||
}
|
||||
|
||||
func generateSessionToken() string {
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
hash := sha256.Sum256(b)
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func isBearerProxyAuth(auth *protocol.ProxyAuth) bool {
|
||||
if auth == nil {
|
||||
return false
|
||||
}
|
||||
if auth.Type != "" {
|
||||
return strings.EqualFold(auth.Type, "bearer")
|
||||
}
|
||||
return auth.Token != ""
|
||||
}
|
||||
|
||||
func extractBearerToken(header string) string {
|
||||
if header == "" {
|
||||
return ""
|
||||
}
|
||||
parts := strings.Fields(header)
|
||||
if len(parts) < 2 {
|
||||
return ""
|
||||
}
|
||||
if !strings.EqualFold(parts[0], "Bearer") {
|
||||
return ""
|
||||
}
|
||||
return parts[1]
|
||||
}
|
||||
|
||||
func (h *Handler) isProxyAuthenticated(r *http.Request, subdomain string) bool {
|
||||
cookie, err := r.Cookie(authCookieName + "_" + subdomain)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return sessionStore.validate(cookie.Value, subdomain)
|
||||
}
|
||||
|
||||
func (h *Handler) isBearerAuthenticated(r *http.Request, auth *protocol.ProxyAuth) bool {
|
||||
token := extractBearerToken(r.Header.Get("Authorization"))
|
||||
if token == "" {
|
||||
return false
|
||||
}
|
||||
return subtle.ConstantTimeCompare([]byte(token), []byte(auth.Token)) == 1
|
||||
}
|
||||
|
||||
func (h *Handler) serveBearerAuthRequired(w http.ResponseWriter, realm string) {
|
||||
if realm == "" {
|
||||
realm = "drip"
|
||||
}
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="%s"`, realm))
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
http.Error(w, "Unauthorized: provide bearer token via Authorization header", http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func (h *Handler) handleProxyLogin(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection, subdomain string) {
|
||||
h.handleProxyLoginWithRateLimit(w, r, tconn, subdomain, "")
|
||||
}
|
||||
|
||||
func (h *Handler) handleProxyLoginWithRateLimit(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection, subdomain string, clientIP string) {
|
||||
if r.Method != http.MethodPost {
|
||||
h.serveLoginPage(w, r, subdomain, "")
|
||||
return
|
||||
}
|
||||
|
||||
if clientIP != "" && authLimiter.isRateLimited(clientIP) {
|
||||
w.Header().Set("Retry-After", "60")
|
||||
http.Error(w, "Too many failed authentication attempts. Please try again later.", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
h.serveLoginPage(w, r, subdomain, "Invalid form data")
|
||||
return
|
||||
}
|
||||
|
||||
password := r.FormValue("password")
|
||||
|
||||
if !tconn.ValidateProxyAuth(password) {
|
||||
if clientIP != "" {
|
||||
authLimiter.recordFailure(clientIP)
|
||||
}
|
||||
h.serveLoginPage(w, r, subdomain, "Invalid password")
|
||||
return
|
||||
}
|
||||
|
||||
if clientIP != "" {
|
||||
authLimiter.resetFailures(clientIP)
|
||||
}
|
||||
|
||||
token := sessionStore.create(subdomain)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: authCookieName + "_" + subdomain,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
MaxAge: int(authSessionDuration.Seconds()),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
redirectURL := r.FormValue("redirect")
|
||||
if redirectURL == "" || redirectURL == "/_drip/login" {
|
||||
redirectURL = "/"
|
||||
}
|
||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
func (h *Handler) serveLoginPage(w http.ResponseWriter, r *http.Request, subdomain string, errorMsg string) {
|
||||
redirectURL := r.URL.Path
|
||||
if r.URL.RawQuery != "" {
|
||||
redirectURL += "?" + r.URL.RawQuery
|
||||
}
|
||||
if redirectURL == "/_drip/login" {
|
||||
redirectURL = "/"
|
||||
}
|
||||
|
||||
errorHTML := ""
|
||||
if errorMsg != "" {
|
||||
errorHTML = fmt.Sprintf(`<p class="error">%s</p>`, html.EscapeString(errorMsg))
|
||||
}
|
||||
|
||||
safeRedirectURL := html.EscapeString(redirectURL)
|
||||
|
||||
htmlContent := fmt.Sprintf(`<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>%s - Drip</title>
|
||||
`+faviconLink+`
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
background: #fff;
|
||||
color: #24292f;
|
||||
line-height: 1.6;
|
||||
}
|
||||
.container { max-width: 720px; margin: 0 auto; padding: 48px 24px; }
|
||||
header { margin-bottom: 48px; }
|
||||
h1 { font-size: 28px; font-weight: 600; margin-bottom: 8px; }
|
||||
h1 span { margin-right: 8px; }
|
||||
.desc { color: #57606a; font-size: 16px; }
|
||||
p { margin-bottom: 24px; }
|
||||
.error { color: #cf222e; margin-bottom: 16px; }
|
||||
.input-wrap {
|
||||
position: relative;
|
||||
background: #f6f8fa;
|
||||
border: 1px solid #d0d7de;
|
||||
border-radius: 6px;
|
||||
margin-bottom: 12px;
|
||||
display: flex;
|
||||
}
|
||||
.input-wrap input {
|
||||
flex: 1;
|
||||
margin: 0;
|
||||
padding: 12px 16px;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Menlo, Consolas, monospace;
|
||||
font-size: 14px;
|
||||
background: transparent;
|
||||
border: none;
|
||||
outline: none;
|
||||
}
|
||||
.input-wrap button {
|
||||
background: #24292f;
|
||||
color: #fff;
|
||||
border: none;
|
||||
padding: 8px 16px;
|
||||
margin: 4px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
cursor: pointer;
|
||||
}
|
||||
.input-wrap button:hover { background: #32383f; }
|
||||
footer { margin-top: 48px; padding-top: 24px; border-top: 1px solid #d0d7de; }
|
||||
footer a { color: #57606a; text-decoration: none; font-size: 14px; }
|
||||
footer a:hover { color: #0969da; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<header>
|
||||
<h1><span>🔒</span>%s</h1>
|
||||
<p class="desc">This tunnel is password protected</p>
|
||||
</header>
|
||||
|
||||
%s
|
||||
<form method="POST" action="/_drip/login">
|
||||
<input type="hidden" name="redirect" value="%s" />
|
||||
<div class="input-wrap">
|
||||
<input type="password" name="password" placeholder="Enter password" required autofocus />
|
||||
<button type="submit">Continue</button>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
<footer>
|
||||
<a href="https://github.com/Gouryella/drip" target="_blank">GitHub</a>
|
||||
</footer>
|
||||
</div>
|
||||
</body>
|
||||
</html>`, subdomain, subdomain, errorHTML, safeRedirectURL)
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(htmlContent))
|
||||
}
|
||||
@@ -2,12 +2,8 @@ package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -16,18 +12,14 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/httputil"
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/wsutil"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// bufio.Reader pool to reduce allocations on hot path
|
||||
@@ -38,56 +30,14 @@ var bufioReaderPool = sync.Pool{
|
||||
}
|
||||
|
||||
const openStreamTimeout = 3 * time.Second
|
||||
const authCookieName = "drip_auth"
|
||||
const authSessionDuration = 24 * time.Hour
|
||||
|
||||
type authSession struct {
|
||||
subdomain string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type authSessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*authSession
|
||||
}
|
||||
|
||||
var sessionStore = &authSessionStore{
|
||||
sessions: make(map[string]*authSession),
|
||||
}
|
||||
|
||||
func (s *authSessionStore) create(subdomain string) string {
|
||||
token := generateSessionToken()
|
||||
s.mu.Lock()
|
||||
s.sessions[token] = &authSession{
|
||||
subdomain: subdomain,
|
||||
expiresAt: time.Now().Add(authSessionDuration),
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return token
|
||||
}
|
||||
|
||||
func (s *authSessionStore) validate(token, subdomain string) bool {
|
||||
s.mu.RLock()
|
||||
session, ok := s.sessions[token]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if time.Now().After(session.expiresAt) {
|
||||
s.mu.Lock()
|
||||
delete(s.sessions, token)
|
||||
s.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
return session.subdomain == subdomain
|
||||
}
|
||||
|
||||
func generateSessionToken() string {
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
hash := sha256.Sum256(b)
|
||||
return hex.EncodeToString(hash[:])
|
||||
type HandlerConfig struct {
|
||||
Manager *tunnel.Manager
|
||||
Logger *zap.Logger
|
||||
ServerDomain string
|
||||
TunnelDomain string
|
||||
AuthToken string
|
||||
MetricsToken string
|
||||
}
|
||||
|
||||
type Handler struct {
|
||||
@@ -113,32 +63,14 @@ type WSConnectionHandler interface {
|
||||
HandleWSConnection(conn net.Conn, remoteAddr string)
|
||||
}
|
||||
|
||||
var privateNetworks []*net.IPNet
|
||||
|
||||
func init() {
|
||||
privateCIDRs := []string{
|
||||
"127.0.0.0/8",
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
"fe80::/10",
|
||||
}
|
||||
for _, cidr := range privateCIDRs {
|
||||
_, ipNet, _ := net.ParseCIDR(cidr)
|
||||
privateNetworks = append(privateNetworks, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, serverDomain, tunnelDomain string, authToken string, metricsToken string) *Handler {
|
||||
func NewHandler(cfg HandlerConfig) *Handler {
|
||||
return &Handler{
|
||||
manager: manager,
|
||||
logger: logger,
|
||||
serverDomain: serverDomain,
|
||||
tunnelDomain: tunnelDomain,
|
||||
authToken: authToken,
|
||||
metricsToken: metricsToken,
|
||||
manager: cfg.Manager,
|
||||
logger: cfg.Logger,
|
||||
serverDomain: cfg.ServerDomain,
|
||||
tunnelDomain: cfg.TunnelDomain,
|
||||
authToken: cfg.AuthToken,
|
||||
metricsToken: cfg.MetricsToken,
|
||||
wsUpgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 256 * 1024,
|
||||
WriteBufferSize: 256 * 1024,
|
||||
@@ -253,22 +185,38 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if tconn.HasIPAccessControl() {
|
||||
clientIP := h.extractClientIP(r)
|
||||
clientIP := netutil.ExtractClientIP(r)
|
||||
if !tconn.IsIPAllowed(clientIP) {
|
||||
http.Error(w, "Access denied: your IP is not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check proxy authentication
|
||||
if tconn.HasProxyAuth() {
|
||||
if r.URL.Path == "/_drip/login" {
|
||||
h.handleProxyLogin(w, r, tconn, subdomain)
|
||||
if auth := tconn.GetProxyAuth(); auth != nil && auth.Enabled {
|
||||
clientIP := netutil.ExtractClientIP(r)
|
||||
|
||||
if authLimiter.isRateLimited(clientIP) {
|
||||
w.Header().Set("Retry-After", "60")
|
||||
http.Error(w, "Too many failed authentication attempts. Please try again later.", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
if !h.isProxyAuthenticated(r, subdomain) {
|
||||
h.serveLoginPage(w, r, subdomain, "")
|
||||
return
|
||||
|
||||
if isBearerProxyAuth(auth) {
|
||||
if !h.isBearerAuthenticated(r, auth) {
|
||||
authLimiter.recordFailure(clientIP)
|
||||
h.serveBearerAuthRequired(w, "drip")
|
||||
return
|
||||
}
|
||||
authLimiter.resetFailures(clientIP)
|
||||
} else {
|
||||
if r.URL.Path == "/_drip/login" {
|
||||
h.handleProxyLoginWithRateLimit(w, r, tconn, subdomain, clientIP)
|
||||
return
|
||||
}
|
||||
if !h.isProxyAuthenticated(r, subdomain) {
|
||||
h.serveLoginPage(w, r, subdomain, "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -283,14 +231,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if httputil.IsWebSocketUpgrade(r) {
|
||||
if h.isWebSocketUpgrade(r) {
|
||||
h.handleWebSocket(w, r, tconn)
|
||||
return
|
||||
}
|
||||
|
||||
stream, err := h.openStreamWithTimeout(tconn)
|
||||
if err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
httputil.SetCloseConnection(w)
|
||||
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
@@ -305,7 +253,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
|
||||
if err := r.Write(countingStream); err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
httputil.SetCloseConnection(w)
|
||||
_ = r.Body.Close()
|
||||
http.Error(w, "Forward failed", http.StatusBadGateway)
|
||||
return
|
||||
@@ -316,7 +264,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
resp, err := http.ReadResponse(reader, r)
|
||||
if err != nil {
|
||||
bufioReaderPool.Put(reader)
|
||||
w.Header().Set("Connection", "close")
|
||||
httputil.SetCloseConnection(w)
|
||||
http.Error(w, "Read response failed", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
@@ -334,7 +282,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if r.Method == http.MethodHead || statusCode == http.StatusNoContent || statusCode == http.StatusNotModified {
|
||||
if resp.ContentLength >= 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
|
||||
httputil.SetContentLength(w, resp.ContentLength)
|
||||
} else {
|
||||
w.Header().Del("Content-Length")
|
||||
}
|
||||
@@ -343,7 +291,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if resp.ContentLength >= 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
|
||||
httputil.SetContentLength(w, resp.ContentLength)
|
||||
} else {
|
||||
w.Header().Del("Content-Length")
|
||||
}
|
||||
@@ -396,58 +344,6 @@ func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, err
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection) {
|
||||
stream, err := h.openStreamWithTimeout(tconn)
|
||||
if err != nil {
|
||||
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
tconn.IncActiveConnections()
|
||||
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
stream.Close()
|
||||
tconn.DecActiveConnections()
|
||||
http.Error(w, "WebSocket not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
clientConn, clientBuf, err := hj.Hijack()
|
||||
if err != nil {
|
||||
stream.Close()
|
||||
tconn.DecActiveConnections()
|
||||
http.Error(w, "Failed to hijack connection", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.Write(stream); err != nil {
|
||||
stream.Close()
|
||||
clientConn.Close()
|
||||
tconn.DecActiveConnections()
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer stream.Close()
|
||||
defer clientConn.Close()
|
||||
defer tconn.DecActiveConnections()
|
||||
|
||||
var clientRW io.ReadWriteCloser = clientConn
|
||||
if clientBuf != nil && clientBuf.Reader.Buffered() > 0 {
|
||||
clientRW = &bufferedReadWriteCloser{
|
||||
Reader: clientBuf.Reader,
|
||||
Conn: clientConn,
|
||||
}
|
||||
}
|
||||
|
||||
_ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW,
|
||||
func(n int64) { tconn.AddBytesOut(n) },
|
||||
func(n int64) { tconn.AddBytesIn(n) },
|
||||
)
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *Handler) copyResponseHeaders(dst http.Header, src http.Header, proxyHost string) {
|
||||
for key, values := range src {
|
||||
canonicalKey := http.CanonicalHeaderKey(key)
|
||||
@@ -530,571 +426,18 @@ func (h *Handler) extractSubdomain(host string) (string, subdomainResult) {
|
||||
return "", subdomainNotFound
|
||||
}
|
||||
|
||||
// extractClientIP extracts the client IP from the request.
|
||||
// It only trusts X-Forwarded-For and X-Real-IP headers when the request
|
||||
// comes from a private/loopback network (typical reverse proxy setup).
|
||||
func (h *Handler) extractClientIP(r *http.Request) string {
|
||||
// First, get the direct remote address
|
||||
remoteIP := h.extractRemoteIP(r.RemoteAddr)
|
||||
|
||||
// Only trust proxy headers if the request comes from a private network
|
||||
if isPrivateIP(remoteIP) {
|
||||
// Check X-Forwarded-For header (may contain multiple IPs)
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP (original client)
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return strings.TrimSpace(xri)
|
||||
}
|
||||
func (h *Handler) validateMetricsAuth(w http.ResponseWriter, r *http.Request, realm string) bool {
|
||||
if h.metricsToken == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Fall back to remote address
|
||||
return remoteIP
|
||||
}
|
||||
token := extractBearerToken(r.Header.Get("Authorization"))
|
||||
|
||||
// extractRemoteIP extracts the IP address from a remote address string (host:port format).
|
||||
func (h *Handler) extractRemoteIP(remoteAddr string) string {
|
||||
host, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return remoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// isPrivateIP checks if the given IP is a private/loopback address.
|
||||
func isPrivateIP(ip string) bool {
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
if subtle.ConstantTimeCompare([]byte(token), []byte(h.metricsToken)) != 1 {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="%s"`, realm))
|
||||
http.Error(w, "Unauthorized: provide metrics token via 'Authorization: Bearer <token>' header", http.StatusUnauthorized)
|
||||
return false
|
||||
}
|
||||
|
||||
for _, network := range privateNetworks {
|
||||
if network.Contains(parsedIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
|
||||
html := `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Drip - Your Tunnel, Your Domain, Anywhere</title>
|
||||
` + faviconLink + `
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
background: #fff;
|
||||
color: #24292f;
|
||||
line-height: 1.6;
|
||||
}
|
||||
.container { max-width: 720px; margin: 0 auto; padding: 48px 24px; }
|
||||
header { margin-bottom: 48px; }
|
||||
h1 { font-size: 28px; font-weight: 600; margin-bottom: 8px; }
|
||||
h1 span { margin-right: 8px; }
|
||||
.desc { color: #57606a; font-size: 16px; }
|
||||
h2 { font-size: 18px; font-weight: 600; margin: 32px 0 12px; }
|
||||
.code-wrap {
|
||||
position: relative;
|
||||
background: #f6f8fa;
|
||||
border: 1px solid #d0d7de;
|
||||
border-radius: 6px;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
.code-wrap pre {
|
||||
margin: 0;
|
||||
padding: 12px 16px;
|
||||
padding-right: 60px;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Menlo, Consolas, monospace;
|
||||
font-size: 14px;
|
||||
overflow-x: auto;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-all;
|
||||
}
|
||||
.copy-btn {
|
||||
position: absolute;
|
||||
top: 8px;
|
||||
right: 8px;
|
||||
background: #fff;
|
||||
border: 1px solid #d0d7de;
|
||||
border-radius: 6px;
|
||||
padding: 4px 6px;
|
||||
cursor: pointer;
|
||||
color: #57606a;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
.copy-btn:hover { background: #f3f4f6; }
|
||||
.copy-btn svg { width: 16px; height: 16px; }
|
||||
.copy-btn .check { display: none; color: #1a7f37; }
|
||||
.copy-btn.copied .copy { display: none; }
|
||||
.copy-btn.copied .check { display: block; }
|
||||
.links { margin-top: 32px; display: flex; gap: 24px; flex-wrap: wrap; }
|
||||
.links a { color: #0969da; text-decoration: none; font-size: 14px; }
|
||||
.links a:hover { text-decoration: underline; }
|
||||
footer { margin-top: 48px; padding-top: 24px; border-top: 1px solid #d0d7de; }
|
||||
footer a { color: #57606a; text-decoration: none; font-size: 14px; }
|
||||
footer a:hover { color: #0969da; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<header>
|
||||
<h1><span>💧</span>Drip</h1>
|
||||
<p class="desc">Your Tunnel, Your Domain, Anywhere</p>
|
||||
</header>
|
||||
|
||||
<p>A self-hosted tunneling solution to securely expose your services to the internet.</p>
|
||||
|
||||
<h2>Install</h2>
|
||||
<div class="code-wrap">
|
||||
<pre>bash <(curl -fsSL https://driptunnel.app/install.sh)</pre>
|
||||
<button class="copy-btn" onclick="copy(this)">
|
||||
<svg class="copy" viewBox="0 0 16 16" fill="currentColor"><path d="M0 6.75C0 5.784.784 5 1.75 5h1.5a.75.75 0 0 1 0 1.5h-1.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-1.5a.75.75 0 0 1 1.5 0v1.5A1.75 1.75 0 0 1 9.25 16h-7.5A1.75 1.75 0 0 1 0 14.25Z"></path><path d="M5 1.75C5 .784 5.784 0 6.75 0h7.5C15.216 0 16 .784 16 1.75v7.5A1.75 1.75 0 0 1 14.25 11h-7.5A1.75 1.75 0 0 1 5 9.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path></svg>
|
||||
<svg class="check" viewBox="0 0 16 16" fill="currentColor"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0Z"></path></svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<h2>Usage</h2>
|
||||
<div class="code-wrap">
|
||||
<pre>drip http 3000</pre>
|
||||
<button class="copy-btn" onclick="copy(this)">
|
||||
<svg class="copy" viewBox="0 0 16 16" fill="currentColor"><path d="M0 6.75C0 5.784.784 5 1.75 5h1.5a.75.75 0 0 1 0 1.5h-1.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-1.5a.75.75 0 0 1 1.5 0v1.5A1.75 1.75 0 0 1 9.25 16h-7.5A1.75 1.75 0 0 1 0 14.25Z"></path><path d="M5 1.75C5 .784 5.784 0 6.75 0h7.5C15.216 0 16 .784 16 1.75v7.5A1.75 1.75 0 0 1 14.25 11h-7.5A1.75 1.75 0 0 1 5 9.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path></svg>
|
||||
<svg class="check" viewBox="0 0 16 16" fill="currentColor"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0Z"></path></svg>
|
||||
</button>
|
||||
</div>
|
||||
<div class="code-wrap">
|
||||
<pre>drip https 443</pre>
|
||||
<button class="copy-btn" onclick="copy(this)">
|
||||
<svg class="copy" viewBox="0 0 16 16" fill="currentColor"><path d="M0 6.75C0 5.784.784 5 1.75 5h1.5a.75.75 0 0 1 0 1.5h-1.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-1.5a.75.75 0 0 1 1.5 0v1.5A1.75 1.75 0 0 1 9.25 16h-7.5A1.75 1.75 0 0 1 0 14.25Z"></path><path d="M5 1.75C5 .784 5.784 0 6.75 0h7.5C15.216 0 16 .784 16 1.75v7.5A1.75 1.75 0 0 1 14.25 11h-7.5A1.75 1.75 0 0 1 5 9.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path></svg>
|
||||
<svg class="check" viewBox="0 0 16 16" fill="currentColor"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0Z"></path></svg>
|
||||
</button>
|
||||
</div>
|
||||
<div class="code-wrap">
|
||||
<pre>drip tcp 5432</pre>
|
||||
<button class="copy-btn" onclick="copy(this)">
|
||||
<svg class="copy" viewBox="0 0 16 16" fill="currentColor"><path d="M0 6.75C0 5.784.784 5 1.75 5h1.5a.75.75 0 0 1 0 1.5h-1.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-1.5a.75.75 0 0 1 1.5 0v1.5A1.75 1.75 0 0 1 9.25 16h-7.5A1.75 1.75 0 0 1 0 14.25Z"></path><path d="M5 1.75C5 .784 5.784 0 6.75 0h7.5C15.216 0 16 .784 16 1.75v7.5A1.75 1.75 0 0 1 14.25 11h-7.5A1.75 1.75 0 0 1 5 9.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path></svg>
|
||||
<svg class="check" viewBox="0 0 16 16" fill="currentColor"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0Z"></path></svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="links">
|
||||
<a href="/health">Health Check</a>
|
||||
<a href="/stats">Statistics</a>
|
||||
<a href="/metrics">Prometheus Metrics</a>
|
||||
</div>
|
||||
|
||||
<footer>
|
||||
<a href="https://github.com/Gouryella/drip" target="_blank">GitHub</a>
|
||||
</footer>
|
||||
</div>
|
||||
<script>
|
||||
function copy(btn) {
|
||||
const text = btn.previousElementSibling.textContent;
|
||||
navigator.clipboard.writeText(text).then(() => {
|
||||
btn.classList.add('copied');
|
||||
setTimeout(() => { btn.classList.remove('copied'); }, 2000);
|
||||
});
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
data := []byte(html)
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
func (h *Handler) serveTunnelNotFound(w http.ResponseWriter, r *http.Request) {
|
||||
html := `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>404 - Tunnel Not Found</title>
|
||||
` + faviconLink + `
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
background: #fff;
|
||||
color: #24292f;
|
||||
line-height: 1.6;
|
||||
}
|
||||
.container { max-width: 720px; margin: 0 auto; padding: 48px 24px; }
|
||||
header { margin-bottom: 48px; }
|
||||
h1 { font-size: 28px; font-weight: 600; margin-bottom: 8px; }
|
||||
h1 span { margin-right: 8px; }
|
||||
.desc { color: #57606a; font-size: 16px; }
|
||||
p { margin-bottom: 16px; }
|
||||
.info-box {
|
||||
background: #f6f8fa;
|
||||
border: 1px solid #d0d7de;
|
||||
border-radius: 6px;
|
||||
padding: 16px;
|
||||
margin: 24px 0;
|
||||
}
|
||||
.info-box ul {
|
||||
margin: 12px 0 0 20px;
|
||||
color: #57606a;
|
||||
}
|
||||
.info-box li { margin-bottom: 8px; }
|
||||
footer { margin-top: 48px; padding-top: 24px; border-top: 1px solid #d0d7de; }
|
||||
footer a { color: #57606a; text-decoration: none; font-size: 14px; }
|
||||
footer a:hover { color: #0969da; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<header>
|
||||
<h1><span>🔍</span>Tunnel Not Found</h1>
|
||||
<p class="desc">The requested tunnel does not exist or has been closed.</p>
|
||||
</header>
|
||||
|
||||
<div class="info-box">
|
||||
<p>This could happen because:</p>
|
||||
<ul>
|
||||
<li>The tunnel was never created</li>
|
||||
<li>The tunnel has been closed by the owner</li>
|
||||
<li>The tunnel URL is incorrect</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<p>If you are the tunnel owner, please restart your tunnel client.</p>
|
||||
|
||||
<footer>
|
||||
<a href="https://github.com/Gouryella/drip" target="_blank">GitHub</a>
|
||||
</footer>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
data := []byte(html)
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) {
|
||||
health := map[string]interface{}{
|
||||
"status": "ok",
|
||||
"active_tunnels": h.manager.Count(),
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(health)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
if h.metricsToken != "" {
|
||||
// Only accept token via Authorization header (Bearer token)
|
||||
// URL query parameters are insecure (logged, cached, visible in browser history)
|
||||
var token string
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
token = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
|
||||
if token != h.metricsToken {
|
||||
w.Header().Set("WWW-Authenticate", `Bearer realm="stats"`)
|
||||
http.Error(w, "Unauthorized: provide metrics token via 'Authorization: Bearer <token>' header", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
connections := h.manager.List()
|
||||
|
||||
// Pre-allocate slice to avoid O(n²) reallocations
|
||||
tunnelStats := make([]map[string]interface{}, 0, len(connections))
|
||||
for _, conn := range connections {
|
||||
if conn == nil {
|
||||
continue
|
||||
}
|
||||
tunnelStats = append(tunnelStats, map[string]interface{}{
|
||||
"subdomain": conn.Subdomain,
|
||||
"tunnel_type": string(conn.GetTunnelType()),
|
||||
"last_active": conn.LastActive.Unix(),
|
||||
"bytes_in": conn.GetBytesIn(),
|
||||
"bytes_out": conn.GetBytesOut(),
|
||||
"active_connections": conn.GetActiveConnections(),
|
||||
"total_bytes": conn.GetBytesIn() + conn.GetBytesOut(),
|
||||
})
|
||||
}
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"total_tunnels": len(tunnelStats),
|
||||
"tunnels": tunnelStats,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(stats)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
func (h *Handler) serveMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
if h.metricsToken != "" {
|
||||
// Only accept token via Authorization header (Bearer token)
|
||||
var token string
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
token = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
|
||||
if token != h.metricsToken {
|
||||
w.Header().Set("WWW-Authenticate", `Bearer realm="metrics"`)
|
||||
http.Error(w, "Unauthorized: provide metrics token via 'Authorization: Bearer <token>' header", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Serve Prometheus metrics
|
||||
promhttp.Handler().ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
type bufferedReadWriteCloser struct {
|
||||
*bufio.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (b *bufferedReadWriteCloser) Read(p []byte) (int, error) {
|
||||
return b.Reader.Read(p)
|
||||
}
|
||||
|
||||
func (h *Handler) isProxyAuthenticated(r *http.Request, subdomain string) bool {
|
||||
cookie, err := r.Cookie(authCookieName + "_" + subdomain)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return sessionStore.validate(cookie.Value, subdomain)
|
||||
}
|
||||
|
||||
func (h *Handler) handleProxyLogin(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection, subdomain string) {
|
||||
if r.Method != http.MethodPost {
|
||||
h.serveLoginPage(w, r, subdomain, "")
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
h.serveLoginPage(w, r, subdomain, "Invalid form data")
|
||||
return
|
||||
}
|
||||
|
||||
password := r.FormValue("password")
|
||||
|
||||
if !tconn.ValidateProxyAuth(password) {
|
||||
h.serveLoginPage(w, r, subdomain, "Invalid password")
|
||||
return
|
||||
}
|
||||
|
||||
token := sessionStore.create(subdomain)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: authCookieName + "_" + subdomain,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
MaxAge: int(authSessionDuration.Seconds()),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
redirectURL := r.FormValue("redirect")
|
||||
if redirectURL == "" || redirectURL == "/_drip/login" {
|
||||
redirectURL = "/"
|
||||
}
|
||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
func (h *Handler) serveLoginPage(w http.ResponseWriter, r *http.Request, subdomain string, errorMsg string) {
|
||||
redirectURL := r.URL.Path
|
||||
if r.URL.RawQuery != "" {
|
||||
redirectURL += "?" + r.URL.RawQuery
|
||||
}
|
||||
if redirectURL == "/_drip/login" {
|
||||
redirectURL = "/"
|
||||
}
|
||||
|
||||
errorHTML := ""
|
||||
if errorMsg != "" {
|
||||
errorHTML = fmt.Sprintf(`<p class="error">%s</p>`, html.EscapeString(errorMsg))
|
||||
}
|
||||
|
||||
safeRedirectURL := html.EscapeString(redirectURL)
|
||||
|
||||
htmlContent := fmt.Sprintf(`<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>%s - Drip</title>
|
||||
`+faviconLink+`
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
background: #fff;
|
||||
color: #24292f;
|
||||
line-height: 1.6;
|
||||
}
|
||||
.container { max-width: 720px; margin: 0 auto; padding: 48px 24px; }
|
||||
header { margin-bottom: 48px; }
|
||||
h1 { font-size: 28px; font-weight: 600; margin-bottom: 8px; }
|
||||
h1 span { margin-right: 8px; }
|
||||
.desc { color: #57606a; font-size: 16px; }
|
||||
p { margin-bottom: 24px; }
|
||||
.error { color: #cf222e; margin-bottom: 16px; }
|
||||
.input-wrap {
|
||||
position: relative;
|
||||
background: #f6f8fa;
|
||||
border: 1px solid #d0d7de;
|
||||
border-radius: 6px;
|
||||
margin-bottom: 12px;
|
||||
display: flex;
|
||||
}
|
||||
.input-wrap input {
|
||||
flex: 1;
|
||||
margin: 0;
|
||||
padding: 12px 16px;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Menlo, Consolas, monospace;
|
||||
font-size: 14px;
|
||||
background: transparent;
|
||||
border: none;
|
||||
outline: none;
|
||||
}
|
||||
.input-wrap button {
|
||||
background: #24292f;
|
||||
color: #fff;
|
||||
border: none;
|
||||
padding: 8px 16px;
|
||||
margin: 4px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
cursor: pointer;
|
||||
}
|
||||
.input-wrap button:hover { background: #32383f; }
|
||||
footer { margin-top: 48px; padding-top: 24px; border-top: 1px solid #d0d7de; }
|
||||
footer a { color: #57606a; text-decoration: none; font-size: 14px; }
|
||||
footer a:hover { color: #0969da; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<header>
|
||||
<h1><span>🔒</span>%s</h1>
|
||||
<p class="desc">This tunnel is password protected</p>
|
||||
</header>
|
||||
|
||||
%s
|
||||
<form method="POST" action="/_drip/login">
|
||||
<input type="hidden" name="redirect" value="%s" />
|
||||
<div class="input-wrap">
|
||||
<input type="password" name="password" placeholder="Enter password" required autofocus />
|
||||
<button type="submit">Continue</button>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
<footer>
|
||||
<a href="https://github.com/Gouryella/drip" target="_blank">GitHub</a>
|
||||
</footer>
|
||||
</div>
|
||||
</body>
|
||||
</html>`, subdomain, subdomain, errorHTML, safeRedirectURL)
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(htmlContent))
|
||||
}
|
||||
|
||||
// handleTunnelWebSocket handles WebSocket connections for tunnel clients
|
||||
func (h *Handler) handleTunnelWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if WSS transport is allowed
|
||||
if !h.IsTransportAllowed("wss") {
|
||||
http.Error(w, "WebSocket transport not allowed on this server", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if h.wsConnHandler == nil {
|
||||
http.Error(w, "WebSocket tunnel not configured", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
ws, err := h.wsUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("WebSocket upgrade failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Configure WebSocket for tunnel use
|
||||
ws.SetReadLimit(protocol.MaxFrameSize + protocol.FrameHeaderSize + 1024)
|
||||
|
||||
// Extract real client IP (support CDN headers)
|
||||
remoteAddr := h.extractClientIP(r)
|
||||
|
||||
h.logger.Info("WebSocket tunnel connection established",
|
||||
zap.String("remote_addr", remoteAddr),
|
||||
)
|
||||
|
||||
// Wrap WebSocket as net.Conn with ping loop for CDN keep-alive
|
||||
conn := wsutil.NewConnWithPing(ws, 30*time.Second)
|
||||
|
||||
// Handle the connection using the registered handler
|
||||
h.wsConnHandler.HandleWSConnection(conn, remoteAddr)
|
||||
}
|
||||
|
||||
// serveDiscovery returns server capabilities for client auto-detection
|
||||
func (h *Handler) serveDiscovery(w http.ResponseWriter, r *http.Request) {
|
||||
transports := h.allowedTransports
|
||||
if len(transports) == 0 {
|
||||
transports = []string{"tcp", "wss"}
|
||||
}
|
||||
|
||||
tunnelTypes := h.allowedTunnelTypes
|
||||
if len(tunnelTypes) == 0 {
|
||||
tunnelTypes = []string{"http", "https", "tcp"}
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"transports": transports,
|
||||
"tunnel_types": tunnelTypes,
|
||||
"preferred": h.GetPreferredTransport(),
|
||||
"version": "1",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Write(data)
|
||||
return true
|
||||
}
|
||||
|
||||
299
internal/server/proxy/pages.go
Normal file
299
internal/server/proxy/pages.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
|
||||
"drip/internal/shared/httputil"
|
||||
)
|
||||
|
||||
func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
|
||||
html := `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Drip - Your Tunnel, Your Domain, Anywhere</title>
|
||||
` + faviconLink + `
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
background: #fff;
|
||||
color: #24292f;
|
||||
line-height: 1.6;
|
||||
}
|
||||
.container { max-width: 720px; margin: 0 auto; padding: 48px 24px; }
|
||||
header { margin-bottom: 48px; }
|
||||
h1 { font-size: 28px; font-weight: 600; margin-bottom: 8px; }
|
||||
h1 span { margin-right: 8px; }
|
||||
.desc { color: #57606a; font-size: 16px; }
|
||||
h2 { font-size: 18px; font-weight: 600; margin: 32px 0 12px; }
|
||||
.code-wrap {
|
||||
position: relative;
|
||||
background: #f6f8fa;
|
||||
border: 1px solid #d0d7de;
|
||||
border-radius: 6px;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
.code-wrap pre {
|
||||
margin: 0;
|
||||
padding: 12px 16px;
|
||||
padding-right: 60px;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Menlo, Consolas, monospace;
|
||||
font-size: 14px;
|
||||
overflow-x: auto;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-all;
|
||||
}
|
||||
.copy-btn {
|
||||
position: absolute;
|
||||
top: 8px;
|
||||
right: 8px;
|
||||
background: #fff;
|
||||
border: 1px solid #d0d7de;
|
||||
border-radius: 6px;
|
||||
padding: 4px 6px;
|
||||
cursor: pointer;
|
||||
color: #57606a;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
.copy-btn:hover { background: #f3f4f6; }
|
||||
.copy-btn svg { width: 16px; height: 16px; }
|
||||
.copy-btn .check { display: none; color: #1a7f37; }
|
||||
.copy-btn.copied .copy { display: none; }
|
||||
.copy-btn.copied .check { display: block; }
|
||||
.links { margin-top: 32px; display: flex; gap: 24px; flex-wrap: wrap; }
|
||||
.links a { color: #0969da; text-decoration: none; font-size: 14px; }
|
||||
.links a:hover { text-decoration: underline; }
|
||||
footer { margin-top: 48px; padding-top: 24px; border-top: 1px solid #d0d7de; }
|
||||
footer a { color: #57606a; text-decoration: none; font-size: 14px; }
|
||||
footer a:hover { color: #0969da; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<header>
|
||||
<h1><span>💧</span>Drip</h1>
|
||||
<p class="desc">Your Tunnel, Your Domain, Anywhere</p>
|
||||
</header>
|
||||
|
||||
<p>A self-hosted tunneling solution to securely expose your services to the internet.</p>
|
||||
|
||||
<h2>Install</h2>
|
||||
<div class="code-wrap">
|
||||
<pre>bash <(curl -fsSL https://driptunnel.app/install.sh)</pre>
|
||||
<button class="copy-btn" onclick="copy(this)">
|
||||
<svg class="copy" viewBox="0 0 16 16" fill="currentColor"><path d="M0 6.75C0 5.784.784 5 1.75 5h1.5a.75.75 0 0 1 0 1.5h-1.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-1.5a.75.75 0 0 1 1.5 0v1.5A1.75 1.75 0 0 1 9.25 16h-7.5A1.75 1.75 0 0 1 0 14.25Z"></path><path d="M5 1.75C5 .784 5.784 0 6.75 0h7.5C15.216 0 16 .784 16 1.75v7.5A1.75 1.75 0 0 1 14.25 11h-7.5A1.75 1.75 0 0 1 5 9.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path></svg>
|
||||
<svg class="check" viewBox="0 0 16 16" fill="currentColor"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0Z"></path></svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<h2>Usage</h2>
|
||||
<div class="code-wrap">
|
||||
<pre>drip http 3000</pre>
|
||||
<button class="copy-btn" onclick="copy(this)">
|
||||
<svg class="copy" viewBox="0 0 16 16" fill="currentColor"><path d="M0 6.75C0 5.784.784 5 1.75 5h1.5a.75.75 0 0 1 0 1.5h-1.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-1.5a.75.75 0 0 1 1.5 0v1.5A1.75 1.75 0 0 1 9.25 16h-7.5A1.75 1.75 0 0 1 0 14.25Z"></path><path d="M5 1.75C5 .784 5.784 0 6.75 0h7.5C15.216 0 16 .784 16 1.75v7.5A1.75 1.75 0 0 1 14.25 11h-7.5A1.75 1.75 0 0 1 5 9.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path></svg>
|
||||
<svg class="check" viewBox="0 0 16 16" fill="currentColor"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0Z"></path></svg>
|
||||
</button>
|
||||
</div>
|
||||
<div class="code-wrap">
|
||||
<pre>drip https 443</pre>
|
||||
<button class="copy-btn" onclick="copy(this)">
|
||||
<svg class="copy" viewBox="0 0 16 16" fill="currentColor"><path d="M0 6.75C0 5.784.784 5 1.75 5h1.5a.75.75 0 0 1 0 1.5h-1.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-1.5a.75.75 0 0 1 1.5 0v1.5A1.75 1.75 0 0 1 9.25 16h-7.5A1.75 1.75 0 0 1 0 14.25Z"></path><path d="M5 1.75C5 .784 5.784 0 6.75 0h7.5C15.216 0 16 .784 16 1.75v7.5A1.75 1.75 0 0 1 14.25 11h-7.5A1.75 1.75 0 0 1 5 9.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path></svg>
|
||||
<svg class="check" viewBox="0 0 16 16" fill="currentColor"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0Z"></path></svg>
|
||||
</button>
|
||||
</div>
|
||||
<div class="code-wrap">
|
||||
<pre>drip tcp 5432</pre>
|
||||
<button class="copy-btn" onclick="copy(this)">
|
||||
<svg class="copy" viewBox="0 0 16 16" fill="currentColor"><path d="M0 6.75C0 5.784.784 5 1.75 5h1.5a.75.75 0 0 1 0 1.5h-1.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-1.5a.75.75 0 0 1 1.5 0v1.5A1.75 1.75 0 0 1 9.25 16h-7.5A1.75 1.75 0 0 1 0 14.25Z"></path><path d="M5 1.75C5 .784 5.784 0 6.75 0h7.5C15.216 0 16 .784 16 1.75v7.5A1.75 1.75 0 0 1 14.25 11h-7.5A1.75 1.75 0 0 1 5 9.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h7.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path></svg>
|
||||
<svg class="check" viewBox="0 0 16 16" fill="currentColor"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0Z"></path></svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="links">
|
||||
<a href="/health">Health Check</a>
|
||||
<a href="/stats">Statistics</a>
|
||||
<a href="/metrics">Prometheus Metrics</a>
|
||||
</div>
|
||||
|
||||
<footer>
|
||||
<a href="https://github.com/Gouryella/drip" target="_blank">GitHub</a>
|
||||
</footer>
|
||||
</div>
|
||||
<script>
|
||||
function copy(btn) {
|
||||
const text = btn.previousElementSibling.textContent;
|
||||
navigator.clipboard.writeText(text).then(() => {
|
||||
btn.classList.add('copied');
|
||||
setTimeout(() => { btn.classList.remove('copied'); }, 2000);
|
||||
});
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
httputil.WriteHTML(w, []byte(html))
|
||||
}
|
||||
|
||||
func (h *Handler) serveTunnelNotFound(w http.ResponseWriter, r *http.Request) {
|
||||
html := `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>404 - Tunnel Not Found</title>
|
||||
` + faviconLink + `
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
background: #fff;
|
||||
color: #24292f;
|
||||
line-height: 1.6;
|
||||
}
|
||||
.container { max-width: 720px; margin: 0 auto; padding: 48px 24px; }
|
||||
header { margin-bottom: 48px; }
|
||||
h1 { font-size: 28px; font-weight: 600; margin-bottom: 8px; }
|
||||
h1 span { margin-right: 8px; }
|
||||
.desc { color: #57606a; font-size: 16px; }
|
||||
p { margin-bottom: 16px; }
|
||||
.info-box {
|
||||
background: #f6f8fa;
|
||||
border: 1px solid #d0d7de;
|
||||
border-radius: 6px;
|
||||
padding: 16px;
|
||||
margin: 24px 0;
|
||||
}
|
||||
.info-box ul {
|
||||
margin: 12px 0 0 20px;
|
||||
color: #57606a;
|
||||
}
|
||||
.info-box li { margin-bottom: 8px; }
|
||||
footer { margin-top: 48px; padding-top: 24px; border-top: 1px solid #d0d7de; }
|
||||
footer a { color: #57606a; text-decoration: none; font-size: 14px; }
|
||||
footer a:hover { color: #0969da; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<header>
|
||||
<h1><span>🔍</span>Tunnel Not Found</h1>
|
||||
<p class="desc">The requested tunnel does not exist or has been closed.</p>
|
||||
</header>
|
||||
|
||||
<div class="info-box">
|
||||
<p>This could happen because:</p>
|
||||
<ul>
|
||||
<li>The tunnel was never created</li>
|
||||
<li>The tunnel has been closed by the owner</li>
|
||||
<li>The tunnel URL is incorrect</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<p>If you are the tunnel owner, please restart your tunnel client.</p>
|
||||
|
||||
<footer>
|
||||
<a href="https://github.com/Gouryella/drip" target="_blank">GitHub</a>
|
||||
</footer>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
httputil.WriteHTMLWithStatus(w, []byte(html), http.StatusNotFound)
|
||||
}
|
||||
|
||||
func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) {
|
||||
health := map[string]interface{}{
|
||||
"status": "ok",
|
||||
"active_tunnels": h.manager.Count(),
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(health)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
httputil.WriteJSON(w, data)
|
||||
}
|
||||
|
||||
func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.validateMetricsAuth(w, r, "stats") {
|
||||
return
|
||||
}
|
||||
|
||||
connections := h.manager.List()
|
||||
|
||||
tunnelStats := make([]map[string]interface{}, 0, len(connections))
|
||||
for _, conn := range connections {
|
||||
if conn == nil {
|
||||
continue
|
||||
}
|
||||
tunnelStats = append(tunnelStats, map[string]interface{}{
|
||||
"subdomain": conn.Subdomain,
|
||||
"tunnel_type": string(conn.GetTunnelType()),
|
||||
"last_active": conn.LastActive.Unix(),
|
||||
"bytes_in": conn.GetBytesIn(),
|
||||
"bytes_out": conn.GetBytesOut(),
|
||||
"active_connections": conn.GetActiveConnections(),
|
||||
"total_bytes": conn.GetBytesIn() + conn.GetBytesOut(),
|
||||
})
|
||||
}
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"total_tunnels": len(tunnelStats),
|
||||
"tunnels": tunnelStats,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(stats)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
httputil.WriteJSON(w, data)
|
||||
}
|
||||
|
||||
func (h *Handler) serveMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.validateMetricsAuth(w, r, "metrics") {
|
||||
return
|
||||
}
|
||||
|
||||
promhttp.Handler().ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (h *Handler) serveDiscovery(w http.ResponseWriter, r *http.Request) {
|
||||
transports := h.allowedTransports
|
||||
if len(transports) == 0 {
|
||||
transports = []string{"tcp", "wss"}
|
||||
}
|
||||
|
||||
tunnelTypes := h.allowedTunnelTypes
|
||||
if len(tunnelTypes) == 0 {
|
||||
tunnelTypes = []string{"http", "https", "tcp"}
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"transports": transports,
|
||||
"tunnel_types": tunnelTypes,
|
||||
"preferred": h.GetPreferredTransport(),
|
||||
"version": "1",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
httputil.WriteJSON(w, data)
|
||||
}
|
||||
113
internal/server/proxy/websocket_handler.go
Normal file
113
internal/server/proxy/websocket_handler.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/httputil"
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/wsutil"
|
||||
)
|
||||
|
||||
type bufferedReadWriteCloser struct {
|
||||
*bufio.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (b *bufferedReadWriteCloser) Read(p []byte) (int, error) {
|
||||
return b.Reader.Read(p)
|
||||
}
|
||||
|
||||
func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection) {
|
||||
stream, err := h.openStreamWithTimeout(tconn)
|
||||
if err != nil {
|
||||
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
tconn.IncActiveConnections()
|
||||
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
stream.Close()
|
||||
tconn.DecActiveConnections()
|
||||
http.Error(w, "WebSocket not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
clientConn, clientBuf, err := hj.Hijack()
|
||||
if err != nil {
|
||||
stream.Close()
|
||||
tconn.DecActiveConnections()
|
||||
http.Error(w, "Failed to hijack connection", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.Write(stream); err != nil {
|
||||
stream.Close()
|
||||
clientConn.Close()
|
||||
tconn.DecActiveConnections()
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer stream.Close()
|
||||
defer clientConn.Close()
|
||||
defer tconn.DecActiveConnections()
|
||||
|
||||
var clientRW io.ReadWriteCloser = clientConn
|
||||
if clientBuf != nil && clientBuf.Reader.Buffered() > 0 {
|
||||
clientRW = &bufferedReadWriteCloser{
|
||||
Reader: clientBuf.Reader,
|
||||
Conn: clientConn,
|
||||
}
|
||||
}
|
||||
|
||||
_ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW,
|
||||
func(n int64) { tconn.AddBytesOut(n) },
|
||||
func(n int64) { tconn.AddBytesIn(n) },
|
||||
)
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *Handler) handleTunnelWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.IsTransportAllowed("wss") {
|
||||
http.Error(w, "WebSocket transport not allowed on this server", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if h.wsConnHandler == nil {
|
||||
http.Error(w, "WebSocket tunnel not configured", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
ws, err := h.wsUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("WebSocket upgrade failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
ws.SetReadLimit(protocol.MaxFrameSize + protocol.FrameHeaderSize + 1024)
|
||||
|
||||
remoteAddr := netutil.ExtractClientIP(r)
|
||||
|
||||
h.logger.Info("WebSocket tunnel connection established",
|
||||
zap.String("remote_addr", remoteAddr),
|
||||
)
|
||||
|
||||
conn := wsutil.NewConnWithPing(ws, 30*time.Second)
|
||||
|
||||
h.wsConnHandler.HandleWSConnection(conn, remoteAddr)
|
||||
}
|
||||
|
||||
func (h *Handler) isWebSocketUpgrade(r *http.Request) bool {
|
||||
return httputil.IsWebSocketUpgrade(r)
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -18,46 +17,54 @@ import (
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/mux"
|
||||
"drip/internal/shared/httputil"
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// bufioWriterPool reuses bufio.Writer instances to reduce GC pressure
|
||||
var bufioWriterPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return bufio.NewWriterSize(nil, 4096)
|
||||
},
|
||||
type ConnectionConfig struct {
|
||||
Conn net.Conn
|
||||
AuthToken string
|
||||
Manager *tunnel.Manager
|
||||
Logger *zap.Logger
|
||||
PortAlloc *PortAllocator
|
||||
Domain string
|
||||
TunnelDomain string
|
||||
PublicPort int
|
||||
HTTPHandler http.Handler
|
||||
GroupManager *ConnectionGroupManager
|
||||
HTTPListener *connQueueListener
|
||||
}
|
||||
|
||||
type Connection struct {
|
||||
conn net.Conn
|
||||
authToken string
|
||||
manager *tunnel.Manager
|
||||
logger *zap.Logger
|
||||
subdomain string
|
||||
port int
|
||||
domain string
|
||||
tunnelDomain string
|
||||
publicPort int
|
||||
portAlloc *PortAllocator
|
||||
tunnelConn *tunnel.Connection
|
||||
stopCh chan struct{}
|
||||
once sync.Once
|
||||
lastHeartbeat time.Time
|
||||
mu sync.RWMutex
|
||||
frameWriter *protocol.FrameWriter
|
||||
httpHandler http.Handler
|
||||
tunnelType protocol.TunnelType
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
session *yamux.Session
|
||||
proxy *Proxy
|
||||
tunnelID string
|
||||
groupManager *ConnectionGroupManager
|
||||
httpListener *connQueueListener
|
||||
handedOff bool
|
||||
conn net.Conn
|
||||
authToken string
|
||||
manager *tunnel.Manager
|
||||
logger *zap.Logger
|
||||
subdomain string
|
||||
port int
|
||||
domain string
|
||||
tunnelDomain string
|
||||
publicPort int
|
||||
portAlloc *PortAllocator
|
||||
tunnelConn *tunnel.Connection
|
||||
stopCh chan struct{}
|
||||
once sync.Once
|
||||
lastHeartbeat time.Time
|
||||
mu sync.RWMutex
|
||||
frameWriter *protocol.FrameWriter
|
||||
httpHandler http.Handler
|
||||
tunnelType protocol.TunnelType
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
session *yamux.Session
|
||||
proxy *Proxy
|
||||
tunnelID string
|
||||
groupManager *ConnectionGroupManager
|
||||
httpListener *connQueueListener
|
||||
handedOff bool
|
||||
lifecycleManager *ConnectionLifecycleManager
|
||||
|
||||
// Server capabilities
|
||||
allowedTunnelTypes []string
|
||||
@@ -65,25 +72,32 @@ type Connection struct {
|
||||
}
|
||||
|
||||
// NewConnection creates a new connection handler
|
||||
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager, httpListener *connQueueListener) *Connection {
|
||||
func NewConnection(cfg ConnectionConfig) *Connection {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
stopCh := make(chan struct{})
|
||||
|
||||
c := &Connection{
|
||||
conn: conn,
|
||||
authToken: authToken,
|
||||
manager: manager,
|
||||
logger: logger,
|
||||
portAlloc: portAlloc,
|
||||
domain: domain,
|
||||
tunnelDomain: tunnelDomain,
|
||||
publicPort: publicPort,
|
||||
httpHandler: httpHandler,
|
||||
stopCh: make(chan struct{}),
|
||||
lastHeartbeat: time.Now(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
groupManager: groupManager,
|
||||
httpListener: httpListener,
|
||||
conn: cfg.Conn,
|
||||
authToken: cfg.AuthToken,
|
||||
manager: cfg.Manager,
|
||||
logger: cfg.Logger,
|
||||
portAlloc: cfg.PortAlloc,
|
||||
domain: cfg.Domain,
|
||||
tunnelDomain: cfg.TunnelDomain,
|
||||
publicPort: cfg.PublicPort,
|
||||
httpHandler: cfg.HTTPHandler,
|
||||
stopCh: stopCh,
|
||||
lastHeartbeat: time.Now(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
groupManager: cfg.GroupManager,
|
||||
httpListener: cfg.HTTPListener,
|
||||
lifecycleManager: NewConnectionLifecycleManager(stopCh, cancel, cfg.Logger),
|
||||
}
|
||||
|
||||
// Set connection in lifecycle manager
|
||||
c.lifecycleManager.SetConnection(cfg.Conn)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -99,17 +113,7 @@ func (c *Connection) Handle() error {
|
||||
return fmt.Errorf("failed to peek connection: %w", err)
|
||||
}
|
||||
|
||||
peekStr := string(peek)
|
||||
httpMethods := []string{"GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC"}
|
||||
isHTTP := false
|
||||
for _, method := range httpMethods {
|
||||
if strings.HasPrefix(peekStr, method) {
|
||||
isHTTP = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isHTTP {
|
||||
if httputil.IsHTTPRequest(peek) {
|
||||
c.logger.Info("Detected HTTP request on TCP port, handling as HTTP")
|
||||
return c.handleHTTPRequest(reader)
|
||||
}
|
||||
@@ -128,7 +132,24 @@ func (c *Connection) Handle() error {
|
||||
defer sf.Close()
|
||||
|
||||
if sf.Frame.Type == protocol.FrameTypeDataConnect {
|
||||
return c.handleDataConnect(sf.Frame, reader)
|
||||
handler := NewDataConnectionHandler(
|
||||
c.conn,
|
||||
reader,
|
||||
c.authToken,
|
||||
c.groupManager,
|
||||
c.stopCh,
|
||||
c.logger,
|
||||
)
|
||||
handler.SetSessionCreatedHandler(func(session *yamux.Session) {
|
||||
c.session = session
|
||||
if c.lifecycleManager != nil {
|
||||
c.lifecycleManager.SetSession(session)
|
||||
}
|
||||
})
|
||||
handler.SetTunnelIDHandler(func(tunnelID string) {
|
||||
c.tunnelID = tunnelID
|
||||
})
|
||||
return handler.Handle(sf.Frame)
|
||||
}
|
||||
|
||||
if sf.Frame.Type != protocol.FrameTypeRegister {
|
||||
@@ -153,121 +174,70 @@ func (c *Connection) Handle() error {
|
||||
return fmt.Errorf("authentication failed")
|
||||
}
|
||||
|
||||
if req.TunnelType == protocol.TunnelTypeTCP {
|
||||
if c.portAlloc == nil {
|
||||
return fmt.Errorf("port allocator not configured")
|
||||
}
|
||||
|
||||
if requestedPort, ok := parseTCPSubdomainPort(req.CustomSubdomain); ok {
|
||||
port, err := c.portAlloc.AllocateSpecific(requestedPort)
|
||||
if err != nil {
|
||||
c.sendError("port_allocation_failed", err.Error())
|
||||
return fmt.Errorf("failed to allocate requested port %d: %w", requestedPort, err)
|
||||
}
|
||||
c.port = port
|
||||
} else {
|
||||
port, err := c.portAlloc.Allocate()
|
||||
if err != nil {
|
||||
c.sendError("port_allocation_failed", err.Error())
|
||||
return fmt.Errorf("failed to allocate port: %w", err)
|
||||
}
|
||||
c.port = port
|
||||
|
||||
if req.CustomSubdomain == "" {
|
||||
req.CustomSubdomain = fmt.Sprintf("tcp-%d", port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
subdomain, err := c.manager.Register(nil, req.CustomSubdomain)
|
||||
if err != nil {
|
||||
c.sendError("registration_failed", err.Error())
|
||||
c.portAlloc.Release(c.port)
|
||||
c.port = 0
|
||||
return fmt.Errorf("tunnel registration failed: %w", err)
|
||||
}
|
||||
|
||||
c.subdomain = subdomain
|
||||
|
||||
tunnelConn, ok := c.manager.Get(subdomain)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to get registered tunnel")
|
||||
}
|
||||
c.tunnelConn = tunnelConn
|
||||
|
||||
c.tunnelConn.Conn = nil
|
||||
c.tunnelConn.SetTunnelType(req.TunnelType)
|
||||
c.tunnelType = req.TunnelType
|
||||
|
||||
if req.IPAccess != nil && (len(req.IPAccess.AllowIPs) > 0 || len(req.IPAccess.DenyIPs) > 0) {
|
||||
c.tunnelConn.SetIPAccessControl(req.IPAccess.AllowIPs, req.IPAccess.DenyIPs)
|
||||
c.logger.Info("IP access control configured",
|
||||
zap.String("subdomain", subdomain),
|
||||
zap.Strings("allow_ips", req.IPAccess.AllowIPs),
|
||||
zap.Strings("deny_ips", req.IPAccess.DenyIPs),
|
||||
)
|
||||
}
|
||||
|
||||
if req.ProxyAuth != nil && req.ProxyAuth.Enabled {
|
||||
c.tunnelConn.SetProxyAuth(req.ProxyAuth)
|
||||
c.logger.Info("Proxy authentication configured",
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
}
|
||||
|
||||
c.logger.Info("Tunnel registered",
|
||||
zap.String("subdomain", subdomain),
|
||||
zap.String("tunnel_type", string(req.TunnelType)),
|
||||
zap.Int("local_port", req.LocalPort),
|
||||
zap.Int("remote_port", c.port),
|
||||
// Use RegistrationHandler for registration logic
|
||||
regHandler := NewRegistrationHandler(
|
||||
c.manager,
|
||||
c.portAlloc,
|
||||
c.groupManager,
|
||||
c.domain,
|
||||
c.tunnelDomain,
|
||||
c.publicPort,
|
||||
c.logger,
|
||||
)
|
||||
|
||||
var tunnelURL string
|
||||
if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS {
|
||||
if c.publicPort == 443 {
|
||||
tunnelURL = fmt.Sprintf("https://%s.%s", subdomain, c.tunnelDomain)
|
||||
} else {
|
||||
tunnelURL = fmt.Sprintf("https://%s.%s:%d", subdomain, c.tunnelDomain, c.publicPort)
|
||||
}
|
||||
} else {
|
||||
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.tunnelDomain, c.port)
|
||||
regReq := &RegistrationRequest{
|
||||
TunnelType: req.TunnelType,
|
||||
CustomSubdomain: req.CustomSubdomain,
|
||||
Token: req.Token,
|
||||
ConnectionType: req.ConnectionType,
|
||||
PoolCapabilities: req.PoolCapabilities,
|
||||
IPAccess: req.IPAccess,
|
||||
ProxyAuth: req.ProxyAuth,
|
||||
LocalPort: req.LocalPort,
|
||||
}
|
||||
|
||||
var tunnelID string
|
||||
var supportsDataConn bool
|
||||
recommendedConns := 0
|
||||
result, err := regHandler.Register(regReq)
|
||||
if err != nil {
|
||||
c.sendError("registration_failed", err.Error())
|
||||
return fmt.Errorf("registration failed: %w", err)
|
||||
}
|
||||
|
||||
if req.PoolCapabilities != nil && req.ConnectionType == "primary" && c.groupManager != nil {
|
||||
group := c.groupManager.CreateGroup(subdomain, req.Token, c, req.TunnelType)
|
||||
tunnelID = group.TunnelID
|
||||
c.tunnelID = tunnelID
|
||||
supportsDataConn = true
|
||||
recommendedConns = 4
|
||||
// Store registration results
|
||||
c.subdomain = result.Subdomain
|
||||
c.port = result.Port
|
||||
c.tunnelConn = result.TunnelConn
|
||||
c.tunnelConn.Conn = nil
|
||||
|
||||
// Update lifecycle manager with registration info
|
||||
if c.lifecycleManager != nil {
|
||||
c.lifecycleManager.SetPortAllocation(c.portAlloc, c.port)
|
||||
c.lifecycleManager.SetTunnelRegistration(c.manager, c.subdomain, "", c.groupManager)
|
||||
}
|
||||
|
||||
// Handle connection groups
|
||||
if result.SupportsDataConn && c.groupManager != nil {
|
||||
group := c.groupManager.CreateGroup(result.Subdomain, req.Token, c, req.TunnelType)
|
||||
result.TunnelID = group.TunnelID
|
||||
c.tunnelID = result.TunnelID
|
||||
|
||||
// Update lifecycle manager with tunnel ID
|
||||
if c.lifecycleManager != nil {
|
||||
c.lifecycleManager.SetTunnelRegistration(c.manager, c.subdomain, c.tunnelID, c.groupManager)
|
||||
}
|
||||
|
||||
c.logger.Info("Created connection group for multi-connection support",
|
||||
zap.String("tunnel_id", tunnelID),
|
||||
zap.String("tunnel_id", result.TunnelID),
|
||||
zap.Int("max_data_conns", req.PoolCapabilities.MaxDataConns),
|
||||
)
|
||||
}
|
||||
|
||||
resp := protocol.RegisterResponse{
|
||||
Subdomain: subdomain,
|
||||
Port: c.port,
|
||||
URL: tunnelURL,
|
||||
Message: "Tunnel registered successfully",
|
||||
TunnelID: tunnelID,
|
||||
SupportsDataConn: supportsDataConn,
|
||||
RecommendedConns: recommendedConns,
|
||||
// Build and send registration response
|
||||
resp, err := regHandler.BuildRegistrationResponse(result)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build registration response: %w", err)
|
||||
}
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal registration response: %w", err)
|
||||
}
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData)
|
||||
|
||||
err = protocol.WriteFrame(c.conn, ackFrame)
|
||||
if err != nil {
|
||||
if err := regHandler.SendRegistrationResponse(c.conn, resp); err != nil {
|
||||
return fmt.Errorf("failed to send registration ack: %w", err)
|
||||
}
|
||||
|
||||
@@ -282,6 +252,11 @@ func (c *Connection) Handle() error {
|
||||
|
||||
c.frameWriter = protocol.NewFrameWriter(c.conn)
|
||||
|
||||
// Update lifecycle manager with frame writer
|
||||
if c.lifecycleManager != nil {
|
||||
c.lifecycleManager.SetFrameWriter(c.frameWriter)
|
||||
}
|
||||
|
||||
c.frameWriter.SetWriteErrorHandler(func(err error) {
|
||||
c.logger.Error("Write error detected, closing connection", zap.Error(err))
|
||||
c.Close()
|
||||
@@ -289,39 +264,30 @@ func (c *Connection) Handle() error {
|
||||
|
||||
go c.heartbeatChecker()
|
||||
|
||||
return c.handleFrames(reader)
|
||||
// Use FrameHandler for frame processing
|
||||
frameHandler := NewFrameHandler(c.conn, reader, c.stopCh, c.frameWriter, c.logger)
|
||||
frameHandler.SetHeartbeatHandler(func() {
|
||||
c.handleHeartbeat()
|
||||
})
|
||||
frameHandler.SetCloseHandler(func() {
|
||||
c.logger.Info("Client requested close")
|
||||
})
|
||||
|
||||
return frameHandler.HandleFrames()
|
||||
}
|
||||
|
||||
func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
if c.httpListener == nil {
|
||||
return c.handleHTTPRequestLegacy(reader)
|
||||
}
|
||||
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
wrappedConn := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
if !c.httpListener.Enqueue(wrappedConn) {
|
||||
c.logger.Warn("HTTP listener queue full, rejecting connection")
|
||||
response := "HTTP/1.1 503 Service Unavailable\r\n" +
|
||||
"Content-Type: text/plain\r\n" +
|
||||
"Content-Length: 32\r\n" +
|
||||
"Connection: close\r\n" +
|
||||
"\r\n" +
|
||||
"Server busy, please retry later\r\n"
|
||||
c.conn.Write([]byte(response))
|
||||
return fmt.Errorf("http listener queue full")
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.conn = nil
|
||||
c.handedOff = true
|
||||
c.mu.Unlock()
|
||||
|
||||
return nil
|
||||
handler := NewHTTPRequestHandler(
|
||||
c.conn,
|
||||
reader,
|
||||
c.httpHandler,
|
||||
c.httpListener,
|
||||
c.ctx,
|
||||
c.logger,
|
||||
&c.mu,
|
||||
&c.handedOff,
|
||||
)
|
||||
return handler.Handle()
|
||||
}
|
||||
|
||||
func (c *Connection) IsHandedOff() bool {
|
||||
@@ -330,113 +296,6 @@ func (c *Connection) IsHandedOff() bool {
|
||||
return c.handedOff
|
||||
}
|
||||
|
||||
func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
|
||||
if c.httpHandler == nil {
|
||||
c.logger.Warn("HTTP request received but no HTTP handler configured")
|
||||
response := "HTTP/1.1 503 Service Unavailable\r\n" +
|
||||
"Content-Type: text/plain\r\n" +
|
||||
"Content-Length: 47\r\n" +
|
||||
"\r\n" +
|
||||
"HTTP handler not configured for this TCP port\r\n"
|
||||
c.conn.Write([]byte(response))
|
||||
return fmt.Errorf("HTTP handler not configured")
|
||||
}
|
||||
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
for {
|
||||
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
|
||||
req, err := http.ReadRequest(reader)
|
||||
if err != nil {
|
||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
c.logger.Debug("Client closed HTTP connection")
|
||||
return nil
|
||||
}
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
c.logger.Debug("HTTP keep-alive timeout")
|
||||
return nil
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errors.Is(err, net.ErrClosed) || strings.Contains(errStr, "use of closed network connection") {
|
||||
c.logger.Debug("HTTP connection closed during read", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
if strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection refused") {
|
||||
c.logger.Debug("Client disconnected abruptly", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
if strings.Contains(errStr, "malformed HTTP") {
|
||||
c.logger.Warn("Received malformed HTTP request",
|
||||
zap.Error(err),
|
||||
zap.String("error_snippet", errStr[:min(len(errStr), 100)]),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
c.logger.Error("Failed to parse HTTP request", zap.Error(err))
|
||||
return fmt.Errorf("failed to parse HTTP request: %w", err)
|
||||
}
|
||||
|
||||
if c.ctx != nil {
|
||||
req = req.WithContext(c.ctx)
|
||||
}
|
||||
|
||||
c.logger.Info("Processing HTTP request on TCP port",
|
||||
zap.String("method", req.Method),
|
||||
zap.String("url", req.URL.String()),
|
||||
zap.String("host", req.Host),
|
||||
)
|
||||
|
||||
// Get writer from pool to reduce GC pressure
|
||||
pooledWriter := bufioWriterPool.Get().(*bufio.Writer)
|
||||
pooledWriter.Reset(c.conn)
|
||||
|
||||
respWriter := &httpResponseWriter{
|
||||
conn: c.conn,
|
||||
writer: pooledWriter,
|
||||
header: make(http.Header),
|
||||
}
|
||||
|
||||
c.httpHandler.ServeHTTP(respWriter, req)
|
||||
|
||||
if err := respWriter.writer.Flush(); err != nil {
|
||||
c.logger.Debug("Failed to flush HTTP response", zap.Error(err))
|
||||
}
|
||||
|
||||
// Return writer to pool
|
||||
pooledWriter.Reset(nil) // Clear reference to connection
|
||||
bufioWriterPool.Put(pooledWriter)
|
||||
|
||||
// Keep TCP_NODELAY enabled for low latency HTTP responses
|
||||
// (removed the toggle that was disabling it)
|
||||
|
||||
c.logger.Debug("HTTP request processing completed",
|
||||
zap.String("method", req.Method),
|
||||
zap.String("url", req.URL.String()),
|
||||
)
|
||||
|
||||
shouldClose := false
|
||||
if req.Close {
|
||||
shouldClose = true
|
||||
} else if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
|
||||
if req.Header.Get("Connection") != "keep-alive" {
|
||||
shouldClose = true
|
||||
}
|
||||
}
|
||||
|
||||
if respWriter.headerWritten && respWriter.header.Get("Connection") == "close" {
|
||||
shouldClose = true
|
||||
}
|
||||
|
||||
if shouldClose {
|
||||
c.logger.Debug("Closing connection as requested by client or server")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseTCPSubdomainPort(subdomain string) (int, bool) {
|
||||
if !strings.HasPrefix(subdomain, "tcp-") {
|
||||
return 0, false
|
||||
@@ -455,55 +314,6 @@ func parseTCPSubdomainPort(subdomain string) (int, bool) {
|
||||
return port, true
|
||||
}
|
||||
|
||||
func (c *Connection) handleFrames(reader *bufio.Reader) error {
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
frame, err := protocol.ReadFrame(reader)
|
||||
if err != nil {
|
||||
if isTimeoutError(err) {
|
||||
c.logger.Warn("Read timeout, connection may be dead")
|
||||
return fmt.Errorf("read timeout")
|
||||
}
|
||||
if err.Error() == "failed to read frame header: EOF" || err.Error() == "EOF" {
|
||||
c.logger.Info("Client disconnected")
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
c.logger.Debug("Connection closed during shutdown")
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("failed to read frame: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
sf := protocol.WithFrame(frame)
|
||||
|
||||
switch sf.Frame.Type {
|
||||
case protocol.FrameTypeHeartbeat:
|
||||
c.handleHeartbeat()
|
||||
sf.Close()
|
||||
|
||||
case protocol.FrameTypeClose:
|
||||
sf.Close()
|
||||
c.logger.Info("Client requested close")
|
||||
return nil
|
||||
|
||||
default:
|
||||
sf.Close()
|
||||
c.logger.Warn("Unexpected frame type",
|
||||
zap.String("type", sf.Frame.Type.String()),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) handleHeartbeat() {
|
||||
c.mu.Lock()
|
||||
c.lastHeartbeat = time.Now()
|
||||
@@ -562,188 +372,71 @@ func (c *Connection) sendError(code, message string) {
|
||||
|
||||
func (c *Connection) Close() {
|
||||
c.once.Do(func() {
|
||||
protocol.UnregisterConnection()
|
||||
close(c.stopCh)
|
||||
// Check if connection was handed off to HTTP handler
|
||||
c.mu.RLock()
|
||||
handedOff := c.handedOff
|
||||
c.mu.RUnlock()
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
// If handed off, don't close the connection - HTTP handler owns it now
|
||||
if handedOff {
|
||||
c.logger.Debug("Connection handed off to HTTP handler, skipping close")
|
||||
return
|
||||
}
|
||||
|
||||
// Prevent race with handleHTTPRequest setting c.conn = nil
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
// Use lifecycle manager for cleanup
|
||||
if c.lifecycleManager != nil {
|
||||
c.lifecycleManager.Close()
|
||||
} else {
|
||||
// Fallback if lifecycle manager not initialized
|
||||
protocol.UnregisterConnection()
|
||||
close(c.stopCh)
|
||||
|
||||
if conn != nil {
|
||||
_ = conn.SetDeadline(time.Now())
|
||||
}
|
||||
|
||||
if c.frameWriter != nil {
|
||||
c.frameWriter.Close()
|
||||
}
|
||||
|
||||
if c.proxy != nil {
|
||||
c.proxy.Stop()
|
||||
}
|
||||
|
||||
if c.session != nil {
|
||||
_ = c.session.Close()
|
||||
}
|
||||
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
if c.port > 0 && c.portAlloc != nil {
|
||||
c.portAlloc.Release(c.port)
|
||||
}
|
||||
|
||||
if c.subdomain != "" {
|
||||
c.manager.Unregister(c.subdomain)
|
||||
if c.tunnelID != "" && c.groupManager != nil {
|
||||
c.groupManager.RemoveGroup(c.tunnelID)
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
c.logger.Info("Connection closed",
|
||||
zap.String("subdomain", c.subdomain),
|
||||
)
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn != nil {
|
||||
_ = conn.SetDeadline(time.Now())
|
||||
}
|
||||
|
||||
if c.frameWriter != nil {
|
||||
c.frameWriter.Close()
|
||||
}
|
||||
|
||||
if c.proxy != nil {
|
||||
c.proxy.Stop()
|
||||
}
|
||||
|
||||
if c.session != nil {
|
||||
_ = c.session.Close()
|
||||
}
|
||||
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
if c.port > 0 && c.portAlloc != nil {
|
||||
c.portAlloc.Release(c.port)
|
||||
}
|
||||
|
||||
if c.subdomain != "" {
|
||||
c.manager.Unregister(c.subdomain)
|
||||
if c.tunnelID != "" && c.groupManager != nil {
|
||||
c.groupManager.RemoveGroup(c.tunnelID)
|
||||
}
|
||||
}
|
||||
|
||||
c.logger.Info("Connection closed",
|
||||
zap.String("subdomain", c.subdomain),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type httpResponseWriter struct {
|
||||
conn net.Conn
|
||||
writer *bufio.Writer
|
||||
header http.Header
|
||||
statusCode int
|
||||
headerWritten bool
|
||||
}
|
||||
|
||||
func (w *httpResponseWriter) Header() http.Header {
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *httpResponseWriter) WriteHeader(statusCode int) {
|
||||
if w.headerWritten {
|
||||
return
|
||||
}
|
||||
w.statusCode = statusCode
|
||||
w.headerWritten = true
|
||||
|
||||
statusText := http.StatusText(statusCode)
|
||||
if statusText == "" {
|
||||
statusText = "Unknown"
|
||||
}
|
||||
|
||||
w.writer.WriteString("HTTP/1.1 ")
|
||||
w.writer.WriteString(fmt.Sprintf("%d", statusCode))
|
||||
w.writer.WriteByte(' ')
|
||||
w.writer.WriteString(statusText)
|
||||
w.writer.WriteString("\r\n")
|
||||
|
||||
for key, values := range w.header {
|
||||
for _, value := range values {
|
||||
w.writer.WriteString(key)
|
||||
w.writer.WriteString(": ")
|
||||
w.writer.WriteString(value)
|
||||
w.writer.WriteString("\r\n")
|
||||
}
|
||||
}
|
||||
|
||||
w.writer.WriteString("\r\n")
|
||||
}
|
||||
|
||||
func (w *httpResponseWriter) Write(data []byte) (int, error) {
|
||||
if !w.headerWritten {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return w.writer.Write(data)
|
||||
}
|
||||
|
||||
func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Reader) error {
|
||||
var req protocol.DataConnectRequest
|
||||
if err := json.Unmarshal(frame.Payload, &req); err != nil {
|
||||
c.sendError("invalid_request", "Failed to parse data connect request")
|
||||
return fmt.Errorf("failed to parse data connect request: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("Data connection request received",
|
||||
zap.String("tunnel_id", req.TunnelID),
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
if c.groupManager == nil {
|
||||
c.sendDataConnectError("not_supported", "Multi-connection not supported")
|
||||
return fmt.Errorf("group manager not available")
|
||||
}
|
||||
|
||||
if c.authToken != "" && req.Token != c.authToken {
|
||||
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
|
||||
return fmt.Errorf("authentication failed for data connection")
|
||||
}
|
||||
|
||||
group, ok := c.groupManager.GetGroup(req.TunnelID)
|
||||
if !ok || group == nil {
|
||||
c.sendDataConnectError("join_failed", "Tunnel not found")
|
||||
return fmt.Errorf("tunnel not found: %s", req.TunnelID)
|
||||
}
|
||||
|
||||
if group.Token != "" && req.Token != group.Token {
|
||||
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
|
||||
return fmt.Errorf("authentication failed for data connection")
|
||||
}
|
||||
|
||||
c.tunnelID = req.TunnelID
|
||||
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: true,
|
||||
ConnectionID: req.ConnectionID,
|
||||
Message: "Data connection accepted",
|
||||
}
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal data connect response: %w", err)
|
||||
}
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
|
||||
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
|
||||
return fmt.Errorf("failed to send data connect ack: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("Data connection established",
|
||||
zap.String("tunnel_id", req.TunnelID),
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
_ = c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Server acts as yamux Client, client connector acts as yamux Server
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
// Use optimized mux config for server
|
||||
cfg := mux.NewServerConfig()
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
group.AddSession(req.ConnectionID, session)
|
||||
defer group.RemoveSession(req.ConnectionID)
|
||||
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return nil
|
||||
case <-session.CloseChan():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func isTimeoutError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
@@ -755,20 +448,6 @@ func isTimeoutError(err error) bool {
|
||||
return strings.Contains(err.Error(), "i/o timeout")
|
||||
}
|
||||
|
||||
func (c *Connection) sendDataConnectError(code, message string) {
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: false,
|
||||
Message: fmt.Sprintf("%s: %s", code, message),
|
||||
}
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to marshal data connect error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
_ = protocol.WriteFrame(c.conn, frame)
|
||||
}
|
||||
|
||||
// SetAllowedTunnelTypes sets the allowed tunnel types for this connection
|
||||
func (c *Connection) SetAllowedTunnelTypes(types []string) {
|
||||
c.allowedTunnelTypes = types
|
||||
|
||||
@@ -17,10 +17,10 @@ import (
|
||||
|
||||
// sessionEntry represents a session with its current stream count for heap operations
|
||||
type sessionEntry struct {
|
||||
id string
|
||||
session *yamux.Session
|
||||
streams int
|
||||
heapIdx int // index in the heap, managed by heap.Interface
|
||||
id string
|
||||
session *yamux.Session
|
||||
streams int
|
||||
heapIdx int // index in the heap, managed by heap.Interface
|
||||
}
|
||||
|
||||
// sessionHeap implements heap.Interface for O(log n) session selection
|
||||
|
||||
161
internal/server/tcp/data_connection_handler.go
Normal file
161
internal/server/tcp/data_connection_handler.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/hashicorp/yamux"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"drip/internal/shared/mux"
|
||||
"drip/internal/shared/protocol"
|
||||
)
|
||||
|
||||
// DataConnectionHandler handles data connection requests for multi-connection support.
|
||||
type DataConnectionHandler struct {
|
||||
conn net.Conn
|
||||
reader *bufio.Reader
|
||||
authToken string
|
||||
groupManager *ConnectionGroupManager
|
||||
stopCh <-chan struct{}
|
||||
logger *zap.Logger
|
||||
onSessionCreated func(*yamux.Session)
|
||||
onTunnelIDSet func(string)
|
||||
}
|
||||
|
||||
// NewDataConnectionHandler creates a new data connection handler.
|
||||
func NewDataConnectionHandler(
|
||||
conn net.Conn,
|
||||
reader *bufio.Reader,
|
||||
authToken string,
|
||||
groupManager *ConnectionGroupManager,
|
||||
stopCh <-chan struct{},
|
||||
logger *zap.Logger,
|
||||
) *DataConnectionHandler {
|
||||
return &DataConnectionHandler{
|
||||
conn: conn,
|
||||
reader: reader,
|
||||
authToken: authToken,
|
||||
groupManager: groupManager,
|
||||
stopCh: stopCh,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SetSessionCreatedHandler sets the callback for when a session is created.
|
||||
func (h *DataConnectionHandler) SetSessionCreatedHandler(handler func(*yamux.Session)) {
|
||||
h.onSessionCreated = handler
|
||||
}
|
||||
|
||||
// SetTunnelIDHandler sets the callback for when tunnel ID is set.
|
||||
func (h *DataConnectionHandler) SetTunnelIDHandler(handler func(string)) {
|
||||
h.onTunnelIDSet = handler
|
||||
}
|
||||
|
||||
// Handle processes the data connection request.
|
||||
func (h *DataConnectionHandler) Handle(frame *protocol.Frame) error {
|
||||
var req protocol.DataConnectRequest
|
||||
if err := json.Unmarshal(frame.Payload, &req); err != nil {
|
||||
h.sendError("invalid_request", "Failed to parse data connect request")
|
||||
return fmt.Errorf("failed to parse data connect request: %w", err)
|
||||
}
|
||||
|
||||
h.logger.Info("Data connection request received",
|
||||
zap.String("tunnel_id", req.TunnelID),
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
if h.groupManager == nil {
|
||||
h.sendError("not_supported", "Multi-connection not supported")
|
||||
return fmt.Errorf("group manager not available")
|
||||
}
|
||||
|
||||
if h.authToken != "" && req.Token != h.authToken {
|
||||
h.sendError("authentication_failed", "Invalid authentication token")
|
||||
return fmt.Errorf("authentication failed for data connection")
|
||||
}
|
||||
|
||||
group, ok := h.groupManager.GetGroup(req.TunnelID)
|
||||
if !ok || group == nil {
|
||||
h.sendError("join_failed", "Tunnel not found")
|
||||
return fmt.Errorf("tunnel not found: %s", req.TunnelID)
|
||||
}
|
||||
|
||||
if group.Token != "" && req.Token != group.Token {
|
||||
h.sendError("authentication_failed", "Invalid authentication token")
|
||||
return fmt.Errorf("authentication failed for data connection")
|
||||
}
|
||||
|
||||
if h.onTunnelIDSet != nil {
|
||||
h.onTunnelIDSet(req.TunnelID)
|
||||
}
|
||||
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: true,
|
||||
ConnectionID: req.ConnectionID,
|
||||
Message: "Data connection accepted",
|
||||
}
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal data connect response: %w", err)
|
||||
}
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
|
||||
if err := protocol.WriteFrame(h.conn, ackFrame); err != nil {
|
||||
return fmt.Errorf("failed to send data connect ack: %w", err)
|
||||
}
|
||||
|
||||
h.logger.Info("Data connection established",
|
||||
zap.String("tunnel_id", req.TunnelID),
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
_ = h.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Server acts as yamux Client, client connector acts as yamux Server
|
||||
bc := &bufferedConn{
|
||||
Conn: h.conn,
|
||||
reader: h.reader,
|
||||
}
|
||||
|
||||
// Use optimized mux config for server
|
||||
cfg := mux.NewServerConfig()
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
|
||||
if h.onSessionCreated != nil {
|
||||
h.onSessionCreated(session)
|
||||
}
|
||||
|
||||
group.AddSession(req.ConnectionID, session)
|
||||
defer group.RemoveSession(req.ConnectionID)
|
||||
|
||||
select {
|
||||
case <-h.stopCh:
|
||||
return nil
|
||||
case <-session.CloseChan():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// sendError sends an error response to the client.
|
||||
func (h *DataConnectionHandler) sendError(code, message string) {
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: false,
|
||||
Message: fmt.Sprintf("%s: %s", code, message),
|
||||
}
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to marshal data connect error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
_ = protocol.WriteFrame(h.conn, frame)
|
||||
}
|
||||
122
internal/server/tcp/frame_handler.go
Normal file
122
internal/server/tcp/frame_handler.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/protocol"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// FrameHandler handles protocol frame reading and processing.
|
||||
type FrameHandler struct {
|
||||
conn net.Conn
|
||||
reader *bufio.Reader
|
||||
stopCh <-chan struct{}
|
||||
logger *zap.Logger
|
||||
frameWriter *protocol.FrameWriter
|
||||
|
||||
// Heartbeat tracking
|
||||
onHeartbeat func()
|
||||
onClose func()
|
||||
}
|
||||
|
||||
// NewFrameHandler creates a new frame handler.
|
||||
func NewFrameHandler(
|
||||
conn net.Conn,
|
||||
reader *bufio.Reader,
|
||||
stopCh <-chan struct{},
|
||||
frameWriter *protocol.FrameWriter,
|
||||
logger *zap.Logger,
|
||||
) *FrameHandler {
|
||||
return &FrameHandler{
|
||||
conn: conn,
|
||||
reader: reader,
|
||||
stopCh: stopCh,
|
||||
frameWriter: frameWriter,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SetHeartbeatHandler sets the callback for heartbeat frames.
|
||||
func (fh *FrameHandler) SetHeartbeatHandler(handler func()) {
|
||||
fh.onHeartbeat = handler
|
||||
}
|
||||
|
||||
// SetCloseHandler sets the callback for close frames.
|
||||
func (fh *FrameHandler) SetCloseHandler(handler func()) {
|
||||
fh.onClose = handler
|
||||
}
|
||||
|
||||
// HandleFrames processes incoming frames in a loop.
|
||||
func (fh *FrameHandler) HandleFrames() error {
|
||||
for {
|
||||
select {
|
||||
case <-fh.stopCh:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
fh.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
frame, err := protocol.ReadFrame(fh.reader)
|
||||
if err != nil {
|
||||
return fh.handleReadError(err)
|
||||
}
|
||||
|
||||
sf := protocol.WithFrame(frame)
|
||||
err = fh.processFrame(sf)
|
||||
sf.Close()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleReadError handles errors that occur while reading frames.
|
||||
func (fh *FrameHandler) handleReadError(err error) error {
|
||||
if isTimeoutError(err) {
|
||||
fh.logger.Warn("Read timeout, connection may be dead")
|
||||
return fmt.Errorf("read timeout")
|
||||
}
|
||||
|
||||
if err.Error() == "failed to read frame header: EOF" || err.Error() == "EOF" {
|
||||
fh.logger.Info("Client disconnected")
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-fh.stopCh:
|
||||
fh.logger.Debug("Connection closed during shutdown")
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("failed to read frame: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// processFrame processes a single frame based on its type.
|
||||
func (fh *FrameHandler) processFrame(sf *protocol.SafeFrame) error {
|
||||
switch sf.Frame.Type {
|
||||
case protocol.FrameTypeHeartbeat:
|
||||
if fh.onHeartbeat != nil {
|
||||
fh.onHeartbeat()
|
||||
}
|
||||
return nil
|
||||
|
||||
case protocol.FrameTypeClose:
|
||||
fh.logger.Info("Client requested close")
|
||||
if fh.onClose != nil {
|
||||
fh.onClose()
|
||||
}
|
||||
return fmt.Errorf("client requested close")
|
||||
|
||||
default:
|
||||
fh.logger.Warn("Unexpected frame type",
|
||||
zap.String("type", sf.Frame.Type.String()),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
251
internal/server/tcp/http_request_handler.go
Normal file
251
internal/server/tcp/http_request_handler.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// bufioWriterPool reuses bufio.Writer instances to reduce GC pressure
|
||||
var bufioWriterPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return bufio.NewWriterSize(nil, 4096)
|
||||
},
|
||||
}
|
||||
|
||||
// HTTPRequestHandler handles HTTP requests on TCP connections.
|
||||
type HTTPRequestHandler struct {
|
||||
conn net.Conn
|
||||
reader *bufio.Reader
|
||||
httpHandler http.Handler
|
||||
httpListener *connQueueListener
|
||||
ctx interface{ Done() <-chan struct{} }
|
||||
logger *zap.Logger
|
||||
mu *sync.RWMutex
|
||||
handedOff *bool
|
||||
}
|
||||
|
||||
// NewHTTPRequestHandler creates a new HTTP request handler.
|
||||
func NewHTTPRequestHandler(
|
||||
conn net.Conn,
|
||||
reader *bufio.Reader,
|
||||
httpHandler http.Handler,
|
||||
httpListener *connQueueListener,
|
||||
ctx interface{ Done() <-chan struct{} },
|
||||
logger *zap.Logger,
|
||||
mu *sync.RWMutex,
|
||||
handedOff *bool,
|
||||
) *HTTPRequestHandler {
|
||||
return &HTTPRequestHandler{
|
||||
conn: conn,
|
||||
reader: reader,
|
||||
httpHandler: httpHandler,
|
||||
httpListener: httpListener,
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
mu: mu,
|
||||
handedOff: handedOff,
|
||||
}
|
||||
}
|
||||
|
||||
// Handle processes the HTTP request.
|
||||
func (h *HTTPRequestHandler) Handle() error {
|
||||
if h.httpListener == nil {
|
||||
return h.handleLegacy()
|
||||
}
|
||||
|
||||
h.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
wrappedConn := &bufferedConn{
|
||||
Conn: h.conn,
|
||||
reader: h.reader,
|
||||
}
|
||||
|
||||
if !h.httpListener.Enqueue(wrappedConn) {
|
||||
h.logger.Warn("HTTP listener queue full, rejecting connection")
|
||||
response := "HTTP/1.1 503 Service Unavailable\r\n" +
|
||||
"Content-Type: text/plain\r\n" +
|
||||
"Content-Length: 32\r\n" +
|
||||
"Connection: close\r\n" +
|
||||
"\r\n" +
|
||||
"Server busy, please retry later\r\n"
|
||||
h.conn.Write([]byte(response))
|
||||
return fmt.Errorf("http listener queue full")
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
*h.handedOff = true
|
||||
h.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleLegacy processes HTTP requests using the legacy handler.
|
||||
func (h *HTTPRequestHandler) handleLegacy() error {
|
||||
if h.httpHandler == nil {
|
||||
h.logger.Warn("HTTP request received but no HTTP handler configured")
|
||||
response := "HTTP/1.1 503 Service Unavailable\r\n" +
|
||||
"Content-Type: text/plain\r\n" +
|
||||
"Content-Length: 47\r\n" +
|
||||
"\r\n" +
|
||||
"HTTP handler not configured for this TCP port\r\n"
|
||||
h.conn.Write([]byte(response))
|
||||
return fmt.Errorf("HTTP handler not configured")
|
||||
}
|
||||
|
||||
h.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
for {
|
||||
h.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
|
||||
req, err := http.ReadRequest(h.reader)
|
||||
if err != nil {
|
||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
h.logger.Debug("Client closed HTTP connection")
|
||||
return nil
|
||||
}
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
h.logger.Debug("HTTP keep-alive timeout")
|
||||
return nil
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errors.Is(err, net.ErrClosed) || strings.Contains(errStr, "use of closed network connection") {
|
||||
h.logger.Debug("HTTP connection closed during read", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
if strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection refused") {
|
||||
h.logger.Debug("Client disconnected abruptly", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
if strings.Contains(errStr, "malformed HTTP") {
|
||||
h.logger.Warn("Received malformed HTTP request",
|
||||
zap.Error(err),
|
||||
zap.String("error_snippet", errStr[:min(len(errStr), 100)]),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
h.logger.Error("Failed to parse HTTP request", zap.Error(err))
|
||||
return fmt.Errorf("failed to parse HTTP request: %w", err)
|
||||
}
|
||||
|
||||
if h.ctx != nil {
|
||||
if ctxWithContext, ok := h.ctx.(interface{ Done() <-chan struct{} }); ok {
|
||||
req = req.WithContext(ctxWithContext.(interface {
|
||||
Done() <-chan struct{}
|
||||
Deadline() (deadline time.Time, ok bool)
|
||||
Err() error
|
||||
Value(key interface{}) interface{}
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("Processing HTTP request on TCP port",
|
||||
zap.String("method", req.Method),
|
||||
zap.String("url", req.URL.String()),
|
||||
zap.String("host", req.Host),
|
||||
)
|
||||
|
||||
// Get writer from pool to reduce GC pressure
|
||||
pooledWriter := bufioWriterPool.Get().(*bufio.Writer)
|
||||
pooledWriter.Reset(h.conn)
|
||||
|
||||
respWriter := &httpResponseWriter{
|
||||
conn: h.conn,
|
||||
writer: pooledWriter,
|
||||
header: make(http.Header),
|
||||
}
|
||||
|
||||
h.httpHandler.ServeHTTP(respWriter, req)
|
||||
|
||||
if err := respWriter.writer.Flush(); err != nil {
|
||||
h.logger.Debug("Failed to flush HTTP response", zap.Error(err))
|
||||
}
|
||||
|
||||
// Return writer to pool
|
||||
pooledWriter.Reset(nil) // Clear reference to connection
|
||||
bufioWriterPool.Put(pooledWriter)
|
||||
|
||||
h.logger.Debug("HTTP request processing completed",
|
||||
zap.String("method", req.Method),
|
||||
zap.String("url", req.URL.String()),
|
||||
)
|
||||
|
||||
shouldClose := false
|
||||
if req.Close {
|
||||
shouldClose = true
|
||||
} else if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
|
||||
if req.Header.Get("Connection") != "keep-alive" {
|
||||
shouldClose = true
|
||||
}
|
||||
}
|
||||
|
||||
if respWriter.headerWritten && respWriter.header.Get("Connection") == "close" {
|
||||
shouldClose = true
|
||||
}
|
||||
|
||||
if shouldClose {
|
||||
h.logger.Debug("Closing connection as requested by client or server")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// httpResponseWriter implements http.ResponseWriter for raw TCP connections.
|
||||
type httpResponseWriter struct {
|
||||
conn net.Conn
|
||||
writer *bufio.Writer
|
||||
header http.Header
|
||||
statusCode int
|
||||
headerWritten bool
|
||||
}
|
||||
|
||||
func (w *httpResponseWriter) Header() http.Header {
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *httpResponseWriter) WriteHeader(statusCode int) {
|
||||
if w.headerWritten {
|
||||
return
|
||||
}
|
||||
w.statusCode = statusCode
|
||||
w.headerWritten = true
|
||||
|
||||
statusText := http.StatusText(statusCode)
|
||||
if statusText == "" {
|
||||
statusText = "Unknown"
|
||||
}
|
||||
|
||||
w.writer.WriteString("HTTP/1.1 ")
|
||||
w.writer.WriteString(fmt.Sprintf("%d", statusCode))
|
||||
w.writer.WriteByte(' ')
|
||||
w.writer.WriteString(statusText)
|
||||
w.writer.WriteString("\r\n")
|
||||
|
||||
for key, values := range w.header {
|
||||
for _, value := range values {
|
||||
w.writer.WriteString(key)
|
||||
w.writer.WriteString(": ")
|
||||
w.writer.WriteString(value)
|
||||
w.writer.WriteString("\r\n")
|
||||
}
|
||||
}
|
||||
|
||||
w.writer.WriteString("\r\n")
|
||||
}
|
||||
|
||||
func (w *httpResponseWriter) Write(data []byte) (int, error) {
|
||||
if !w.headerWritten {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return w.writer.Write(data)
|
||||
}
|
||||
137
internal/server/tcp/lifecycle_manager.go
Normal file
137
internal/server/tcp/lifecycle_manager.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/protocol"
|
||||
)
|
||||
|
||||
// ConnectionLifecycleManager manages the lifecycle of a connection.
|
||||
type ConnectionLifecycleManager struct {
|
||||
once sync.Once
|
||||
stopCh chan struct{}
|
||||
cancel func()
|
||||
logger *zap.Logger
|
||||
|
||||
// Resources to clean up
|
||||
conn interface {
|
||||
Close() error
|
||||
SetDeadline(time.Time) error
|
||||
}
|
||||
frameWriter *protocol.FrameWriter
|
||||
proxy interface{ Stop() }
|
||||
session *yamux.Session
|
||||
portAlloc *PortAllocator
|
||||
port int
|
||||
manager *tunnel.Manager
|
||||
subdomain string
|
||||
tunnelID string
|
||||
groupManager *ConnectionGroupManager
|
||||
}
|
||||
|
||||
// NewConnectionLifecycleManager creates a new lifecycle manager.
|
||||
func NewConnectionLifecycleManager(
|
||||
stopCh chan struct{},
|
||||
cancel func(),
|
||||
logger *zap.Logger,
|
||||
) *ConnectionLifecycleManager {
|
||||
return &ConnectionLifecycleManager{
|
||||
stopCh: stopCh,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SetConnection sets the connection to manage.
|
||||
func (clm *ConnectionLifecycleManager) SetConnection(conn interface {
|
||||
Close() error
|
||||
SetDeadline(time.Time) error
|
||||
}) {
|
||||
clm.conn = conn
|
||||
}
|
||||
|
||||
// SetFrameWriter sets the frame writer to close.
|
||||
func (clm *ConnectionLifecycleManager) SetFrameWriter(fw *protocol.FrameWriter) {
|
||||
clm.frameWriter = fw
|
||||
}
|
||||
|
||||
// SetProxy sets the proxy to stop.
|
||||
func (clm *ConnectionLifecycleManager) SetProxy(proxy interface{ Stop() }) {
|
||||
clm.proxy = proxy
|
||||
}
|
||||
|
||||
// SetSession sets the yamux session to close.
|
||||
func (clm *ConnectionLifecycleManager) SetSession(session *yamux.Session) {
|
||||
clm.session = session
|
||||
}
|
||||
|
||||
// SetPortAllocation sets the port allocation to release.
|
||||
func (clm *ConnectionLifecycleManager) SetPortAllocation(portAlloc *PortAllocator, port int) {
|
||||
clm.portAlloc = portAlloc
|
||||
clm.port = port
|
||||
}
|
||||
|
||||
// SetTunnelRegistration sets the tunnel registration to clean up.
|
||||
func (clm *ConnectionLifecycleManager) SetTunnelRegistration(
|
||||
manager *tunnel.Manager,
|
||||
subdomain string,
|
||||
tunnelID string,
|
||||
groupManager *ConnectionGroupManager,
|
||||
) {
|
||||
clm.manager = manager
|
||||
clm.subdomain = subdomain
|
||||
clm.tunnelID = tunnelID
|
||||
clm.groupManager = groupManager
|
||||
}
|
||||
|
||||
// Close closes the connection and cleans up all resources.
|
||||
func (clm *ConnectionLifecycleManager) Close() {
|
||||
clm.once.Do(func() {
|
||||
protocol.UnregisterConnection()
|
||||
close(clm.stopCh)
|
||||
|
||||
if clm.cancel != nil {
|
||||
clm.cancel()
|
||||
}
|
||||
|
||||
if clm.conn != nil {
|
||||
_ = clm.conn.SetDeadline(time.Now())
|
||||
}
|
||||
|
||||
if clm.frameWriter != nil {
|
||||
clm.frameWriter.Close()
|
||||
}
|
||||
|
||||
if clm.proxy != nil {
|
||||
clm.proxy.Stop()
|
||||
}
|
||||
|
||||
if clm.session != nil {
|
||||
_ = clm.session.Close()
|
||||
}
|
||||
|
||||
if clm.conn != nil {
|
||||
clm.conn.Close()
|
||||
}
|
||||
|
||||
if clm.port > 0 && clm.portAlloc != nil {
|
||||
clm.portAlloc.Release(clm.port)
|
||||
}
|
||||
|
||||
if clm.subdomain != "" && clm.manager != nil {
|
||||
clm.manager.Unregister(clm.subdomain)
|
||||
if clm.tunnelID != "" && clm.groupManager != nil {
|
||||
clm.groupManager.RemoveGroup(clm.tunnelID)
|
||||
}
|
||||
}
|
||||
|
||||
clm.logger.Info("Connection closed",
|
||||
zap.String("subdomain", clm.subdomain),
|
||||
)
|
||||
})
|
||||
}
|
||||
@@ -15,11 +15,25 @@ import (
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/recovery"
|
||||
"drip/internal/shared/utils"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type ListenerConfig struct {
|
||||
Address string
|
||||
TLSConfig *tls.Config
|
||||
AuthToken string
|
||||
Manager *tunnel.Manager
|
||||
Logger *zap.Logger
|
||||
PortAlloc *PortAllocator
|
||||
Domain string
|
||||
TunnelDomain string
|
||||
PublicPort int
|
||||
HTTPHandler http.Handler
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
address string
|
||||
tlsConfig *tls.Config
|
||||
@@ -48,47 +62,47 @@ type Listener struct {
|
||||
allowedTunnelTypes []string
|
||||
}
|
||||
|
||||
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler) *Listener {
|
||||
func NewListener(cfg ListenerConfig) *Listener {
|
||||
numCPU := pool.NumCPU()
|
||||
workers := numCPU * 5
|
||||
queueSize := workers * 20
|
||||
workerPool := pool.NewWorkerPool(workers, queueSize)
|
||||
|
||||
logger.Info("Worker pool configured",
|
||||
cfg.Logger.Info("Worker pool configured",
|
||||
zap.Int("cpu_cores", numCPU),
|
||||
zap.Int("workers", workers),
|
||||
zap.Int("queue_size", queueSize),
|
||||
)
|
||||
|
||||
panicMetrics := recovery.NewPanicMetrics(logger, nil)
|
||||
recoverer := recovery.NewRecoverer(logger, panicMetrics)
|
||||
panicMetrics := recovery.NewPanicMetrics(cfg.Logger, nil)
|
||||
recoverer := recovery.NewRecoverer(cfg.Logger, panicMetrics)
|
||||
|
||||
// Initialize worker pool metrics
|
||||
metrics.WorkerPoolSize.Set(float64(workers))
|
||||
|
||||
l := &Listener{
|
||||
address: address,
|
||||
tlsConfig: tlsConfig,
|
||||
authToken: authToken,
|
||||
manager: manager,
|
||||
portAlloc: portAlloc,
|
||||
logger: logger,
|
||||
domain: domain,
|
||||
tunnelDomain: tunnelDomain,
|
||||
publicPort: publicPort,
|
||||
httpHandler: httpHandler,
|
||||
address: cfg.Address,
|
||||
tlsConfig: cfg.TLSConfig,
|
||||
authToken: cfg.AuthToken,
|
||||
manager: cfg.Manager,
|
||||
portAlloc: cfg.PortAlloc,
|
||||
logger: cfg.Logger,
|
||||
domain: cfg.Domain,
|
||||
tunnelDomain: cfg.TunnelDomain,
|
||||
publicPort: cfg.PublicPort,
|
||||
httpHandler: cfg.HTTPHandler,
|
||||
stopCh: make(chan struct{}),
|
||||
connections: make(map[string]*Connection),
|
||||
workerPool: workerPool,
|
||||
recoverer: recoverer,
|
||||
panicMetrics: panicMetrics,
|
||||
groupManager: NewConnectionGroupManager(logger),
|
||||
groupManager: NewConnectionGroupManager(cfg.Logger),
|
||||
}
|
||||
|
||||
// Set up WebSocket connection handler if httpHandler supports it
|
||||
if h, ok := httpHandler.(*proxy.Handler); ok {
|
||||
if h, ok := cfg.HTTPHandler.(*proxy.Handler); ok {
|
||||
h.SetWSConnectionHandler(l)
|
||||
h.SetPublicPort(publicPort)
|
||||
h.SetPublicPort(cfg.PublicPort)
|
||||
}
|
||||
|
||||
return l
|
||||
@@ -269,7 +283,19 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
)
|
||||
}
|
||||
|
||||
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
|
||||
conn := NewConnection(ConnectionConfig{
|
||||
Conn: netConn,
|
||||
AuthToken: l.authToken,
|
||||
Manager: l.manager,
|
||||
Logger: l.logger,
|
||||
PortAlloc: l.portAlloc,
|
||||
Domain: l.domain,
|
||||
TunnelDomain: l.tunnelDomain,
|
||||
PublicPort: l.publicPort,
|
||||
HTTPHandler: l.httpHandler,
|
||||
GroupManager: l.groupManager,
|
||||
HTTPListener: l.httpListener,
|
||||
})
|
||||
conn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
|
||||
conn.SetAllowedTransports(l.allowedTransports)
|
||||
|
||||
@@ -297,18 +323,11 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
if err := conn.Handle(); err != nil {
|
||||
errStr := err.Error()
|
||||
|
||||
if strings.Contains(errStr, "EOF") ||
|
||||
strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection refused") {
|
||||
if utils.IsNetworkError(errStr) {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(errStr, "payload too large") ||
|
||||
strings.Contains(errStr, "failed to read registration frame") ||
|
||||
strings.Contains(errStr, "expected register frame") ||
|
||||
strings.Contains(errStr, "failed to parse registration request") ||
|
||||
strings.Contains(errStr, "failed to parse HTTP request") {
|
||||
if utils.IsProtocolError(errStr) {
|
||||
l.logger.Warn("Protocol validation failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
@@ -387,7 +406,19 @@ func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) {
|
||||
)
|
||||
|
||||
// Create connection handler (no TLS verification needed - already done by HTTP server)
|
||||
tcpConn := NewConnection(conn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
|
||||
tcpConn := NewConnection(ConnectionConfig{
|
||||
Conn: conn,
|
||||
AuthToken: l.authToken,
|
||||
Manager: l.manager,
|
||||
Logger: l.logger,
|
||||
PortAlloc: l.portAlloc,
|
||||
Domain: l.domain,
|
||||
TunnelDomain: l.tunnelDomain,
|
||||
PublicPort: l.publicPort,
|
||||
HTTPHandler: l.httpHandler,
|
||||
GroupManager: l.groupManager,
|
||||
HTTPListener: l.httpListener,
|
||||
})
|
||||
tcpConn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
|
||||
|
||||
l.connMu.Lock()
|
||||
@@ -412,19 +443,11 @@ func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) {
|
||||
if err := tcpConn.Handle(); err != nil {
|
||||
errStr := err.Error()
|
||||
|
||||
if strings.Contains(errStr, "EOF") ||
|
||||
strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection refused") ||
|
||||
strings.Contains(errStr, "websocket: close") {
|
||||
if utils.IsNetworkError(errStr) {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(errStr, "payload too large") ||
|
||||
strings.Contains(errStr, "failed to read registration frame") ||
|
||||
strings.Contains(errStr, "expected register frame") ||
|
||||
strings.Contains(errStr, "failed to parse registration request") ||
|
||||
strings.Contains(errStr, "tunnel type not allowed") {
|
||||
if utils.IsProtocolError(errStr) {
|
||||
l.logger.Warn("WebSocket tunnel protocol validation failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
|
||||
186
internal/server/tcp/registration_handler.go
Normal file
186
internal/server/tcp/registration_handler.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/utils"
|
||||
)
|
||||
|
||||
// RegistrationHandler handles tunnel registration logic.
|
||||
type RegistrationHandler struct {
|
||||
manager *tunnel.Manager
|
||||
portAlloc *PortAllocator
|
||||
groupManager *ConnectionGroupManager
|
||||
domain string
|
||||
tunnelDomain string
|
||||
publicPort int
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewRegistrationHandler creates a new registration handler.
|
||||
func NewRegistrationHandler(
|
||||
manager *tunnel.Manager,
|
||||
portAlloc *PortAllocator,
|
||||
groupManager *ConnectionGroupManager,
|
||||
domain, tunnelDomain string,
|
||||
publicPort int,
|
||||
logger *zap.Logger,
|
||||
) *RegistrationHandler {
|
||||
return &RegistrationHandler{
|
||||
manager: manager,
|
||||
portAlloc: portAlloc,
|
||||
groupManager: groupManager,
|
||||
domain: domain,
|
||||
tunnelDomain: tunnelDomain,
|
||||
publicPort: publicPort,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// RegistrationRequest contains all information needed for registration.
|
||||
type RegistrationRequest struct {
|
||||
TunnelType protocol.TunnelType
|
||||
CustomSubdomain string
|
||||
Token string
|
||||
ConnectionType string
|
||||
PoolCapabilities *protocol.PoolCapabilities
|
||||
IPAccess *protocol.IPAccessControl
|
||||
ProxyAuth *protocol.ProxyAuth
|
||||
LocalPort int
|
||||
}
|
||||
|
||||
// RegistrationResult contains the result of a registration attempt.
|
||||
type RegistrationResult struct {
|
||||
Subdomain string
|
||||
Port int
|
||||
TunnelURL string
|
||||
TunnelID string
|
||||
SupportsDataConn bool
|
||||
RecommendedConns int
|
||||
TunnelConn *tunnel.Connection
|
||||
}
|
||||
|
||||
// Register handles the tunnel registration process.
|
||||
func (rh *RegistrationHandler) Register(req *RegistrationRequest) (*RegistrationResult, error) {
|
||||
// Allocate port for TCP tunnels
|
||||
port := 0
|
||||
if req.TunnelType == protocol.TunnelTypeTCP {
|
||||
if rh.portAlloc == nil {
|
||||
return nil, fmt.Errorf("port allocator not configured")
|
||||
}
|
||||
|
||||
if requestedPort, ok := parseTCPSubdomainPort(req.CustomSubdomain); ok {
|
||||
allocatedPort, err := rh.portAlloc.AllocateSpecific(requestedPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to allocate requested port %d: %w", requestedPort, err)
|
||||
}
|
||||
port = allocatedPort
|
||||
} else {
|
||||
allocatedPort, err := rh.portAlloc.Allocate()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to allocate port: %w", err)
|
||||
}
|
||||
port = allocatedPort
|
||||
|
||||
if req.CustomSubdomain == "" {
|
||||
req.CustomSubdomain = fmt.Sprintf("tcp-%d", port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Register with tunnel manager
|
||||
subdomain, err := rh.manager.Register(nil, req.CustomSubdomain)
|
||||
if err != nil {
|
||||
if port > 0 && rh.portAlloc != nil {
|
||||
rh.portAlloc.Release(port)
|
||||
}
|
||||
return nil, fmt.Errorf("tunnel registration failed: %w", err)
|
||||
}
|
||||
|
||||
// Get tunnel connection
|
||||
tunnelConn, ok := rh.manager.Get(subdomain)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to get registered tunnel")
|
||||
}
|
||||
|
||||
// Configure tunnel
|
||||
tunnelConn.SetTunnelType(req.TunnelType)
|
||||
|
||||
if req.IPAccess != nil && (len(req.IPAccess.AllowIPs) > 0 || len(req.IPAccess.DenyIPs) > 0) {
|
||||
tunnelConn.SetIPAccessControl(req.IPAccess.AllowIPs, req.IPAccess.DenyIPs)
|
||||
rh.logger.Info("IP access control configured",
|
||||
zap.String("subdomain", subdomain),
|
||||
zap.Strings("allow_ips", req.IPAccess.AllowIPs),
|
||||
zap.Strings("deny_ips", req.IPAccess.DenyIPs),
|
||||
)
|
||||
}
|
||||
|
||||
if req.ProxyAuth != nil && req.ProxyAuth.Enabled {
|
||||
tunnelConn.SetProxyAuth(req.ProxyAuth)
|
||||
rh.logger.Info("Proxy authentication configured",
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
}
|
||||
|
||||
// Build tunnel URL
|
||||
urlBuilder := utils.NewTunnelURLBuilder(rh.tunnelDomain, rh.publicPort)
|
||||
tunnelURL := urlBuilder.BuildURL(subdomain, req.TunnelType, port)
|
||||
|
||||
// Handle connection groups for multi-connection support
|
||||
var tunnelID string
|
||||
var supportsDataConn bool
|
||||
recommendedConns := 0
|
||||
|
||||
if req.PoolCapabilities != nil && req.ConnectionType == "primary" && rh.groupManager != nil {
|
||||
// This will be handled by the caller since it needs the connection instance
|
||||
supportsDataConn = true
|
||||
recommendedConns = 4
|
||||
}
|
||||
|
||||
rh.logger.Info("Tunnel registered",
|
||||
zap.String("subdomain", subdomain),
|
||||
zap.String("tunnel_type", string(req.TunnelType)),
|
||||
zap.Int("local_port", req.LocalPort),
|
||||
zap.Int("remote_port", port),
|
||||
)
|
||||
|
||||
return &RegistrationResult{
|
||||
Subdomain: subdomain,
|
||||
Port: port,
|
||||
TunnelURL: tunnelURL,
|
||||
TunnelID: tunnelID,
|
||||
SupportsDataConn: supportsDataConn,
|
||||
RecommendedConns: recommendedConns,
|
||||
TunnelConn: tunnelConn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BuildRegistrationResponse creates a protocol registration response.
|
||||
func (rh *RegistrationHandler) BuildRegistrationResponse(result *RegistrationResult) (*protocol.RegisterResponse, error) {
|
||||
resp := &protocol.RegisterResponse{
|
||||
Subdomain: result.Subdomain,
|
||||
Port: result.Port,
|
||||
URL: result.TunnelURL,
|
||||
Message: "Tunnel registered successfully",
|
||||
TunnelID: result.TunnelID,
|
||||
SupportsDataConn: result.SupportsDataConn,
|
||||
RecommendedConns: result.RecommendedConns,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// SendRegistrationResponse sends the registration response frame.
|
||||
func (rh *RegistrationHandler) SendRegistrationResponse(conn interface{ Write([]byte) (int, error) }, resp *protocol.RegisterResponse) error {
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal registration response: %w", err)
|
||||
}
|
||||
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData)
|
||||
return protocol.WriteFrame(conn, ackFrame)
|
||||
}
|
||||
@@ -35,6 +35,11 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
|
||||
}
|
||||
c.session = session
|
||||
|
||||
// Update lifecycle manager with session
|
||||
if c.lifecycleManager != nil {
|
||||
c.lifecycleManager.SetSession(session)
|
||||
}
|
||||
|
||||
openStream := session.Open
|
||||
if c.groupManager != nil {
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
|
||||
@@ -48,6 +53,11 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
|
||||
c.proxy.SetIPAccessCheck(c.tunnelConn.IsIPAllowed)
|
||||
}
|
||||
|
||||
// Update lifecycle manager with proxy
|
||||
if c.lifecycleManager != nil {
|
||||
c.lifecycleManager.SetProxy(c.proxy)
|
||||
}
|
||||
|
||||
if err := c.proxy.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start tcp proxy: %w", err)
|
||||
}
|
||||
@@ -76,6 +86,11 @@ func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
|
||||
}
|
||||
c.session = session
|
||||
|
||||
// Update lifecycle manager with session
|
||||
if c.lifecycleManager != nil {
|
||||
c.lifecycleManager.SetSession(session)
|
||||
}
|
||||
|
||||
openStream := session.Open
|
||||
if c.groupManager != nil {
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
|
||||
|
||||
@@ -31,12 +31,6 @@ var (
|
||||
ErrRateLimitExceeded = errors.New("rate limit exceeded, try again later")
|
||||
)
|
||||
|
||||
// rateLimitEntry tracks registration attempts per IP
|
||||
type rateLimitEntry struct {
|
||||
count int
|
||||
windowEnd time.Time
|
||||
}
|
||||
|
||||
// shard holds a subset of tunnels with its own lock
|
||||
type shard struct {
|
||||
tunnels map[string]*Connection
|
||||
@@ -52,16 +46,16 @@ type Manager struct {
|
||||
// Limits
|
||||
maxTunnels int
|
||||
maxTunnelsPerIP int
|
||||
rateLimit int
|
||||
rateLimitWindow time.Duration
|
||||
|
||||
// Global counters (atomic for lock-free reads)
|
||||
tunnelCount atomic.Int64
|
||||
|
||||
// Per-IP tracking (requires separate lock as it spans shards)
|
||||
ipMu sync.RWMutex
|
||||
tunnelsByIP map[string]int // IP -> tunnel count
|
||||
rateLimits map[string]*rateLimitEntry // IP -> rate limit entry
|
||||
tunnelsByIP map[string]int // IP -> tunnel count
|
||||
|
||||
// Rate limiting
|
||||
rateLimiter *RateLimiter
|
||||
|
||||
// Lifecycle
|
||||
stopCh chan struct{}
|
||||
@@ -71,7 +65,7 @@ type Manager struct {
|
||||
type ManagerConfig struct {
|
||||
MaxTunnels int
|
||||
MaxTunnelsPerIP int
|
||||
RateLimit int // Registrations per IP per window
|
||||
RateLimit int // Registrations per IP per window
|
||||
RateLimitWindow time.Duration
|
||||
}
|
||||
|
||||
@@ -117,10 +111,8 @@ func NewManagerWithConfig(logger *zap.Logger, cfg ManagerConfig) *Manager {
|
||||
logger: logger,
|
||||
maxTunnels: cfg.MaxTunnels,
|
||||
maxTunnelsPerIP: cfg.MaxTunnelsPerIP,
|
||||
rateLimit: cfg.RateLimit,
|
||||
rateLimitWindow: cfg.RateLimitWindow,
|
||||
tunnelsByIP: make(map[string]int),
|
||||
rateLimits: make(map[string]*rateLimitEntry),
|
||||
rateLimiter: NewRateLimiter(cfg.RateLimit, cfg.RateLimitWindow, logger),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
@@ -140,28 +132,6 @@ func (m *Manager) getShard(subdomain string) *shard {
|
||||
return &m.shards[h.Sum32()%numShards]
|
||||
}
|
||||
|
||||
// checkRateLimit checks if the IP has exceeded rate limit (caller must hold ipMu)
|
||||
func (m *Manager) checkRateLimitLocked(ip string) bool {
|
||||
now := time.Now()
|
||||
entry, exists := m.rateLimits[ip]
|
||||
|
||||
if !exists || now.After(entry.windowEnd) {
|
||||
// New window
|
||||
m.rateLimits[ip] = &rateLimitEntry{
|
||||
count: 1,
|
||||
windowEnd: now.Add(m.rateLimitWindow),
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
if entry.count >= m.rateLimit {
|
||||
return false
|
||||
}
|
||||
|
||||
entry.count++
|
||||
return true
|
||||
}
|
||||
|
||||
// Register registers a new tunnel connection with IP-based limits
|
||||
func (m *Manager) Register(conn *websocket.Conn, customSubdomain string) (string, error) {
|
||||
return m.RegisterWithIP(conn, customSubdomain, "")
|
||||
@@ -193,19 +163,14 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r
|
||||
|
||||
// Check per-IP limits and reserve slot atomically
|
||||
if remoteIP != "" {
|
||||
m.ipMu.Lock()
|
||||
if !m.checkRateLimitLocked(remoteIP) {
|
||||
m.ipMu.Unlock()
|
||||
// Check rate limit first (has its own lock)
|
||||
if !m.rateLimiter.CheckAndIncrement(remoteIP) {
|
||||
rollbackGlobal()
|
||||
m.logger.Warn("Rate limit exceeded",
|
||||
zap.String("ip", remoteIP),
|
||||
zap.Int("limit", m.rateLimit),
|
||||
)
|
||||
metrics.RateLimitRejections.WithLabelValues("registration", remoteIP).Inc()
|
||||
metrics.TunnelRegistrationFailures.WithLabelValues("rate_limit").Inc()
|
||||
return "", ErrRateLimitExceeded
|
||||
}
|
||||
|
||||
m.ipMu.Lock()
|
||||
if m.tunnelsByIP[remoteIP] >= m.maxTunnelsPerIP {
|
||||
currentPerIP := m.tunnelsByIP[remoteIP]
|
||||
m.ipMu.Unlock()
|
||||
@@ -427,14 +392,7 @@ func (m *Manager) CleanupStale(timeout time.Duration) int {
|
||||
}
|
||||
|
||||
// Cleanup expired rate limit entries
|
||||
m.ipMu.Lock()
|
||||
now := time.Now()
|
||||
for ip, entry := range m.rateLimits {
|
||||
if now.After(entry.windowEnd) {
|
||||
delete(m.rateLimits, ip)
|
||||
}
|
||||
}
|
||||
m.ipMu.Unlock()
|
||||
m.rateLimiter.Cleanup()
|
||||
|
||||
if totalCleaned > 0 {
|
||||
m.logger.Info("Cleaned up stale tunnels",
|
||||
|
||||
86
internal/server/tunnel/rate_limiter.go
Normal file
86
internal/server/tunnel/rate_limiter.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/server/metrics"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// rateLimitEntry tracks registration attempts per IP
|
||||
type rateLimitEntry struct {
|
||||
count int
|
||||
windowEnd time.Time
|
||||
}
|
||||
|
||||
// RateLimiter manages rate limiting for tunnel registrations.
|
||||
type RateLimiter struct {
|
||||
mu sync.RWMutex
|
||||
rateLimits map[string]*rateLimitEntry
|
||||
rateLimit int
|
||||
rateLimitWindow time.Duration
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter.
|
||||
func NewRateLimiter(rateLimit int, rateLimitWindow time.Duration, logger *zap.Logger) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
rateLimits: make(map[string]*rateLimitEntry),
|
||||
rateLimit: rateLimit,
|
||||
rateLimitWindow: rateLimitWindow,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CheckAndIncrement checks if the IP has exceeded rate limit and increments the counter.
|
||||
func (rl *RateLimiter) CheckAndIncrement(ip string) bool {
|
||||
if ip == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
entry, exists := rl.rateLimits[ip]
|
||||
|
||||
if !exists || now.After(entry.windowEnd) {
|
||||
// New window
|
||||
rl.rateLimits[ip] = &rateLimitEntry{
|
||||
count: 1,
|
||||
windowEnd: now.Add(rl.rateLimitWindow),
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
if entry.count >= rl.rateLimit {
|
||||
rl.logger.Warn("Rate limit exceeded",
|
||||
zap.String("ip", ip),
|
||||
zap.Int("limit", rl.rateLimit),
|
||||
)
|
||||
metrics.RateLimitRejections.WithLabelValues("registration", ip).Inc()
|
||||
return false
|
||||
}
|
||||
|
||||
entry.count++
|
||||
return true
|
||||
}
|
||||
|
||||
// Cleanup removes expired rate limit entries.
|
||||
func (rl *RateLimiter) Cleanup() int {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
removed := 0
|
||||
|
||||
for ip, entry := range rl.rateLimits {
|
||||
if now.After(entry.windowEnd) {
|
||||
delete(rl.rateLimits, ip)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
|
||||
return removed
|
||||
}
|
||||
39
internal/shared/httputil/detection.go
Normal file
39
internal/shared/httputil/detection.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package httputil
|
||||
|
||||
import "strings"
|
||||
|
||||
// HTTPMethods contains common HTTP method prefixes for protocol detection.
|
||||
var HTTPMethods = []string{
|
||||
"GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC",
|
||||
}
|
||||
|
||||
// IsHTTPRequest checks if the given bytes represent the start of an HTTP request.
|
||||
// It checks for common HTTP method prefixes.
|
||||
func IsHTTPRequest(data []byte) bool {
|
||||
if len(data) < 4 {
|
||||
return false
|
||||
}
|
||||
|
||||
dataStr := string(data[:4])
|
||||
for _, method := range HTTPMethods {
|
||||
if strings.HasPrefix(dataStr, method) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// DetectHTTPMethod returns the HTTP method if the data starts with one, or empty string.
|
||||
func DetectHTTPMethod(data []byte) string {
|
||||
if len(data) < 4 {
|
||||
return ""
|
||||
}
|
||||
|
||||
dataStr := string(data)
|
||||
for _, method := range HTTPMethods {
|
||||
if strings.HasPrefix(dataStr, method) {
|
||||
return strings.TrimSpace(method)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
58
internal/shared/httputil/error_response.go
Normal file
58
internal/shared/httputil/error_response.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package httputil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// HTTPErrorResponse represents a standard HTTP error response.
|
||||
type HTTPErrorResponse struct {
|
||||
StatusCode int
|
||||
StatusText string
|
||||
Message string
|
||||
}
|
||||
|
||||
// Common HTTP error responses
|
||||
var (
|
||||
ServiceUnavailable = &HTTPErrorResponse{
|
||||
StatusCode: 503,
|
||||
StatusText: "Service Unavailable",
|
||||
Message: "Server busy, please retry later",
|
||||
}
|
||||
|
||||
HandlerNotConfigured = &HTTPErrorResponse{
|
||||
StatusCode: 503,
|
||||
StatusText: "Service Unavailable",
|
||||
Message: "HTTP handler not configured for this TCP port",
|
||||
}
|
||||
)
|
||||
|
||||
// WriteErrorResponse writes an HTTP error response to the connection.
|
||||
func WriteErrorResponse(conn net.Conn, resp *HTTPErrorResponse) error {
|
||||
response := fmt.Sprintf(
|
||||
"HTTP/1.1 %d %s\r\n"+
|
||||
"Content-Type: text/plain\r\n"+
|
||||
"Content-Length: %d\r\n"+
|
||||
"Connection: close\r\n"+
|
||||
"\r\n"+
|
||||
"%s\r\n",
|
||||
resp.StatusCode,
|
||||
resp.StatusText,
|
||||
len(resp.Message)+2,
|
||||
resp.Message,
|
||||
)
|
||||
_, err := conn.Write([]byte(response))
|
||||
return err
|
||||
}
|
||||
|
||||
// WriteServiceUnavailable writes a 503 Service Unavailable response.
|
||||
func WriteServiceUnavailable(conn net.Conn, message string) error {
|
||||
if message == "" {
|
||||
message = ServiceUnavailable.Message
|
||||
}
|
||||
return WriteErrorResponse(conn, &HTTPErrorResponse{
|
||||
StatusCode: 503,
|
||||
StatusText: "Service Unavailable",
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
38
internal/shared/httputil/response.go
Normal file
38
internal/shared/httputil/response.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package httputil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// WriteJSON writes a JSON response with the appropriate headers.
|
||||
func WriteJSON(w http.ResponseWriter, data []byte) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
// WriteHTML writes an HTML response with the appropriate headers.
|
||||
func WriteHTML(w http.ResponseWriter, data []byte) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
// WriteHTMLWithStatus writes an HTML response with a custom status code.
|
||||
func WriteHTMLWithStatus(w http.ResponseWriter, data []byte, statusCode int) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
w.WriteHeader(statusCode)
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
// SetContentLength sets the Content-Length header.
|
||||
func SetContentLength(w http.ResponseWriter, length int64) {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", length))
|
||||
}
|
||||
|
||||
// SetCloseConnection sets the Connection: close header.
|
||||
func SetCloseConnection(w http.ResponseWriter) {
|
||||
w.Header().Set("Connection", "close")
|
||||
}
|
||||
116
internal/shared/mux/session_builder.go
Normal file
116
internal/shared/mux/session_builder.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
// BufferedConn wraps a connection with a buffered reader.
|
||||
type BufferedConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
// NewBufferedConn creates a new buffered connection.
|
||||
func NewBufferedConn(conn net.Conn, reader *bufio.Reader) *BufferedConn {
|
||||
return &BufferedConn{
|
||||
Conn: conn,
|
||||
reader: reader,
|
||||
}
|
||||
}
|
||||
|
||||
// Read reads from the buffered reader if available, otherwise from the connection.
|
||||
func (bc *BufferedConn) Read(p []byte) (int, error) {
|
||||
if bc.reader != nil {
|
||||
return bc.reader.Read(p)
|
||||
}
|
||||
return bc.Conn.Read(p)
|
||||
}
|
||||
|
||||
// SessionBuilder helps build yamux sessions with consistent configuration.
|
||||
type SessionBuilder struct {
|
||||
conn net.Conn
|
||||
reader *bufio.Reader
|
||||
config *yamux.Config
|
||||
isServer bool
|
||||
}
|
||||
|
||||
// NewSessionBuilder creates a new session builder.
|
||||
func NewSessionBuilder(conn net.Conn) *SessionBuilder {
|
||||
return &SessionBuilder{
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
// WithReader sets the buffered reader for the session.
|
||||
func (sb *SessionBuilder) WithReader(reader *bufio.Reader) *SessionBuilder {
|
||||
sb.reader = reader
|
||||
return sb
|
||||
}
|
||||
|
||||
// WithConfig sets the yamux configuration.
|
||||
func (sb *SessionBuilder) WithConfig(config *yamux.Config) *SessionBuilder {
|
||||
sb.config = config
|
||||
return sb
|
||||
}
|
||||
|
||||
// AsServer configures the session as a server.
|
||||
func (sb *SessionBuilder) AsServer() *SessionBuilder {
|
||||
sb.isServer = true
|
||||
if sb.config == nil {
|
||||
sb.config = NewServerConfig()
|
||||
}
|
||||
return sb
|
||||
}
|
||||
|
||||
// AsClient configures the session as a client.
|
||||
func (sb *SessionBuilder) AsClient() *SessionBuilder {
|
||||
sb.isServer = false
|
||||
if sb.config == nil {
|
||||
sb.config = NewClientConfig()
|
||||
}
|
||||
return sb
|
||||
}
|
||||
|
||||
// Build creates the yamux session.
|
||||
func (sb *SessionBuilder) Build() (*yamux.Session, error) {
|
||||
conn := sb.conn
|
||||
|
||||
// Wrap with buffered reader if provided
|
||||
if sb.reader != nil {
|
||||
conn = NewBufferedConn(sb.conn, sb.reader)
|
||||
}
|
||||
|
||||
// Use default config if not set
|
||||
if sb.config == nil {
|
||||
if sb.isServer {
|
||||
sb.config = NewServerConfig()
|
||||
} else {
|
||||
sb.config = NewClientConfig()
|
||||
}
|
||||
}
|
||||
|
||||
// Create session based on role
|
||||
if sb.isServer {
|
||||
return yamux.Server(conn, sb.config)
|
||||
}
|
||||
return yamux.Client(conn, sb.config)
|
||||
}
|
||||
|
||||
// BuildClientSession creates a yamux client session with the given connection and reader.
|
||||
func BuildClientSession(conn net.Conn, reader *bufio.Reader) (*yamux.Session, error) {
|
||||
return NewSessionBuilder(conn).
|
||||
WithReader(reader).
|
||||
AsClient().
|
||||
Build()
|
||||
}
|
||||
|
||||
// BuildServerSession creates a yamux server session with the given connection and reader.
|
||||
func BuildServerSession(conn net.Conn, reader *bufio.Reader) (*yamux.Session, error) {
|
||||
return NewSessionBuilder(conn).
|
||||
WithReader(reader).
|
||||
AsServer().
|
||||
Build()
|
||||
}
|
||||
78
internal/shared/netutil/ip.go
Normal file
78
internal/shared/netutil/ip.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var privateNetworks []*net.IPNet
|
||||
|
||||
func init() {
|
||||
privateCIDRs := []string{
|
||||
"127.0.0.0/8",
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
"fe80::/10",
|
||||
}
|
||||
for _, cidr := range privateCIDRs {
|
||||
_, ipNet, _ := net.ParseCIDR(cidr)
|
||||
privateNetworks = append(privateNetworks, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractRemoteIP extracts the IP address from a remote address string (host:port format).
|
||||
func ExtractRemoteIP(remoteAddr string) string {
|
||||
host, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return remoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// IsPrivateIP checks if the given IP is a private/loopback address.
|
||||
func IsPrivateIP(ip string) bool {
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, network := range privateNetworks {
|
||||
if network.Contains(parsedIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ExtractClientIP extracts the client IP from the request.
|
||||
// It only trusts X-Forwarded-For and X-Real-IP headers when the request
|
||||
// comes from a private/loopback network (typical reverse proxy setup).
|
||||
func ExtractClientIP(r *http.Request) string {
|
||||
// First, get the direct remote address
|
||||
remoteIP := ExtractRemoteIP(r.RemoteAddr)
|
||||
|
||||
// Only trust proxy headers if the request comes from a private network
|
||||
if IsPrivateIP(remoteIP) {
|
||||
// Check X-Forwarded-For header (may contain multiple IPs)
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP (original client)
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return strings.TrimSpace(xri)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to remote address
|
||||
return remoteIP
|
||||
}
|
||||
@@ -3,10 +3,10 @@ package pool
|
||||
import "sync"
|
||||
|
||||
const (
|
||||
SizeSmall = 4 * 1024 // 4KB - HTTP headers, small messages
|
||||
SizeMedium = 32 * 1024 // 32KB - HTTP request/response bodies
|
||||
SizeLarge = 256 * 1024 // 256KB - Data pipe, file transfers
|
||||
SizeXLarge = 1024 * 1024 // 1MB - Large file transfers, bulk data
|
||||
SizeSmall = 4 * 1024 // 4KB - HTTP headers, small messages
|
||||
SizeMedium = 32 * 1024 // 32KB - HTTP request/response bodies
|
||||
SizeLarge = 256 * 1024 // 256KB - Data pipe, file transfers
|
||||
SizeXLarge = 1024 * 1024 // 1MB - Large file transfers, bulk data
|
||||
)
|
||||
|
||||
type BufferPool struct {
|
||||
|
||||
48
internal/shared/pool/bufio_pool.go
Normal file
48
internal/shared/pool/bufio_pool.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// BufioReaderPool provides a pool of bufio.Reader instances.
|
||||
var BufioReaderPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return bufio.NewReaderSize(nil, 32*1024)
|
||||
},
|
||||
}
|
||||
|
||||
// BufioWriterPool provides a pool of bufio.Writer instances.
|
||||
var BufioWriterPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return bufio.NewWriterSize(nil, 4096)
|
||||
},
|
||||
}
|
||||
|
||||
// GetReader gets a bufio.Reader from the pool and resets it to read from r.
|
||||
func GetReader(r interface{}) *bufio.Reader {
|
||||
reader := BufioReaderPool.Get().(*bufio.Reader)
|
||||
if resetter, ok := r.(interface{ Reset(interface{}) }); ok {
|
||||
resetter.Reset(r)
|
||||
}
|
||||
return reader
|
||||
}
|
||||
|
||||
// PutReader returns a bufio.Reader to the pool.
|
||||
func PutReader(reader *bufio.Reader) {
|
||||
BufioReaderPool.Put(reader)
|
||||
}
|
||||
|
||||
// GetWriter gets a bufio.Writer from the pool and resets it to write to w.
|
||||
func GetWriter(w interface{}) *bufio.Writer {
|
||||
writer := BufioWriterPool.Get().(*bufio.Writer)
|
||||
if resetter, ok := w.(interface{ Reset(interface{}) }); ok {
|
||||
resetter.Reset(w)
|
||||
}
|
||||
return writer
|
||||
}
|
||||
|
||||
// PutWriter returns a bufio.Writer to the pool.
|
||||
func PutWriter(writer *bufio.Writer) {
|
||||
BufioWriterPool.Put(writer)
|
||||
}
|
||||
68
internal/shared/protocol/error_sender.go
Normal file
68
internal/shared/protocol/error_sender.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ErrorSender handles sending error frames over connections.
|
||||
type ErrorSender struct {
|
||||
conn net.Conn
|
||||
frameWriter *FrameWriter
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewErrorSender creates a new error sender.
|
||||
func NewErrorSender(conn net.Conn, frameWriter *FrameWriter, logger *zap.Logger) *ErrorSender {
|
||||
return &ErrorSender{
|
||||
conn: conn,
|
||||
frameWriter: frameWriter,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SendError sends an error frame with the given code and message.
|
||||
func (e *ErrorSender) SendError(code, message string) error {
|
||||
errMsg := ErrorMessage{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(errMsg)
|
||||
if err != nil {
|
||||
e.logger.Error("Failed to marshal error message", zap.Error(err))
|
||||
return fmt.Errorf("failed to marshal error: %w", err)
|
||||
}
|
||||
|
||||
errFrame := NewFrame(FrameTypeError, data)
|
||||
|
||||
if e.frameWriter == nil {
|
||||
return WriteFrame(e.conn, errFrame)
|
||||
}
|
||||
|
||||
return e.frameWriter.WriteFrame(errFrame)
|
||||
}
|
||||
|
||||
// SendAuthenticationError sends an authentication failed error.
|
||||
func (e *ErrorSender) SendAuthenticationError() error {
|
||||
return e.SendError("authentication_failed", "Invalid authentication token")
|
||||
}
|
||||
|
||||
// SendRegistrationError sends a registration failed error.
|
||||
func (e *ErrorSender) SendRegistrationError(message string) error {
|
||||
return e.SendError("registration_failed", message)
|
||||
}
|
||||
|
||||
// SendPortAllocationError sends a port allocation failed error.
|
||||
func (e *ErrorSender) SendPortAllocationError(message string) error {
|
||||
return e.SendError("port_allocation_failed", message)
|
||||
}
|
||||
|
||||
// SendTunnelTypeNotAllowedError sends a tunnel type not allowed error.
|
||||
func (e *ErrorSender) SendTunnelTypeNotAllowedError(tunnelType string) error {
|
||||
return e.SendError("tunnel_type_not_allowed",
|
||||
fmt.Sprintf("Tunnel type '%s' is not allowed on this server", tunnelType))
|
||||
}
|
||||
74
internal/shared/protocol/heartbeat.go
Normal file
74
internal/shared/protocol/heartbeat.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// HeartbeatManager manages heartbeat tracking and checking for connections.
|
||||
type HeartbeatManager struct {
|
||||
mu sync.RWMutex
|
||||
lastHeartbeat time.Time
|
||||
interval time.Duration
|
||||
timeout time.Duration
|
||||
logger *zap.Logger
|
||||
frameWriter *FrameWriter
|
||||
}
|
||||
|
||||
// NewHeartbeatManager creates a new heartbeat manager.
|
||||
func NewHeartbeatManager(interval, timeout time.Duration, frameWriter *FrameWriter, logger *zap.Logger) *HeartbeatManager {
|
||||
return &HeartbeatManager{
|
||||
lastHeartbeat: time.Now(),
|
||||
interval: interval,
|
||||
timeout: timeout,
|
||||
frameWriter: frameWriter,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateLastHeartbeat updates the last heartbeat timestamp.
|
||||
func (h *HeartbeatManager) UpdateLastHeartbeat() {
|
||||
h.mu.Lock()
|
||||
h.lastHeartbeat = time.Now()
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
// GetLastHeartbeat returns the last heartbeat timestamp.
|
||||
func (h *HeartbeatManager) GetLastHeartbeat() time.Time {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.lastHeartbeat
|
||||
}
|
||||
|
||||
// IsAlive checks if the connection is still alive based on heartbeat timeout.
|
||||
func (h *HeartbeatManager) IsAlive() bool {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return time.Since(h.lastHeartbeat) < h.timeout
|
||||
}
|
||||
|
||||
// SendHeartbeatAck sends a heartbeat acknowledgment frame.
|
||||
func (h *HeartbeatManager) SendHeartbeatAck() error {
|
||||
ackFrame := NewFrame(FrameTypeHeartbeatAck, nil)
|
||||
err := h.frameWriter.WriteControl(ackFrame)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to send heartbeat ack", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleHeartbeat handles a received heartbeat frame.
|
||||
func (h *HeartbeatManager) HandleHeartbeat() error {
|
||||
h.UpdateLastHeartbeat()
|
||||
return h.SendHeartbeatAck()
|
||||
}
|
||||
|
||||
// TimeSinceLastHeartbeat returns the duration since the last heartbeat.
|
||||
func (h *HeartbeatManager) TimeSinceLastHeartbeat() time.Duration {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return time.Since(h.lastHeartbeat)
|
||||
}
|
||||
@@ -14,7 +14,9 @@ type IPAccessControl struct {
|
||||
|
||||
type ProxyAuth struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
}
|
||||
|
||||
type RegisterRequest struct {
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
globalMemoryStatusEx = kernel32.NewProc("GlobalMemoryStatusEx")
|
||||
kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
globalMemoryStatusEx = kernel32.NewProc("GlobalMemoryStatusEx")
|
||||
)
|
||||
|
||||
type memoryStatusEx struct {
|
||||
|
||||
49
internal/shared/utils/allowed_list.go
Normal file
49
internal/shared/utils/allowed_list.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package utils
|
||||
|
||||
import "strings"
|
||||
|
||||
// AllowedList manages a list of allowed values with case-insensitive matching.
|
||||
type AllowedList struct {
|
||||
items []string
|
||||
}
|
||||
|
||||
// NewAllowedList creates a new AllowedList from the given items.
|
||||
func NewAllowedList(items []string) *AllowedList {
|
||||
return &AllowedList{items: items}
|
||||
}
|
||||
|
||||
// IsAllowed checks if a value is in the allowed list.
|
||||
// If the list is empty, all values are allowed.
|
||||
// Matching is case-insensitive.
|
||||
func (a *AllowedList) IsAllowed(value string) bool {
|
||||
if len(a.items) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, item := range a.items {
|
||||
if strings.EqualFold(item, value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetPreferred returns the first item in the list, or the default value if empty.
|
||||
func (a *AllowedList) GetPreferred(defaultValue string) string {
|
||||
if len(a.items) == 0 {
|
||||
return defaultValue
|
||||
}
|
||||
if len(a.items) == 1 {
|
||||
return a.items[0]
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// Items returns the underlying slice of allowed items.
|
||||
func (a *AllowedList) Items() []string {
|
||||
return a.items
|
||||
}
|
||||
|
||||
// IsEmpty returns true if the allowed list is empty.
|
||||
func (a *AllowedList) IsEmpty() bool {
|
||||
return len(a.items) == 0
|
||||
}
|
||||
35
internal/shared/utils/errors.go
Normal file
35
internal/shared/utils/errors.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package utils
|
||||
|
||||
import "strings"
|
||||
|
||||
// IsNetworkError checks if an error message indicates a common network error
|
||||
// that should be handled gracefully (not logged as severe errors).
|
||||
func IsNetworkError(errStr string) bool {
|
||||
return strings.Contains(errStr, "EOF") ||
|
||||
strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection refused") ||
|
||||
strings.Contains(errStr, "use of closed network connection") ||
|
||||
strings.Contains(errStr, "websocket: close")
|
||||
}
|
||||
|
||||
// IsProtocolError checks if an error message indicates a protocol-level error
|
||||
// (invalid requests, malformed data, etc.).
|
||||
func IsProtocolError(errStr string) bool {
|
||||
return strings.Contains(errStr, "payload too large") ||
|
||||
strings.Contains(errStr, "failed to read registration frame") ||
|
||||
strings.Contains(errStr, "expected register frame") ||
|
||||
strings.Contains(errStr, "failed to parse registration request") ||
|
||||
strings.Contains(errStr, "failed to parse HTTP request") ||
|
||||
strings.Contains(errStr, "tunnel type not allowed")
|
||||
}
|
||||
|
||||
// ContainsAny checks if a string contains any of the given substrings.
|
||||
func ContainsAny(s string, substrings ...string) bool {
|
||||
for _, substr := range substrings {
|
||||
if strings.Contains(s, substr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
42
internal/shared/utils/url_builder.go
Normal file
42
internal/shared/utils/url_builder.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"drip/internal/shared/protocol"
|
||||
)
|
||||
|
||||
// TunnelURLBuilder helps construct tunnel URLs consistently.
|
||||
type TunnelURLBuilder struct {
|
||||
tunnelDomain string
|
||||
publicPort int
|
||||
}
|
||||
|
||||
// NewTunnelURLBuilder creates a new URL builder.
|
||||
func NewTunnelURLBuilder(tunnelDomain string, publicPort int) *TunnelURLBuilder {
|
||||
return &TunnelURLBuilder{
|
||||
tunnelDomain: tunnelDomain,
|
||||
publicPort: publicPort,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildHTTPURL builds an HTTP/HTTPS tunnel URL.
|
||||
func (b *TunnelURLBuilder) BuildHTTPURL(subdomain string) string {
|
||||
if b.publicPort == 443 {
|
||||
return fmt.Sprintf("https://%s.%s", subdomain, b.tunnelDomain)
|
||||
}
|
||||
return fmt.Sprintf("https://%s.%s:%d", subdomain, b.tunnelDomain, b.publicPort)
|
||||
}
|
||||
|
||||
// BuildTCPURL builds a TCP tunnel URL.
|
||||
func (b *TunnelURLBuilder) BuildTCPURL(port int) string {
|
||||
return fmt.Sprintf("tcp://%s:%d", b.tunnelDomain, port)
|
||||
}
|
||||
|
||||
// BuildURL builds a tunnel URL based on the tunnel type.
|
||||
func (b *TunnelURLBuilder) BuildURL(subdomain string, tunnelType protocol.TunnelType, port int) string {
|
||||
if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
|
||||
return b.BuildHTTPURL(subdomain)
|
||||
}
|
||||
return b.BuildTCPURL(port)
|
||||
}
|
||||
@@ -12,15 +12,16 @@ import (
|
||||
|
||||
// TunnelConfig holds configuration for a predefined tunnel
|
||||
type TunnelConfig struct {
|
||||
Name string `yaml:"name"` // Tunnel name (required, unique identifier)
|
||||
Type string `yaml:"type"` // Tunnel type: http, https, tcp (required)
|
||||
Port int `yaml:"port"` // Local port to forward (required)
|
||||
Address string `yaml:"address,omitempty"` // Local address (default: 127.0.0.1)
|
||||
Subdomain string `yaml:"subdomain,omitempty"` // Custom subdomain
|
||||
Transport string `yaml:"transport,omitempty"` // Transport: auto, tcp, wss
|
||||
AllowIPs []string `yaml:"allow_ips,omitempty"` // Allowed IPs/CIDRs
|
||||
DenyIPs []string `yaml:"deny_ips,omitempty"` // Denied IPs/CIDRs
|
||||
Auth string `yaml:"auth,omitempty"` // Proxy authentication password (http/https only)
|
||||
Name string `yaml:"name"` // Tunnel name (required, unique identifier)
|
||||
Type string `yaml:"type"` // Tunnel type: http, https, tcp (required)
|
||||
Port int `yaml:"port"` // Local port to forward (required)
|
||||
Address string `yaml:"address,omitempty"` // Local address (default: 127.0.0.1)
|
||||
Subdomain string `yaml:"subdomain,omitempty"` // Custom subdomain
|
||||
Transport string `yaml:"transport,omitempty"` // Transport: auto, tcp, wss
|
||||
AllowIPs []string `yaml:"allow_ips,omitempty"` // Allowed IPs/CIDRs
|
||||
DenyIPs []string `yaml:"deny_ips,omitempty"` // Denied IPs/CIDRs
|
||||
Auth string `yaml:"auth,omitempty"` // Proxy authentication password (http/https only)
|
||||
AuthBearer string `yaml:"auth_bearer,omitempty"` // Proxy authentication bearer token (http/https only)
|
||||
}
|
||||
|
||||
// Validate checks if the tunnel configuration is valid
|
||||
@@ -44,6 +45,9 @@ func (t *TunnelConfig) Validate() error {
|
||||
return fmt.Errorf("invalid transport '%s' for '%s': must be auto, tcp, or wss", t.Transport, t.Name)
|
||||
}
|
||||
}
|
||||
if t.Auth != "" && t.AuthBearer != "" {
|
||||
return fmt.Errorf("only one of auth or auth_bearer can be set for '%s'", t.Name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user