From 307cf8e6ccecc1af5c945a7b576707a98bcddc31 Mon Sep 17 00:00:00 2001 From: zhiqing Date: Thu, 29 Jan 2026 14:40:53 +0800 Subject: [PATCH 1/2] feat: Add Bearer Token authentication support and optimize code structure - Add Bearer Token authentication, supporting tunnel access control via the --auth-bearer parameter - Refactor large modules into smaller, more focused components to improve code maintainability - Update dependency versions, including golang.org/x/crypto, golang.org/x/net, etc. - Add SilenceUsage and SilenceErrors configuration for all CLI commands - Modify connector configuration structure to support the new authentication method - Update recent change log in README with new feature descriptions BREAKING CHANGE: Authentication via Bearer Token is now supported, requiring the new --auth-bearer parameter --- README.md | 7 + README_CN.md | 7 + go.mod | 18 +- go.sum | 32 +- internal/client/cli/attach.go | 8 +- internal/client/cli/config.go | 58 +- internal/client/cli/http.go | 14 +- internal/client/cli/https.go | 13 +- internal/client/cli/list.go | 6 +- internal/client/cli/root.go | 2 + internal/client/cli/server.go | 62 +- internal/client/cli/start.go | 5 +- internal/client/cli/stop.go | 8 +- internal/client/cli/tcp.go | 6 +- internal/client/cli/tunnel_helpers.go | 6 + internal/client/tcp/connection_dialer.go | 197 +++++ internal/client/tcp/connector.go | 3 +- internal/client/tcp/pool_client.go | 197 +---- internal/client/tcp/pool_session.go | 2 +- internal/client/tcp/session_scaler.go | 150 ++++ internal/server/proxy/auth_handler.go | 412 ++++++++++ internal/server/proxy/handler.go | 765 ++---------------- internal/server/proxy/pages.go | 299 +++++++ internal/server/proxy/websocket_handler.go | 113 +++ internal/server/tcp/connection.go | 751 +++++------------ internal/server/tcp/connection_group.go | 8 +- .../server/tcp/data_connection_handler.go | 161 ++++ internal/server/tcp/frame_handler.go | 122 +++ internal/server/tcp/http_request_handler.go | 251 ++++++ internal/server/tcp/lifecycle_manager.go | 137 ++++ internal/server/tcp/listener.go | 99 ++- internal/server/tcp/registration_handler.go | 186 +++++ internal/server/tcp/tunnel.go | 15 + internal/server/tunnel/manager.go | 62 +- internal/server/tunnel/rate_limiter.go | 86 ++ internal/shared/httputil/detection.go | 39 + internal/shared/httputil/error_response.go | 58 ++ internal/shared/httputil/response.go | 38 + internal/shared/mux/session_builder.go | 116 +++ internal/shared/netutil/ip.go | 78 ++ internal/shared/pool/buffer_pool.go | 8 +- internal/shared/pool/bufio_pool.go | 48 ++ internal/shared/protocol/error_sender.go | 68 ++ internal/shared/protocol/heartbeat.go | 74 ++ internal/shared/protocol/messages.go | 2 + internal/shared/tuning/mem_windows.go | 4 +- internal/shared/utils/allowed_list.go | 49 ++ internal/shared/utils/errors.go | 35 + internal/shared/utils/url_builder.go | 42 + pkg/config/client_config.go | 22 +- 50 files changed, 3338 insertions(+), 1611 deletions(-) create mode 100644 internal/client/tcp/connection_dialer.go create mode 100644 internal/client/tcp/session_scaler.go create mode 100644 internal/server/proxy/auth_handler.go create mode 100644 internal/server/proxy/pages.go create mode 100644 internal/server/proxy/websocket_handler.go create mode 100644 internal/server/tcp/data_connection_handler.go create mode 100644 internal/server/tcp/frame_handler.go create mode 100644 internal/server/tcp/http_request_handler.go create mode 100644 internal/server/tcp/lifecycle_manager.go create mode 100644 internal/server/tcp/registration_handler.go create mode 100644 internal/server/tunnel/rate_limiter.go create mode 100644 internal/shared/httputil/detection.go create mode 100644 internal/shared/httputil/error_response.go create mode 100644 internal/shared/httputil/response.go create mode 100644 internal/shared/mux/session_builder.go create mode 100644 internal/shared/netutil/ip.go create mode 100644 internal/shared/pool/bufio_pool.go create mode 100644 internal/shared/protocol/error_sender.go create mode 100644 internal/shared/protocol/heartbeat.go create mode 100644 internal/shared/utils/allowed_list.go create mode 100644 internal/shared/utils/errors.go create mode 100644 internal/shared/utils/url_builder.go diff --git a/README.md b/README.md index d742e6f..3f36f99 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,13 @@ - **Actually free** - Use your own domain, no paid tiers or feature restrictions - **Open source** - BSD 3-Clause License +## Recent Changes + +### 2025-01-29 + +- **Bearer Token Authentication** - Added bearer token authentication support for tunnel access control +- **Code Optimization** - Refactored large modules into smaller, focused components for better maintainability + ## Quick Start ### Install diff --git a/README_CN.md b/README_CN.md index 7011fe4..24d1d91 100644 --- a/README_CN.md +++ b/README_CN.md @@ -33,6 +33,13 @@ - **真的免费** - 用你自己的域名,没有付费档位或功能阉割 - **开源** - BSD 3-Clause 协议 +## 最近更新 + +### 2025-01-29 + +- **Bearer Token 认证** - 新增 Bearer Token 认证支持,用于隧道访问控制 +- **代码优化** - 将大型模块重构为更小、更专注的组件,提升可维护性 + ## 快速开始 ### 安装 diff --git a/go.mod b/go.mod index 3341300..9eabd79 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module drip -go 1.25.4 +go 1.25.5 require ( github.com/charmbracelet/lipgloss v1.1.0 @@ -10,9 +10,9 @@ require ( github.com/prometheus/client_golang v1.23.2 github.com/spf13/cobra v1.10.2 go.uber.org/zap v1.27.1 - golang.org/x/crypto v0.46.0 - golang.org/x/net v0.48.0 - golang.org/x/sys v0.39.0 + golang.org/x/crypto v0.47.0 + golang.org/x/net v0.49.0 + golang.org/x/sys v0.40.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -21,12 +21,12 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/charmbracelet/colorprofile v0.4.1 // indirect - github.com/charmbracelet/x/ansi v0.11.3 // indirect + github.com/charmbracelet/x/ansi v0.11.4 // indirect github.com/charmbracelet/x/cellbuf v0.0.14 // indirect github.com/charmbracelet/x/term v0.2.2 // indirect - github.com/clipperhouse/displaywidth v0.6.1 // indirect + github.com/clipperhouse/displaywidth v0.7.0 // indirect github.com/clipperhouse/stringish v0.1.1 // indirect - github.com/clipperhouse/uax29/v2 v2.3.0 // indirect + github.com/clipperhouse/uax29/v2 v2.3.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -34,13 +34,13 @@ require ( github.com/muesli/termenv v0.16.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/prometheus/client_model v0.6.2 // indirect - github.com/prometheus/common v0.67.4 // indirect + github.com/prometheus/common v0.67.5 // indirect github.com/prometheus/procfs v0.19.2 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect go.uber.org/multierr v1.11.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect - golang.org/x/text v0.32.0 // indirect + golang.org/x/text v0.33.0 // indirect google.golang.org/protobuf v1.36.11 // indirect ) diff --git a/go.sum b/go.sum index 0ec1251..a2e8541 100644 --- a/go.sum +++ b/go.sum @@ -8,18 +8,18 @@ github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk= github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= -github.com/charmbracelet/x/ansi v0.11.3 h1:6DcVaqWI82BBVM/atTyq6yBoRLZFBsnoDoX9GCu2YOI= -github.com/charmbracelet/x/ansi v0.11.3/go.mod h1:yI7Zslym9tCJcedxz5+WBq+eUGMJT0bM06Fqy1/Y4dI= +github.com/charmbracelet/x/ansi v0.11.4 h1:6G65PLu6HjmE858CnTUQY1LXT3ZUWwfvqEROLF8vqHI= +github.com/charmbracelet/x/ansi v0.11.4/go.mod h1:/5AZ+UfWExW3int5H5ugnsG/PWjNcSQcwYsHBlPFQN4= github.com/charmbracelet/x/cellbuf v0.0.14 h1:iUEMryGyFTelKW3THW4+FfPgi4fkmKnnaLOXuc+/Kj4= github.com/charmbracelet/x/cellbuf v0.0.14/go.mod h1:P447lJl49ywBbil/KjCk2HexGh4tEY9LH0/1QrZZ9rA= github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= -github.com/clipperhouse/displaywidth v0.6.1 h1:/zMlAezfDzT2xy6acHBzwIfyu2ic0hgkT83UX5EY2gY= -github.com/clipperhouse/displaywidth v0.6.1/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o= +github.com/clipperhouse/displaywidth v0.7.0 h1:QNv1GYsnLX9QBrcWUtMlogpTXuM5FVnBwKWp1O5NwmE= +github.com/clipperhouse/displaywidth v0.7.0/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o= github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= -github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4= -github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= +github.com/clipperhouse/uax29/v2 v2.3.1 h1:RjM8gnVbFbgI67SBekIC7ihFpyXwRPYWXn9BZActHbw= +github.com/clipperhouse/uax29/v2 v2.3.1/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -57,8 +57,8 @@ github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= -github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+Lvsc= -github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI= +github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4= +github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw= github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= @@ -84,17 +84,17 @@ go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= -golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= -golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/client/cli/attach.go b/internal/client/cli/attach.go index a5e572d..d5967a8 100644 --- a/internal/client/cli/attach.go +++ b/internal/client/cli/attach.go @@ -28,9 +28,11 @@ Examples: drip attach tcp 5432 Attach to TCP tunnel on port 5432 Press Ctrl+C to detach (tunnel will continue running).`, - Aliases: []string{"logs", "tail"}, - Args: cobra.MaximumNArgs(2), - RunE: runAttach, + Aliases: []string{"logs", "tail"}, + Args: cobra.MaximumNArgs(2), + RunE: runAttach, + SilenceUsage: true, + SilenceErrors: true, } func init() { diff --git a/internal/client/cli/config.go b/internal/client/cli/config.go index 668ed42..0d9c369 100644 --- a/internal/client/cli/config.go +++ b/internal/client/cli/config.go @@ -13,44 +13,56 @@ import ( ) var configCmd = &cobra.Command{ - Use: "config", - Short: "Manage configuration", - Long: "Manage Drip client configuration (server, token, tunnels)", + Use: "config", + Short: "Manage configuration", + Long: "Manage Drip client configuration (server, token, tunnels)", + SilenceUsage: true, + SilenceErrors: true, } var configInitCmd = &cobra.Command{ - Use: "init", - Short: "Initialize configuration interactively", - Long: "Initialize Drip configuration with interactive prompts", - RunE: runConfigInit, + Use: "init", + Short: "Initialize configuration interactively", + Long: "Initialize Drip configuration with interactive prompts", + RunE: runConfigInit, + SilenceUsage: true, + SilenceErrors: true, } var configShowCmd = &cobra.Command{ - Use: "show", - Short: "Show current configuration", - Long: "Display the current Drip configuration", - RunE: runConfigShow, + Use: "show", + Short: "Show current configuration", + Long: "Display the current Drip configuration", + RunE: runConfigShow, + SilenceUsage: true, + SilenceErrors: true, } var configSetCmd = &cobra.Command{ - Use: "set", - Short: "Set configuration values", - Long: "Set specific configuration values (server, token)", - RunE: runConfigSet, + Use: "set", + Short: "Set configuration values", + Long: "Set specific configuration values (server, token)", + RunE: runConfigSet, + SilenceUsage: true, + SilenceErrors: true, } var configResetCmd = &cobra.Command{ - Use: "reset", - Short: "Reset configuration", - Long: "Delete the configuration file", - RunE: runConfigReset, + Use: "reset", + Short: "Reset configuration", + Long: "Delete the configuration file", + RunE: runConfigReset, + SilenceUsage: true, + SilenceErrors: true, } var configValidateCmd = &cobra.Command{ - Use: "validate", - Short: "Validate configuration", - Long: "Validate the configuration file", - RunE: runConfigValidate, + Use: "validate", + Short: "Validate configuration", + Long: "Validate the configuration file", + RunE: runConfigValidate, + SilenceUsage: true, + SilenceErrors: true, } var ( diff --git a/internal/client/cli/http.go b/internal/client/cli/http.go index c3c2156..cac5c66 100644 --- a/internal/client/cli/http.go +++ b/internal/client/cli/http.go @@ -19,6 +19,7 @@ var ( allowIPs []string denyIPs []string authPass string + authBearer string transport string ) @@ -34,6 +35,7 @@ Example: drip http 3000 --allow-ip 10.0.0.1 Allow single IP drip http 3000 --deny-ip 1.2.3.4 Block specific IP drip http 3000 --auth secret Enable proxy authentication with password + drip http 3000 --auth-bearer sk-xxx Enable proxy authentication with bearer token drip http 3000 --transport wss Use WebSocket over TLS (CDN-friendly) Configuration: @@ -44,8 +46,10 @@ Transport options: auto - Automatically select based on server address (default) tcp - Direct TLS 1.3 connection wss - WebSocket over TLS (works through CDN like Cloudflare)`, - Args: cobra.ExactArgs(1), - RunE: runHTTP, + Args: cobra.ExactArgs(1), + RunE: runHTTP, + SilenceUsage: true, + SilenceErrors: true, } func init() { @@ -55,6 +59,7 @@ func init() { httpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)") httpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)") httpCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication") + httpCmd.Flags().StringVar(&authBearer, "auth-bearer", "", "Bearer token for proxy authentication") httpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)") httpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process") httpCmd.Flags().MarkHidden("daemon-child") @@ -71,6 +76,10 @@ func runHTTP(_ *cobra.Command, args []string) error { return StartDaemon("http", port, buildDaemonArgs("http", args, subdomain, localAddress)) } + if authPass != "" && authBearer != "" { + return fmt.Errorf("cannot use --auth and --auth-bearer together") + } + serverAddr, token, err := resolveServerAddrAndToken("http", port) if err != nil { return err @@ -87,6 +96,7 @@ func runHTTP(_ *cobra.Command, args []string) error { AllowIPs: allowIPs, DenyIPs: denyIPs, AuthPass: authPass, + AuthBearer: authBearer, Transport: parseTransport(transport), } diff --git a/internal/client/cli/https.go b/internal/client/cli/https.go index 085f74a..1760db5 100644 --- a/internal/client/cli/https.go +++ b/internal/client/cli/https.go @@ -22,6 +22,7 @@ Example: drip https 443 --allow-ip 10.0.0.1 Allow single IP drip https 443 --deny-ip 1.2.3.4 Block specific IP drip https 443 --auth secret Enable proxy authentication with password + drip https 443 --auth-bearer sk-xxx Enable proxy authentication with bearer token drip https 443 --transport wss Use WebSocket over TLS (CDN-friendly) Configuration: @@ -32,8 +33,10 @@ Transport options: auto - Automatically select based on server address (default) tcp - Direct TLS 1.3 connection wss - WebSocket over TLS (works through CDN like Cloudflare)`, - Args: cobra.ExactArgs(1), - RunE: runHTTPS, + Args: cobra.ExactArgs(1), + RunE: runHTTPS, + SilenceUsage: true, + SilenceErrors: true, } func init() { @@ -43,6 +46,7 @@ func init() { httpsCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)") httpsCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)") httpsCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication") + httpsCmd.Flags().StringVar(&authBearer, "auth-bearer", "", "Bearer token for proxy authentication") httpsCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)") httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process") httpsCmd.Flags().MarkHidden("daemon-child") @@ -59,6 +63,10 @@ func runHTTPS(_ *cobra.Command, args []string) error { return StartDaemon("https", port, buildDaemonArgs("https", args, subdomain, localAddress)) } + if authPass != "" && authBearer != "" { + return fmt.Errorf("cannot use --auth and --auth-bearer together") + } + serverAddr, token, err := resolveServerAddrAndToken("https", port) if err != nil { return err @@ -75,6 +83,7 @@ func runHTTPS(_ *cobra.Command, args []string) error { AllowIPs: allowIPs, DenyIPs: denyIPs, AuthPass: authPass, + AuthBearer: authBearer, Transport: parseTransport(transport), } diff --git a/internal/client/cli/list.go b/internal/client/cli/list.go index 93f59d5..c320d9e 100644 --- a/internal/client/cli/list.go +++ b/internal/client/cli/list.go @@ -33,8 +33,10 @@ This command shows: In interactive mode, you can select a tunnel to: - Attach: View real-time logs - Stop: Terminate the tunnel`, - Aliases: []string{"ls", "ps", "status"}, - RunE: runList, + Aliases: []string{"ls", "ps", "status"}, + RunE: runList, + SilenceUsage: true, + SilenceErrors: true, } func init() { diff --git a/internal/client/cli/root.go b/internal/client/cli/root.go index 777e22d..31dfe23 100644 --- a/internal/client/cli/root.go +++ b/internal/client/cli/root.go @@ -45,6 +45,8 @@ Features: ✓ Auto-save configuration ✓ Custom subdomains ✓ Authentication via token`, + SilenceUsage: true, + SilenceErrors: true, } func init() { diff --git a/internal/client/cli/server.go b/internal/client/cli/server.go index 38b98d4..9d16621 100644 --- a/internal/client/cli/server.go +++ b/internal/client/cli/server.go @@ -22,28 +22,30 @@ import ( ) var ( - serverPort int - serverPublicPort int - serverDomain string - serverTunnelDomain string - serverAuthToken string - serverMetricsToken string - serverDebug bool - serverTCPPortMin int - serverTCPPortMax int - serverTLSCert string - serverTLSKey string - serverPprofPort int - serverTransports string - serverTunnelTypes string - serverConfigFile string + serverPort int + serverPublicPort int + serverDomain string + serverTunnelDomain string + serverAuthToken string + serverMetricsToken string + serverDebug bool + serverTCPPortMin int + serverTCPPortMax int + serverTLSCert string + serverTLSKey string + serverPprofPort int + serverTransports string + serverTunnelTypes string + serverConfigFile string ) var serverCmd = &cobra.Command{ - Use: "server", - Short: "Start Drip server", - Long: `Start the Drip tunnel server to accept client connections`, - RunE: runServer, + Use: "server", + Short: "Start Drip server", + Long: `Start the Drip tunnel server to accept client connections`, + RunE: runServer, + SilenceUsage: true, + SilenceErrors: true, } func init() { @@ -285,11 +287,29 @@ func runServer(cmd *cobra.Command, _ []string) error { listenAddr := fmt.Sprintf("0.0.0.0:%d", cfg.Port) - httpHandler := proxy.NewHandler(tunnelManager, logger, cfg.Domain, cfg.TunnelDomain, cfg.AuthToken, cfg.MetricsToken) + httpHandler := proxy.NewHandler(proxy.HandlerConfig{ + Manager: tunnelManager, + Logger: logger, + ServerDomain: cfg.Domain, + TunnelDomain: cfg.TunnelDomain, + AuthToken: cfg.AuthToken, + MetricsToken: cfg.MetricsToken, + }) httpHandler.SetAllowedTransports(cfg.AllowedTransports) httpHandler.SetAllowedTunnelTypes(cfg.AllowedTunnelTypes) - listener := tcp.NewListener(listenAddr, tlsConfig, cfg.AuthToken, tunnelManager, logger, portAllocator, cfg.Domain, cfg.TunnelDomain, cfg.PublicPort, httpHandler) + listener := tcp.NewListener(tcp.ListenerConfig{ + Address: listenAddr, + TLSConfig: tlsConfig, + AuthToken: cfg.AuthToken, + Manager: tunnelManager, + Logger: logger, + PortAlloc: portAllocator, + Domain: cfg.Domain, + TunnelDomain: cfg.TunnelDomain, + PublicPort: cfg.PublicPort, + HTTPHandler: httpHandler, + }) listener.SetAllowedTransports(cfg.AllowedTransports) listener.SetAllowedTunnelTypes(cfg.AllowedTunnelTypes) diff --git a/internal/client/cli/start.go b/internal/client/cli/start.go index 4de5df2..3ade7fa 100644 --- a/internal/client/cli/start.go +++ b/internal/client/cli/start.go @@ -54,7 +54,9 @@ Configuration file example (~/.drip/config.yaml): allow_ips: - 192.168.0.0/16 - 10.0.0.0/8`, - RunE: runStart, + RunE: runStart, + SilenceUsage: true, + SilenceErrors: true, } func init() { @@ -238,6 +240,7 @@ func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp AllowIPs: t.AllowIPs, DenyIPs: t.DenyIPs, AuthPass: t.Auth, + AuthBearer: t.AuthBearer, Transport: transport, } } diff --git a/internal/client/cli/stop.go b/internal/client/cli/stop.go index bf9fae1..30f96d1 100644 --- a/internal/client/cli/stop.go +++ b/internal/client/cli/stop.go @@ -19,9 +19,11 @@ Examples: drip stop all Stop all running tunnels Use 'drip list' to see running tunnels.`, - Aliases: []string{"kill"}, - Args: cobra.MinimumNArgs(1), - RunE: runStop, + Aliases: []string{"kill"}, + Args: cobra.MinimumNArgs(1), + RunE: runStop, + SilenceUsage: true, + SilenceErrors: true, } func init() { diff --git a/internal/client/cli/tcp.go b/internal/client/cli/tcp.go index 3f42ca3..2ab643e 100644 --- a/internal/client/cli/tcp.go +++ b/internal/client/cli/tcp.go @@ -41,8 +41,10 @@ Transport options: Note: TCP tunnels require dynamic port allocation on the server. When using CDN (--transport wss), the server must still expose the allocated port directly.`, - Args: cobra.ExactArgs(1), - RunE: runTCP, + Args: cobra.ExactArgs(1), + RunE: runTCP, + SilenceUsage: true, + SilenceErrors: true, } func init() { diff --git a/internal/client/cli/tunnel_helpers.go b/internal/client/cli/tunnel_helpers.go index 3ece9b3..4863ec2 100644 --- a/internal/client/cli/tunnel_helpers.go +++ b/internal/client/cli/tunnel_helpers.go @@ -24,6 +24,12 @@ func buildDaemonArgs(tunnelType string, args []string, subdomain string, localAd if authToken != "" { daemonArgs = append(daemonArgs, "--token", authToken) } + if authPass != "" { + daemonArgs = append(daemonArgs, "--auth", authPass) + } + if authBearer != "" { + daemonArgs = append(daemonArgs, "--auth-bearer", authBearer) + } if insecure { daemonArgs = append(daemonArgs, "--insecure") } diff --git a/internal/client/tcp/connection_dialer.go b/internal/client/tcp/connection_dialer.go new file mode 100644 index 0000000..b0ea294 --- /dev/null +++ b/internal/client/tcp/connection_dialer.go @@ -0,0 +1,197 @@ +package tcp + +import ( + "crypto/tls" + "fmt" + "net" + "net/http" + "strings" + "time" + + json "github.com/goccy/go-json" + "github.com/gorilla/websocket" + "go.uber.org/zap" + + "drip/internal/shared/wsutil" +) + +// serverCapabilities holds the discovered server capabilities +type serverCapabilities struct { + Transports []string `json:"transports"` + Preferred string `json:"preferred"` +} + +// ConnectionDialer handles establishing connections to the server. +type ConnectionDialer struct { + serverAddr string + tlsConfig *tls.Config + token string + transport TransportType + logger *zap.Logger +} + +// NewConnectionDialer creates a new connection dialer. +func NewConnectionDialer( + serverAddr string, + tlsConfig *tls.Config, + token string, + transport TransportType, + logger *zap.Logger, +) *ConnectionDialer { + return &ConnectionDialer{ + serverAddr: serverAddr, + tlsConfig: tlsConfig, + token: token, + transport: transport, + logger: logger, + } +} + +// Dial establishes a connection using the appropriate transport. +func (d *ConnectionDialer) Dial() (net.Conn, error) { + switch d.transport { + case TransportWebSocket: + return d.dialWebSocket() + case TransportTCP: + // User explicitly requested TCP, verify server supports it + caps := d.discoverServerCapabilities() + if caps != nil && len(caps.Transports) > 0 { + tcpSupported := false + for _, t := range caps.Transports { + if t == "tcp" { + tcpSupported = true + break + } + } + if !tcpSupported { + return nil, fmt.Errorf("server only supports %v transport(s), but --transport tcp was specified. Use --transport wss instead", caps.Transports) + } + } + return d.dialTLS() + default: // TransportAuto + // Check if server address indicates WebSocket + if strings.HasPrefix(d.serverAddr, "wss://") { + return d.dialWebSocket() + } + // Query server for preferred transport + caps := d.discoverServerCapabilities() + if caps != nil && caps.Preferred == "wss" { + return d.dialWebSocket() + } + // Default to TCP + return d.dialTLS() + } +} + +// dialTLS establishes a TLS connection to the server. +func (d *ConnectionDialer) dialTLS() (net.Conn, error) { + dialer := &net.Dialer{Timeout: 10 * time.Second} + conn, err := tls.DialWithDialer(dialer, "tcp", d.serverAddr, d.tlsConfig) + if err != nil { + return nil, fmt.Errorf("failed to connect: %w", err) + } + + state := conn.ConnectionState() + if state.Version != tls.VersionTLS13 { + _ = conn.Close() + return nil, fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version) + } + + if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok { + _ = tcpConn.SetNoDelay(true) + _ = tcpConn.SetKeepAlive(true) + _ = tcpConn.SetKeepAlivePeriod(30 * time.Second) + _ = tcpConn.SetReadBuffer(256 * 1024) + _ = tcpConn.SetWriteBuffer(256 * 1024) + } + + return conn, nil +} + +// dialWebSocket establishes a WebSocket connection to the server over TLS. +func (d *ConnectionDialer) dialWebSocket() (net.Conn, error) { + // Build WebSocket URL + host, port, err := net.SplitHostPort(d.serverAddr) + if err != nil { + // No port specified, use default + host = d.serverAddr + port = "443" + } + + wsURL := fmt.Sprintf("wss://%s:%s/_drip/ws", host, port) + + d.logger.Debug("Connecting via WebSocket over TLS", + zap.String("url", wsURL), + ) + + dialer := websocket.Dialer{ + TLSClientConfig: d.tlsConfig, + HandshakeTimeout: 10 * time.Second, + ReadBufferSize: 256 * 1024, + WriteBufferSize: 256 * 1024, + } + + // Add authorization header if token is set + header := http.Header{} + if d.token != "" { + header.Set("Authorization", "Bearer "+d.token) + } + + ws, resp, err := dialer.Dial(wsURL, header) + if err != nil { + if resp != nil { + return nil, fmt.Errorf("WebSocket dial failed (status %d): %w", resp.StatusCode, err) + } + return nil, fmt.Errorf("WebSocket dial failed: %w", err) + } + + d.logger.Info("Connected via WebSocket over TLS", + zap.String("url", wsURL), + ) + + // Wrap WebSocket connection to implement net.Conn with ping loop for CDN keep-alive + return wsutil.NewConnWithPing(ws, 30*time.Second), nil +} + +// discoverServerCapabilities queries the server for its capabilities. +func (d *ConnectionDialer) discoverServerCapabilities() *serverCapabilities { + host, port, err := net.SplitHostPort(d.serverAddr) + if err != nil { + host = d.serverAddr + port = "443" + } + + discoverURL := fmt.Sprintf("https://%s:%s/_drip/discover", host, port) + + client := &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: d.tlsConfig, + }, + } + + resp, err := client.Get(discoverURL) + if err != nil { + d.logger.Debug("Failed to discover server capabilities", + zap.Error(err), + ) + return nil + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil + } + + var caps serverCapabilities + if err := json.NewDecoder(resp.Body).Decode(&caps); err != nil { + return nil + } + + d.logger.Debug("Discovered server capabilities", + zap.Strings("transports", caps.Transports), + zap.String("preferred", caps.Preferred), + ) + + return &caps +} diff --git a/internal/client/tcp/connector.go b/internal/client/tcp/connector.go index ce2c5c7..961aa2c 100644 --- a/internal/client/tcp/connector.go +++ b/internal/client/tcp/connector.go @@ -41,7 +41,8 @@ type ConnectorConfig struct { DenyIPs []string // Proxy authentication - AuthPass string + AuthPass string + AuthBearer string // Transport protocol selection Transport TransportType diff --git a/internal/client/tcp/pool_client.go b/internal/client/tcp/pool_client.go index e711a84..b913604 100644 --- a/internal/client/tcp/pool_client.go +++ b/internal/client/tcp/pool_client.go @@ -13,7 +13,6 @@ import ( "sync/atomic" "time" - "github.com/gorilla/websocket" json "github.com/goccy/go-json" "github.com/hashicorp/yamux" "go.uber.org/zap" @@ -22,7 +21,6 @@ import ( "drip/internal/shared/mux" "drip/internal/shared/protocol" "drip/internal/shared/stats" - "drip/internal/shared/wsutil" "drip/pkg/config" ) @@ -71,11 +69,18 @@ type PoolClient struct { allowIPs []string denyIPs []string - authPass string + authPass string + authBearer string // Transport protocol selection transport TransportType insecure bool + + // Connection dialer + dialer *ConnectionDialer + + // Session scaler + scaler *SessionScaler } // NewPoolClient creates a new pool client. @@ -169,8 +174,10 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient { allowIPs: cfg.AllowIPs, denyIPs: cfg.DenyIPs, authPass: cfg.AuthPass, + authBearer: cfg.AuthBearer, transport: transport, insecure: cfg.Insecure, + dialer: NewConnectionDialer(serverAddr, tlsConfig, cfg.Token, transport, logger), } if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS { @@ -183,7 +190,7 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient { // Connect establishes the primary connection and starts background workers. func (c *PoolClient) Connect() error { - primaryConn, err := c.dial() + primaryConn, err := c.dialer.Dial() if err != nil { return err } @@ -208,9 +215,16 @@ func (c *PoolClient) Connect() error { } } - if c.authPass != "" { + if c.authBearer != "" { + req.ProxyAuth = &protocol.ProxyAuth{ + Enabled: true, + Type: "bearer", + Token: c.authBearer, + } + } else if c.authPass != "" { req.ProxyAuth = &protocol.ProxyAuth{ Enabled: true, + Type: "password", Password: c.authPass, } } @@ -299,8 +313,9 @@ func (c *PoolClient) Connect() error { c.warmupSessions() - c.wg.Add(1) - go c.scalerLoop() + // Initialize and start session scaler + c.scaler = NewSessionScaler(c, c.logger, c.stopCh, &c.wg) + c.scaler.Start() } go func() { @@ -311,162 +326,6 @@ func (c *PoolClient) Connect() error { return nil } -func (c *PoolClient) dialTLS() (net.Conn, error) { - dialer := &net.Dialer{Timeout: 10 * time.Second} - conn, err := tls.DialWithDialer(dialer, "tcp", c.serverAddr, c.tlsConfig) - if err != nil { - return nil, fmt.Errorf("failed to connect: %w", err) - } - - state := conn.ConnectionState() - if state.Version != tls.VersionTLS13 { - _ = conn.Close() - return nil, fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version) - } - - if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok { - _ = tcpConn.SetNoDelay(true) - _ = tcpConn.SetKeepAlive(true) - _ = tcpConn.SetKeepAlivePeriod(30 * time.Second) - _ = tcpConn.SetReadBuffer(256 * 1024) - _ = tcpConn.SetWriteBuffer(256 * 1024) - } - - return conn, nil -} - -// serverCapabilities holds the discovered server capabilities -type serverCapabilities struct { - Transports []string `json:"transports"` - Preferred string `json:"preferred"` -} - -// dial selects the appropriate transport and establishes a connection -func (c *PoolClient) dial() (net.Conn, error) { - switch c.transport { - case TransportWebSocket: - return c.dialWebSocket() - case TransportTCP: - // User explicitly requested TCP, verify server supports it - caps := c.discoverServerCapabilities() - if caps != nil && len(caps.Transports) > 0 { - tcpSupported := false - for _, t := range caps.Transports { - if t == "tcp" { - tcpSupported = true - break - } - } - if !tcpSupported { - return nil, fmt.Errorf("server only supports %v transport(s), but --transport tcp was specified. Use --transport wss instead", caps.Transports) - } - } - return c.dialTLS() - default: // TransportAuto - // Check if server address indicates WebSocket - if strings.HasPrefix(c.serverAddr, "wss://") { - return c.dialWebSocket() - } - // Query server for preferred transport - caps := c.discoverServerCapabilities() - if caps != nil && caps.Preferred == "wss" { - return c.dialWebSocket() - } - // Default to TCP - return c.dialTLS() - } -} - -// discoverServerCapabilities queries the server for its capabilities -func (c *PoolClient) discoverServerCapabilities() *serverCapabilities { - host, port, err := net.SplitHostPort(c.serverAddr) - if err != nil { - host = c.serverAddr - port = "443" - } - - discoverURL := fmt.Sprintf("https://%s:%s/_drip/discover", host, port) - - client := &http.Client{ - Timeout: 5 * time.Second, - Transport: &http.Transport{ - TLSClientConfig: c.tlsConfig, - }, - } - - resp, err := client.Get(discoverURL) - if err != nil { - c.logger.Debug("Failed to discover server capabilities", - zap.Error(err), - ) - return nil - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil - } - - var caps serverCapabilities - if err := json.NewDecoder(resp.Body).Decode(&caps); err != nil { - return nil - } - - c.logger.Debug("Discovered server capabilities", - zap.Strings("transports", caps.Transports), - zap.String("preferred", caps.Preferred), - ) - - return &caps -} - -// dialWebSocket establishes a WebSocket connection to the server over TLS -func (c *PoolClient) dialWebSocket() (net.Conn, error) { - // Build WebSocket URL - host, port, err := net.SplitHostPort(c.serverAddr) - if err != nil { - // No port specified, use default - host = c.serverAddr - port = "443" - } - - wsURL := fmt.Sprintf("wss://%s:%s/_drip/ws", host, port) - - c.logger.Debug("Connecting via WebSocket over TLS", - zap.String("url", wsURL), - ) - - dialer := websocket.Dialer{ - TLSClientConfig: c.tlsConfig, - HandshakeTimeout: 10 * time.Second, - ReadBufferSize: 256 * 1024, - WriteBufferSize: 256 * 1024, - } - - // Add authorization header if token is set - header := http.Header{} - if c.token != "" { - header.Set("Authorization", "Bearer "+c.token) - } - - ws, resp, err := dialer.Dial(wsURL, header) - if err != nil { - if resp != nil { - return nil, fmt.Errorf("WebSocket dial failed (status %d): %w", resp.StatusCode, err) - } - return nil, fmt.Errorf("WebSocket dial failed: %w", err) - } - - // Wrap WebSocket as net.Conn with ping loop for CDN keep-alive - conn := wsutil.NewConnWithPing(ws, 30*time.Second) - - c.logger.Debug("WebSocket connection established", - zap.String("remote_addr", ws.RemoteAddr().String()), - ) - - return conn, nil -} - func (c *PoolClient) acceptLoop(h *sessionHandle, isPrimary bool) { defer c.wg.Done() @@ -623,12 +482,12 @@ func (c *PoolClient) Close() error { return closeErr } -func (c *PoolClient) Wait() { <-c.doneCh } -func (c *PoolClient) GetURL() string { return c.assignedURL } -func (c *PoolClient) GetSubdomain() string { return c.subdomain } -func (c *PoolClient) GetLatency() time.Duration { return time.Duration(c.latencyNanos.Load()) } -func (c *PoolClient) GetStats() *stats.TrafficStats { return c.stats } -func (c *PoolClient) IsClosed() bool { return c.closed.Load() } +func (c *PoolClient) Wait() { <-c.doneCh } +func (c *PoolClient) GetURL() string { return c.assignedURL } +func (c *PoolClient) GetSubdomain() string { return c.subdomain } +func (c *PoolClient) GetLatency() time.Duration { return time.Duration(c.latencyNanos.Load()) } +func (c *PoolClient) GetStats() *stats.TrafficStats { return c.stats } +func (c *PoolClient) IsClosed() bool { return c.closed.Load() } func (c *PoolClient) SetLatencyCallback(cb LatencyCallback) { if cb == nil { diff --git a/internal/client/tcp/pool_session.go b/internal/client/tcp/pool_session.go index 0f4cafb..f7d5e80 100644 --- a/internal/client/tcp/pool_session.go +++ b/internal/client/tcp/pool_session.go @@ -188,7 +188,7 @@ func (c *PoolClient) addDataSession() error { return fmt.Errorf("server does not support data connections") } - conn, err := c.dial() + conn, err := c.dialer.Dial() if err != nil { return err } diff --git a/internal/client/tcp/session_scaler.go b/internal/client/tcp/session_scaler.go new file mode 100644 index 0000000..8b5a89f --- /dev/null +++ b/internal/client/tcp/session_scaler.go @@ -0,0 +1,150 @@ +package tcp + +import ( + "sync" + "time" + + "go.uber.org/zap" +) + +// SessionScaler manages automatic scaling of yamux sessions based on load. +type SessionScaler struct { + client *PoolClient + logger *zap.Logger + stopCh <-chan struct{} + wg *sync.WaitGroup + + // Scaling configuration + checkInterval time.Duration + scaleUpCooldown time.Duration + scaleDownCooldown time.Duration + capacityPerSession int64 + scaleUpLoad float64 + scaleDownLoad float64 + burstThreshold float64 + maxBurstAdd int +} + +// NewSessionScaler creates a new session scaler. +func NewSessionScaler( + client *PoolClient, + logger *zap.Logger, + stopCh <-chan struct{}, + wg *sync.WaitGroup, +) *SessionScaler { + return &SessionScaler{ + client: client, + logger: logger, + stopCh: stopCh, + wg: wg, + checkInterval: 1 * time.Second, + scaleUpCooldown: 1 * time.Second, + scaleDownCooldown: 60 * time.Second, + capacityPerSession: 256, + scaleUpLoad: 0.6, + scaleDownLoad: 0.2, + burstThreshold: 0.9, + maxBurstAdd: 4, + } +} + +// Start starts the scaler loop. +func (s *SessionScaler) Start() { + s.wg.Add(1) + go s.scalerLoop() +} + +// scalerLoop monitors load and adjusts session count. +func (s *SessionScaler) scalerLoop() { + defer s.wg.Done() + + ticker := time.NewTicker(s.checkInterval) + defer ticker.Stop() + + for { + select { + case <-s.stopCh: + return + case <-ticker.C: + } + + s.client.mu.Lock() + desired := s.client.desiredTotal + if desired == 0 { + desired = s.client.initialSessions + s.client.desiredTotal = desired + } + lastScale := s.client.lastScale + s.client.mu.Unlock() + + current := s.client.sessionCount() + if current == 0 { + continue + } + + activeConns := s.client.stats.GetActiveConnections() + capacity := int64(current) * s.capacityPerSession + load := float64(activeConns) / float64(capacity) + + now := time.Now() + + // Burst scaling: rapid scale-up under extreme load + if load > s.burstThreshold && current < s.client.maxSessions { + toAdd := min(s.maxBurstAdd, s.client.maxSessions-current) + s.logger.Info("Burst scaling up sessions", + zap.Int("current", current), + zap.Int("adding", toAdd), + zap.Float64("load", load), + ) + for i := 0; i < toAdd; i++ { + _ = s.client.addDataSession() + } + s.client.mu.Lock() + s.client.desiredTotal = current + toAdd + s.client.lastScale = now + s.client.mu.Unlock() + continue + } + + // Scale up: add sessions when load is high + if load > s.scaleUpLoad && current < s.client.maxSessions { + if now.Sub(lastScale) < s.scaleUpCooldown { + continue + } + newDesired := min(desired+1, s.client.maxSessions) + if newDesired > desired { + s.logger.Debug("Scaling up sessions", + zap.Int("current", current), + zap.Int("desired", newDesired), + zap.Float64("load", load), + ) + _ = s.client.addDataSession() + s.client.mu.Lock() + s.client.desiredTotal = newDesired + s.client.lastScale = now + s.client.mu.Unlock() + } + continue + } + + // Scale down: remove idle sessions when load is low + if load < s.scaleDownLoad && current > s.client.minSessions { + if now.Sub(lastScale) < s.scaleDownCooldown { + continue + } + newDesired := max(desired-1, s.client.minSessions) + if newDesired < desired && current > newDesired { + s.logger.Debug("Scaling down sessions", + zap.Int("current", current), + zap.Int("desired", newDesired), + zap.Float64("load", load), + ) + s.client.removeIdleSessions(1) + s.client.mu.Lock() + s.client.desiredTotal = newDesired + s.client.lastScale = now + s.client.mu.Unlock() + } + } + } +} diff --git a/internal/server/proxy/auth_handler.go b/internal/server/proxy/auth_handler.go new file mode 100644 index 0000000..428d5fd --- /dev/null +++ b/internal/server/proxy/auth_handler.go @@ -0,0 +1,412 @@ +package proxy + +import ( + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/hex" + "fmt" + "html" + "net/http" + "strings" + "sync" + "time" + + "drip/internal/server/tunnel" + "drip/internal/shared/protocol" +) + +const authCookieName = "drip_auth" +const authSessionDuration = 24 * time.Hour + +const ( + authRateLimitWindow = 1 * time.Minute + authRateLimitMax = 10 + authRateLimitLockout = 5 * time.Minute + authRateLimitLockoutThreshold = 20 +) + +type authSession struct { + subdomain string + expiresAt time.Time +} + +type authSessionStore struct { + mu sync.RWMutex + sessions map[string]*authSession +} + +type authRateLimitEntry struct { + failures int + windowStart time.Time + lockedUntil time.Time +} + +type authRateLimiter struct { + mu sync.RWMutex + entries map[string]*authRateLimitEntry +} + +var sessionStore = &authSessionStore{ + sessions: make(map[string]*authSession), +} + +var authLimiter = &authRateLimiter{ + entries: make(map[string]*authRateLimitEntry), +} + +func init() { + go authLimiter.startCleanupLoop() + go sessionStore.startCleanupLoop() +} + +func (rl *authRateLimiter) startCleanupLoop() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + rl.cleanup() + } +} + +func (s *authSessionStore) startCleanupLoop() { + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + s.cleanup() + } +} + +func (s *authSessionStore) cleanup() { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + for token, session := range s.sessions { + if now.After(session.expiresAt) { + delete(s.sessions, token) + } + } +} + +func (rl *authRateLimiter) isRateLimited(ip string) bool { + if ip == "" { + return false + } + + rl.mu.RLock() + entry, exists := rl.entries[ip] + rl.mu.RUnlock() + + if !exists { + return false + } + + now := time.Now() + + if !entry.lockedUntil.IsZero() && now.Before(entry.lockedUntil) { + return true + } + + if now.Sub(entry.windowStart) < authRateLimitWindow && entry.failures >= authRateLimitMax { + return true + } + + return false +} + +func (rl *authRateLimiter) recordFailure(ip string) { + if ip == "" { + return + } + + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + entry, exists := rl.entries[ip] + + if !exists { + rl.entries[ip] = &authRateLimitEntry{ + failures: 1, + windowStart: now, + } + return + } + + if now.Sub(entry.windowStart) >= authRateLimitWindow { + entry.failures = 1 + entry.windowStart = now + entry.lockedUntil = time.Time{} + return + } + + entry.failures++ + + if entry.failures >= authRateLimitLockoutThreshold { + entry.lockedUntil = now.Add(authRateLimitLockout) + } +} + +func (rl *authRateLimiter) resetFailures(ip string) { + if ip == "" { + return + } + + rl.mu.Lock() + delete(rl.entries, ip) + rl.mu.Unlock() +} + +func (rl *authRateLimiter) cleanup() { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + for ip, entry := range rl.entries { + windowExpired := now.Sub(entry.windowStart) >= authRateLimitWindow + lockoutExpired := entry.lockedUntil.IsZero() || now.After(entry.lockedUntil) + if windowExpired && lockoutExpired { + delete(rl.entries, ip) + } + } +} + +func (s *authSessionStore) create(subdomain string) string { + token := generateSessionToken() + s.mu.Lock() + s.sessions[token] = &authSession{ + subdomain: subdomain, + expiresAt: time.Now().Add(authSessionDuration), + } + s.mu.Unlock() + return token +} + +func (s *authSessionStore) validate(token, subdomain string) bool { + s.mu.RLock() + session, ok := s.sessions[token] + s.mu.RUnlock() + + if !ok { + return false + } + if time.Now().After(session.expiresAt) { + s.mu.Lock() + delete(s.sessions, token) + s.mu.Unlock() + return false + } + return session.subdomain == subdomain +} + +func generateSessionToken() string { + b := make([]byte, 32) + rand.Read(b) + hash := sha256.Sum256(b) + return hex.EncodeToString(hash[:]) +} + +func isBearerProxyAuth(auth *protocol.ProxyAuth) bool { + if auth == nil { + return false + } + if auth.Type != "" { + return strings.EqualFold(auth.Type, "bearer") + } + return auth.Token != "" +} + +func extractBearerToken(header string) string { + if header == "" { + return "" + } + parts := strings.Fields(header) + if len(parts) < 2 { + return "" + } + if !strings.EqualFold(parts[0], "Bearer") { + return "" + } + return parts[1] +} + +func (h *Handler) isProxyAuthenticated(r *http.Request, subdomain string) bool { + cookie, err := r.Cookie(authCookieName + "_" + subdomain) + if err != nil { + return false + } + return sessionStore.validate(cookie.Value, subdomain) +} + +func (h *Handler) isBearerAuthenticated(r *http.Request, auth *protocol.ProxyAuth) bool { + token := extractBearerToken(r.Header.Get("Authorization")) + if token == "" { + return false + } + return subtle.ConstantTimeCompare([]byte(token), []byte(auth.Token)) == 1 +} + +func (h *Handler) serveBearerAuthRequired(w http.ResponseWriter, realm string) { + if realm == "" { + realm = "drip" + } + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="%s"`, realm)) + w.Header().Set("Cache-Control", "no-store") + http.Error(w, "Unauthorized: provide bearer token via Authorization header", http.StatusUnauthorized) +} + +func (h *Handler) handleProxyLogin(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection, subdomain string) { + h.handleProxyLoginWithRateLimit(w, r, tconn, subdomain, "") +} + +func (h *Handler) handleProxyLoginWithRateLimit(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection, subdomain string, clientIP string) { + if r.Method != http.MethodPost { + h.serveLoginPage(w, r, subdomain, "") + return + } + + if clientIP != "" && authLimiter.isRateLimited(clientIP) { + w.Header().Set("Retry-After", "60") + http.Error(w, "Too many failed authentication attempts. Please try again later.", http.StatusTooManyRequests) + return + } + + if err := r.ParseForm(); err != nil { + h.serveLoginPage(w, r, subdomain, "Invalid form data") + return + } + + password := r.FormValue("password") + + if !tconn.ValidateProxyAuth(password) { + if clientIP != "" { + authLimiter.recordFailure(clientIP) + } + h.serveLoginPage(w, r, subdomain, "Invalid password") + return + } + + if clientIP != "" { + authLimiter.resetFailures(clientIP) + } + + token := sessionStore.create(subdomain) + http.SetCookie(w, &http.Cookie{ + Name: authCookieName + "_" + subdomain, + Value: token, + Path: "/", + MaxAge: int(authSessionDuration.Seconds()), + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + + redirectURL := r.FormValue("redirect") + if redirectURL == "" || redirectURL == "/_drip/login" { + redirectURL = "/" + } + http.Redirect(w, r, redirectURL, http.StatusSeeOther) +} + +func (h *Handler) serveLoginPage(w http.ResponseWriter, r *http.Request, subdomain string, errorMsg string) { + redirectURL := r.URL.Path + if r.URL.RawQuery != "" { + redirectURL += "?" + r.URL.RawQuery + } + if redirectURL == "/_drip/login" { + redirectURL = "/" + } + + errorHTML := "" + if errorMsg != "" { + errorHTML = fmt.Sprintf(`

%s

`, html.EscapeString(errorMsg)) + } + + safeRedirectURL := html.EscapeString(redirectURL) + + htmlContent := fmt.Sprintf(` + + + + + %s - Drip + `+faviconLink+` + + + +
+
+

🔒%s

+

This tunnel is password protected

+
+ + %s +
+ +
+ + +
+
+ + +
+ +`, subdomain, subdomain, errorHTML, safeRedirectURL) + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate") + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(htmlContent)) +} diff --git a/internal/server/proxy/handler.go b/internal/server/proxy/handler.go index 2621bfd..386932f 100644 --- a/internal/server/proxy/handler.go +++ b/internal/server/proxy/handler.go @@ -2,12 +2,8 @@ package proxy import ( "bufio" - "context" - "crypto/rand" - "crypto/sha256" - "encoding/hex" + "crypto/subtle" "fmt" - "html" "io" "net" "net/http" @@ -16,18 +12,14 @@ import ( "sync" "time" - json "github.com/goccy/go-json" "github.com/gorilla/websocket" + "go.uber.org/zap" "drip/internal/server/tunnel" "drip/internal/shared/httputil" "drip/internal/shared/netutil" "drip/internal/shared/pool" "drip/internal/shared/protocol" - "drip/internal/shared/wsutil" - - "github.com/prometheus/client_golang/prometheus/promhttp" - "go.uber.org/zap" ) // bufio.Reader pool to reduce allocations on hot path @@ -38,56 +30,14 @@ var bufioReaderPool = sync.Pool{ } const openStreamTimeout = 3 * time.Second -const authCookieName = "drip_auth" -const authSessionDuration = 24 * time.Hour -type authSession struct { - subdomain string - expiresAt time.Time -} - -type authSessionStore struct { - mu sync.RWMutex - sessions map[string]*authSession -} - -var sessionStore = &authSessionStore{ - sessions: make(map[string]*authSession), -} - -func (s *authSessionStore) create(subdomain string) string { - token := generateSessionToken() - s.mu.Lock() - s.sessions[token] = &authSession{ - subdomain: subdomain, - expiresAt: time.Now().Add(authSessionDuration), - } - s.mu.Unlock() - return token -} - -func (s *authSessionStore) validate(token, subdomain string) bool { - s.mu.RLock() - session, ok := s.sessions[token] - s.mu.RUnlock() - - if !ok { - return false - } - if time.Now().After(session.expiresAt) { - s.mu.Lock() - delete(s.sessions, token) - s.mu.Unlock() - return false - } - return session.subdomain == subdomain -} - -func generateSessionToken() string { - b := make([]byte, 32) - rand.Read(b) - hash := sha256.Sum256(b) - return hex.EncodeToString(hash[:]) +type HandlerConfig struct { + Manager *tunnel.Manager + Logger *zap.Logger + ServerDomain string + TunnelDomain string + AuthToken string + MetricsToken string } type Handler struct { @@ -113,32 +63,14 @@ type WSConnectionHandler interface { HandleWSConnection(conn net.Conn, remoteAddr string) } -var privateNetworks []*net.IPNet - -func init() { - privateCIDRs := []string{ - "127.0.0.0/8", - "10.0.0.0/8", - "172.16.0.0/12", - "192.168.0.0/16", - "::1/128", - "fc00::/7", - "fe80::/10", - } - for _, cidr := range privateCIDRs { - _, ipNet, _ := net.ParseCIDR(cidr) - privateNetworks = append(privateNetworks, ipNet) - } -} - -func NewHandler(manager *tunnel.Manager, logger *zap.Logger, serverDomain, tunnelDomain string, authToken string, metricsToken string) *Handler { +func NewHandler(cfg HandlerConfig) *Handler { return &Handler{ - manager: manager, - logger: logger, - serverDomain: serverDomain, - tunnelDomain: tunnelDomain, - authToken: authToken, - metricsToken: metricsToken, + manager: cfg.Manager, + logger: cfg.Logger, + serverDomain: cfg.ServerDomain, + tunnelDomain: cfg.TunnelDomain, + authToken: cfg.AuthToken, + metricsToken: cfg.MetricsToken, wsUpgrader: websocket.Upgrader{ ReadBufferSize: 256 * 1024, WriteBufferSize: 256 * 1024, @@ -253,22 +185,38 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if tconn.HasIPAccessControl() { - clientIP := h.extractClientIP(r) + clientIP := netutil.ExtractClientIP(r) if !tconn.IsIPAllowed(clientIP) { http.Error(w, "Access denied: your IP is not allowed", http.StatusForbidden) return } } - // Check proxy authentication - if tconn.HasProxyAuth() { - if r.URL.Path == "/_drip/login" { - h.handleProxyLogin(w, r, tconn, subdomain) + if auth := tconn.GetProxyAuth(); auth != nil && auth.Enabled { + clientIP := netutil.ExtractClientIP(r) + + if authLimiter.isRateLimited(clientIP) { + w.Header().Set("Retry-After", "60") + http.Error(w, "Too many failed authentication attempts. Please try again later.", http.StatusTooManyRequests) return } - if !h.isProxyAuthenticated(r, subdomain) { - h.serveLoginPage(w, r, subdomain, "") - return + + if isBearerProxyAuth(auth) { + if !h.isBearerAuthenticated(r, auth) { + authLimiter.recordFailure(clientIP) + h.serveBearerAuthRequired(w, "drip") + return + } + authLimiter.resetFailures(clientIP) + } else { + if r.URL.Path == "/_drip/login" { + h.handleProxyLoginWithRateLimit(w, r, tconn, subdomain, clientIP) + return + } + if !h.isProxyAuthenticated(r, subdomain) { + h.serveLoginPage(w, r, subdomain, "") + return + } } } @@ -283,14 +231,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - if httputil.IsWebSocketUpgrade(r) { + if h.isWebSocketUpgrade(r) { h.handleWebSocket(w, r, tconn) return } stream, err := h.openStreamWithTimeout(tconn) if err != nil { - w.Header().Set("Connection", "close") + httputil.SetCloseConnection(w) http.Error(w, "Tunnel unavailable", http.StatusBadGateway) return } @@ -305,7 +253,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ) if err := r.Write(countingStream); err != nil { - w.Header().Set("Connection", "close") + httputil.SetCloseConnection(w) _ = r.Body.Close() http.Error(w, "Forward failed", http.StatusBadGateway) return @@ -316,7 +264,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { resp, err := http.ReadResponse(reader, r) if err != nil { bufioReaderPool.Put(reader) - w.Header().Set("Connection", "close") + httputil.SetCloseConnection(w) http.Error(w, "Read response failed", http.StatusBadGateway) return } @@ -334,7 +282,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodHead || statusCode == http.StatusNoContent || statusCode == http.StatusNotModified { if resp.ContentLength >= 0 { - w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength)) + httputil.SetContentLength(w, resp.ContentLength) } else { w.Header().Del("Content-Length") } @@ -343,7 +291,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if resp.ContentLength >= 0 { - w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength)) + httputil.SetContentLength(w, resp.ContentLength) } else { w.Header().Del("Content-Length") } @@ -396,58 +344,6 @@ func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, err } } -func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection) { - stream, err := h.openStreamWithTimeout(tconn) - if err != nil { - http.Error(w, "Tunnel unavailable", http.StatusBadGateway) - return - } - - tconn.IncActiveConnections() - - hj, ok := w.(http.Hijacker) - if !ok { - stream.Close() - tconn.DecActiveConnections() - http.Error(w, "WebSocket not supported", http.StatusInternalServerError) - return - } - - clientConn, clientBuf, err := hj.Hijack() - if err != nil { - stream.Close() - tconn.DecActiveConnections() - http.Error(w, "Failed to hijack connection", http.StatusInternalServerError) - return - } - - if err := r.Write(stream); err != nil { - stream.Close() - clientConn.Close() - tconn.DecActiveConnections() - return - } - - go func() { - defer stream.Close() - defer clientConn.Close() - defer tconn.DecActiveConnections() - - var clientRW io.ReadWriteCloser = clientConn - if clientBuf != nil && clientBuf.Reader.Buffered() > 0 { - clientRW = &bufferedReadWriteCloser{ - Reader: clientBuf.Reader, - Conn: clientConn, - } - } - - _ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW, - func(n int64) { tconn.AddBytesOut(n) }, - func(n int64) { tconn.AddBytesIn(n) }, - ) - }() -} - func (h *Handler) copyResponseHeaders(dst http.Header, src http.Header, proxyHost string) { for key, values := range src { canonicalKey := http.CanonicalHeaderKey(key) @@ -530,571 +426,18 @@ func (h *Handler) extractSubdomain(host string) (string, subdomainResult) { return "", subdomainNotFound } -// extractClientIP extracts the client IP from the request. -// It only trusts X-Forwarded-For and X-Real-IP headers when the request -// comes from a private/loopback network (typical reverse proxy setup). -func (h *Handler) extractClientIP(r *http.Request) string { - // First, get the direct remote address - remoteIP := h.extractRemoteIP(r.RemoteAddr) - - // Only trust proxy headers if the request comes from a private network - if isPrivateIP(remoteIP) { - // Check X-Forwarded-For header (may contain multiple IPs) - if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - // Take the first IP (original client) - if idx := strings.Index(xff, ","); idx != -1 { - return strings.TrimSpace(xff[:idx]) - } - return strings.TrimSpace(xff) - } - - // Check X-Real-IP header - if xri := r.Header.Get("X-Real-IP"); xri != "" { - return strings.TrimSpace(xri) - } +func (h *Handler) validateMetricsAuth(w http.ResponseWriter, r *http.Request, realm string) bool { + if h.metricsToken == "" { + return true } - // Fall back to remote address - return remoteIP -} + token := extractBearerToken(r.Header.Get("Authorization")) -// extractRemoteIP extracts the IP address from a remote address string (host:port format). -func (h *Handler) extractRemoteIP(remoteAddr string) string { - host, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - return remoteAddr - } - return host -} - -// isPrivateIP checks if the given IP is a private/loopback address. -func isPrivateIP(ip string) bool { - parsedIP := net.ParseIP(ip) - if parsedIP == nil { + if subtle.ConstantTimeCompare([]byte(token), []byte(h.metricsToken)) != 1 { + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="%s"`, realm)) + http.Error(w, "Unauthorized: provide metrics token via 'Authorization: Bearer ' header", http.StatusUnauthorized) return false } - for _, network := range privateNetworks { - if network.Contains(parsedIP) { - return true - } - } - - return false -} - -func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) { - html := ` - - - - - Drip - Your Tunnel, Your Domain, Anywhere - ` + faviconLink + ` - - - -
-
-

💧Drip

-

Your Tunnel, Your Domain, Anywhere

-
- -

A self-hosted tunneling solution to securely expose your services to the internet.

- -

Install

-
-
bash <(curl -fsSL https://driptunnel.app/install.sh)
- -
- -

Usage

-
-
drip http 3000
- -
-
-
drip https 443
- -
-
-
drip tcp 5432
- -
- - - - -
- - -` - - data := []byte(html) - w.Header().Set("Content-Type", "text/html") - w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) - w.Write(data) -} - -func (h *Handler) serveTunnelNotFound(w http.ResponseWriter, r *http.Request) { - html := ` - - - - - 404 - Tunnel Not Found - ` + faviconLink + ` - - - -
-
-

🔍Tunnel Not Found

-

The requested tunnel does not exist or has been closed.

-
- -
-

This could happen because:

-
    -
  • The tunnel was never created
  • -
  • The tunnel has been closed by the owner
  • -
  • The tunnel URL is incorrect
  • -
-
- -

If you are the tunnel owner, please restart your tunnel client.

- - -
- -` - - data := []byte(html) - w.Header().Set("Content-Type", "text/html") - w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) - w.WriteHeader(http.StatusNotFound) - w.Write(data) -} - -func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) { - health := map[string]interface{}{ - "status": "ok", - "active_tunnels": h.manager.Count(), - "timestamp": time.Now().Unix(), - } - - data, err := json.Marshal(health) - if err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) - w.Write(data) -} - -func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) { - if h.metricsToken != "" { - // Only accept token via Authorization header (Bearer token) - // URL query parameters are insecure (logged, cached, visible in browser history) - var token string - authHeader := r.Header.Get("Authorization") - if strings.HasPrefix(authHeader, "Bearer ") { - token = strings.TrimPrefix(authHeader, "Bearer ") - } - - if token != h.metricsToken { - w.Header().Set("WWW-Authenticate", `Bearer realm="stats"`) - http.Error(w, "Unauthorized: provide metrics token via 'Authorization: Bearer ' header", http.StatusUnauthorized) - return - } - } - - connections := h.manager.List() - - // Pre-allocate slice to avoid O(n²) reallocations - tunnelStats := make([]map[string]interface{}, 0, len(connections)) - for _, conn := range connections { - if conn == nil { - continue - } - tunnelStats = append(tunnelStats, map[string]interface{}{ - "subdomain": conn.Subdomain, - "tunnel_type": string(conn.GetTunnelType()), - "last_active": conn.LastActive.Unix(), - "bytes_in": conn.GetBytesIn(), - "bytes_out": conn.GetBytesOut(), - "active_connections": conn.GetActiveConnections(), - "total_bytes": conn.GetBytesIn() + conn.GetBytesOut(), - }) - } - - stats := map[string]interface{}{ - "total_tunnels": len(tunnelStats), - "tunnels": tunnelStats, - } - - data, err := json.Marshal(stats) - if err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) - w.Write(data) -} - -func (h *Handler) serveMetrics(w http.ResponseWriter, r *http.Request) { - if h.metricsToken != "" { - // Only accept token via Authorization header (Bearer token) - var token string - authHeader := r.Header.Get("Authorization") - if strings.HasPrefix(authHeader, "Bearer ") { - token = strings.TrimPrefix(authHeader, "Bearer ") - } - - if token != h.metricsToken { - w.Header().Set("WWW-Authenticate", `Bearer realm="metrics"`) - http.Error(w, "Unauthorized: provide metrics token via 'Authorization: Bearer ' header", http.StatusUnauthorized) - return - } - } - - // Serve Prometheus metrics - promhttp.Handler().ServeHTTP(w, r) -} - -type bufferedReadWriteCloser struct { - *bufio.Reader - net.Conn -} - -func (b *bufferedReadWriteCloser) Read(p []byte) (int, error) { - return b.Reader.Read(p) -} - -func (h *Handler) isProxyAuthenticated(r *http.Request, subdomain string) bool { - cookie, err := r.Cookie(authCookieName + "_" + subdomain) - if err != nil { - return false - } - return sessionStore.validate(cookie.Value, subdomain) -} - -func (h *Handler) handleProxyLogin(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection, subdomain string) { - if r.Method != http.MethodPost { - h.serveLoginPage(w, r, subdomain, "") - return - } - - if err := r.ParseForm(); err != nil { - h.serveLoginPage(w, r, subdomain, "Invalid form data") - return - } - - password := r.FormValue("password") - - if !tconn.ValidateProxyAuth(password) { - h.serveLoginPage(w, r, subdomain, "Invalid password") - return - } - - token := sessionStore.create(subdomain) - http.SetCookie(w, &http.Cookie{ - Name: authCookieName + "_" + subdomain, - Value: token, - Path: "/", - MaxAge: int(authSessionDuration.Seconds()), - HttpOnly: true, - Secure: true, - SameSite: http.SameSiteLaxMode, - }) - - redirectURL := r.FormValue("redirect") - if redirectURL == "" || redirectURL == "/_drip/login" { - redirectURL = "/" - } - http.Redirect(w, r, redirectURL, http.StatusSeeOther) -} - -func (h *Handler) serveLoginPage(w http.ResponseWriter, r *http.Request, subdomain string, errorMsg string) { - redirectURL := r.URL.Path - if r.URL.RawQuery != "" { - redirectURL += "?" + r.URL.RawQuery - } - if redirectURL == "/_drip/login" { - redirectURL = "/" - } - - errorHTML := "" - if errorMsg != "" { - errorHTML = fmt.Sprintf(`

%s

`, html.EscapeString(errorMsg)) - } - - safeRedirectURL := html.EscapeString(redirectURL) - - htmlContent := fmt.Sprintf(` - - - - - %s - Drip - `+faviconLink+` - - - -
-
-

🔒%s

-

This tunnel is password protected

-
- - %s -
- -
- - -
-
- - -
- -`, subdomain, subdomain, errorHTML, safeRedirectURL) - - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate") - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte(htmlContent)) -} - -// handleTunnelWebSocket handles WebSocket connections for tunnel clients -func (h *Handler) handleTunnelWebSocket(w http.ResponseWriter, r *http.Request) { - // Check if WSS transport is allowed - if !h.IsTransportAllowed("wss") { - http.Error(w, "WebSocket transport not allowed on this server", http.StatusForbidden) - return - } - - if h.wsConnHandler == nil { - http.Error(w, "WebSocket tunnel not configured", http.StatusServiceUnavailable) - return - } - - ws, err := h.wsUpgrader.Upgrade(w, r, nil) - if err != nil { - h.logger.Error("WebSocket upgrade failed", zap.Error(err)) - return - } - - // Configure WebSocket for tunnel use - ws.SetReadLimit(protocol.MaxFrameSize + protocol.FrameHeaderSize + 1024) - - // Extract real client IP (support CDN headers) - remoteAddr := h.extractClientIP(r) - - h.logger.Info("WebSocket tunnel connection established", - zap.String("remote_addr", remoteAddr), - ) - - // Wrap WebSocket as net.Conn with ping loop for CDN keep-alive - conn := wsutil.NewConnWithPing(ws, 30*time.Second) - - // Handle the connection using the registered handler - h.wsConnHandler.HandleWSConnection(conn, remoteAddr) -} - -// serveDiscovery returns server capabilities for client auto-detection -func (h *Handler) serveDiscovery(w http.ResponseWriter, r *http.Request) { - transports := h.allowedTransports - if len(transports) == 0 { - transports = []string{"tcp", "wss"} - } - - tunnelTypes := h.allowedTunnelTypes - if len(tunnelTypes) == 0 { - tunnelTypes = []string{"http", "https", "tcp"} - } - - response := map[string]interface{}{ - "transports": transports, - "tunnel_types": tunnelTypes, - "preferred": h.GetPreferredTransport(), - "version": "1", - } - - data, err := json.Marshal(response) - if err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Cache-Control", "no-cache") - w.Write(data) + return true } diff --git a/internal/server/proxy/pages.go b/internal/server/proxy/pages.go new file mode 100644 index 0000000..e4ee07f --- /dev/null +++ b/internal/server/proxy/pages.go @@ -0,0 +1,299 @@ +package proxy + +import ( + "net/http" + "time" + + json "github.com/goccy/go-json" + "github.com/prometheus/client_golang/prometheus/promhttp" + + "drip/internal/shared/httputil" +) + +func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) { + html := ` + + + + + Drip - Your Tunnel, Your Domain, Anywhere + ` + faviconLink + ` + + + +
+
+

💧Drip

+

Your Tunnel, Your Domain, Anywhere

+
+ +

A self-hosted tunneling solution to securely expose your services to the internet.

+ +

Install

+
+
bash <(curl -fsSL https://driptunnel.app/install.sh)
+ +
+ +

Usage

+
+
drip http 3000
+ +
+
+
drip https 443
+ +
+
+
drip tcp 5432
+ +
+ + + + +
+ + +` + + httputil.WriteHTML(w, []byte(html)) +} + +func (h *Handler) serveTunnelNotFound(w http.ResponseWriter, r *http.Request) { + html := ` + + + + + 404 - Tunnel Not Found + ` + faviconLink + ` + + + +
+
+

🔍Tunnel Not Found

+

The requested tunnel does not exist or has been closed.

+
+ +
+

This could happen because:

+
    +
  • The tunnel was never created
  • +
  • The tunnel has been closed by the owner
  • +
  • The tunnel URL is incorrect
  • +
+
+ +

If you are the tunnel owner, please restart your tunnel client.

+ + +
+ +` + + httputil.WriteHTMLWithStatus(w, []byte(html), http.StatusNotFound) +} + +func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) { + health := map[string]interface{}{ + "status": "ok", + "active_tunnels": h.manager.Count(), + "timestamp": time.Now().Unix(), + } + + data, err := json.Marshal(health) + if err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + + httputil.WriteJSON(w, data) +} + +func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) { + if !h.validateMetricsAuth(w, r, "stats") { + return + } + + connections := h.manager.List() + + tunnelStats := make([]map[string]interface{}, 0, len(connections)) + for _, conn := range connections { + if conn == nil { + continue + } + tunnelStats = append(tunnelStats, map[string]interface{}{ + "subdomain": conn.Subdomain, + "tunnel_type": string(conn.GetTunnelType()), + "last_active": conn.LastActive.Unix(), + "bytes_in": conn.GetBytesIn(), + "bytes_out": conn.GetBytesOut(), + "active_connections": conn.GetActiveConnections(), + "total_bytes": conn.GetBytesIn() + conn.GetBytesOut(), + }) + } + + stats := map[string]interface{}{ + "total_tunnels": len(tunnelStats), + "tunnels": tunnelStats, + } + + data, err := json.Marshal(stats) + if err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + + httputil.WriteJSON(w, data) +} + +func (h *Handler) serveMetrics(w http.ResponseWriter, r *http.Request) { + if !h.validateMetricsAuth(w, r, "metrics") { + return + } + + promhttp.Handler().ServeHTTP(w, r) +} + +func (h *Handler) serveDiscovery(w http.ResponseWriter, r *http.Request) { + transports := h.allowedTransports + if len(transports) == 0 { + transports = []string{"tcp", "wss"} + } + + tunnelTypes := h.allowedTunnelTypes + if len(tunnelTypes) == 0 { + tunnelTypes = []string{"http", "https", "tcp"} + } + + response := map[string]interface{}{ + "transports": transports, + "tunnel_types": tunnelTypes, + "preferred": h.GetPreferredTransport(), + "version": "1", + } + + data, err := json.Marshal(response) + if err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + + w.Header().Set("Cache-Control", "no-cache") + httputil.WriteJSON(w, data) +} diff --git a/internal/server/proxy/websocket_handler.go b/internal/server/proxy/websocket_handler.go new file mode 100644 index 0000000..41ec952 --- /dev/null +++ b/internal/server/proxy/websocket_handler.go @@ -0,0 +1,113 @@ +package proxy + +import ( + "bufio" + "context" + "io" + "net" + "net/http" + "time" + + "go.uber.org/zap" + + "drip/internal/server/tunnel" + "drip/internal/shared/httputil" + "drip/internal/shared/netutil" + "drip/internal/shared/protocol" + "drip/internal/shared/wsutil" +) + +type bufferedReadWriteCloser struct { + *bufio.Reader + net.Conn +} + +func (b *bufferedReadWriteCloser) Read(p []byte) (int, error) { + return b.Reader.Read(p) +} + +func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection) { + stream, err := h.openStreamWithTimeout(tconn) + if err != nil { + http.Error(w, "Tunnel unavailable", http.StatusBadGateway) + return + } + + tconn.IncActiveConnections() + + hj, ok := w.(http.Hijacker) + if !ok { + stream.Close() + tconn.DecActiveConnections() + http.Error(w, "WebSocket not supported", http.StatusInternalServerError) + return + } + + clientConn, clientBuf, err := hj.Hijack() + if err != nil { + stream.Close() + tconn.DecActiveConnections() + http.Error(w, "Failed to hijack connection", http.StatusInternalServerError) + return + } + + if err := r.Write(stream); err != nil { + stream.Close() + clientConn.Close() + tconn.DecActiveConnections() + return + } + + go func() { + defer stream.Close() + defer clientConn.Close() + defer tconn.DecActiveConnections() + + var clientRW io.ReadWriteCloser = clientConn + if clientBuf != nil && clientBuf.Reader.Buffered() > 0 { + clientRW = &bufferedReadWriteCloser{ + Reader: clientBuf.Reader, + Conn: clientConn, + } + } + + _ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW, + func(n int64) { tconn.AddBytesOut(n) }, + func(n int64) { tconn.AddBytesIn(n) }, + ) + }() +} + +func (h *Handler) handleTunnelWebSocket(w http.ResponseWriter, r *http.Request) { + if !h.IsTransportAllowed("wss") { + http.Error(w, "WebSocket transport not allowed on this server", http.StatusForbidden) + return + } + + if h.wsConnHandler == nil { + http.Error(w, "WebSocket tunnel not configured", http.StatusServiceUnavailable) + return + } + + ws, err := h.wsUpgrader.Upgrade(w, r, nil) + if err != nil { + h.logger.Error("WebSocket upgrade failed", zap.Error(err)) + return + } + + ws.SetReadLimit(protocol.MaxFrameSize + protocol.FrameHeaderSize + 1024) + + remoteAddr := netutil.ExtractClientIP(r) + + h.logger.Info("WebSocket tunnel connection established", + zap.String("remote_addr", remoteAddr), + ) + + conn := wsutil.NewConnWithPing(ws, 30*time.Second) + + h.wsConnHandler.HandleWSConnection(conn, remoteAddr) +} + +func (h *Handler) isWebSocketUpgrade(r *http.Request) bool { + return httputil.IsWebSocketUpgrade(r) +} diff --git a/internal/server/tcp/connection.go b/internal/server/tcp/connection.go index baeb215..ebbf59d 100644 --- a/internal/server/tcp/connection.go +++ b/internal/server/tcp/connection.go @@ -5,7 +5,6 @@ import ( "context" "errors" "fmt" - "io" "net" "net/http" "strconv" @@ -18,46 +17,54 @@ import ( "drip/internal/server/tunnel" "drip/internal/shared/constants" - "drip/internal/shared/mux" + "drip/internal/shared/httputil" "drip/internal/shared/protocol" "go.uber.org/zap" ) -// bufioWriterPool reuses bufio.Writer instances to reduce GC pressure -var bufioWriterPool = sync.Pool{ - New: func() interface{} { - return bufio.NewWriterSize(nil, 4096) - }, +type ConnectionConfig struct { + Conn net.Conn + AuthToken string + Manager *tunnel.Manager + Logger *zap.Logger + PortAlloc *PortAllocator + Domain string + TunnelDomain string + PublicPort int + HTTPHandler http.Handler + GroupManager *ConnectionGroupManager + HTTPListener *connQueueListener } type Connection struct { - conn net.Conn - authToken string - manager *tunnel.Manager - logger *zap.Logger - subdomain string - port int - domain string - tunnelDomain string - publicPort int - portAlloc *PortAllocator - tunnelConn *tunnel.Connection - stopCh chan struct{} - once sync.Once - lastHeartbeat time.Time - mu sync.RWMutex - frameWriter *protocol.FrameWriter - httpHandler http.Handler - tunnelType protocol.TunnelType - ctx context.Context - cancel context.CancelFunc - session *yamux.Session - proxy *Proxy - tunnelID string - groupManager *ConnectionGroupManager - httpListener *connQueueListener - handedOff bool + conn net.Conn + authToken string + manager *tunnel.Manager + logger *zap.Logger + subdomain string + port int + domain string + tunnelDomain string + publicPort int + portAlloc *PortAllocator + tunnelConn *tunnel.Connection + stopCh chan struct{} + once sync.Once + lastHeartbeat time.Time + mu sync.RWMutex + frameWriter *protocol.FrameWriter + httpHandler http.Handler + tunnelType protocol.TunnelType + ctx context.Context + cancel context.CancelFunc + session *yamux.Session + proxy *Proxy + tunnelID string + groupManager *ConnectionGroupManager + httpListener *connQueueListener + handedOff bool + lifecycleManager *ConnectionLifecycleManager // Server capabilities allowedTunnelTypes []string @@ -65,25 +72,32 @@ type Connection struct { } // NewConnection creates a new connection handler -func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager, httpListener *connQueueListener) *Connection { +func NewConnection(cfg ConnectionConfig) *Connection { ctx, cancel := context.WithCancel(context.Background()) + stopCh := make(chan struct{}) + c := &Connection{ - conn: conn, - authToken: authToken, - manager: manager, - logger: logger, - portAlloc: portAlloc, - domain: domain, - tunnelDomain: tunnelDomain, - publicPort: publicPort, - httpHandler: httpHandler, - stopCh: make(chan struct{}), - lastHeartbeat: time.Now(), - ctx: ctx, - cancel: cancel, - groupManager: groupManager, - httpListener: httpListener, + conn: cfg.Conn, + authToken: cfg.AuthToken, + manager: cfg.Manager, + logger: cfg.Logger, + portAlloc: cfg.PortAlloc, + domain: cfg.Domain, + tunnelDomain: cfg.TunnelDomain, + publicPort: cfg.PublicPort, + httpHandler: cfg.HTTPHandler, + stopCh: stopCh, + lastHeartbeat: time.Now(), + ctx: ctx, + cancel: cancel, + groupManager: cfg.GroupManager, + httpListener: cfg.HTTPListener, + lifecycleManager: NewConnectionLifecycleManager(stopCh, cancel, cfg.Logger), } + + // Set connection in lifecycle manager + c.lifecycleManager.SetConnection(cfg.Conn) + return c } @@ -99,17 +113,7 @@ func (c *Connection) Handle() error { return fmt.Errorf("failed to peek connection: %w", err) } - peekStr := string(peek) - httpMethods := []string{"GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC"} - isHTTP := false - for _, method := range httpMethods { - if strings.HasPrefix(peekStr, method) { - isHTTP = true - break - } - } - - if isHTTP { + if httputil.IsHTTPRequest(peek) { c.logger.Info("Detected HTTP request on TCP port, handling as HTTP") return c.handleHTTPRequest(reader) } @@ -128,7 +132,24 @@ func (c *Connection) Handle() error { defer sf.Close() if sf.Frame.Type == protocol.FrameTypeDataConnect { - return c.handleDataConnect(sf.Frame, reader) + handler := NewDataConnectionHandler( + c.conn, + reader, + c.authToken, + c.groupManager, + c.stopCh, + c.logger, + ) + handler.SetSessionCreatedHandler(func(session *yamux.Session) { + c.session = session + if c.lifecycleManager != nil { + c.lifecycleManager.SetSession(session) + } + }) + handler.SetTunnelIDHandler(func(tunnelID string) { + c.tunnelID = tunnelID + }) + return handler.Handle(sf.Frame) } if sf.Frame.Type != protocol.FrameTypeRegister { @@ -153,121 +174,70 @@ func (c *Connection) Handle() error { return fmt.Errorf("authentication failed") } - if req.TunnelType == protocol.TunnelTypeTCP { - if c.portAlloc == nil { - return fmt.Errorf("port allocator not configured") - } - - if requestedPort, ok := parseTCPSubdomainPort(req.CustomSubdomain); ok { - port, err := c.portAlloc.AllocateSpecific(requestedPort) - if err != nil { - c.sendError("port_allocation_failed", err.Error()) - return fmt.Errorf("failed to allocate requested port %d: %w", requestedPort, err) - } - c.port = port - } else { - port, err := c.portAlloc.Allocate() - if err != nil { - c.sendError("port_allocation_failed", err.Error()) - return fmt.Errorf("failed to allocate port: %w", err) - } - c.port = port - - if req.CustomSubdomain == "" { - req.CustomSubdomain = fmt.Sprintf("tcp-%d", port) - } - } - } - - subdomain, err := c.manager.Register(nil, req.CustomSubdomain) - if err != nil { - c.sendError("registration_failed", err.Error()) - c.portAlloc.Release(c.port) - c.port = 0 - return fmt.Errorf("tunnel registration failed: %w", err) - } - - c.subdomain = subdomain - - tunnelConn, ok := c.manager.Get(subdomain) - if !ok { - return fmt.Errorf("failed to get registered tunnel") - } - c.tunnelConn = tunnelConn - - c.tunnelConn.Conn = nil - c.tunnelConn.SetTunnelType(req.TunnelType) - c.tunnelType = req.TunnelType - - if req.IPAccess != nil && (len(req.IPAccess.AllowIPs) > 0 || len(req.IPAccess.DenyIPs) > 0) { - c.tunnelConn.SetIPAccessControl(req.IPAccess.AllowIPs, req.IPAccess.DenyIPs) - c.logger.Info("IP access control configured", - zap.String("subdomain", subdomain), - zap.Strings("allow_ips", req.IPAccess.AllowIPs), - zap.Strings("deny_ips", req.IPAccess.DenyIPs), - ) - } - - if req.ProxyAuth != nil && req.ProxyAuth.Enabled { - c.tunnelConn.SetProxyAuth(req.ProxyAuth) - c.logger.Info("Proxy authentication configured", - zap.String("subdomain", subdomain), - ) - } - - c.logger.Info("Tunnel registered", - zap.String("subdomain", subdomain), - zap.String("tunnel_type", string(req.TunnelType)), - zap.Int("local_port", req.LocalPort), - zap.Int("remote_port", c.port), + // Use RegistrationHandler for registration logic + regHandler := NewRegistrationHandler( + c.manager, + c.portAlloc, + c.groupManager, + c.domain, + c.tunnelDomain, + c.publicPort, + c.logger, ) - var tunnelURL string - if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS { - if c.publicPort == 443 { - tunnelURL = fmt.Sprintf("https://%s.%s", subdomain, c.tunnelDomain) - } else { - tunnelURL = fmt.Sprintf("https://%s.%s:%d", subdomain, c.tunnelDomain, c.publicPort) - } - } else { - tunnelURL = fmt.Sprintf("tcp://%s:%d", c.tunnelDomain, c.port) + regReq := &RegistrationRequest{ + TunnelType: req.TunnelType, + CustomSubdomain: req.CustomSubdomain, + Token: req.Token, + ConnectionType: req.ConnectionType, + PoolCapabilities: req.PoolCapabilities, + IPAccess: req.IPAccess, + ProxyAuth: req.ProxyAuth, + LocalPort: req.LocalPort, } - var tunnelID string - var supportsDataConn bool - recommendedConns := 0 + result, err := regHandler.Register(regReq) + if err != nil { + c.sendError("registration_failed", err.Error()) + return fmt.Errorf("registration failed: %w", err) + } - if req.PoolCapabilities != nil && req.ConnectionType == "primary" && c.groupManager != nil { - group := c.groupManager.CreateGroup(subdomain, req.Token, c, req.TunnelType) - tunnelID = group.TunnelID - c.tunnelID = tunnelID - supportsDataConn = true - recommendedConns = 4 + // Store registration results + c.subdomain = result.Subdomain + c.port = result.Port + c.tunnelConn = result.TunnelConn + c.tunnelConn.Conn = nil + + // Update lifecycle manager with registration info + if c.lifecycleManager != nil { + c.lifecycleManager.SetPortAllocation(c.portAlloc, c.port) + c.lifecycleManager.SetTunnelRegistration(c.manager, c.subdomain, "", c.groupManager) + } + + // Handle connection groups + if result.SupportsDataConn && c.groupManager != nil { + group := c.groupManager.CreateGroup(result.Subdomain, req.Token, c, req.TunnelType) + result.TunnelID = group.TunnelID + c.tunnelID = result.TunnelID + + // Update lifecycle manager with tunnel ID + if c.lifecycleManager != nil { + c.lifecycleManager.SetTunnelRegistration(c.manager, c.subdomain, c.tunnelID, c.groupManager) + } c.logger.Info("Created connection group for multi-connection support", - zap.String("tunnel_id", tunnelID), + zap.String("tunnel_id", result.TunnelID), zap.Int("max_data_conns", req.PoolCapabilities.MaxDataConns), ) } - resp := protocol.RegisterResponse{ - Subdomain: subdomain, - Port: c.port, - URL: tunnelURL, - Message: "Tunnel registered successfully", - TunnelID: tunnelID, - SupportsDataConn: supportsDataConn, - RecommendedConns: recommendedConns, + // Build and send registration response + resp, err := regHandler.BuildRegistrationResponse(result) + if err != nil { + return fmt.Errorf("failed to build registration response: %w", err) } - respData, err := json.Marshal(resp) - if err != nil { - return fmt.Errorf("failed to marshal registration response: %w", err) - } - ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData) - - err = protocol.WriteFrame(c.conn, ackFrame) - if err != nil { + if err := regHandler.SendRegistrationResponse(c.conn, resp); err != nil { return fmt.Errorf("failed to send registration ack: %w", err) } @@ -282,6 +252,11 @@ func (c *Connection) Handle() error { c.frameWriter = protocol.NewFrameWriter(c.conn) + // Update lifecycle manager with frame writer + if c.lifecycleManager != nil { + c.lifecycleManager.SetFrameWriter(c.frameWriter) + } + c.frameWriter.SetWriteErrorHandler(func(err error) { c.logger.Error("Write error detected, closing connection", zap.Error(err)) c.Close() @@ -289,39 +264,30 @@ func (c *Connection) Handle() error { go c.heartbeatChecker() - return c.handleFrames(reader) + // Use FrameHandler for frame processing + frameHandler := NewFrameHandler(c.conn, reader, c.stopCh, c.frameWriter, c.logger) + frameHandler.SetHeartbeatHandler(func() { + c.handleHeartbeat() + }) + frameHandler.SetCloseHandler(func() { + c.logger.Info("Client requested close") + }) + + return frameHandler.HandleFrames() } func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error { - if c.httpListener == nil { - return c.handleHTTPRequestLegacy(reader) - } - - c.conn.SetReadDeadline(time.Time{}) - - wrappedConn := &bufferedConn{ - Conn: c.conn, - reader: reader, - } - - if !c.httpListener.Enqueue(wrappedConn) { - c.logger.Warn("HTTP listener queue full, rejecting connection") - response := "HTTP/1.1 503 Service Unavailable\r\n" + - "Content-Type: text/plain\r\n" + - "Content-Length: 32\r\n" + - "Connection: close\r\n" + - "\r\n" + - "Server busy, please retry later\r\n" - c.conn.Write([]byte(response)) - return fmt.Errorf("http listener queue full") - } - - c.mu.Lock() - c.conn = nil - c.handedOff = true - c.mu.Unlock() - - return nil + handler := NewHTTPRequestHandler( + c.conn, + reader, + c.httpHandler, + c.httpListener, + c.ctx, + c.logger, + &c.mu, + &c.handedOff, + ) + return handler.Handle() } func (c *Connection) IsHandedOff() bool { @@ -330,113 +296,6 @@ func (c *Connection) IsHandedOff() bool { return c.handedOff } -func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error { - if c.httpHandler == nil { - c.logger.Warn("HTTP request received but no HTTP handler configured") - response := "HTTP/1.1 503 Service Unavailable\r\n" + - "Content-Type: text/plain\r\n" + - "Content-Length: 47\r\n" + - "\r\n" + - "HTTP handler not configured for this TCP port\r\n" - c.conn.Write([]byte(response)) - return fmt.Errorf("HTTP handler not configured") - } - - c.conn.SetReadDeadline(time.Time{}) - - for { - c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) - - req, err := http.ReadRequest(reader) - if err != nil { - if err == io.EOF || err == io.ErrUnexpectedEOF { - c.logger.Debug("Client closed HTTP connection") - return nil - } - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - c.logger.Debug("HTTP keep-alive timeout") - return nil - } - errStr := err.Error() - if errors.Is(err, net.ErrClosed) || strings.Contains(errStr, "use of closed network connection") { - c.logger.Debug("HTTP connection closed during read", zap.Error(err)) - return nil - } - if strings.Contains(errStr, "connection reset by peer") || - strings.Contains(errStr, "broken pipe") || - strings.Contains(errStr, "connection refused") { - c.logger.Debug("Client disconnected abruptly", zap.Error(err)) - return nil - } - if strings.Contains(errStr, "malformed HTTP") { - c.logger.Warn("Received malformed HTTP request", - zap.Error(err), - zap.String("error_snippet", errStr[:min(len(errStr), 100)]), - ) - return nil - } - c.logger.Error("Failed to parse HTTP request", zap.Error(err)) - return fmt.Errorf("failed to parse HTTP request: %w", err) - } - - if c.ctx != nil { - req = req.WithContext(c.ctx) - } - - c.logger.Info("Processing HTTP request on TCP port", - zap.String("method", req.Method), - zap.String("url", req.URL.String()), - zap.String("host", req.Host), - ) - - // Get writer from pool to reduce GC pressure - pooledWriter := bufioWriterPool.Get().(*bufio.Writer) - pooledWriter.Reset(c.conn) - - respWriter := &httpResponseWriter{ - conn: c.conn, - writer: pooledWriter, - header: make(http.Header), - } - - c.httpHandler.ServeHTTP(respWriter, req) - - if err := respWriter.writer.Flush(); err != nil { - c.logger.Debug("Failed to flush HTTP response", zap.Error(err)) - } - - // Return writer to pool - pooledWriter.Reset(nil) // Clear reference to connection - bufioWriterPool.Put(pooledWriter) - - // Keep TCP_NODELAY enabled for low latency HTTP responses - // (removed the toggle that was disabling it) - - c.logger.Debug("HTTP request processing completed", - zap.String("method", req.Method), - zap.String("url", req.URL.String()), - ) - - shouldClose := false - if req.Close { - shouldClose = true - } else if req.ProtoMajor == 1 && req.ProtoMinor == 0 { - if req.Header.Get("Connection") != "keep-alive" { - shouldClose = true - } - } - - if respWriter.headerWritten && respWriter.header.Get("Connection") == "close" { - shouldClose = true - } - - if shouldClose { - c.logger.Debug("Closing connection as requested by client or server") - return nil - } - } -} - func parseTCPSubdomainPort(subdomain string) (int, bool) { if !strings.HasPrefix(subdomain, "tcp-") { return 0, false @@ -455,55 +314,6 @@ func parseTCPSubdomainPort(subdomain string) (int, bool) { return port, true } -func (c *Connection) handleFrames(reader *bufio.Reader) error { - for { - select { - case <-c.stopCh: - return nil - default: - } - - c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout)) - frame, err := protocol.ReadFrame(reader) - if err != nil { - if isTimeoutError(err) { - c.logger.Warn("Read timeout, connection may be dead") - return fmt.Errorf("read timeout") - } - if err.Error() == "failed to read frame header: EOF" || err.Error() == "EOF" { - c.logger.Info("Client disconnected") - return nil - } - select { - case <-c.stopCh: - c.logger.Debug("Connection closed during shutdown") - return nil - default: - return fmt.Errorf("failed to read frame: %w", err) - } - } - - sf := protocol.WithFrame(frame) - - switch sf.Frame.Type { - case protocol.FrameTypeHeartbeat: - c.handleHeartbeat() - sf.Close() - - case protocol.FrameTypeClose: - sf.Close() - c.logger.Info("Client requested close") - return nil - - default: - sf.Close() - c.logger.Warn("Unexpected frame type", - zap.String("type", sf.Frame.Type.String()), - ) - } - } -} - func (c *Connection) handleHeartbeat() { c.mu.Lock() c.lastHeartbeat = time.Now() @@ -562,188 +372,71 @@ func (c *Connection) sendError(code, message string) { func (c *Connection) Close() { c.once.Do(func() { - protocol.UnregisterConnection() - close(c.stopCh) + // Check if connection was handed off to HTTP handler + c.mu.RLock() + handedOff := c.handedOff + c.mu.RUnlock() - if c.cancel != nil { - c.cancel() + // If handed off, don't close the connection - HTTP handler owns it now + if handedOff { + c.logger.Debug("Connection handed off to HTTP handler, skipping close") + return } - // Prevent race with handleHTTPRequest setting c.conn = nil - c.mu.Lock() - conn := c.conn - c.mu.Unlock() + // Use lifecycle manager for cleanup + if c.lifecycleManager != nil { + c.lifecycleManager.Close() + } else { + // Fallback if lifecycle manager not initialized + protocol.UnregisterConnection() + close(c.stopCh) - if conn != nil { - _ = conn.SetDeadline(time.Now()) - } - - if c.frameWriter != nil { - c.frameWriter.Close() - } - - if c.proxy != nil { - c.proxy.Stop() - } - - if c.session != nil { - _ = c.session.Close() - } - - if conn != nil { - conn.Close() - } - - if c.port > 0 && c.portAlloc != nil { - c.portAlloc.Release(c.port) - } - - if c.subdomain != "" { - c.manager.Unregister(c.subdomain) - if c.tunnelID != "" && c.groupManager != nil { - c.groupManager.RemoveGroup(c.tunnelID) + if c.cancel != nil { + c.cancel() } - } - c.logger.Info("Connection closed", - zap.String("subdomain", c.subdomain), - ) + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn != nil { + _ = conn.SetDeadline(time.Now()) + } + + if c.frameWriter != nil { + c.frameWriter.Close() + } + + if c.proxy != nil { + c.proxy.Stop() + } + + if c.session != nil { + _ = c.session.Close() + } + + if conn != nil { + conn.Close() + } + + if c.port > 0 && c.portAlloc != nil { + c.portAlloc.Release(c.port) + } + + if c.subdomain != "" { + c.manager.Unregister(c.subdomain) + if c.tunnelID != "" && c.groupManager != nil { + c.groupManager.RemoveGroup(c.tunnelID) + } + } + + c.logger.Info("Connection closed", + zap.String("subdomain", c.subdomain), + ) + } }) } -type httpResponseWriter struct { - conn net.Conn - writer *bufio.Writer - header http.Header - statusCode int - headerWritten bool -} - -func (w *httpResponseWriter) Header() http.Header { - return w.header -} - -func (w *httpResponseWriter) WriteHeader(statusCode int) { - if w.headerWritten { - return - } - w.statusCode = statusCode - w.headerWritten = true - - statusText := http.StatusText(statusCode) - if statusText == "" { - statusText = "Unknown" - } - - w.writer.WriteString("HTTP/1.1 ") - w.writer.WriteString(fmt.Sprintf("%d", statusCode)) - w.writer.WriteByte(' ') - w.writer.WriteString(statusText) - w.writer.WriteString("\r\n") - - for key, values := range w.header { - for _, value := range values { - w.writer.WriteString(key) - w.writer.WriteString(": ") - w.writer.WriteString(value) - w.writer.WriteString("\r\n") - } - } - - w.writer.WriteString("\r\n") -} - -func (w *httpResponseWriter) Write(data []byte) (int, error) { - if !w.headerWritten { - w.WriteHeader(http.StatusOK) - } - return w.writer.Write(data) -} - -func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Reader) error { - var req protocol.DataConnectRequest - if err := json.Unmarshal(frame.Payload, &req); err != nil { - c.sendError("invalid_request", "Failed to parse data connect request") - return fmt.Errorf("failed to parse data connect request: %w", err) - } - - c.logger.Info("Data connection request received", - zap.String("tunnel_id", req.TunnelID), - zap.String("connection_id", req.ConnectionID), - ) - - if c.groupManager == nil { - c.sendDataConnectError("not_supported", "Multi-connection not supported") - return fmt.Errorf("group manager not available") - } - - if c.authToken != "" && req.Token != c.authToken { - c.sendDataConnectError("authentication_failed", "Invalid authentication token") - return fmt.Errorf("authentication failed for data connection") - } - - group, ok := c.groupManager.GetGroup(req.TunnelID) - if !ok || group == nil { - c.sendDataConnectError("join_failed", "Tunnel not found") - return fmt.Errorf("tunnel not found: %s", req.TunnelID) - } - - if group.Token != "" && req.Token != group.Token { - c.sendDataConnectError("authentication_failed", "Invalid authentication token") - return fmt.Errorf("authentication failed for data connection") - } - - c.tunnelID = req.TunnelID - - resp := protocol.DataConnectResponse{ - Accepted: true, - ConnectionID: req.ConnectionID, - Message: "Data connection accepted", - } - - respData, err := json.Marshal(resp) - if err != nil { - return fmt.Errorf("failed to marshal data connect response: %w", err) - } - ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData) - - if err := protocol.WriteFrame(c.conn, ackFrame); err != nil { - return fmt.Errorf("failed to send data connect ack: %w", err) - } - - c.logger.Info("Data connection established", - zap.String("tunnel_id", req.TunnelID), - zap.String("connection_id", req.ConnectionID), - ) - - _ = c.conn.SetReadDeadline(time.Time{}) - - // Server acts as yamux Client, client connector acts as yamux Server - bc := &bufferedConn{ - Conn: c.conn, - reader: reader, - } - - // Use optimized mux config for server - cfg := mux.NewServerConfig() - - session, err := yamux.Client(bc, cfg) - if err != nil { - return fmt.Errorf("failed to init yamux session: %w", err) - } - c.session = session - - group.AddSession(req.ConnectionID, session) - defer group.RemoveSession(req.ConnectionID) - - select { - case <-c.stopCh: - return nil - case <-session.CloseChan(): - return nil - } -} - func isTimeoutError(err error) bool { if err == nil { return false @@ -755,20 +448,6 @@ func isTimeoutError(err error) bool { return strings.Contains(err.Error(), "i/o timeout") } -func (c *Connection) sendDataConnectError(code, message string) { - resp := protocol.DataConnectResponse{ - Accepted: false, - Message: fmt.Sprintf("%s: %s", code, message), - } - respData, err := json.Marshal(resp) - if err != nil { - c.logger.Error("Failed to marshal data connect error", zap.Error(err)) - return - } - frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData) - _ = protocol.WriteFrame(c.conn, frame) -} - // SetAllowedTunnelTypes sets the allowed tunnel types for this connection func (c *Connection) SetAllowedTunnelTypes(types []string) { c.allowedTunnelTypes = types diff --git a/internal/server/tcp/connection_group.go b/internal/server/tcp/connection_group.go index 3b0a207..95d1ac4 100644 --- a/internal/server/tcp/connection_group.go +++ b/internal/server/tcp/connection_group.go @@ -17,10 +17,10 @@ import ( // sessionEntry represents a session with its current stream count for heap operations type sessionEntry struct { - id string - session *yamux.Session - streams int - heapIdx int // index in the heap, managed by heap.Interface + id string + session *yamux.Session + streams int + heapIdx int // index in the heap, managed by heap.Interface } // sessionHeap implements heap.Interface for O(log n) session selection diff --git a/internal/server/tcp/data_connection_handler.go b/internal/server/tcp/data_connection_handler.go new file mode 100644 index 0000000..c23bbe0 --- /dev/null +++ b/internal/server/tcp/data_connection_handler.go @@ -0,0 +1,161 @@ +package tcp + +import ( + "bufio" + "fmt" + "net" + "time" + + json "github.com/goccy/go-json" + "github.com/hashicorp/yamux" + "go.uber.org/zap" + + "drip/internal/shared/mux" + "drip/internal/shared/protocol" +) + +// DataConnectionHandler handles data connection requests for multi-connection support. +type DataConnectionHandler struct { + conn net.Conn + reader *bufio.Reader + authToken string + groupManager *ConnectionGroupManager + stopCh <-chan struct{} + logger *zap.Logger + onSessionCreated func(*yamux.Session) + onTunnelIDSet func(string) +} + +// NewDataConnectionHandler creates a new data connection handler. +func NewDataConnectionHandler( + conn net.Conn, + reader *bufio.Reader, + authToken string, + groupManager *ConnectionGroupManager, + stopCh <-chan struct{}, + logger *zap.Logger, +) *DataConnectionHandler { + return &DataConnectionHandler{ + conn: conn, + reader: reader, + authToken: authToken, + groupManager: groupManager, + stopCh: stopCh, + logger: logger, + } +} + +// SetSessionCreatedHandler sets the callback for when a session is created. +func (h *DataConnectionHandler) SetSessionCreatedHandler(handler func(*yamux.Session)) { + h.onSessionCreated = handler +} + +// SetTunnelIDHandler sets the callback for when tunnel ID is set. +func (h *DataConnectionHandler) SetTunnelIDHandler(handler func(string)) { + h.onTunnelIDSet = handler +} + +// Handle processes the data connection request. +func (h *DataConnectionHandler) Handle(frame *protocol.Frame) error { + var req protocol.DataConnectRequest + if err := json.Unmarshal(frame.Payload, &req); err != nil { + h.sendError("invalid_request", "Failed to parse data connect request") + return fmt.Errorf("failed to parse data connect request: %w", err) + } + + h.logger.Info("Data connection request received", + zap.String("tunnel_id", req.TunnelID), + zap.String("connection_id", req.ConnectionID), + ) + + if h.groupManager == nil { + h.sendError("not_supported", "Multi-connection not supported") + return fmt.Errorf("group manager not available") + } + + if h.authToken != "" && req.Token != h.authToken { + h.sendError("authentication_failed", "Invalid authentication token") + return fmt.Errorf("authentication failed for data connection") + } + + group, ok := h.groupManager.GetGroup(req.TunnelID) + if !ok || group == nil { + h.sendError("join_failed", "Tunnel not found") + return fmt.Errorf("tunnel not found: %s", req.TunnelID) + } + + if group.Token != "" && req.Token != group.Token { + h.sendError("authentication_failed", "Invalid authentication token") + return fmt.Errorf("authentication failed for data connection") + } + + if h.onTunnelIDSet != nil { + h.onTunnelIDSet(req.TunnelID) + } + + resp := protocol.DataConnectResponse{ + Accepted: true, + ConnectionID: req.ConnectionID, + Message: "Data connection accepted", + } + + respData, err := json.Marshal(resp) + if err != nil { + return fmt.Errorf("failed to marshal data connect response: %w", err) + } + ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData) + + if err := protocol.WriteFrame(h.conn, ackFrame); err != nil { + return fmt.Errorf("failed to send data connect ack: %w", err) + } + + h.logger.Info("Data connection established", + zap.String("tunnel_id", req.TunnelID), + zap.String("connection_id", req.ConnectionID), + ) + + _ = h.conn.SetReadDeadline(time.Time{}) + + // Server acts as yamux Client, client connector acts as yamux Server + bc := &bufferedConn{ + Conn: h.conn, + reader: h.reader, + } + + // Use optimized mux config for server + cfg := mux.NewServerConfig() + + session, err := yamux.Client(bc, cfg) + if err != nil { + return fmt.Errorf("failed to init yamux session: %w", err) + } + + if h.onSessionCreated != nil { + h.onSessionCreated(session) + } + + group.AddSession(req.ConnectionID, session) + defer group.RemoveSession(req.ConnectionID) + + select { + case <-h.stopCh: + return nil + case <-session.CloseChan(): + return nil + } +} + +// sendError sends an error response to the client. +func (h *DataConnectionHandler) sendError(code, message string) { + resp := protocol.DataConnectResponse{ + Accepted: false, + Message: fmt.Sprintf("%s: %s", code, message), + } + respData, err := json.Marshal(resp) + if err != nil { + h.logger.Error("Failed to marshal data connect error", zap.Error(err)) + return + } + frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData) + _ = protocol.WriteFrame(h.conn, frame) +} diff --git a/internal/server/tcp/frame_handler.go b/internal/server/tcp/frame_handler.go new file mode 100644 index 0000000..39cf73e --- /dev/null +++ b/internal/server/tcp/frame_handler.go @@ -0,0 +1,122 @@ +package tcp + +import ( + "bufio" + "fmt" + "net" + "time" + + "drip/internal/shared/constants" + "drip/internal/shared/protocol" + "go.uber.org/zap" +) + +// FrameHandler handles protocol frame reading and processing. +type FrameHandler struct { + conn net.Conn + reader *bufio.Reader + stopCh <-chan struct{} + logger *zap.Logger + frameWriter *protocol.FrameWriter + + // Heartbeat tracking + onHeartbeat func() + onClose func() +} + +// NewFrameHandler creates a new frame handler. +func NewFrameHandler( + conn net.Conn, + reader *bufio.Reader, + stopCh <-chan struct{}, + frameWriter *protocol.FrameWriter, + logger *zap.Logger, +) *FrameHandler { + return &FrameHandler{ + conn: conn, + reader: reader, + stopCh: stopCh, + frameWriter: frameWriter, + logger: logger, + } +} + +// SetHeartbeatHandler sets the callback for heartbeat frames. +func (fh *FrameHandler) SetHeartbeatHandler(handler func()) { + fh.onHeartbeat = handler +} + +// SetCloseHandler sets the callback for close frames. +func (fh *FrameHandler) SetCloseHandler(handler func()) { + fh.onClose = handler +} + +// HandleFrames processes incoming frames in a loop. +func (fh *FrameHandler) HandleFrames() error { + for { + select { + case <-fh.stopCh: + return nil + default: + } + + fh.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout)) + frame, err := protocol.ReadFrame(fh.reader) + if err != nil { + return fh.handleReadError(err) + } + + sf := protocol.WithFrame(frame) + err = fh.processFrame(sf) + sf.Close() + + if err != nil { + return err + } + } +} + +// handleReadError handles errors that occur while reading frames. +func (fh *FrameHandler) handleReadError(err error) error { + if isTimeoutError(err) { + fh.logger.Warn("Read timeout, connection may be dead") + return fmt.Errorf("read timeout") + } + + if err.Error() == "failed to read frame header: EOF" || err.Error() == "EOF" { + fh.logger.Info("Client disconnected") + return nil + } + + select { + case <-fh.stopCh: + fh.logger.Debug("Connection closed during shutdown") + return nil + default: + return fmt.Errorf("failed to read frame: %w", err) + } +} + +// processFrame processes a single frame based on its type. +func (fh *FrameHandler) processFrame(sf *protocol.SafeFrame) error { + switch sf.Frame.Type { + case protocol.FrameTypeHeartbeat: + if fh.onHeartbeat != nil { + fh.onHeartbeat() + } + return nil + + case protocol.FrameTypeClose: + fh.logger.Info("Client requested close") + if fh.onClose != nil { + fh.onClose() + } + return fmt.Errorf("client requested close") + + default: + fh.logger.Warn("Unexpected frame type", + zap.String("type", sf.Frame.Type.String()), + ) + return nil + } +} diff --git a/internal/server/tcp/http_request_handler.go b/internal/server/tcp/http_request_handler.go new file mode 100644 index 0000000..23a7706 --- /dev/null +++ b/internal/server/tcp/http_request_handler.go @@ -0,0 +1,251 @@ +package tcp + +import ( + "bufio" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "sync" + "time" + + "go.uber.org/zap" +) + +// bufioWriterPool reuses bufio.Writer instances to reduce GC pressure +var bufioWriterPool = sync.Pool{ + New: func() interface{} { + return bufio.NewWriterSize(nil, 4096) + }, +} + +// HTTPRequestHandler handles HTTP requests on TCP connections. +type HTTPRequestHandler struct { + conn net.Conn + reader *bufio.Reader + httpHandler http.Handler + httpListener *connQueueListener + ctx interface{ Done() <-chan struct{} } + logger *zap.Logger + mu *sync.RWMutex + handedOff *bool +} + +// NewHTTPRequestHandler creates a new HTTP request handler. +func NewHTTPRequestHandler( + conn net.Conn, + reader *bufio.Reader, + httpHandler http.Handler, + httpListener *connQueueListener, + ctx interface{ Done() <-chan struct{} }, + logger *zap.Logger, + mu *sync.RWMutex, + handedOff *bool, +) *HTTPRequestHandler { + return &HTTPRequestHandler{ + conn: conn, + reader: reader, + httpHandler: httpHandler, + httpListener: httpListener, + ctx: ctx, + logger: logger, + mu: mu, + handedOff: handedOff, + } +} + +// Handle processes the HTTP request. +func (h *HTTPRequestHandler) Handle() error { + if h.httpListener == nil { + return h.handleLegacy() + } + + h.conn.SetReadDeadline(time.Time{}) + + wrappedConn := &bufferedConn{ + Conn: h.conn, + reader: h.reader, + } + + if !h.httpListener.Enqueue(wrappedConn) { + h.logger.Warn("HTTP listener queue full, rejecting connection") + response := "HTTP/1.1 503 Service Unavailable\r\n" + + "Content-Type: text/plain\r\n" + + "Content-Length: 32\r\n" + + "Connection: close\r\n" + + "\r\n" + + "Server busy, please retry later\r\n" + h.conn.Write([]byte(response)) + return fmt.Errorf("http listener queue full") + } + + h.mu.Lock() + *h.handedOff = true + h.mu.Unlock() + + return nil +} + +// handleLegacy processes HTTP requests using the legacy handler. +func (h *HTTPRequestHandler) handleLegacy() error { + if h.httpHandler == nil { + h.logger.Warn("HTTP request received but no HTTP handler configured") + response := "HTTP/1.1 503 Service Unavailable\r\n" + + "Content-Type: text/plain\r\n" + + "Content-Length: 47\r\n" + + "\r\n" + + "HTTP handler not configured for this TCP port\r\n" + h.conn.Write([]byte(response)) + return fmt.Errorf("HTTP handler not configured") + } + + h.conn.SetReadDeadline(time.Time{}) + + for { + h.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + + req, err := http.ReadRequest(h.reader) + if err != nil { + if err == io.EOF || err == io.ErrUnexpectedEOF { + h.logger.Debug("Client closed HTTP connection") + return nil + } + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + h.logger.Debug("HTTP keep-alive timeout") + return nil + } + errStr := err.Error() + if errors.Is(err, net.ErrClosed) || strings.Contains(errStr, "use of closed network connection") { + h.logger.Debug("HTTP connection closed during read", zap.Error(err)) + return nil + } + if strings.Contains(errStr, "connection reset by peer") || + strings.Contains(errStr, "broken pipe") || + strings.Contains(errStr, "connection refused") { + h.logger.Debug("Client disconnected abruptly", zap.Error(err)) + return nil + } + if strings.Contains(errStr, "malformed HTTP") { + h.logger.Warn("Received malformed HTTP request", + zap.Error(err), + zap.String("error_snippet", errStr[:min(len(errStr), 100)]), + ) + return nil + } + h.logger.Error("Failed to parse HTTP request", zap.Error(err)) + return fmt.Errorf("failed to parse HTTP request: %w", err) + } + + if h.ctx != nil { + if ctxWithContext, ok := h.ctx.(interface{ Done() <-chan struct{} }); ok { + req = req.WithContext(ctxWithContext.(interface { + Done() <-chan struct{} + Deadline() (deadline time.Time, ok bool) + Err() error + Value(key interface{}) interface{} + })) + } + } + + h.logger.Info("Processing HTTP request on TCP port", + zap.String("method", req.Method), + zap.String("url", req.URL.String()), + zap.String("host", req.Host), + ) + + // Get writer from pool to reduce GC pressure + pooledWriter := bufioWriterPool.Get().(*bufio.Writer) + pooledWriter.Reset(h.conn) + + respWriter := &httpResponseWriter{ + conn: h.conn, + writer: pooledWriter, + header: make(http.Header), + } + + h.httpHandler.ServeHTTP(respWriter, req) + + if err := respWriter.writer.Flush(); err != nil { + h.logger.Debug("Failed to flush HTTP response", zap.Error(err)) + } + + // Return writer to pool + pooledWriter.Reset(nil) // Clear reference to connection + bufioWriterPool.Put(pooledWriter) + + h.logger.Debug("HTTP request processing completed", + zap.String("method", req.Method), + zap.String("url", req.URL.String()), + ) + + shouldClose := false + if req.Close { + shouldClose = true + } else if req.ProtoMajor == 1 && req.ProtoMinor == 0 { + if req.Header.Get("Connection") != "keep-alive" { + shouldClose = true + } + } + + if respWriter.headerWritten && respWriter.header.Get("Connection") == "close" { + shouldClose = true + } + + if shouldClose { + h.logger.Debug("Closing connection as requested by client or server") + return nil + } + } +} + +// httpResponseWriter implements http.ResponseWriter for raw TCP connections. +type httpResponseWriter struct { + conn net.Conn + writer *bufio.Writer + header http.Header + statusCode int + headerWritten bool +} + +func (w *httpResponseWriter) Header() http.Header { + return w.header +} + +func (w *httpResponseWriter) WriteHeader(statusCode int) { + if w.headerWritten { + return + } + w.statusCode = statusCode + w.headerWritten = true + + statusText := http.StatusText(statusCode) + if statusText == "" { + statusText = "Unknown" + } + + w.writer.WriteString("HTTP/1.1 ") + w.writer.WriteString(fmt.Sprintf("%d", statusCode)) + w.writer.WriteByte(' ') + w.writer.WriteString(statusText) + w.writer.WriteString("\r\n") + + for key, values := range w.header { + for _, value := range values { + w.writer.WriteString(key) + w.writer.WriteString(": ") + w.writer.WriteString(value) + w.writer.WriteString("\r\n") + } + } + + w.writer.WriteString("\r\n") +} + +func (w *httpResponseWriter) Write(data []byte) (int, error) { + if !w.headerWritten { + w.WriteHeader(http.StatusOK) + } + return w.writer.Write(data) +} diff --git a/internal/server/tcp/lifecycle_manager.go b/internal/server/tcp/lifecycle_manager.go new file mode 100644 index 0000000..c849fec --- /dev/null +++ b/internal/server/tcp/lifecycle_manager.go @@ -0,0 +1,137 @@ +package tcp + +import ( + "sync" + "time" + + "github.com/hashicorp/yamux" + "go.uber.org/zap" + + "drip/internal/server/tunnel" + "drip/internal/shared/protocol" +) + +// ConnectionLifecycleManager manages the lifecycle of a connection. +type ConnectionLifecycleManager struct { + once sync.Once + stopCh chan struct{} + cancel func() + logger *zap.Logger + + // Resources to clean up + conn interface { + Close() error + SetDeadline(time.Time) error + } + frameWriter *protocol.FrameWriter + proxy interface{ Stop() } + session *yamux.Session + portAlloc *PortAllocator + port int + manager *tunnel.Manager + subdomain string + tunnelID string + groupManager *ConnectionGroupManager +} + +// NewConnectionLifecycleManager creates a new lifecycle manager. +func NewConnectionLifecycleManager( + stopCh chan struct{}, + cancel func(), + logger *zap.Logger, +) *ConnectionLifecycleManager { + return &ConnectionLifecycleManager{ + stopCh: stopCh, + cancel: cancel, + logger: logger, + } +} + +// SetConnection sets the connection to manage. +func (clm *ConnectionLifecycleManager) SetConnection(conn interface { + Close() error + SetDeadline(time.Time) error +}) { + clm.conn = conn +} + +// SetFrameWriter sets the frame writer to close. +func (clm *ConnectionLifecycleManager) SetFrameWriter(fw *protocol.FrameWriter) { + clm.frameWriter = fw +} + +// SetProxy sets the proxy to stop. +func (clm *ConnectionLifecycleManager) SetProxy(proxy interface{ Stop() }) { + clm.proxy = proxy +} + +// SetSession sets the yamux session to close. +func (clm *ConnectionLifecycleManager) SetSession(session *yamux.Session) { + clm.session = session +} + +// SetPortAllocation sets the port allocation to release. +func (clm *ConnectionLifecycleManager) SetPortAllocation(portAlloc *PortAllocator, port int) { + clm.portAlloc = portAlloc + clm.port = port +} + +// SetTunnelRegistration sets the tunnel registration to clean up. +func (clm *ConnectionLifecycleManager) SetTunnelRegistration( + manager *tunnel.Manager, + subdomain string, + tunnelID string, + groupManager *ConnectionGroupManager, +) { + clm.manager = manager + clm.subdomain = subdomain + clm.tunnelID = tunnelID + clm.groupManager = groupManager +} + +// Close closes the connection and cleans up all resources. +func (clm *ConnectionLifecycleManager) Close() { + clm.once.Do(func() { + protocol.UnregisterConnection() + close(clm.stopCh) + + if clm.cancel != nil { + clm.cancel() + } + + if clm.conn != nil { + _ = clm.conn.SetDeadline(time.Now()) + } + + if clm.frameWriter != nil { + clm.frameWriter.Close() + } + + if clm.proxy != nil { + clm.proxy.Stop() + } + + if clm.session != nil { + _ = clm.session.Close() + } + + if clm.conn != nil { + clm.conn.Close() + } + + if clm.port > 0 && clm.portAlloc != nil { + clm.portAlloc.Release(clm.port) + } + + if clm.subdomain != "" && clm.manager != nil { + clm.manager.Unregister(clm.subdomain) + if clm.tunnelID != "" && clm.groupManager != nil { + clm.groupManager.RemoveGroup(clm.tunnelID) + } + } + + clm.logger.Info("Connection closed", + zap.String("subdomain", clm.subdomain), + ) + }) +} diff --git a/internal/server/tcp/listener.go b/internal/server/tcp/listener.go index 9ea4f1f..93b50c8 100644 --- a/internal/server/tcp/listener.go +++ b/internal/server/tcp/listener.go @@ -15,11 +15,25 @@ import ( "drip/internal/server/tunnel" "drip/internal/shared/pool" "drip/internal/shared/recovery" + "drip/internal/shared/utils" "go.uber.org/zap" "golang.org/x/net/http2" ) +type ListenerConfig struct { + Address string + TLSConfig *tls.Config + AuthToken string + Manager *tunnel.Manager + Logger *zap.Logger + PortAlloc *PortAllocator + Domain string + TunnelDomain string + PublicPort int + HTTPHandler http.Handler +} + type Listener struct { address string tlsConfig *tls.Config @@ -48,47 +62,47 @@ type Listener struct { allowedTunnelTypes []string } -func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler) *Listener { +func NewListener(cfg ListenerConfig) *Listener { numCPU := pool.NumCPU() workers := numCPU * 5 queueSize := workers * 20 workerPool := pool.NewWorkerPool(workers, queueSize) - logger.Info("Worker pool configured", + cfg.Logger.Info("Worker pool configured", zap.Int("cpu_cores", numCPU), zap.Int("workers", workers), zap.Int("queue_size", queueSize), ) - panicMetrics := recovery.NewPanicMetrics(logger, nil) - recoverer := recovery.NewRecoverer(logger, panicMetrics) + panicMetrics := recovery.NewPanicMetrics(cfg.Logger, nil) + recoverer := recovery.NewRecoverer(cfg.Logger, panicMetrics) // Initialize worker pool metrics metrics.WorkerPoolSize.Set(float64(workers)) l := &Listener{ - address: address, - tlsConfig: tlsConfig, - authToken: authToken, - manager: manager, - portAlloc: portAlloc, - logger: logger, - domain: domain, - tunnelDomain: tunnelDomain, - publicPort: publicPort, - httpHandler: httpHandler, + address: cfg.Address, + tlsConfig: cfg.TLSConfig, + authToken: cfg.AuthToken, + manager: cfg.Manager, + portAlloc: cfg.PortAlloc, + logger: cfg.Logger, + domain: cfg.Domain, + tunnelDomain: cfg.TunnelDomain, + publicPort: cfg.PublicPort, + httpHandler: cfg.HTTPHandler, stopCh: make(chan struct{}), connections: make(map[string]*Connection), workerPool: workerPool, recoverer: recoverer, panicMetrics: panicMetrics, - groupManager: NewConnectionGroupManager(logger), + groupManager: NewConnectionGroupManager(cfg.Logger), } // Set up WebSocket connection handler if httpHandler supports it - if h, ok := httpHandler.(*proxy.Handler); ok { + if h, ok := cfg.HTTPHandler.(*proxy.Handler); ok { h.SetWSConnectionHandler(l) - h.SetPublicPort(publicPort) + h.SetPublicPort(cfg.PublicPort) } return l @@ -269,7 +283,19 @@ func (l *Listener) handleConnection(netConn net.Conn) { ) } - conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener) + conn := NewConnection(ConnectionConfig{ + Conn: netConn, + AuthToken: l.authToken, + Manager: l.manager, + Logger: l.logger, + PortAlloc: l.portAlloc, + Domain: l.domain, + TunnelDomain: l.tunnelDomain, + PublicPort: l.publicPort, + HTTPHandler: l.httpHandler, + GroupManager: l.groupManager, + HTTPListener: l.httpListener, + }) conn.SetAllowedTunnelTypes(l.allowedTunnelTypes) conn.SetAllowedTransports(l.allowedTransports) @@ -297,18 +323,11 @@ func (l *Listener) handleConnection(netConn net.Conn) { if err := conn.Handle(); err != nil { errStr := err.Error() - if strings.Contains(errStr, "EOF") || - strings.Contains(errStr, "connection reset by peer") || - strings.Contains(errStr, "broken pipe") || - strings.Contains(errStr, "connection refused") { + if utils.IsNetworkError(errStr) { return } - if strings.Contains(errStr, "payload too large") || - strings.Contains(errStr, "failed to read registration frame") || - strings.Contains(errStr, "expected register frame") || - strings.Contains(errStr, "failed to parse registration request") || - strings.Contains(errStr, "failed to parse HTTP request") { + if utils.IsProtocolError(errStr) { l.logger.Warn("Protocol validation failed", zap.String("remote_addr", connID), zap.Error(err), @@ -387,7 +406,19 @@ func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) { ) // Create connection handler (no TLS verification needed - already done by HTTP server) - tcpConn := NewConnection(conn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener) + tcpConn := NewConnection(ConnectionConfig{ + Conn: conn, + AuthToken: l.authToken, + Manager: l.manager, + Logger: l.logger, + PortAlloc: l.portAlloc, + Domain: l.domain, + TunnelDomain: l.tunnelDomain, + PublicPort: l.publicPort, + HTTPHandler: l.httpHandler, + GroupManager: l.groupManager, + HTTPListener: l.httpListener, + }) tcpConn.SetAllowedTunnelTypes(l.allowedTunnelTypes) l.connMu.Lock() @@ -412,19 +443,11 @@ func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) { if err := tcpConn.Handle(); err != nil { errStr := err.Error() - if strings.Contains(errStr, "EOF") || - strings.Contains(errStr, "connection reset by peer") || - strings.Contains(errStr, "broken pipe") || - strings.Contains(errStr, "connection refused") || - strings.Contains(errStr, "websocket: close") { + if utils.IsNetworkError(errStr) { return } - if strings.Contains(errStr, "payload too large") || - strings.Contains(errStr, "failed to read registration frame") || - strings.Contains(errStr, "expected register frame") || - strings.Contains(errStr, "failed to parse registration request") || - strings.Contains(errStr, "tunnel type not allowed") { + if utils.IsProtocolError(errStr) { l.logger.Warn("WebSocket tunnel protocol validation failed", zap.String("remote_addr", connID), zap.Error(err), diff --git a/internal/server/tcp/registration_handler.go b/internal/server/tcp/registration_handler.go new file mode 100644 index 0000000..400496b --- /dev/null +++ b/internal/server/tcp/registration_handler.go @@ -0,0 +1,186 @@ +package tcp + +import ( + "fmt" + + json "github.com/goccy/go-json" + "go.uber.org/zap" + + "drip/internal/server/tunnel" + "drip/internal/shared/protocol" + "drip/internal/shared/utils" +) + +// RegistrationHandler handles tunnel registration logic. +type RegistrationHandler struct { + manager *tunnel.Manager + portAlloc *PortAllocator + groupManager *ConnectionGroupManager + domain string + tunnelDomain string + publicPort int + logger *zap.Logger +} + +// NewRegistrationHandler creates a new registration handler. +func NewRegistrationHandler( + manager *tunnel.Manager, + portAlloc *PortAllocator, + groupManager *ConnectionGroupManager, + domain, tunnelDomain string, + publicPort int, + logger *zap.Logger, +) *RegistrationHandler { + return &RegistrationHandler{ + manager: manager, + portAlloc: portAlloc, + groupManager: groupManager, + domain: domain, + tunnelDomain: tunnelDomain, + publicPort: publicPort, + logger: logger, + } +} + +// RegistrationRequest contains all information needed for registration. +type RegistrationRequest struct { + TunnelType protocol.TunnelType + CustomSubdomain string + Token string + ConnectionType string + PoolCapabilities *protocol.PoolCapabilities + IPAccess *protocol.IPAccessControl + ProxyAuth *protocol.ProxyAuth + LocalPort int +} + +// RegistrationResult contains the result of a registration attempt. +type RegistrationResult struct { + Subdomain string + Port int + TunnelURL string + TunnelID string + SupportsDataConn bool + RecommendedConns int + TunnelConn *tunnel.Connection +} + +// Register handles the tunnel registration process. +func (rh *RegistrationHandler) Register(req *RegistrationRequest) (*RegistrationResult, error) { + // Allocate port for TCP tunnels + port := 0 + if req.TunnelType == protocol.TunnelTypeTCP { + if rh.portAlloc == nil { + return nil, fmt.Errorf("port allocator not configured") + } + + if requestedPort, ok := parseTCPSubdomainPort(req.CustomSubdomain); ok { + allocatedPort, err := rh.portAlloc.AllocateSpecific(requestedPort) + if err != nil { + return nil, fmt.Errorf("failed to allocate requested port %d: %w", requestedPort, err) + } + port = allocatedPort + } else { + allocatedPort, err := rh.portAlloc.Allocate() + if err != nil { + return nil, fmt.Errorf("failed to allocate port: %w", err) + } + port = allocatedPort + + if req.CustomSubdomain == "" { + req.CustomSubdomain = fmt.Sprintf("tcp-%d", port) + } + } + } + + // Register with tunnel manager + subdomain, err := rh.manager.Register(nil, req.CustomSubdomain) + if err != nil { + if port > 0 && rh.portAlloc != nil { + rh.portAlloc.Release(port) + } + return nil, fmt.Errorf("tunnel registration failed: %w", err) + } + + // Get tunnel connection + tunnelConn, ok := rh.manager.Get(subdomain) + if !ok { + return nil, fmt.Errorf("failed to get registered tunnel") + } + + // Configure tunnel + tunnelConn.SetTunnelType(req.TunnelType) + + if req.IPAccess != nil && (len(req.IPAccess.AllowIPs) > 0 || len(req.IPAccess.DenyIPs) > 0) { + tunnelConn.SetIPAccessControl(req.IPAccess.AllowIPs, req.IPAccess.DenyIPs) + rh.logger.Info("IP access control configured", + zap.String("subdomain", subdomain), + zap.Strings("allow_ips", req.IPAccess.AllowIPs), + zap.Strings("deny_ips", req.IPAccess.DenyIPs), + ) + } + + if req.ProxyAuth != nil && req.ProxyAuth.Enabled { + tunnelConn.SetProxyAuth(req.ProxyAuth) + rh.logger.Info("Proxy authentication configured", + zap.String("subdomain", subdomain), + ) + } + + // Build tunnel URL + urlBuilder := utils.NewTunnelURLBuilder(rh.tunnelDomain, rh.publicPort) + tunnelURL := urlBuilder.BuildURL(subdomain, req.TunnelType, port) + + // Handle connection groups for multi-connection support + var tunnelID string + var supportsDataConn bool + recommendedConns := 0 + + if req.PoolCapabilities != nil && req.ConnectionType == "primary" && rh.groupManager != nil { + // This will be handled by the caller since it needs the connection instance + supportsDataConn = true + recommendedConns = 4 + } + + rh.logger.Info("Tunnel registered", + zap.String("subdomain", subdomain), + zap.String("tunnel_type", string(req.TunnelType)), + zap.Int("local_port", req.LocalPort), + zap.Int("remote_port", port), + ) + + return &RegistrationResult{ + Subdomain: subdomain, + Port: port, + TunnelURL: tunnelURL, + TunnelID: tunnelID, + SupportsDataConn: supportsDataConn, + RecommendedConns: recommendedConns, + TunnelConn: tunnelConn, + }, nil +} + +// BuildRegistrationResponse creates a protocol registration response. +func (rh *RegistrationHandler) BuildRegistrationResponse(result *RegistrationResult) (*protocol.RegisterResponse, error) { + resp := &protocol.RegisterResponse{ + Subdomain: result.Subdomain, + Port: result.Port, + URL: result.TunnelURL, + Message: "Tunnel registered successfully", + TunnelID: result.TunnelID, + SupportsDataConn: result.SupportsDataConn, + RecommendedConns: result.RecommendedConns, + } + return resp, nil +} + +// SendRegistrationResponse sends the registration response frame. +func (rh *RegistrationHandler) SendRegistrationResponse(conn interface{ Write([]byte) (int, error) }, resp *protocol.RegisterResponse) error { + respData, err := json.Marshal(resp) + if err != nil { + return fmt.Errorf("failed to marshal registration response: %w", err) + } + + ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData) + return protocol.WriteFrame(conn, ackFrame) +} diff --git a/internal/server/tcp/tunnel.go b/internal/server/tcp/tunnel.go index 26c21c9..0d9f721 100644 --- a/internal/server/tcp/tunnel.go +++ b/internal/server/tcp/tunnel.go @@ -35,6 +35,11 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error { } c.session = session + // Update lifecycle manager with session + if c.lifecycleManager != nil { + c.lifecycleManager.SetSession(session) + } + openStream := session.Open if c.groupManager != nil { if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil { @@ -48,6 +53,11 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error { c.proxy.SetIPAccessCheck(c.tunnelConn.IsIPAllowed) } + // Update lifecycle manager with proxy + if c.lifecycleManager != nil { + c.lifecycleManager.SetProxy(c.proxy) + } + if err := c.proxy.Start(); err != nil { return fmt.Errorf("failed to start tcp proxy: %w", err) } @@ -76,6 +86,11 @@ func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error { } c.session = session + // Update lifecycle manager with session + if c.lifecycleManager != nil { + c.lifecycleManager.SetSession(session) + } + openStream := session.Open if c.groupManager != nil { if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil { diff --git a/internal/server/tunnel/manager.go b/internal/server/tunnel/manager.go index 7a964fb..58158fd 100644 --- a/internal/server/tunnel/manager.go +++ b/internal/server/tunnel/manager.go @@ -31,12 +31,6 @@ var ( ErrRateLimitExceeded = errors.New("rate limit exceeded, try again later") ) -// rateLimitEntry tracks registration attempts per IP -type rateLimitEntry struct { - count int - windowEnd time.Time -} - // shard holds a subset of tunnels with its own lock type shard struct { tunnels map[string]*Connection @@ -52,16 +46,16 @@ type Manager struct { // Limits maxTunnels int maxTunnelsPerIP int - rateLimit int - rateLimitWindow time.Duration // Global counters (atomic for lock-free reads) tunnelCount atomic.Int64 // Per-IP tracking (requires separate lock as it spans shards) ipMu sync.RWMutex - tunnelsByIP map[string]int // IP -> tunnel count - rateLimits map[string]*rateLimitEntry // IP -> rate limit entry + tunnelsByIP map[string]int // IP -> tunnel count + + // Rate limiting + rateLimiter *RateLimiter // Lifecycle stopCh chan struct{} @@ -71,7 +65,7 @@ type Manager struct { type ManagerConfig struct { MaxTunnels int MaxTunnelsPerIP int - RateLimit int // Registrations per IP per window + RateLimit int // Registrations per IP per window RateLimitWindow time.Duration } @@ -117,10 +111,8 @@ func NewManagerWithConfig(logger *zap.Logger, cfg ManagerConfig) *Manager { logger: logger, maxTunnels: cfg.MaxTunnels, maxTunnelsPerIP: cfg.MaxTunnelsPerIP, - rateLimit: cfg.RateLimit, - rateLimitWindow: cfg.RateLimitWindow, tunnelsByIP: make(map[string]int), - rateLimits: make(map[string]*rateLimitEntry), + rateLimiter: NewRateLimiter(cfg.RateLimit, cfg.RateLimitWindow, logger), stopCh: make(chan struct{}), } @@ -140,28 +132,6 @@ func (m *Manager) getShard(subdomain string) *shard { return &m.shards[h.Sum32()%numShards] } -// checkRateLimit checks if the IP has exceeded rate limit (caller must hold ipMu) -func (m *Manager) checkRateLimitLocked(ip string) bool { - now := time.Now() - entry, exists := m.rateLimits[ip] - - if !exists || now.After(entry.windowEnd) { - // New window - m.rateLimits[ip] = &rateLimitEntry{ - count: 1, - windowEnd: now.Add(m.rateLimitWindow), - } - return true - } - - if entry.count >= m.rateLimit { - return false - } - - entry.count++ - return true -} - // Register registers a new tunnel connection with IP-based limits func (m *Manager) Register(conn *websocket.Conn, customSubdomain string) (string, error) { return m.RegisterWithIP(conn, customSubdomain, "") @@ -193,19 +163,14 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r // Check per-IP limits and reserve slot atomically if remoteIP != "" { - m.ipMu.Lock() - if !m.checkRateLimitLocked(remoteIP) { - m.ipMu.Unlock() + // Check rate limit first (has its own lock) + if !m.rateLimiter.CheckAndIncrement(remoteIP) { rollbackGlobal() - m.logger.Warn("Rate limit exceeded", - zap.String("ip", remoteIP), - zap.Int("limit", m.rateLimit), - ) - metrics.RateLimitRejections.WithLabelValues("registration", remoteIP).Inc() metrics.TunnelRegistrationFailures.WithLabelValues("rate_limit").Inc() return "", ErrRateLimitExceeded } + m.ipMu.Lock() if m.tunnelsByIP[remoteIP] >= m.maxTunnelsPerIP { currentPerIP := m.tunnelsByIP[remoteIP] m.ipMu.Unlock() @@ -427,14 +392,7 @@ func (m *Manager) CleanupStale(timeout time.Duration) int { } // Cleanup expired rate limit entries - m.ipMu.Lock() - now := time.Now() - for ip, entry := range m.rateLimits { - if now.After(entry.windowEnd) { - delete(m.rateLimits, ip) - } - } - m.ipMu.Unlock() + m.rateLimiter.Cleanup() if totalCleaned > 0 { m.logger.Info("Cleaned up stale tunnels", diff --git a/internal/server/tunnel/rate_limiter.go b/internal/server/tunnel/rate_limiter.go new file mode 100644 index 0000000..056d6ed --- /dev/null +++ b/internal/server/tunnel/rate_limiter.go @@ -0,0 +1,86 @@ +package tunnel + +import ( + "sync" + "time" + + "drip/internal/server/metrics" + "go.uber.org/zap" +) + +// rateLimitEntry tracks registration attempts per IP +type rateLimitEntry struct { + count int + windowEnd time.Time +} + +// RateLimiter manages rate limiting for tunnel registrations. +type RateLimiter struct { + mu sync.RWMutex + rateLimits map[string]*rateLimitEntry + rateLimit int + rateLimitWindow time.Duration + logger *zap.Logger +} + +// NewRateLimiter creates a new rate limiter. +func NewRateLimiter(rateLimit int, rateLimitWindow time.Duration, logger *zap.Logger) *RateLimiter { + return &RateLimiter{ + rateLimits: make(map[string]*rateLimitEntry), + rateLimit: rateLimit, + rateLimitWindow: rateLimitWindow, + logger: logger, + } +} + +// CheckAndIncrement checks if the IP has exceeded rate limit and increments the counter. +func (rl *RateLimiter) CheckAndIncrement(ip string) bool { + if ip == "" { + return true + } + + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + entry, exists := rl.rateLimits[ip] + + if !exists || now.After(entry.windowEnd) { + // New window + rl.rateLimits[ip] = &rateLimitEntry{ + count: 1, + windowEnd: now.Add(rl.rateLimitWindow), + } + return true + } + + if entry.count >= rl.rateLimit { + rl.logger.Warn("Rate limit exceeded", + zap.String("ip", ip), + zap.Int("limit", rl.rateLimit), + ) + metrics.RateLimitRejections.WithLabelValues("registration", ip).Inc() + return false + } + + entry.count++ + return true +} + +// Cleanup removes expired rate limit entries. +func (rl *RateLimiter) Cleanup() int { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + removed := 0 + + for ip, entry := range rl.rateLimits { + if now.After(entry.windowEnd) { + delete(rl.rateLimits, ip) + removed++ + } + } + + return removed +} diff --git a/internal/shared/httputil/detection.go b/internal/shared/httputil/detection.go new file mode 100644 index 0000000..8843049 --- /dev/null +++ b/internal/shared/httputil/detection.go @@ -0,0 +1,39 @@ +package httputil + +import "strings" + +// HTTPMethods contains common HTTP method prefixes for protocol detection. +var HTTPMethods = []string{ + "GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC", +} + +// IsHTTPRequest checks if the given bytes represent the start of an HTTP request. +// It checks for common HTTP method prefixes. +func IsHTTPRequest(data []byte) bool { + if len(data) < 4 { + return false + } + + dataStr := string(data[:4]) + for _, method := range HTTPMethods { + if strings.HasPrefix(dataStr, method) { + return true + } + } + return false +} + +// DetectHTTPMethod returns the HTTP method if the data starts with one, or empty string. +func DetectHTTPMethod(data []byte) string { + if len(data) < 4 { + return "" + } + + dataStr := string(data) + for _, method := range HTTPMethods { + if strings.HasPrefix(dataStr, method) { + return strings.TrimSpace(method) + } + } + return "" +} diff --git a/internal/shared/httputil/error_response.go b/internal/shared/httputil/error_response.go new file mode 100644 index 0000000..95e25f0 --- /dev/null +++ b/internal/shared/httputil/error_response.go @@ -0,0 +1,58 @@ +package httputil + +import ( + "fmt" + "net" +) + +// HTTPErrorResponse represents a standard HTTP error response. +type HTTPErrorResponse struct { + StatusCode int + StatusText string + Message string +} + +// Common HTTP error responses +var ( + ServiceUnavailable = &HTTPErrorResponse{ + StatusCode: 503, + StatusText: "Service Unavailable", + Message: "Server busy, please retry later", + } + + HandlerNotConfigured = &HTTPErrorResponse{ + StatusCode: 503, + StatusText: "Service Unavailable", + Message: "HTTP handler not configured for this TCP port", + } +) + +// WriteErrorResponse writes an HTTP error response to the connection. +func WriteErrorResponse(conn net.Conn, resp *HTTPErrorResponse) error { + response := fmt.Sprintf( + "HTTP/1.1 %d %s\r\n"+ + "Content-Type: text/plain\r\n"+ + "Content-Length: %d\r\n"+ + "Connection: close\r\n"+ + "\r\n"+ + "%s\r\n", + resp.StatusCode, + resp.StatusText, + len(resp.Message)+2, + resp.Message, + ) + _, err := conn.Write([]byte(response)) + return err +} + +// WriteServiceUnavailable writes a 503 Service Unavailable response. +func WriteServiceUnavailable(conn net.Conn, message string) error { + if message == "" { + message = ServiceUnavailable.Message + } + return WriteErrorResponse(conn, &HTTPErrorResponse{ + StatusCode: 503, + StatusText: "Service Unavailable", + Message: message, + }) +} diff --git a/internal/shared/httputil/response.go b/internal/shared/httputil/response.go new file mode 100644 index 0000000..b8bb543 --- /dev/null +++ b/internal/shared/httputil/response.go @@ -0,0 +1,38 @@ +package httputil + +import ( + "fmt" + "net/http" +) + +// WriteJSON writes a JSON response with the appropriate headers. +func WriteJSON(w http.ResponseWriter, data []byte) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) + w.Write(data) +} + +// WriteHTML writes an HTML response with the appropriate headers. +func WriteHTML(w http.ResponseWriter, data []byte) { + w.Header().Set("Content-Type", "text/html") + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) + w.Write(data) +} + +// WriteHTMLWithStatus writes an HTML response with a custom status code. +func WriteHTMLWithStatus(w http.ResponseWriter, data []byte, statusCode int) { + w.Header().Set("Content-Type", "text/html") + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) + w.WriteHeader(statusCode) + w.Write(data) +} + +// SetContentLength sets the Content-Length header. +func SetContentLength(w http.ResponseWriter, length int64) { + w.Header().Set("Content-Length", fmt.Sprintf("%d", length)) +} + +// SetCloseConnection sets the Connection: close header. +func SetCloseConnection(w http.ResponseWriter) { + w.Header().Set("Connection", "close") +} diff --git a/internal/shared/mux/session_builder.go b/internal/shared/mux/session_builder.go new file mode 100644 index 0000000..6a96666 --- /dev/null +++ b/internal/shared/mux/session_builder.go @@ -0,0 +1,116 @@ +package mux + +import ( + "bufio" + "net" + + "github.com/hashicorp/yamux" +) + +// BufferedConn wraps a connection with a buffered reader. +type BufferedConn struct { + net.Conn + reader *bufio.Reader +} + +// NewBufferedConn creates a new buffered connection. +func NewBufferedConn(conn net.Conn, reader *bufio.Reader) *BufferedConn { + return &BufferedConn{ + Conn: conn, + reader: reader, + } +} + +// Read reads from the buffered reader if available, otherwise from the connection. +func (bc *BufferedConn) Read(p []byte) (int, error) { + if bc.reader != nil { + return bc.reader.Read(p) + } + return bc.Conn.Read(p) +} + +// SessionBuilder helps build yamux sessions with consistent configuration. +type SessionBuilder struct { + conn net.Conn + reader *bufio.Reader + config *yamux.Config + isServer bool +} + +// NewSessionBuilder creates a new session builder. +func NewSessionBuilder(conn net.Conn) *SessionBuilder { + return &SessionBuilder{ + conn: conn, + } +} + +// WithReader sets the buffered reader for the session. +func (sb *SessionBuilder) WithReader(reader *bufio.Reader) *SessionBuilder { + sb.reader = reader + return sb +} + +// WithConfig sets the yamux configuration. +func (sb *SessionBuilder) WithConfig(config *yamux.Config) *SessionBuilder { + sb.config = config + return sb +} + +// AsServer configures the session as a server. +func (sb *SessionBuilder) AsServer() *SessionBuilder { + sb.isServer = true + if sb.config == nil { + sb.config = NewServerConfig() + } + return sb +} + +// AsClient configures the session as a client. +func (sb *SessionBuilder) AsClient() *SessionBuilder { + sb.isServer = false + if sb.config == nil { + sb.config = NewClientConfig() + } + return sb +} + +// Build creates the yamux session. +func (sb *SessionBuilder) Build() (*yamux.Session, error) { + conn := sb.conn + + // Wrap with buffered reader if provided + if sb.reader != nil { + conn = NewBufferedConn(sb.conn, sb.reader) + } + + // Use default config if not set + if sb.config == nil { + if sb.isServer { + sb.config = NewServerConfig() + } else { + sb.config = NewClientConfig() + } + } + + // Create session based on role + if sb.isServer { + return yamux.Server(conn, sb.config) + } + return yamux.Client(conn, sb.config) +} + +// BuildClientSession creates a yamux client session with the given connection and reader. +func BuildClientSession(conn net.Conn, reader *bufio.Reader) (*yamux.Session, error) { + return NewSessionBuilder(conn). + WithReader(reader). + AsClient(). + Build() +} + +// BuildServerSession creates a yamux server session with the given connection and reader. +func BuildServerSession(conn net.Conn, reader *bufio.Reader) (*yamux.Session, error) { + return NewSessionBuilder(conn). + WithReader(reader). + AsServer(). + Build() +} diff --git a/internal/shared/netutil/ip.go b/internal/shared/netutil/ip.go new file mode 100644 index 0000000..a89e586 --- /dev/null +++ b/internal/shared/netutil/ip.go @@ -0,0 +1,78 @@ +package netutil + +import ( + "net" + "net/http" + "strings" +) + +var privateNetworks []*net.IPNet + +func init() { + privateCIDRs := []string{ + "127.0.0.0/8", + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "::1/128", + "fc00::/7", + "fe80::/10", + } + for _, cidr := range privateCIDRs { + _, ipNet, _ := net.ParseCIDR(cidr) + privateNetworks = append(privateNetworks, ipNet) + } +} + +// ExtractRemoteIP extracts the IP address from a remote address string (host:port format). +func ExtractRemoteIP(remoteAddr string) string { + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return remoteAddr + } + return host +} + +// IsPrivateIP checks if the given IP is a private/loopback address. +func IsPrivateIP(ip string) bool { + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + return false + } + + for _, network := range privateNetworks { + if network.Contains(parsedIP) { + return true + } + } + + return false +} + +// ExtractClientIP extracts the client IP from the request. +// It only trusts X-Forwarded-For and X-Real-IP headers when the request +// comes from a private/loopback network (typical reverse proxy setup). +func ExtractClientIP(r *http.Request) string { + // First, get the direct remote address + remoteIP := ExtractRemoteIP(r.RemoteAddr) + + // Only trust proxy headers if the request comes from a private network + if IsPrivateIP(remoteIP) { + // Check X-Forwarded-For header (may contain multiple IPs) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Take the first IP (original client) + if idx := strings.Index(xff, ","); idx != -1 { + return strings.TrimSpace(xff[:idx]) + } + return strings.TrimSpace(xff) + } + + // Check X-Real-IP header + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return strings.TrimSpace(xri) + } + } + + // Fall back to remote address + return remoteIP +} diff --git a/internal/shared/pool/buffer_pool.go b/internal/shared/pool/buffer_pool.go index 65c4930..8207dec 100644 --- a/internal/shared/pool/buffer_pool.go +++ b/internal/shared/pool/buffer_pool.go @@ -3,10 +3,10 @@ package pool import "sync" const ( - SizeSmall = 4 * 1024 // 4KB - HTTP headers, small messages - SizeMedium = 32 * 1024 // 32KB - HTTP request/response bodies - SizeLarge = 256 * 1024 // 256KB - Data pipe, file transfers - SizeXLarge = 1024 * 1024 // 1MB - Large file transfers, bulk data + SizeSmall = 4 * 1024 // 4KB - HTTP headers, small messages + SizeMedium = 32 * 1024 // 32KB - HTTP request/response bodies + SizeLarge = 256 * 1024 // 256KB - Data pipe, file transfers + SizeXLarge = 1024 * 1024 // 1MB - Large file transfers, bulk data ) type BufferPool struct { diff --git a/internal/shared/pool/bufio_pool.go b/internal/shared/pool/bufio_pool.go new file mode 100644 index 0000000..1df3522 --- /dev/null +++ b/internal/shared/pool/bufio_pool.go @@ -0,0 +1,48 @@ +package pool + +import ( + "bufio" + "sync" +) + +// BufioReaderPool provides a pool of bufio.Reader instances. +var BufioReaderPool = sync.Pool{ + New: func() interface{} { + return bufio.NewReaderSize(nil, 32*1024) + }, +} + +// BufioWriterPool provides a pool of bufio.Writer instances. +var BufioWriterPool = sync.Pool{ + New: func() interface{} { + return bufio.NewWriterSize(nil, 4096) + }, +} + +// GetReader gets a bufio.Reader from the pool and resets it to read from r. +func GetReader(r interface{}) *bufio.Reader { + reader := BufioReaderPool.Get().(*bufio.Reader) + if resetter, ok := r.(interface{ Reset(interface{}) }); ok { + resetter.Reset(r) + } + return reader +} + +// PutReader returns a bufio.Reader to the pool. +func PutReader(reader *bufio.Reader) { + BufioReaderPool.Put(reader) +} + +// GetWriter gets a bufio.Writer from the pool and resets it to write to w. +func GetWriter(w interface{}) *bufio.Writer { + writer := BufioWriterPool.Get().(*bufio.Writer) + if resetter, ok := w.(interface{ Reset(interface{}) }); ok { + resetter.Reset(w) + } + return writer +} + +// PutWriter returns a bufio.Writer to the pool. +func PutWriter(writer *bufio.Writer) { + BufioWriterPool.Put(writer) +} diff --git a/internal/shared/protocol/error_sender.go b/internal/shared/protocol/error_sender.go new file mode 100644 index 0000000..4b015f5 --- /dev/null +++ b/internal/shared/protocol/error_sender.go @@ -0,0 +1,68 @@ +package protocol + +import ( + "fmt" + "net" + + json "github.com/goccy/go-json" + "go.uber.org/zap" +) + +// ErrorSender handles sending error frames over connections. +type ErrorSender struct { + conn net.Conn + frameWriter *FrameWriter + logger *zap.Logger +} + +// NewErrorSender creates a new error sender. +func NewErrorSender(conn net.Conn, frameWriter *FrameWriter, logger *zap.Logger) *ErrorSender { + return &ErrorSender{ + conn: conn, + frameWriter: frameWriter, + logger: logger, + } +} + +// SendError sends an error frame with the given code and message. +func (e *ErrorSender) SendError(code, message string) error { + errMsg := ErrorMessage{ + Code: code, + Message: message, + } + + data, err := json.Marshal(errMsg) + if err != nil { + e.logger.Error("Failed to marshal error message", zap.Error(err)) + return fmt.Errorf("failed to marshal error: %w", err) + } + + errFrame := NewFrame(FrameTypeError, data) + + if e.frameWriter == nil { + return WriteFrame(e.conn, errFrame) + } + + return e.frameWriter.WriteFrame(errFrame) +} + +// SendAuthenticationError sends an authentication failed error. +func (e *ErrorSender) SendAuthenticationError() error { + return e.SendError("authentication_failed", "Invalid authentication token") +} + +// SendRegistrationError sends a registration failed error. +func (e *ErrorSender) SendRegistrationError(message string) error { + return e.SendError("registration_failed", message) +} + +// SendPortAllocationError sends a port allocation failed error. +func (e *ErrorSender) SendPortAllocationError(message string) error { + return e.SendError("port_allocation_failed", message) +} + +// SendTunnelTypeNotAllowedError sends a tunnel type not allowed error. +func (e *ErrorSender) SendTunnelTypeNotAllowedError(tunnelType string) error { + return e.SendError("tunnel_type_not_allowed", + fmt.Sprintf("Tunnel type '%s' is not allowed on this server", tunnelType)) +} diff --git a/internal/shared/protocol/heartbeat.go b/internal/shared/protocol/heartbeat.go new file mode 100644 index 0000000..0b24bec --- /dev/null +++ b/internal/shared/protocol/heartbeat.go @@ -0,0 +1,74 @@ +package protocol + +import ( + "sync" + "time" + + "go.uber.org/zap" +) + +// HeartbeatManager manages heartbeat tracking and checking for connections. +type HeartbeatManager struct { + mu sync.RWMutex + lastHeartbeat time.Time + interval time.Duration + timeout time.Duration + logger *zap.Logger + frameWriter *FrameWriter +} + +// NewHeartbeatManager creates a new heartbeat manager. +func NewHeartbeatManager(interval, timeout time.Duration, frameWriter *FrameWriter, logger *zap.Logger) *HeartbeatManager { + return &HeartbeatManager{ + lastHeartbeat: time.Now(), + interval: interval, + timeout: timeout, + frameWriter: frameWriter, + logger: logger, + } +} + +// UpdateLastHeartbeat updates the last heartbeat timestamp. +func (h *HeartbeatManager) UpdateLastHeartbeat() { + h.mu.Lock() + h.lastHeartbeat = time.Now() + h.mu.Unlock() +} + +// GetLastHeartbeat returns the last heartbeat timestamp. +func (h *HeartbeatManager) GetLastHeartbeat() time.Time { + h.mu.RLock() + defer h.mu.RUnlock() + return h.lastHeartbeat +} + +// IsAlive checks if the connection is still alive based on heartbeat timeout. +func (h *HeartbeatManager) IsAlive() bool { + h.mu.RLock() + defer h.mu.RUnlock() + return time.Since(h.lastHeartbeat) < h.timeout +} + +// SendHeartbeatAck sends a heartbeat acknowledgment frame. +func (h *HeartbeatManager) SendHeartbeatAck() error { + ackFrame := NewFrame(FrameTypeHeartbeatAck, nil) + err := h.frameWriter.WriteControl(ackFrame) + if err != nil { + h.logger.Error("Failed to send heartbeat ack", zap.Error(err)) + return err + } + return nil +} + +// HandleHeartbeat handles a received heartbeat frame. +func (h *HeartbeatManager) HandleHeartbeat() error { + h.UpdateLastHeartbeat() + return h.SendHeartbeatAck() +} + +// TimeSinceLastHeartbeat returns the duration since the last heartbeat. +func (h *HeartbeatManager) TimeSinceLastHeartbeat() time.Duration { + h.mu.RLock() + defer h.mu.RUnlock() + return time.Since(h.lastHeartbeat) +} diff --git a/internal/shared/protocol/messages.go b/internal/shared/protocol/messages.go index 0238f3b..d82947a 100644 --- a/internal/shared/protocol/messages.go +++ b/internal/shared/protocol/messages.go @@ -14,7 +14,9 @@ type IPAccessControl struct { type ProxyAuth struct { Enabled bool `json:"enabled"` + Type string `json:"type,omitempty"` Password string `json:"password,omitempty"` + Token string `json:"token,omitempty"` } type RegisterRequest struct { diff --git a/internal/shared/tuning/mem_windows.go b/internal/shared/tuning/mem_windows.go index aa64143..594765a 100644 --- a/internal/shared/tuning/mem_windows.go +++ b/internal/shared/tuning/mem_windows.go @@ -8,8 +8,8 @@ import ( ) var ( - kernel32 = syscall.NewLazyDLL("kernel32.dll") - globalMemoryStatusEx = kernel32.NewProc("GlobalMemoryStatusEx") + kernel32 = syscall.NewLazyDLL("kernel32.dll") + globalMemoryStatusEx = kernel32.NewProc("GlobalMemoryStatusEx") ) type memoryStatusEx struct { diff --git a/internal/shared/utils/allowed_list.go b/internal/shared/utils/allowed_list.go new file mode 100644 index 0000000..108b9b5 --- /dev/null +++ b/internal/shared/utils/allowed_list.go @@ -0,0 +1,49 @@ +package utils + +import "strings" + +// AllowedList manages a list of allowed values with case-insensitive matching. +type AllowedList struct { + items []string +} + +// NewAllowedList creates a new AllowedList from the given items. +func NewAllowedList(items []string) *AllowedList { + return &AllowedList{items: items} +} + +// IsAllowed checks if a value is in the allowed list. +// If the list is empty, all values are allowed. +// Matching is case-insensitive. +func (a *AllowedList) IsAllowed(value string) bool { + if len(a.items) == 0 { + return true + } + for _, item := range a.items { + if strings.EqualFold(item, value) { + return true + } + } + return false +} + +// GetPreferred returns the first item in the list, or the default value if empty. +func (a *AllowedList) GetPreferred(defaultValue string) string { + if len(a.items) == 0 { + return defaultValue + } + if len(a.items) == 1 { + return a.items[0] + } + return defaultValue +} + +// Items returns the underlying slice of allowed items. +func (a *AllowedList) Items() []string { + return a.items +} + +// IsEmpty returns true if the allowed list is empty. +func (a *AllowedList) IsEmpty() bool { + return len(a.items) == 0 +} diff --git a/internal/shared/utils/errors.go b/internal/shared/utils/errors.go new file mode 100644 index 0000000..88cd9b3 --- /dev/null +++ b/internal/shared/utils/errors.go @@ -0,0 +1,35 @@ +package utils + +import "strings" + +// IsNetworkError checks if an error message indicates a common network error +// that should be handled gracefully (not logged as severe errors). +func IsNetworkError(errStr string) bool { + return strings.Contains(errStr, "EOF") || + strings.Contains(errStr, "connection reset by peer") || + strings.Contains(errStr, "broken pipe") || + strings.Contains(errStr, "connection refused") || + strings.Contains(errStr, "use of closed network connection") || + strings.Contains(errStr, "websocket: close") +} + +// IsProtocolError checks if an error message indicates a protocol-level error +// (invalid requests, malformed data, etc.). +func IsProtocolError(errStr string) bool { + return strings.Contains(errStr, "payload too large") || + strings.Contains(errStr, "failed to read registration frame") || + strings.Contains(errStr, "expected register frame") || + strings.Contains(errStr, "failed to parse registration request") || + strings.Contains(errStr, "failed to parse HTTP request") || + strings.Contains(errStr, "tunnel type not allowed") +} + +// ContainsAny checks if a string contains any of the given substrings. +func ContainsAny(s string, substrings ...string) bool { + for _, substr := range substrings { + if strings.Contains(s, substr) { + return true + } + } + return false +} diff --git a/internal/shared/utils/url_builder.go b/internal/shared/utils/url_builder.go new file mode 100644 index 0000000..7263d6b --- /dev/null +++ b/internal/shared/utils/url_builder.go @@ -0,0 +1,42 @@ +package utils + +import ( + "fmt" + + "drip/internal/shared/protocol" +) + +// TunnelURLBuilder helps construct tunnel URLs consistently. +type TunnelURLBuilder struct { + tunnelDomain string + publicPort int +} + +// NewTunnelURLBuilder creates a new URL builder. +func NewTunnelURLBuilder(tunnelDomain string, publicPort int) *TunnelURLBuilder { + return &TunnelURLBuilder{ + tunnelDomain: tunnelDomain, + publicPort: publicPort, + } +} + +// BuildHTTPURL builds an HTTP/HTTPS tunnel URL. +func (b *TunnelURLBuilder) BuildHTTPURL(subdomain string) string { + if b.publicPort == 443 { + return fmt.Sprintf("https://%s.%s", subdomain, b.tunnelDomain) + } + return fmt.Sprintf("https://%s.%s:%d", subdomain, b.tunnelDomain, b.publicPort) +} + +// BuildTCPURL builds a TCP tunnel URL. +func (b *TunnelURLBuilder) BuildTCPURL(port int) string { + return fmt.Sprintf("tcp://%s:%d", b.tunnelDomain, port) +} + +// BuildURL builds a tunnel URL based on the tunnel type. +func (b *TunnelURLBuilder) BuildURL(subdomain string, tunnelType protocol.TunnelType, port int) string { + if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS { + return b.BuildHTTPURL(subdomain) + } + return b.BuildTCPURL(port) +} diff --git a/pkg/config/client_config.go b/pkg/config/client_config.go index 7cde8e5..944dba5 100644 --- a/pkg/config/client_config.go +++ b/pkg/config/client_config.go @@ -12,15 +12,16 @@ import ( // TunnelConfig holds configuration for a predefined tunnel type TunnelConfig struct { - Name string `yaml:"name"` // Tunnel name (required, unique identifier) - Type string `yaml:"type"` // Tunnel type: http, https, tcp (required) - Port int `yaml:"port"` // Local port to forward (required) - Address string `yaml:"address,omitempty"` // Local address (default: 127.0.0.1) - Subdomain string `yaml:"subdomain,omitempty"` // Custom subdomain - Transport string `yaml:"transport,omitempty"` // Transport: auto, tcp, wss - AllowIPs []string `yaml:"allow_ips,omitempty"` // Allowed IPs/CIDRs - DenyIPs []string `yaml:"deny_ips,omitempty"` // Denied IPs/CIDRs - Auth string `yaml:"auth,omitempty"` // Proxy authentication password (http/https only) + Name string `yaml:"name"` // Tunnel name (required, unique identifier) + Type string `yaml:"type"` // Tunnel type: http, https, tcp (required) + Port int `yaml:"port"` // Local port to forward (required) + Address string `yaml:"address,omitempty"` // Local address (default: 127.0.0.1) + Subdomain string `yaml:"subdomain,omitempty"` // Custom subdomain + Transport string `yaml:"transport,omitempty"` // Transport: auto, tcp, wss + AllowIPs []string `yaml:"allow_ips,omitempty"` // Allowed IPs/CIDRs + DenyIPs []string `yaml:"deny_ips,omitempty"` // Denied IPs/CIDRs + Auth string `yaml:"auth,omitempty"` // Proxy authentication password (http/https only) + AuthBearer string `yaml:"auth_bearer,omitempty"` // Proxy authentication bearer token (http/https only) } // Validate checks if the tunnel configuration is valid @@ -44,6 +45,9 @@ func (t *TunnelConfig) Validate() error { return fmt.Errorf("invalid transport '%s' for '%s': must be auto, tcp, or wss", t.Transport, t.Name) } } + if t.Auth != "" && t.AuthBearer != "" { + return fmt.Errorf("only one of auth or auth_bearer can be set for '%s'", t.Name) + } return nil } From dfb19930c7758654473ba89ae26f77b82baf8513 Mon Sep 17 00:00:00 2001 From: Gouryella Date: Thu, 29 Jan 2026 14:53:23 +0800 Subject: [PATCH 2/2] refactor(pool): optimize bufio pool's type definition and reset logic --- internal/shared/pool/bufio_pool.go | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/internal/shared/pool/bufio_pool.go b/internal/shared/pool/bufio_pool.go index 1df3522..90258e3 100644 --- a/internal/shared/pool/bufio_pool.go +++ b/internal/shared/pool/bufio_pool.go @@ -2,6 +2,7 @@ package pool import ( "bufio" + "io" "sync" ) @@ -20,11 +21,9 @@ var BufioWriterPool = sync.Pool{ } // GetReader gets a bufio.Reader from the pool and resets it to read from r. -func GetReader(r interface{}) *bufio.Reader { +func GetReader(r io.Reader) *bufio.Reader { reader := BufioReaderPool.Get().(*bufio.Reader) - if resetter, ok := r.(interface{ Reset(interface{}) }); ok { - resetter.Reset(r) - } + reader.Reset(r) return reader } @@ -34,11 +33,9 @@ func PutReader(reader *bufio.Reader) { } // GetWriter gets a bufio.Writer from the pool and resets it to write to w. -func GetWriter(w interface{}) *bufio.Writer { +func GetWriter(w io.Writer) *bufio.Writer { writer := BufioWriterPool.Get().(*bufio.Writer) - if resetter, ok := w.(interface{ Reset(interface{}) }); ok { - resetter.Reset(w) - } + writer.Reset(w) return writer }