mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-24 05:10:43 +00:00
Merge pull request #8 from Gouryella/refactor/protocol-v2
refactor(protocol): switch to yamux multiplexing with connection pooling
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -53,3 +53,4 @@ certs/
|
||||
.drip-server.env
|
||||
benchmark-results/
|
||||
drip
|
||||
drip-linux-amd64
|
||||
|
||||
3
go.mod
3
go.mod
@@ -6,8 +6,8 @@ require (
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/goccy/go-json v0.10.5
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/hashicorp/yamux v0.1.2
|
||||
github.com/spf13/cobra v1.10.1
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1
|
||||
go.uber.org/zap v1.27.1
|
||||
golang.org/x/crypto v0.45.0
|
||||
golang.org/x/sys v0.38.0
|
||||
@@ -28,7 +28,6 @@ require (
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/stretchr/testify v1.11.1 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@@ -17,6 +17,8 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8=
|
||||
github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
@@ -40,10 +42,6 @@ github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"drip/internal/shared/ui"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ var attachCmd = &cobra.Command{
|
||||
Examples:
|
||||
drip attach List running tunnels and select one
|
||||
drip attach http 3000 Attach to HTTP tunnel on port 3000
|
||||
drip attach https 8443 Attach to HTTPS tunnel on port 8443
|
||||
drip attach tcp 5432 Attach to TCP tunnel on port 5432
|
||||
|
||||
Press Ctrl+C to detach (tunnel will continue running).`,
|
||||
@@ -36,7 +37,7 @@ func init() {
|
||||
rootCmd.AddCommand(attachCmd)
|
||||
}
|
||||
|
||||
func runAttach(cmd *cobra.Command, args []string) error {
|
||||
func runAttach(_ *cobra.Command, args []string) error {
|
||||
CleanupStaleDaemons()
|
||||
|
||||
daemons, err := ListAllDaemons()
|
||||
@@ -59,8 +60,8 @@ func runAttach(cmd *cobra.Command, args []string) error {
|
||||
|
||||
if len(args) == 2 {
|
||||
tunnelType := args[0]
|
||||
if tunnelType != "http" && tunnelType != "tcp" {
|
||||
return fmt.Errorf("invalid tunnel type: %s (must be 'http' or 'tcp')", tunnelType)
|
||||
if tunnelType != "http" && tunnelType != "https" && tunnelType != "tcp" {
|
||||
return fmt.Errorf("invalid tunnel type: %s (must be 'http', 'https', or 'tcp')", tunnelType)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(args[1])
|
||||
@@ -119,10 +120,13 @@ func selectDaemonInteractive(daemons []*DaemonInfo) (*DaemonInfo, error) {
|
||||
uptime := time.Since(d.StartTime)
|
||||
|
||||
var typeStr string
|
||||
if d.Type == "http" {
|
||||
typeStr = ui.Success("HTTP")
|
||||
} else {
|
||||
typeStr = ui.Highlight("TCP")
|
||||
switch d.Type {
|
||||
case "http":
|
||||
typeStr = ui.Highlight("HTTP")
|
||||
case "https":
|
||||
typeStr = ui.Highlight("HTTPS")
|
||||
default:
|
||||
typeStr = ui.Cyan("TCP")
|
||||
}
|
||||
|
||||
table.AddRow([]string{
|
||||
@@ -196,7 +200,7 @@ func attachToDaemon(daemon *DaemonInfo) error {
|
||||
select {
|
||||
case <-sigCh:
|
||||
if tailCmd.Process != nil {
|
||||
tailCmd.Process.Kill()
|
||||
_ = tailCmd.Process.Kill()
|
||||
}
|
||||
fmt.Println()
|
||||
fmt.Println(ui.Warning("Detached from tunnel (tunnel is still running)"))
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"drip/internal/shared/ui"
|
||||
"drip/pkg/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@@ -77,7 +77,7 @@ func init() {
|
||||
rootCmd.AddCommand(configCmd)
|
||||
}
|
||||
|
||||
func runConfigInit(cmd *cobra.Command, args []string) error {
|
||||
func runConfigInit(_ *cobra.Command, _ []string) error {
|
||||
fmt.Print(ui.RenderConfigInit())
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
@@ -109,7 +109,7 @@ func runConfigInit(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func runConfigShow(cmd *cobra.Command, args []string) error {
|
||||
func runConfigShow(_ *cobra.Command, _ []string) error {
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -137,7 +137,7 @@ func runConfigShow(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func runConfigSet(cmd *cobra.Command, args []string) error {
|
||||
func runConfigSet(_ *cobra.Command, _ []string) error {
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
cfg = &config.ClientConfig{
|
||||
@@ -173,7 +173,7 @@ func runConfigSet(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func runConfigReset(cmd *cobra.Command, args []string) error {
|
||||
func runConfigReset(_ *cobra.Command, _ []string) error {
|
||||
configPath := config.DefaultClientConfigPath()
|
||||
|
||||
if !config.ConfigExists("") {
|
||||
@@ -202,7 +202,7 @@ func runConfigReset(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func runConfigValidate(cmd *cobra.Command, args []string) error {
|
||||
func runConfigValidate(_ *cobra.Command, _ []string) error {
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
fmt.Println(ui.Error("Failed to load configuration"))
|
||||
|
||||
@@ -5,10 +5,10 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"drip/internal/shared/ui"
|
||||
"drip/pkg/config"
|
||||
json "github.com/goccy/go-json"
|
||||
)
|
||||
|
||||
@@ -194,14 +194,67 @@ func StartDaemon(tunnelType string, port int, args []string) error {
|
||||
return fmt.Errorf("failed to start daemon: %w", err)
|
||||
}
|
||||
|
||||
// Don't wait for the process - let it run in background
|
||||
// The child process will save its own daemon info after connecting
|
||||
_ = logFile.Close()
|
||||
_ = devNull.Close()
|
||||
|
||||
fmt.Println(ui.RenderDaemonStarted(tunnelType, port, cmd.Process.Pid, logPath))
|
||||
localHost := parseFlagValue(cleanArgs, "--address", "-a", "127.0.0.1")
|
||||
displayHost := localHost
|
||||
if displayHost == "127.0.0.1" {
|
||||
displayHost = "localhost"
|
||||
}
|
||||
forwardAddr := fmt.Sprintf("%s:%d", displayHost, port)
|
||||
|
||||
serverAddr := parseFlagValue(cleanArgs, "--server", "-s", "")
|
||||
if serverAddr == "" {
|
||||
if cfg, err := config.LoadClientConfig(""); err == nil {
|
||||
serverAddr = cfg.Server
|
||||
}
|
||||
}
|
||||
|
||||
var url string
|
||||
|
||||
info, err := waitForDaemonInfo(tunnelType, port, cmd.Process.Pid, 30*time.Second)
|
||||
if err == nil && info != nil && info.PID == cmd.Process.Pid && info.URL != "" {
|
||||
url = info.URL
|
||||
if info.Server != "" {
|
||||
serverAddr = info.Server
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println(ui.RenderDaemonStarted(tunnelType, port, cmd.Process.Pid, logPath, url, forwardAddr, serverAddr))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseFlagValue(args []string, longName string, shortName string, defaultValue string) string {
|
||||
for i := 0; i < len(args); i++ {
|
||||
if args[i] == longName || args[i] == shortName {
|
||||
if i+1 < len(args) && args[i+1] != "" {
|
||||
return args[i+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func waitForDaemonInfo(tunnelType string, port int, pid int, timeout time.Duration) (*DaemonInfo, error) {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if !IsProcessRunning(pid) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
info, err := LoadDaemonInfo(tunnelType, port)
|
||||
if err == nil && info != nil && info.PID == pid {
|
||||
if info.URL != "" {
|
||||
return info, nil
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// CleanupStaleDaemons removes daemon info for processes that are no longer running
|
||||
func CleanupStaleDaemons() error {
|
||||
daemons, err := ListAllDaemons()
|
||||
@@ -220,28 +273,16 @@ func CleanupStaleDaemons() error {
|
||||
|
||||
// FormatDuration formats a duration in a human-readable way
|
||||
func FormatDuration(d time.Duration) string {
|
||||
if d < time.Minute {
|
||||
switch {
|
||||
case d < time.Minute:
|
||||
return fmt.Sprintf("%ds", int(d.Seconds()))
|
||||
} else if d < time.Hour {
|
||||
case d < time.Hour:
|
||||
return fmt.Sprintf("%dm %ds", int(d.Minutes()), int(d.Seconds())%60)
|
||||
} else if d < 24*time.Hour {
|
||||
case d < 24*time.Hour:
|
||||
return fmt.Sprintf("%dh %dm", int(d.Hours()), int(d.Minutes())%60)
|
||||
}
|
||||
|
||||
days := int(d.Hours()) / 24
|
||||
hours := int(d.Hours()) % 24
|
||||
return fmt.Sprintf("%dd %dh", days, hours)
|
||||
}
|
||||
|
||||
// ParsePortFromArgs extracts the port number from command arguments
|
||||
func ParsePortFromArgs(args []string) (int, error) {
|
||||
for _, arg := range args {
|
||||
if len(arg) > 0 && arg[0] == '-' {
|
||||
continue
|
||||
}
|
||||
port, err := strconv.Atoi(arg)
|
||||
if err == nil && port > 0 && port <= 65535 {
|
||||
return port, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("port number not found in arguments")
|
||||
}
|
||||
|
||||
@@ -2,22 +2,14 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/pkg/config"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const (
|
||||
maxReconnectAttempts = 5
|
||||
reconnectInterval = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
subdomain string
|
||||
daemonMode bool
|
||||
@@ -52,55 +44,19 @@ func init() {
|
||||
rootCmd.AddCommand(httpCmd)
|
||||
}
|
||||
|
||||
func runHTTP(cmd *cobra.Command, args []string) error {
|
||||
func runHTTP(_ *cobra.Command, args []string) error {
|
||||
port, err := strconv.Atoi(args[0])
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("invalid port number: %s", args[0])
|
||||
}
|
||||
|
||||
if daemonMode && !daemonMarker {
|
||||
daemonArgs := append([]string{"http"}, args...)
|
||||
daemonArgs = append(daemonArgs, "--daemon-child")
|
||||
if subdomain != "" {
|
||||
daemonArgs = append(daemonArgs, "--subdomain", subdomain)
|
||||
}
|
||||
if localAddress != "127.0.0.1" {
|
||||
daemonArgs = append(daemonArgs, "--address", localAddress)
|
||||
}
|
||||
if serverURL != "" {
|
||||
daemonArgs = append(daemonArgs, "--server", serverURL)
|
||||
}
|
||||
if authToken != "" {
|
||||
daemonArgs = append(daemonArgs, "--token", authToken)
|
||||
}
|
||||
if insecure {
|
||||
daemonArgs = append(daemonArgs, "--insecure")
|
||||
}
|
||||
if verbose {
|
||||
daemonArgs = append(daemonArgs, "--verbose")
|
||||
}
|
||||
return StartDaemon("http", port, daemonArgs)
|
||||
return StartDaemon("http", port, buildDaemonArgs("http", args, subdomain, localAddress))
|
||||
}
|
||||
|
||||
var serverAddr, token string
|
||||
|
||||
if serverURL == "" {
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return fmt.Errorf(`configuration not found.
|
||||
|
||||
Please run 'drip config init' first, or use flags:
|
||||
drip http %d --server SERVER:PORT --token TOKEN`, port)
|
||||
}
|
||||
serverAddr = cfg.Server
|
||||
token = cfg.Token
|
||||
} else {
|
||||
serverAddr = serverURL
|
||||
token = authToken
|
||||
}
|
||||
|
||||
if serverAddr == "" {
|
||||
return fmt.Errorf("server address is required")
|
||||
serverAddr, token, err := resolveServerAddrAndToken("http", port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
@@ -115,15 +71,7 @@ Please run 'drip config init' first, or use flags:
|
||||
|
||||
var daemon *DaemonInfo
|
||||
if daemonMarker {
|
||||
daemon = &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: "http",
|
||||
Port: port,
|
||||
Subdomain: subdomain,
|
||||
Server: serverAddr,
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
daemon = newDaemonInfo("http", port, subdomain, serverAddr)
|
||||
}
|
||||
|
||||
return runTunnelWithUI(connConfig, daemon)
|
||||
|
||||
@@ -2,24 +2,14 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/pkg/config"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
httpsSubdomain string
|
||||
httpsDaemonMode bool
|
||||
httpsDaemonMarker bool
|
||||
httpsLocalAddress string
|
||||
)
|
||||
|
||||
var httpsCmd = &cobra.Command{
|
||||
Use: "https <port>",
|
||||
Short: "Start HTTPS tunnel",
|
||||
@@ -39,86 +29,42 @@ Note: Uses TCP over TLS 1.3 for secure communication`,
|
||||
}
|
||||
|
||||
func init() {
|
||||
httpsCmd.Flags().StringVarP(&httpsSubdomain, "subdomain", "n", "", "Custom subdomain (optional)")
|
||||
httpsCmd.Flags().BoolVarP(&httpsDaemonMode, "daemon", "d", false, "Run in background (daemon mode)")
|
||||
httpsCmd.Flags().StringVarP(&httpsLocalAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
|
||||
httpsCmd.Flags().BoolVar(&httpsDaemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
httpsCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)")
|
||||
httpsCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)")
|
||||
httpsCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
|
||||
httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
httpsCmd.Flags().MarkHidden("daemon-child")
|
||||
rootCmd.AddCommand(httpsCmd)
|
||||
}
|
||||
|
||||
func runHTTPS(cmd *cobra.Command, args []string) error {
|
||||
func runHTTPS(_ *cobra.Command, args []string) error {
|
||||
port, err := strconv.Atoi(args[0])
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("invalid port number: %s", args[0])
|
||||
}
|
||||
|
||||
if httpsDaemonMode && !httpsDaemonMarker {
|
||||
daemonArgs := append([]string{"https"}, args...)
|
||||
daemonArgs = append(daemonArgs, "--daemon-child")
|
||||
if httpsSubdomain != "" {
|
||||
daemonArgs = append(daemonArgs, "--subdomain", httpsSubdomain)
|
||||
}
|
||||
if httpsLocalAddress != "127.0.0.1" {
|
||||
daemonArgs = append(daemonArgs, "--address", httpsLocalAddress)
|
||||
}
|
||||
if serverURL != "" {
|
||||
daemonArgs = append(daemonArgs, "--server", serverURL)
|
||||
}
|
||||
if authToken != "" {
|
||||
daemonArgs = append(daemonArgs, "--token", authToken)
|
||||
}
|
||||
if insecure {
|
||||
daemonArgs = append(daemonArgs, "--insecure")
|
||||
}
|
||||
if verbose {
|
||||
daemonArgs = append(daemonArgs, "--verbose")
|
||||
}
|
||||
return StartDaemon("https", port, daemonArgs)
|
||||
if daemonMode && !daemonMarker {
|
||||
return StartDaemon("https", port, buildDaemonArgs("https", args, subdomain, localAddress))
|
||||
}
|
||||
|
||||
var serverAddr, token string
|
||||
|
||||
if serverURL == "" {
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return fmt.Errorf(`configuration not found.
|
||||
|
||||
Please run 'drip config init' first, or use flags:
|
||||
drip https %d --server SERVER:PORT --token TOKEN`, port)
|
||||
}
|
||||
serverAddr = cfg.Server
|
||||
token = cfg.Token
|
||||
} else {
|
||||
serverAddr = serverURL
|
||||
token = authToken
|
||||
}
|
||||
|
||||
if serverAddr == "" {
|
||||
return fmt.Errorf("server address is required")
|
||||
serverAddr, token, err := resolveServerAddrAndToken("https", port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
ServerAddr: serverAddr,
|
||||
Token: token,
|
||||
TunnelType: protocol.TunnelTypeHTTPS,
|
||||
LocalHost: httpsLocalAddress,
|
||||
LocalHost: localAddress,
|
||||
LocalPort: port,
|
||||
Subdomain: httpsSubdomain,
|
||||
Subdomain: subdomain,
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
var daemon *DaemonInfo
|
||||
if httpsDaemonMarker {
|
||||
daemon = &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: "https",
|
||||
Port: port,
|
||||
Subdomain: httpsSubdomain,
|
||||
Server: serverAddr,
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
if daemonMarker {
|
||||
daemon = newDaemonInfo("https", port, subdomain, serverAddr)
|
||||
}
|
||||
|
||||
return runTunnelWithUI(connConfig, daemon)
|
||||
|
||||
@@ -8,13 +8,11 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"drip/internal/shared/ui"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
interactiveMode bool
|
||||
)
|
||||
var interactiveMode bool
|
||||
|
||||
var listCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
@@ -26,7 +24,7 @@ Example:
|
||||
drip list -i Interactive mode (select to attach/stop)
|
||||
|
||||
This command shows:
|
||||
- Tunnel type (HTTP/TCP)
|
||||
- Tunnel type (HTTP/HTTPS/TCP)
|
||||
- Local port being tunneled
|
||||
- Public URL
|
||||
- Process ID (PID)
|
||||
@@ -44,7 +42,7 @@ func init() {
|
||||
rootCmd.AddCommand(listCmd)
|
||||
}
|
||||
|
||||
func runList(cmd *cobra.Command, args []string) error {
|
||||
func runList(_ *cobra.Command, _ []string) error {
|
||||
CleanupStaleDaemons()
|
||||
|
||||
daemons, err := ListAllDaemons()
|
||||
@@ -99,7 +97,7 @@ func runList(cmd *cobra.Command, args []string) error {
|
||||
|
||||
fmt.Print(table.Render())
|
||||
|
||||
if interactiveMode || shouldPromptForAction() {
|
||||
if interactiveMode {
|
||||
return runInteractiveList(daemons)
|
||||
}
|
||||
|
||||
@@ -114,13 +112,6 @@ func runList(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldPromptForAction() bool {
|
||||
if fileInfo, _ := os.Stdout.Stat(); (fileInfo.Mode() & os.ModeCharDevice) == 0 {
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func runInteractiveList(daemons []*DaemonInfo) error {
|
||||
var runningDaemons []*DaemonInfo
|
||||
for _, d := range daemons {
|
||||
@@ -161,9 +152,9 @@ func runInteractiveList(daemons []*DaemonInfo) error {
|
||||
"",
|
||||
ui.Muted("What would you like to do?"),
|
||||
"",
|
||||
ui.Cyan(" 1.") + " Attach (view logs)",
|
||||
ui.Cyan(" 2.") + " Stop tunnel",
|
||||
ui.Muted(" q.") + " Cancel",
|
||||
ui.Cyan(" 1.")+" Attach (view logs)",
|
||||
ui.Cyan(" 2.")+" Stop tunnel",
|
||||
ui.Muted(" q.")+" Cancel",
|
||||
))
|
||||
|
||||
fmt.Print(ui.Muted("Choose an action: "))
|
||||
|
||||
@@ -3,7 +3,7 @@ package cli
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"drip/internal/shared/ui"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@@ -56,14 +56,12 @@ func init() {
|
||||
versionCmd.Flags().BoolVar(&versionPlain, "short", false, "Print version information without styling")
|
||||
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
// http and tcp commands are added in their respective init() functions
|
||||
// config command is added in config.go init() function
|
||||
}
|
||||
|
||||
var versionCmd = &cobra.Command{
|
||||
Use: "version",
|
||||
Short: "Print version information",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
Run: func(_ *cobra.Command, _ []string) {
|
||||
if versionPlain {
|
||||
fmt.Printf("Version: %s\nGit Commit: %s\nBuild Time: %s\n", Version, GitCommit, BuildTime)
|
||||
return
|
||||
|
||||
@@ -59,7 +59,7 @@ func init() {
|
||||
serverCmd.Flags().IntVar(&serverPprofPort, "pprof", getEnvInt("DRIP_PPROF_PORT", 0), "Enable pprof on specified port (env: DRIP_PPROF_PORT)")
|
||||
}
|
||||
|
||||
func runServer(cmd *cobra.Command, args []string) error {
|
||||
func runServer(_ *cobra.Command, _ []string) error {
|
||||
if serverTLSCert == "" {
|
||||
return fmt.Errorf("TLS certificate path is required (use --tls-cert flag or DRIP_TLS_CERT environment variable)")
|
||||
}
|
||||
@@ -126,11 +126,9 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
|
||||
listenAddr := fmt.Sprintf("0.0.0.0:%d", serverPort)
|
||||
|
||||
responseHandler := proxy.NewResponseHandler(logger)
|
||||
httpHandler := proxy.NewHandler(tunnelManager, logger, serverDomain, serverAuthToken)
|
||||
|
||||
httpHandler := proxy.NewHandler(tunnelManager, logger, responseHandler, serverDomain, serverAuthToken)
|
||||
|
||||
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, displayPort, httpHandler, responseHandler)
|
||||
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, displayPort, httpHandler)
|
||||
|
||||
if err := listener.Start(); err != nil {
|
||||
logger.Fatal("Failed to start TCP listener", zap.Error(err))
|
||||
|
||||
@@ -28,7 +28,7 @@ func init() {
|
||||
rootCmd.AddCommand(stopCmd)
|
||||
}
|
||||
|
||||
func runStop(cmd *cobra.Command, args []string) error {
|
||||
func runStop(_ *cobra.Command, args []string) error {
|
||||
if args[0] == "all" {
|
||||
return stopAllDaemons()
|
||||
}
|
||||
|
||||
@@ -2,13 +2,10 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/pkg/config"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@@ -47,55 +44,19 @@ func init() {
|
||||
rootCmd.AddCommand(tcpCmd)
|
||||
}
|
||||
|
||||
func runTCP(cmd *cobra.Command, args []string) error {
|
||||
func runTCP(_ *cobra.Command, args []string) error {
|
||||
port, err := strconv.Atoi(args[0])
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("invalid port number: %s", args[0])
|
||||
}
|
||||
|
||||
if daemonMode && !daemonMarker {
|
||||
daemonArgs := append([]string{"tcp"}, args...)
|
||||
daemonArgs = append(daemonArgs, "--daemon-child")
|
||||
if subdomain != "" {
|
||||
daemonArgs = append(daemonArgs, "--subdomain", subdomain)
|
||||
}
|
||||
if localAddress != "127.0.0.1" {
|
||||
daemonArgs = append(daemonArgs, "--address", localAddress)
|
||||
}
|
||||
if serverURL != "" {
|
||||
daemonArgs = append(daemonArgs, "--server", serverURL)
|
||||
}
|
||||
if authToken != "" {
|
||||
daemonArgs = append(daemonArgs, "--token", authToken)
|
||||
}
|
||||
if insecure {
|
||||
daemonArgs = append(daemonArgs, "--insecure")
|
||||
}
|
||||
if verbose {
|
||||
daemonArgs = append(daemonArgs, "--verbose")
|
||||
}
|
||||
return StartDaemon("tcp", port, daemonArgs)
|
||||
return StartDaemon("tcp", port, buildDaemonArgs("tcp", args, subdomain, localAddress))
|
||||
}
|
||||
|
||||
var serverAddr, token string
|
||||
|
||||
if serverURL == "" {
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return fmt.Errorf(`configuration not found.
|
||||
|
||||
Please run 'drip config init' first, or use flags:
|
||||
drip tcp %d --server SERVER:PORT --token TOKEN`, port)
|
||||
}
|
||||
serverAddr = cfg.Server
|
||||
token = cfg.Token
|
||||
} else {
|
||||
serverAddr = serverURL
|
||||
token = authToken
|
||||
}
|
||||
|
||||
if serverAddr == "" {
|
||||
return fmt.Errorf("server address is required")
|
||||
serverAddr, token, err := resolveServerAddrAndToken("tcp", port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
@@ -110,15 +71,7 @@ Please run 'drip config init' first, or use flags:
|
||||
|
||||
var daemon *DaemonInfo
|
||||
if daemonMarker {
|
||||
daemon = &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: "tcp",
|
||||
Port: port,
|
||||
Subdomain: subdomain,
|
||||
Server: serverAddr,
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
daemon = newDaemonInfo("tcp", port, subdomain, serverAddr)
|
||||
}
|
||||
|
||||
return runTunnelWithUI(connConfig, daemon)
|
||||
|
||||
67
internal/client/cli/tunnel_helpers.go
Normal file
67
internal/client/cli/tunnel_helpers.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"drip/pkg/config"
|
||||
)
|
||||
|
||||
func buildDaemonArgs(tunnelType string, args []string, subdomain string, localAddress string) []string {
|
||||
daemonArgs := append([]string{tunnelType}, args...)
|
||||
daemonArgs = append(daemonArgs, "--daemon-child")
|
||||
|
||||
if subdomain != "" {
|
||||
daemonArgs = append(daemonArgs, "--subdomain", subdomain)
|
||||
}
|
||||
if localAddress != "127.0.0.1" {
|
||||
daemonArgs = append(daemonArgs, "--address", localAddress)
|
||||
}
|
||||
if serverURL != "" {
|
||||
daemonArgs = append(daemonArgs, "--server", serverURL)
|
||||
}
|
||||
if authToken != "" {
|
||||
daemonArgs = append(daemonArgs, "--token", authToken)
|
||||
}
|
||||
if insecure {
|
||||
daemonArgs = append(daemonArgs, "--insecure")
|
||||
}
|
||||
if verbose {
|
||||
daemonArgs = append(daemonArgs, "--verbose")
|
||||
}
|
||||
|
||||
return daemonArgs
|
||||
}
|
||||
|
||||
func resolveServerAddrAndToken(tunnelType string, port int) (string, string, error) {
|
||||
if serverURL != "" {
|
||||
return serverURL, authToken, nil
|
||||
}
|
||||
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(`configuration not found.
|
||||
|
||||
Please run 'drip config init' first, or use flags:
|
||||
drip %s %d --server SERVER:PORT --token TOKEN`, tunnelType, port)
|
||||
}
|
||||
|
||||
if cfg.Server == "" {
|
||||
return "", "", fmt.Errorf("server address is required")
|
||||
}
|
||||
|
||||
return cfg.Server, cfg.Token, nil
|
||||
}
|
||||
|
||||
func newDaemonInfo(tunnelType string, port int, subdomain string, serverAddr string) *DaemonInfo {
|
||||
return &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: tunnelType,
|
||||
Port: port,
|
||||
Subdomain: subdomain,
|
||||
Server: serverAddr,
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
}
|
||||
@@ -8,13 +8,17 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/ui"
|
||||
"drip/internal/shared/utils"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// runTunnelWithUI runs a tunnel with the new UI
|
||||
const (
|
||||
maxReconnectAttempts = 5
|
||||
reconnectInterval = 3 * time.Second
|
||||
)
|
||||
|
||||
func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) error {
|
||||
if err := utils.InitLogger(verbose); err != nil {
|
||||
return fmt.Errorf("failed to initialize logger: %w", err)
|
||||
@@ -28,7 +32,7 @@ func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) er
|
||||
|
||||
reconnectAttempts := 0
|
||||
for {
|
||||
connector := tcp.NewConnector(connConfig, logger)
|
||||
connector := tcp.NewTunnelClient(connConfig, logger)
|
||||
|
||||
fmt.Println(ui.RenderConnecting(connConfig.ServerAddr, reconnectAttempts, maxReconnectAttempts))
|
||||
|
||||
@@ -87,8 +91,8 @@ func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) er
|
||||
disconnected := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
renderTicker := time.NewTicker(1 * time.Second)
|
||||
defer renderTicker.Stop()
|
||||
|
||||
var lastLatency time.Duration
|
||||
lastRenderedLines := 0
|
||||
@@ -97,27 +101,38 @@ func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) er
|
||||
select {
|
||||
case latency := <-latencyCh:
|
||||
lastLatency = latency
|
||||
case <-ticker.C:
|
||||
case <-renderTicker.C:
|
||||
stats := connector.GetStats()
|
||||
if stats != nil {
|
||||
stats.UpdateSpeed()
|
||||
snapshot := stats.GetSnapshot()
|
||||
|
||||
status.Latency = lastLatency
|
||||
status.BytesIn = snapshot.TotalBytesIn
|
||||
status.BytesOut = snapshot.TotalBytesOut
|
||||
status.SpeedIn = float64(snapshot.SpeedIn)
|
||||
status.SpeedOut = float64(snapshot.SpeedOut)
|
||||
status.TotalRequest = snapshot.TotalRequests
|
||||
|
||||
statsView := ui.RenderTunnelStats(status)
|
||||
if lastRenderedLines > 0 {
|
||||
fmt.Print(clearLines(lastRenderedLines))
|
||||
}
|
||||
|
||||
fmt.Print(statsView)
|
||||
lastRenderedLines = countRenderedLines(statsView)
|
||||
if stats == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
stats.UpdateSpeed()
|
||||
snapshot := stats.GetSnapshot()
|
||||
|
||||
status.Latency = lastLatency
|
||||
status.BytesIn = snapshot.TotalBytesIn
|
||||
status.BytesOut = snapshot.TotalBytesOut
|
||||
status.SpeedIn = float64(snapshot.SpeedIn)
|
||||
status.SpeedOut = float64(snapshot.SpeedOut)
|
||||
|
||||
if status.Type == "tcp" {
|
||||
if snapshot.SpeedIn == 0 && snapshot.SpeedOut == 0 {
|
||||
status.TotalRequest = 0
|
||||
} else {
|
||||
status.TotalRequest = snapshot.ActiveConnections
|
||||
}
|
||||
} else {
|
||||
status.TotalRequest = snapshot.TotalRequests
|
||||
}
|
||||
|
||||
statsView := ui.RenderTunnelStats(status)
|
||||
if lastRenderedLines > 0 {
|
||||
fmt.Print(clearLines(lastRenderedLines))
|
||||
}
|
||||
|
||||
fmt.Print(statsView)
|
||||
lastRenderedLines = countRenderedLines(statsView)
|
||||
case <-stopDisplay:
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,493 +1,50 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"sync"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/recovery"
|
||||
"drip/pkg/config"
|
||||
"drip/internal/shared/stats"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// LatencyCallback is called when latency is measured
|
||||
type LatencyCallback func(latency time.Duration)
|
||||
|
||||
// Connector manages the TCP connection to the server
|
||||
type Connector struct {
|
||||
serverAddr string
|
||||
tlsConfig *tls.Config
|
||||
token string
|
||||
tunnelType protocol.TunnelType
|
||||
localHost string
|
||||
localPort int
|
||||
subdomain string
|
||||
conn net.Conn
|
||||
logger *zap.Logger
|
||||
stopCh chan struct{}
|
||||
once sync.Once
|
||||
registered bool
|
||||
assignedURL string
|
||||
frameHandler *FrameHandler
|
||||
frameWriter *protocol.FrameWriter
|
||||
latencyCallback LatencyCallback
|
||||
heartbeatSentAt time.Time
|
||||
heartbeatMu sync.Mutex
|
||||
lastLatency time.Duration
|
||||
handlerWg sync.WaitGroup // Tracks active data frame handlers
|
||||
closed bool
|
||||
closedMu sync.RWMutex
|
||||
|
||||
// Worker pool for handling data frames
|
||||
dataFrameQueue chan *protocol.Frame
|
||||
workerCount int
|
||||
|
||||
recoverer *recovery.Recoverer
|
||||
panicMetrics *recovery.PanicMetrics
|
||||
}
|
||||
|
||||
// ConnectorConfig holds connector configuration
|
||||
type ConnectorConfig struct {
|
||||
ServerAddr string
|
||||
Token string
|
||||
TunnelType protocol.TunnelType
|
||||
LocalHost string // Local host address (default: 127.0.0.1)
|
||||
LocalHost string
|
||||
LocalPort int
|
||||
Subdomain string // Optional custom subdomain
|
||||
Insecure bool // Skip TLS verification (testing only)
|
||||
Subdomain string
|
||||
Insecure bool
|
||||
|
||||
PoolSize int
|
||||
PoolMin int
|
||||
PoolMax int
|
||||
}
|
||||
|
||||
// NewConnector creates a new connector
|
||||
func NewConnector(cfg *ConnectorConfig, logger *zap.Logger) *Connector {
|
||||
var tlsConfig *tls.Config
|
||||
if cfg.Insecure {
|
||||
tlsConfig = config.GetClientTLSConfigInsecure()
|
||||
} else {
|
||||
host, _, _ := net.SplitHostPort(cfg.ServerAddr)
|
||||
tlsConfig = config.GetClientTLSConfig(host)
|
||||
}
|
||||
|
||||
localHost := cfg.LocalHost
|
||||
if localHost == "" {
|
||||
localHost = "127.0.0.1"
|
||||
}
|
||||
|
||||
numCPU := pool.NumCPU()
|
||||
workerCount := max(numCPU+numCPU/2, 4)
|
||||
|
||||
panicMetrics := recovery.NewPanicMetrics(logger, nil)
|
||||
recoverer := recovery.NewRecoverer(logger, panicMetrics)
|
||||
|
||||
return &Connector{
|
||||
serverAddr: cfg.ServerAddr,
|
||||
tlsConfig: tlsConfig,
|
||||
token: cfg.Token,
|
||||
tunnelType: cfg.TunnelType,
|
||||
localHost: localHost,
|
||||
localPort: cfg.LocalPort,
|
||||
subdomain: cfg.Subdomain,
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
dataFrameQueue: make(chan *protocol.Frame, workerCount*100),
|
||||
workerCount: workerCount,
|
||||
recoverer: recoverer,
|
||||
panicMetrics: panicMetrics,
|
||||
}
|
||||
type TunnelClient interface {
|
||||
Connect() error
|
||||
Close() error
|
||||
Wait()
|
||||
GetURL() string
|
||||
GetSubdomain() string
|
||||
SetLatencyCallback(cb LatencyCallback)
|
||||
GetLatency() time.Duration
|
||||
GetStats() *stats.TrafficStats
|
||||
IsClosed() bool
|
||||
}
|
||||
|
||||
// Connect connects to the server and registers the tunnel
|
||||
func (c *Connector) Connect() error {
|
||||
c.logger.Info("Connecting to server",
|
||||
zap.String("server", c.serverAddr),
|
||||
zap.String("tunnel_type", string(c.tunnelType)),
|
||||
zap.String("local_host", c.localHost),
|
||||
zap.Int("local_port", c.localPort),
|
||||
)
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", c.serverAddr, c.tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
|
||||
state := conn.ConnectionState()
|
||||
if state.Version != tls.VersionTLS13 {
|
||||
conn.Close()
|
||||
return fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
|
||||
}
|
||||
|
||||
c.logger.Info("TLS connection established",
|
||||
zap.String("cipher_suite", tls.CipherSuiteName(state.CipherSuite)),
|
||||
)
|
||||
|
||||
if err := c.register(); err != nil {
|
||||
conn.Close()
|
||||
return fmt.Errorf("registration failed: %w", err)
|
||||
}
|
||||
|
||||
c.frameWriter = protocol.NewFrameWriter(c.conn)
|
||||
bufferPool := pool.NewBufferPool()
|
||||
|
||||
c.frameHandler = NewFrameHandler(
|
||||
c.conn,
|
||||
c.frameWriter,
|
||||
c.localHost,
|
||||
c.localPort,
|
||||
c.tunnelType,
|
||||
c.logger,
|
||||
c.IsClosed,
|
||||
bufferPool,
|
||||
)
|
||||
|
||||
c.frameWriter.EnableHeartbeat(constants.HeartbeatInterval, c.createHeartbeatFrame)
|
||||
|
||||
for i := 0; i < c.workerCount; i++ {
|
||||
c.handlerWg.Add(1)
|
||||
go c.dataFrameWorker(i)
|
||||
}
|
||||
|
||||
go c.frameHandler.WarmupConnectionPool(3)
|
||||
go c.monitorQueuePressure()
|
||||
go c.handleFrames()
|
||||
|
||||
return nil
|
||||
func NewTunnelClient(cfg *ConnectorConfig, logger *zap.Logger) TunnelClient {
|
||||
return NewPoolClient(cfg, logger)
|
||||
}
|
||||
|
||||
// register sends registration request and waits for acknowledgment
|
||||
func (c *Connector) register() error {
|
||||
req := protocol.RegisterRequest{
|
||||
Token: c.token,
|
||||
CustomSubdomain: c.subdomain,
|
||||
TunnelType: c.tunnelType,
|
||||
LocalPort: c.localPort,
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
regFrame := protocol.NewFrame(protocol.FrameTypeRegister, payload)
|
||||
err = protocol.WriteFrame(c.conn, regFrame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send registration: %w", err)
|
||||
}
|
||||
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
ackFrame, err := protocol.ReadFrame(c.conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read ack: %w", err)
|
||||
}
|
||||
defer ackFrame.Release()
|
||||
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if ackFrame.Type == protocol.FrameTypeError {
|
||||
var errMsg protocol.ErrorMessage
|
||||
if err := json.Unmarshal(ackFrame.Payload, &errMsg); err == nil {
|
||||
return fmt.Errorf("registration error: %s - %s", errMsg.Code, errMsg.Message)
|
||||
}
|
||||
return fmt.Errorf("registration error")
|
||||
}
|
||||
|
||||
if ackFrame.Type != protocol.FrameTypeRegisterAck {
|
||||
return fmt.Errorf("unexpected frame type: %s", ackFrame.Type)
|
||||
}
|
||||
|
||||
var resp protocol.RegisterResponse
|
||||
if err := json.Unmarshal(ackFrame.Payload, &resp); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
c.registered = true
|
||||
c.assignedURL = resp.URL
|
||||
c.subdomain = resp.Subdomain
|
||||
|
||||
c.logger.Info("Tunnel registered successfully",
|
||||
zap.String("subdomain", resp.Subdomain),
|
||||
zap.String("url", resp.URL),
|
||||
zap.Int("remote_port", resp.Port),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connector) dataFrameWorker(workerID int) {
|
||||
defer c.handlerWg.Done()
|
||||
defer c.recoverer.Recover(fmt.Sprintf("dataFrameWorker-%d", workerID))
|
||||
|
||||
for {
|
||||
select {
|
||||
case frame, ok := <-c.dataFrameQueue:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
func() {
|
||||
sf := protocol.WithFrame(frame)
|
||||
defer sf.Close()
|
||||
defer c.recoverer.Recover("handleDataFrame")
|
||||
|
||||
if err := c.frameHandler.HandleDataFrame(sf.Frame); err != nil {
|
||||
c.logger.Error("Failed to handle data frame",
|
||||
zap.Int("worker_id", workerID),
|
||||
zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
case <-c.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleFrames handles incoming frames from server
|
||||
func (c *Connector) handleFrames() {
|
||||
defer c.Close()
|
||||
defer c.recoverer.Recover("handleFrames")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
frame, err := protocol.ReadFrame(c.conn)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
c.logger.Warn("Read timeout")
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
default:
|
||||
c.logger.Error("Failed to read frame", zap.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
sf := protocol.WithFrame(frame)
|
||||
|
||||
switch sf.Frame.Type {
|
||||
case protocol.FrameTypeHeartbeatAck:
|
||||
c.heartbeatMu.Lock()
|
||||
if !c.heartbeatSentAt.IsZero() {
|
||||
latency := time.Since(c.heartbeatSentAt)
|
||||
c.lastLatency = latency
|
||||
c.heartbeatMu.Unlock()
|
||||
|
||||
c.logger.Debug("Received heartbeat ack", zap.Duration("latency", latency))
|
||||
|
||||
if c.latencyCallback != nil {
|
||||
c.latencyCallback(latency)
|
||||
}
|
||||
} else {
|
||||
c.heartbeatMu.Unlock()
|
||||
c.logger.Debug("Received heartbeat ack")
|
||||
}
|
||||
sf.Close()
|
||||
|
||||
case protocol.FrameTypeData:
|
||||
select {
|
||||
case c.dataFrameQueue <- sf.Frame:
|
||||
case <-c.stopCh:
|
||||
sf.Close()
|
||||
return
|
||||
default:
|
||||
c.logger.Warn("Data frame queue full, dropping frame")
|
||||
sf.Close()
|
||||
}
|
||||
|
||||
case protocol.FrameTypeClose:
|
||||
sf.Close()
|
||||
c.logger.Info("Server requested close")
|
||||
return
|
||||
|
||||
case protocol.FrameTypeError:
|
||||
var errMsg protocol.ErrorMessage
|
||||
if err := json.Unmarshal(sf.Frame.Payload, &errMsg); err == nil {
|
||||
c.logger.Error("Received error from server",
|
||||
zap.String("code", errMsg.Code),
|
||||
zap.String("message", errMsg.Message),
|
||||
)
|
||||
}
|
||||
sf.Close()
|
||||
return
|
||||
|
||||
default:
|
||||
sf.Close()
|
||||
c.logger.Warn("Unexpected frame type",
|
||||
zap.String("type", sf.Frame.Type.String()),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connector) createHeartbeatFrame() *protocol.Frame {
|
||||
c.closedMu.RLock()
|
||||
if c.closed {
|
||||
c.closedMu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
c.closedMu.RUnlock()
|
||||
|
||||
c.heartbeatMu.Lock()
|
||||
c.heartbeatSentAt = time.Now()
|
||||
c.heartbeatMu.Unlock()
|
||||
|
||||
return protocol.NewFrame(protocol.FrameTypeHeartbeat, nil)
|
||||
}
|
||||
|
||||
// SendFrame sends a frame to the server
|
||||
func (c *Connector) SendFrame(frame *protocol.Frame) error {
|
||||
if !c.registered {
|
||||
return fmt.Errorf("not registered")
|
||||
}
|
||||
|
||||
return c.frameWriter.WriteFrame(frame)
|
||||
}
|
||||
|
||||
func (c *Connector) Close() error {
|
||||
c.once.Do(func() {
|
||||
c.closedMu.Lock()
|
||||
c.closed = true
|
||||
c.closedMu.Unlock()
|
||||
|
||||
close(c.stopCh)
|
||||
close(c.dataFrameQueue)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.handlerWg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
c.logger.Warn("Force closing: some handlers are still active")
|
||||
}
|
||||
|
||||
if c.conn != nil {
|
||||
closeFrame := protocol.NewFrame(protocol.FrameTypeClose, nil)
|
||||
|
||||
if c.frameWriter != nil {
|
||||
c.frameWriter.WriteFrame(closeFrame)
|
||||
c.frameWriter.Close()
|
||||
} else {
|
||||
protocol.WriteFrame(c.conn, closeFrame)
|
||||
}
|
||||
|
||||
c.conn.Close()
|
||||
}
|
||||
c.logger.Info("Connector closed")
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wait blocks until connection is closed
|
||||
func (c *Connector) Wait() {
|
||||
<-c.stopCh
|
||||
}
|
||||
|
||||
// GetURL returns the assigned tunnel URL
|
||||
func (c *Connector) GetURL() string {
|
||||
return c.assignedURL
|
||||
}
|
||||
|
||||
// GetSubdomain returns the assigned subdomain
|
||||
func (c *Connector) GetSubdomain() string {
|
||||
return c.subdomain
|
||||
}
|
||||
|
||||
// SetLatencyCallback sets the callback for latency updates
|
||||
func (c *Connector) SetLatencyCallback(cb LatencyCallback) {
|
||||
c.latencyCallback = cb
|
||||
}
|
||||
|
||||
// GetLatency returns the last measured latency
|
||||
func (c *Connector) GetLatency() time.Duration {
|
||||
c.heartbeatMu.Lock()
|
||||
defer c.heartbeatMu.Unlock()
|
||||
return c.lastLatency
|
||||
}
|
||||
|
||||
// GetStats returns the traffic stats from the frame handler
|
||||
func (c *Connector) GetStats() *TrafficStats {
|
||||
if c.frameHandler != nil {
|
||||
return c.frameHandler.GetStats()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsClosed returns whether the connector has been closed
|
||||
func (c *Connector) IsClosed() bool {
|
||||
c.closedMu.RLock()
|
||||
defer c.closedMu.RUnlock()
|
||||
return c.closed
|
||||
}
|
||||
func (c *Connector) monitorQueuePressure() {
|
||||
defer c.recoverer.Recover("monitorQueuePressure")
|
||||
|
||||
const (
|
||||
pauseThreshold = 0.80
|
||||
resumeThreshold = 0.50
|
||||
checkInterval = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
isPaused := false
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
queueLen := len(c.dataFrameQueue)
|
||||
queueCap := cap(c.dataFrameQueue)
|
||||
usage := float64(queueLen) / float64(queueCap)
|
||||
|
||||
if usage > pauseThreshold && !isPaused {
|
||||
c.sendFlowControl("*", protocol.FlowControlPause)
|
||||
isPaused = true
|
||||
c.logger.Warn("Queue pressure high, sent pause signal",
|
||||
zap.Int("queue_len", queueLen),
|
||||
zap.Int("queue_cap", queueCap),
|
||||
zap.Float64("usage", usage))
|
||||
} else if usage < resumeThreshold && isPaused {
|
||||
c.sendFlowControl("*", protocol.FlowControlResume)
|
||||
isPaused = false
|
||||
c.logger.Info("Queue pressure normal, sent resume signal",
|
||||
zap.Int("queue_len", queueLen),
|
||||
zap.Int("queue_cap", queueCap),
|
||||
zap.Float64("usage", usage))
|
||||
}
|
||||
|
||||
case <-c.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connector) sendFlowControl(streamID string, action protocol.FlowControlAction) {
|
||||
frame := protocol.NewFlowControlFrame(streamID, action)
|
||||
if err := c.SendFrame(frame); err != nil {
|
||||
c.logger.Error("Failed to send flow control",
|
||||
zap.String("action", string(action)),
|
||||
zap.Error(err))
|
||||
}
|
||||
func isExpectedCloseError(err error) bool {
|
||||
s := err.Error()
|
||||
return strings.Contains(s, "EOF") ||
|
||||
strings.Contains(s, "use of closed") ||
|
||||
strings.Contains(s, "connection reset")
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
450
internal/client/tcp/pool_client.go
Normal file
450
internal/client/tcp/pool_client.go
Normal file
@@ -0,0 +1,450 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/hashicorp/yamux"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/stats"
|
||||
"drip/pkg/config"
|
||||
)
|
||||
|
||||
// PoolClient manages a pool of yamux sessions for tunnel connections.
|
||||
type PoolClient struct {
|
||||
serverAddr string
|
||||
tlsConfig *tls.Config
|
||||
token string
|
||||
tunnelType protocol.TunnelType
|
||||
localHost string
|
||||
localPort int
|
||||
subdomain string
|
||||
|
||||
assignedURL string
|
||||
tunnelID string
|
||||
|
||||
minSessions int
|
||||
maxSessions int
|
||||
initialSessions int
|
||||
|
||||
stats *stats.TrafficStats
|
||||
|
||||
httpClient *http.Client
|
||||
|
||||
latencyCallback atomic.Value // LatencyCallback
|
||||
latencyNanos atomic.Int64
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
once sync.Once
|
||||
wg sync.WaitGroup
|
||||
closed atomic.Bool
|
||||
|
||||
primary *sessionHandle
|
||||
|
||||
mu sync.RWMutex
|
||||
dataSessions map[string]*sessionHandle
|
||||
desiredTotal int
|
||||
lastScale time.Time
|
||||
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewPoolClient creates a new pool client.
|
||||
func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
|
||||
var tlsConfig *tls.Config
|
||||
if cfg.Insecure {
|
||||
tlsConfig = config.GetClientTLSConfigInsecure()
|
||||
} else {
|
||||
host, _, _ := net.SplitHostPort(cfg.ServerAddr)
|
||||
tlsConfig = config.GetClientTLSConfig(host)
|
||||
}
|
||||
|
||||
localHost := cfg.LocalHost
|
||||
if localHost == "" {
|
||||
localHost = "127.0.0.1"
|
||||
}
|
||||
|
||||
tunnelType := cfg.TunnelType
|
||||
if tunnelType == "" {
|
||||
tunnelType = protocol.TunnelTypeTCP
|
||||
}
|
||||
|
||||
numCPU := runtime.NumCPU()
|
||||
|
||||
minSessions := cfg.PoolMin
|
||||
if minSessions <= 0 {
|
||||
minSessions = 2
|
||||
}
|
||||
|
||||
maxSessions := cfg.PoolMax
|
||||
if maxSessions <= 0 {
|
||||
maxSessions = max(numCPU*16, minSessions)
|
||||
}
|
||||
if maxSessions < minSessions {
|
||||
maxSessions = minSessions
|
||||
}
|
||||
|
||||
initialSessions := cfg.PoolSize
|
||||
if initialSessions <= 0 {
|
||||
initialSessions = 4
|
||||
}
|
||||
initialSessions = min(max(initialSessions, minSessions), maxSessions)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
c := &PoolClient{
|
||||
serverAddr: cfg.ServerAddr,
|
||||
tlsConfig: tlsConfig,
|
||||
token: cfg.Token,
|
||||
tunnelType: tunnelType,
|
||||
localHost: localHost,
|
||||
localPort: cfg.LocalPort,
|
||||
subdomain: cfg.Subdomain,
|
||||
minSessions: minSessions,
|
||||
maxSessions: maxSessions,
|
||||
initialSessions: initialSessions,
|
||||
stats: stats.NewTrafficStats(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
stopCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
dataSessions: make(map[string]*sessionHandle),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
|
||||
c.httpClient = newLocalHTTPClient(tunnelType)
|
||||
}
|
||||
|
||||
c.latencyCallback.Store(LatencyCallback(func(time.Duration) {}))
|
||||
return c
|
||||
}
|
||||
|
||||
// Connect establishes the primary connection and starts background workers.
|
||||
func (c *PoolClient) Connect() error {
|
||||
primaryConn, err := c.dialTLS()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
maxData := max(c.maxSessions-1, 0)
|
||||
req := protocol.RegisterRequest{
|
||||
Token: c.token,
|
||||
CustomSubdomain: c.subdomain,
|
||||
TunnelType: c.tunnelType,
|
||||
LocalPort: c.localPort,
|
||||
ConnectionType: "primary",
|
||||
PoolCapabilities: &protocol.PoolCapabilities{
|
||||
MaxDataConns: maxData,
|
||||
Version: 1,
|
||||
},
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
_ = primaryConn.Close()
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
if err := protocol.WriteFrame(primaryConn, protocol.NewFrame(protocol.FrameTypeRegister, payload)); err != nil {
|
||||
_ = primaryConn.Close()
|
||||
return fmt.Errorf("failed to send registration: %w", err)
|
||||
}
|
||||
|
||||
_ = primaryConn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
ack, err := protocol.ReadFrame(primaryConn)
|
||||
if err != nil {
|
||||
_ = primaryConn.Close()
|
||||
return fmt.Errorf("failed to read register ack: %w", err)
|
||||
}
|
||||
defer ack.Release()
|
||||
_ = primaryConn.SetReadDeadline(time.Time{})
|
||||
|
||||
if ack.Type == protocol.FrameTypeError {
|
||||
var errMsg protocol.ErrorMessage
|
||||
if e := json.Unmarshal(ack.Payload, &errMsg); e == nil {
|
||||
_ = primaryConn.Close()
|
||||
return fmt.Errorf("registration error: %s - %s", errMsg.Code, errMsg.Message)
|
||||
}
|
||||
_ = primaryConn.Close()
|
||||
return fmt.Errorf("registration error")
|
||||
}
|
||||
if ack.Type != protocol.FrameTypeRegisterAck {
|
||||
_ = primaryConn.Close()
|
||||
return fmt.Errorf("unexpected register ack frame: %s", ack.Type)
|
||||
}
|
||||
|
||||
var resp protocol.RegisterResponse
|
||||
if err := json.Unmarshal(ack.Payload, &resp); err != nil {
|
||||
_ = primaryConn.Close()
|
||||
return fmt.Errorf("failed to parse register response: %w", err)
|
||||
}
|
||||
|
||||
c.assignedURL = resp.URL
|
||||
c.subdomain = resp.Subdomain
|
||||
if resp.SupportsDataConn && resp.TunnelID != "" {
|
||||
c.tunnelID = resp.TunnelID
|
||||
}
|
||||
|
||||
yamuxCfg := yamux.DefaultConfig()
|
||||
yamuxCfg.EnableKeepAlive = false
|
||||
yamuxCfg.LogOutput = io.Discard
|
||||
yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
session, err := yamux.Server(primaryConn, yamuxCfg)
|
||||
if err != nil {
|
||||
_ = primaryConn.Close()
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
|
||||
primary := &sessionHandle{
|
||||
id: "primary",
|
||||
conn: primaryConn,
|
||||
session: session,
|
||||
}
|
||||
primary.touch()
|
||||
c.primary = primary
|
||||
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
<-c.stopCh
|
||||
}()
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.acceptLoop(primary, true)
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.sessionWatcher(primary, true)
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.pingLoop(primary)
|
||||
|
||||
if c.tunnelID != "" {
|
||||
c.mu.Lock()
|
||||
c.desiredTotal = c.initialSessions
|
||||
c.mu.Unlock()
|
||||
|
||||
c.ensureSessions()
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.scalerLoop()
|
||||
}
|
||||
|
||||
go func() {
|
||||
c.wg.Wait()
|
||||
close(c.doneCh)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PoolClient) dialTLS() (net.Conn, error) {
|
||||
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", c.serverAddr, c.tlsConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
state := conn.ConnectionState()
|
||||
if state.Version != tls.VersionTLS13 {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
|
||||
}
|
||||
|
||||
if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok {
|
||||
_ = tcpConn.SetNoDelay(true)
|
||||
_ = tcpConn.SetKeepAlive(true)
|
||||
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||||
_ = tcpConn.SetReadBuffer(256 * 1024)
|
||||
_ = tcpConn.SetWriteBuffer(256 * 1024)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *PoolClient) acceptLoop(h *sessionHandle, isPrimary bool) {
|
||||
defer c.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
stream, err := h.session.Accept()
|
||||
if err != nil {
|
||||
if c.IsClosed() || isExpectedCloseError(err) {
|
||||
return
|
||||
}
|
||||
if isPrimary {
|
||||
c.logger.Debug("Primary session accept failed", zap.Error(err))
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
c.logger.Debug("Data session accept failed", zap.String("session_id", h.id), zap.Error(err))
|
||||
c.removeDataSession(h.id)
|
||||
return
|
||||
}
|
||||
|
||||
h.active.Add(1)
|
||||
h.touch()
|
||||
|
||||
c.stats.AddRequest()
|
||||
c.stats.IncActiveConnections()
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.handleStream(h, stream)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PoolClient) sessionWatcher(h *sessionHandle, isPrimary bool) {
|
||||
defer c.wg.Done()
|
||||
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
case <-h.session.CloseChan():
|
||||
if isPrimary {
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
c.removeDataSession(h.id)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PoolClient) pingLoop(h *sessionHandle) {
|
||||
defer c.wg.Done()
|
||||
|
||||
const maxConsecutiveFailures = 3
|
||||
|
||||
ticker := time.NewTicker(constants.HeartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
consecutiveFailures := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
if h.session == nil || h.session.IsClosed() {
|
||||
return
|
||||
}
|
||||
|
||||
latency, err := h.session.Ping()
|
||||
if err != nil {
|
||||
consecutiveFailures++
|
||||
c.logger.Debug("Ping failed",
|
||||
zap.String("session_id", h.id),
|
||||
zap.Int("consecutive_failures", consecutiveFailures),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
if consecutiveFailures >= maxConsecutiveFailures {
|
||||
c.logger.Warn("Session ping failed too many times, closing",
|
||||
zap.String("session_id", h.id),
|
||||
zap.Int("failures", consecutiveFailures),
|
||||
)
|
||||
if h.id == "primary" {
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
c.removeDataSession(h.id)
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
consecutiveFailures = 0
|
||||
h.touch()
|
||||
|
||||
c.latencyNanos.Store(int64(latency))
|
||||
if cb, ok := c.latencyCallback.Load().(LatencyCallback); ok && cb != nil {
|
||||
cb(latency)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the client and all sessions.
|
||||
func (c *PoolClient) Close() error {
|
||||
var closeErr error
|
||||
|
||||
c.once.Do(func() {
|
||||
c.closed.Store(true)
|
||||
close(c.stopCh)
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
var data []*sessionHandle
|
||||
var primary *sessionHandle
|
||||
|
||||
c.mu.Lock()
|
||||
for _, h := range c.dataSessions {
|
||||
data = append(data, h)
|
||||
}
|
||||
c.dataSessions = make(map[string]*sessionHandle)
|
||||
primary = c.primary
|
||||
c.primary = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
for _, h := range data {
|
||||
if h == nil {
|
||||
continue
|
||||
}
|
||||
if h.session != nil {
|
||||
_ = h.session.Close()
|
||||
}
|
||||
if h.conn != nil {
|
||||
_ = h.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if primary != nil {
|
||||
if primary.session != nil {
|
||||
closeErr = primary.session.Close()
|
||||
}
|
||||
if primary.conn != nil {
|
||||
_ = primary.conn.Close()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return closeErr
|
||||
}
|
||||
|
||||
func (c *PoolClient) Wait() { <-c.doneCh }
|
||||
func (c *PoolClient) GetURL() string { return c.assignedURL }
|
||||
func (c *PoolClient) GetSubdomain() string { return c.subdomain }
|
||||
func (c *PoolClient) GetLatency() time.Duration { return time.Duration(c.latencyNanos.Load()) }
|
||||
func (c *PoolClient) GetStats() *stats.TrafficStats { return c.stats }
|
||||
func (c *PoolClient) IsClosed() bool { return c.closed.Load() }
|
||||
|
||||
func (c *PoolClient) SetLatencyCallback(cb LatencyCallback) {
|
||||
if cb == nil {
|
||||
cb = func(time.Duration) {}
|
||||
}
|
||||
c.latencyCallback.Store(cb)
|
||||
}
|
||||
253
internal/client/tcp/pool_handler.go
Normal file
253
internal/client/tcp/pool_handler.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/httputil"
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// handleStream routes incoming stream to appropriate handler.
|
||||
func (c *PoolClient) handleStream(h *sessionHandle, stream net.Conn) {
|
||||
defer c.wg.Done()
|
||||
defer func() {
|
||||
h.active.Add(-1)
|
||||
c.stats.DecActiveConnections()
|
||||
}()
|
||||
defer stream.Close()
|
||||
|
||||
switch c.tunnelType {
|
||||
case protocol.TunnelTypeHTTP, protocol.TunnelTypeHTTPS:
|
||||
c.handleHTTPStream(stream)
|
||||
default:
|
||||
c.handleTCPStream(stream)
|
||||
}
|
||||
}
|
||||
|
||||
// handleTCPStream handles raw TCP tunneling.
|
||||
func (c *PoolClient) handleTCPStream(stream net.Conn) {
|
||||
localConn, err := net.DialTimeout("tcp", net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort)), 10*time.Second)
|
||||
if err != nil {
|
||||
c.logger.Debug("Dial local failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
defer localConn.Close()
|
||||
|
||||
if tcpConn, ok := localConn.(*net.TCPConn); ok {
|
||||
_ = tcpConn.SetNoDelay(true)
|
||||
_ = tcpConn.SetKeepAlive(true)
|
||||
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||||
_ = tcpConn.SetReadBuffer(256 * 1024)
|
||||
_ = tcpConn.SetWriteBuffer(256 * 1024)
|
||||
}
|
||||
|
||||
_ = netutil.PipeWithCallbacksAndBufferSize(
|
||||
c.ctx,
|
||||
stream,
|
||||
localConn,
|
||||
pool.SizeLarge,
|
||||
func(n int64) { c.stats.AddBytesIn(n) },
|
||||
func(n int64) { c.stats.AddBytesOut(n) },
|
||||
)
|
||||
}
|
||||
|
||||
// handleHTTPStream handles HTTP/HTTPS proxy requests.
|
||||
func (c *PoolClient) handleHTTPStream(stream net.Conn) {
|
||||
_ = stream.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
|
||||
cc := netutil.NewCountingConn(stream,
|
||||
func(n int64) { c.stats.AddBytesIn(n) },
|
||||
func(n int64) { c.stats.AddBytesOut(n) },
|
||||
)
|
||||
|
||||
br := bufio.NewReaderSize(cc, 32*1024)
|
||||
req, err := http.ReadRequest(br)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
_ = stream.SetReadDeadline(time.Time{})
|
||||
|
||||
if httputil.IsWebSocketUpgrade(req) {
|
||||
c.handleWebSocketUpgrade(cc, req)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(c.ctx)
|
||||
defer cancel()
|
||||
|
||||
scheme := "http"
|
||||
if c.tunnelType == protocol.TunnelTypeHTTPS {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
targetURL := fmt.Sprintf("%s://%s:%d%s", scheme, c.localHost, c.localPort, req.URL.RequestURI())
|
||||
outReq, err := http.NewRequestWithContext(ctx, req.Method, targetURL, req.Body)
|
||||
if err != nil {
|
||||
httputil.WriteProxyError(cc, http.StatusBadGateway, "Bad Gateway")
|
||||
return
|
||||
}
|
||||
|
||||
origHost := req.Host
|
||||
httputil.CopyHeaders(outReq.Header, req.Header)
|
||||
httputil.CleanHopByHopHeaders(outReq.Header)
|
||||
|
||||
targetHost := c.localHost
|
||||
if c.localPort != 80 && c.localPort != 443 {
|
||||
targetHost = fmt.Sprintf("%s:%d", c.localHost, c.localPort)
|
||||
}
|
||||
outReq.Host = targetHost
|
||||
outReq.Header.Set("Host", targetHost)
|
||||
if origHost != "" {
|
||||
outReq.Header.Set("X-Forwarded-Host", origHost)
|
||||
}
|
||||
outReq.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
resp, err := c.httpClient.Do(outReq)
|
||||
if err != nil {
|
||||
httputil.WriteProxyError(cc, http.StatusBadGateway, "Local service unavailable")
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
_ = stream.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||
if err := writeResponseHeader(cc, resp); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
stream.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
nr, er := resp.Body.Read(buf)
|
||||
if nr > 0 {
|
||||
_ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
nw, ew := cc.Write(buf[:nr])
|
||||
if ew != nil || nr != nw {
|
||||
break
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
close(done)
|
||||
}
|
||||
|
||||
// handleWebSocketUpgrade handles WebSocket upgrade requests.
|
||||
func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
|
||||
scheme := "ws"
|
||||
if c.tunnelType == protocol.TunnelTypeHTTPS {
|
||||
scheme = "wss"
|
||||
}
|
||||
|
||||
targetAddr := net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort))
|
||||
localConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second)
|
||||
if err != nil {
|
||||
httputil.WriteProxyError(cc, http.StatusBadGateway, "WebSocket backend unavailable")
|
||||
return
|
||||
}
|
||||
defer localConn.Close()
|
||||
|
||||
if c.tunnelType == protocol.TunnelTypeHTTPS {
|
||||
tlsConn := tls.Client(localConn, &tls.Config{InsecureSkipVerify: true})
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
httputil.WriteProxyError(cc, http.StatusBadGateway, "TLS handshake failed")
|
||||
return
|
||||
}
|
||||
localConn = tlsConn
|
||||
}
|
||||
|
||||
req.URL.Scheme = scheme
|
||||
req.URL.Host = targetAddr
|
||||
if err := req.Write(localConn); err != nil {
|
||||
httputil.WriteProxyError(cc, http.StatusBadGateway, "Failed to forward upgrade request")
|
||||
return
|
||||
}
|
||||
|
||||
localBr := bufio.NewReader(localConn)
|
||||
resp, err := http.ReadResponse(localBr, req)
|
||||
if err != nil {
|
||||
httputil.WriteProxyError(cc, http.StatusBadGateway, "Failed to read upgrade response")
|
||||
return
|
||||
}
|
||||
|
||||
if err := resp.Write(cc); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusSwitchingProtocols {
|
||||
_ = netutil.PipeWithCallbacksAndBufferSize(
|
||||
c.ctx,
|
||||
cc,
|
||||
localConn,
|
||||
pool.SizeLarge,
|
||||
func(n int64) { c.stats.AddBytesIn(n) },
|
||||
func(n int64) { c.stats.AddBytesOut(n) },
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// newLocalHTTPClient creates an HTTP client for local service requests.
|
||||
func newLocalHTTPClient(tunnelType protocol.TunnelType) *http.Client {
|
||||
var tlsConfig *tls.Config
|
||||
if tunnelType == protocol.TunnelTypeHTTPS {
|
||||
tlsConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 2000,
|
||||
MaxIdleConnsPerHost: 1000,
|
||||
MaxConnsPerHost: 0,
|
||||
IdleConnTimeout: 180 * time.Second,
|
||||
DisableCompression: true,
|
||||
DisableKeepAlives: false,
|
||||
TLSHandshakeTimeout: 5 * time.Second,
|
||||
TLSClientConfig: tlsConfig,
|
||||
ResponseHeaderTimeout: 15 * time.Second,
|
||||
ExpectContinueTimeout: 500 * time.Millisecond,
|
||||
WriteBufferSize: 32 * 1024,
|
||||
ReadBufferSize: 32 * 1024,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 3 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
},
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func writeResponseHeader(w io.Writer, resp *http.Response) error {
|
||||
statusLine := fmt.Sprintf("HTTP/%d.%d %d %s\r\n",
|
||||
resp.ProtoMajor, resp.ProtoMinor,
|
||||
resp.StatusCode, http.StatusText(resp.StatusCode))
|
||||
if _, err := io.WriteString(w, statusLine); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := resp.Header.Write(w); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := io.WriteString(w, "\r\n")
|
||||
return err
|
||||
}
|
||||
312
internal/client/tcp/pool_session.go
Normal file
312
internal/client/tcp/pool_session.go
Normal file
@@ -0,0 +1,312 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/hashicorp/yamux"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/protocol"
|
||||
)
|
||||
|
||||
// sessionHandle wraps a yamux session with metadata.
|
||||
type sessionHandle struct {
|
||||
id string
|
||||
conn net.Conn
|
||||
session *yamux.Session
|
||||
active atomic.Int64
|
||||
lastActive atomic.Int64 // unix nanos
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func (h *sessionHandle) touch() {
|
||||
h.lastActive.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func (h *sessionHandle) lastActiveTime() time.Time {
|
||||
n := h.lastActive.Load()
|
||||
if n == 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
return time.Unix(0, n)
|
||||
}
|
||||
|
||||
// scalerLoop monitors load and adjusts session count.
|
||||
func (c *PoolClient) scalerLoop() {
|
||||
defer c.wg.Done()
|
||||
|
||||
const (
|
||||
checkInterval = 5 * time.Second
|
||||
scaleUpCooldown = 5 * time.Second
|
||||
scaleDownCooldown = 60 * time.Second
|
||||
capacityPerSession = int64(64)
|
||||
scaleUpLoad = 0.7
|
||||
scaleDownLoad = 0.3
|
||||
)
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
desired := c.desiredTotal
|
||||
if desired == 0 {
|
||||
desired = c.initialSessions
|
||||
c.desiredTotal = desired
|
||||
}
|
||||
lastScale := c.lastScale
|
||||
c.mu.Unlock()
|
||||
|
||||
current := c.sessionCount()
|
||||
if current <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
active := c.stats.GetActiveConnections()
|
||||
load := float64(active) / float64(int64(current)*capacityPerSession)
|
||||
|
||||
sinceLastScale := time.Since(lastScale)
|
||||
if sinceLastScale >= scaleUpCooldown && load > scaleUpLoad && desired < c.maxSessions {
|
||||
c.mu.Lock()
|
||||
c.desiredTotal = min(c.desiredTotal+1, c.maxSessions)
|
||||
c.lastScale = time.Now()
|
||||
c.mu.Unlock()
|
||||
} else if sinceLastScale >= scaleDownCooldown && load < scaleDownLoad && desired > c.minSessions {
|
||||
c.mu.Lock()
|
||||
c.desiredTotal = max(c.desiredTotal-1, c.minSessions)
|
||||
c.lastScale = time.Now()
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
c.ensureSessions()
|
||||
}
|
||||
}
|
||||
|
||||
// ensureSessions adjusts session count to match desired.
|
||||
func (c *PoolClient) ensureSessions() {
|
||||
if c.IsClosed() || c.tunnelID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
desired := c.desiredTotal
|
||||
c.mu.RUnlock()
|
||||
|
||||
desired = min(max(desired, c.minSessions), c.maxSessions)
|
||||
|
||||
current := c.sessionCount()
|
||||
if current < desired {
|
||||
for i := 0; i < desired-current; i++ {
|
||||
if err := c.addDataSession(); err != nil {
|
||||
c.logger.Debug("Add data session failed", zap.Error(err))
|
||||
break
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if current > desired {
|
||||
c.removeIdleSessions(current - desired)
|
||||
}
|
||||
}
|
||||
|
||||
// addDataSession creates a new data session.
|
||||
func (c *PoolClient) addDataSession() error {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return net.ErrClosed
|
||||
default:
|
||||
}
|
||||
|
||||
if c.tunnelID == "" {
|
||||
return fmt.Errorf("server does not support data connections")
|
||||
}
|
||||
|
||||
conn, err := c.dialTLS()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connID := fmt.Sprintf("data-%d", time.Now().UnixNano())
|
||||
|
||||
req := protocol.DataConnectRequest{
|
||||
TunnelID: c.tunnelID,
|
||||
Token: c.token,
|
||||
ConnectionID: connID,
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("failed to marshal data connect request: %w", err)
|
||||
}
|
||||
|
||||
if err := protocol.WriteFrame(conn, protocol.NewFrame(protocol.FrameTypeDataConnect, payload)); err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("failed to send data connect: %w", err)
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
ack, err := protocol.ReadFrame(conn)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("failed to read data connect ack: %w", err)
|
||||
}
|
||||
defer ack.Release()
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if ack.Type == protocol.FrameTypeError {
|
||||
var errMsg protocol.ErrorMessage
|
||||
if e := json.Unmarshal(ack.Payload, &errMsg); e == nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("data connect error: %s - %s", errMsg.Code, errMsg.Message)
|
||||
}
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("data connect error")
|
||||
}
|
||||
if ack.Type != protocol.FrameTypeDataConnectAck {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("unexpected data connect ack frame: %s", ack.Type)
|
||||
}
|
||||
|
||||
var resp protocol.DataConnectResponse
|
||||
if err := json.Unmarshal(ack.Payload, &resp); err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("failed to parse data connect response: %w", err)
|
||||
}
|
||||
if !resp.Accepted {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("data connection rejected: %s", resp.Message)
|
||||
}
|
||||
|
||||
yamuxCfg := yamux.DefaultConfig()
|
||||
yamuxCfg.EnableKeepAlive = false
|
||||
yamuxCfg.LogOutput = io.Discard
|
||||
yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
session, err := yamux.Server(conn, yamuxCfg)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
|
||||
h := &sessionHandle{
|
||||
id: connID,
|
||||
conn: conn,
|
||||
session: session,
|
||||
}
|
||||
h.touch()
|
||||
|
||||
c.mu.Lock()
|
||||
c.dataSessions[connID] = h
|
||||
c.mu.Unlock()
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.acceptLoop(h, false)
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.sessionWatcher(h, false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeIdleSessions removes n idle sessions.
|
||||
func (c *PoolClient) removeIdleSessions(n int) {
|
||||
if n <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
type candidate struct {
|
||||
id string
|
||||
active int64
|
||||
lastActive time.Time
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
candidates := make([]candidate, 0, len(c.dataSessions))
|
||||
for id, h := range c.dataSessions {
|
||||
candidates = append(candidates, candidate{
|
||||
id: id,
|
||||
active: h.active.Load(),
|
||||
lastActive: h.lastActiveTime(),
|
||||
})
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
removed := 0
|
||||
for removed < n {
|
||||
var best candidate
|
||||
found := false
|
||||
for _, cand := range candidates {
|
||||
if cand.active != 0 {
|
||||
continue
|
||||
}
|
||||
if !found || cand.lastActive.Before(best.lastActive) {
|
||||
best = cand
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
if c.removeDataSession(best.id) {
|
||||
removed++
|
||||
}
|
||||
for i := range candidates {
|
||||
if candidates[i].id == best.id {
|
||||
candidates[i].active = 1
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// removeDataSession removes a data session by ID.
|
||||
func (c *PoolClient) removeDataSession(id string) bool {
|
||||
var h *sessionHandle
|
||||
|
||||
c.mu.Lock()
|
||||
h = c.dataSessions[id]
|
||||
delete(c.dataSessions, id)
|
||||
c.mu.Unlock()
|
||||
|
||||
if h == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !h.closed.CompareAndSwap(false, true) {
|
||||
return false
|
||||
}
|
||||
|
||||
if h.session != nil {
|
||||
_ = h.session.Close()
|
||||
}
|
||||
if h.conn != nil {
|
||||
_ = h.conn.Close()
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// sessionCount returns the total number of active sessions.
|
||||
func (c *PoolClient) sessionCount() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
count := len(c.dataSessions)
|
||||
if c.primary != nil {
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
@@ -1,490 +1,251 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/httputil"
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/utils"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const openStreamTimeout = 10 * time.Second
|
||||
|
||||
type Handler struct {
|
||||
manager *tunnel.Manager
|
||||
logger *zap.Logger
|
||||
responses *ResponseHandler
|
||||
domain string
|
||||
authToken string
|
||||
headerPool *pool.HeaderPool
|
||||
bufferPool *pool.AdaptiveBufferPool
|
||||
manager *tunnel.Manager
|
||||
logger *zap.Logger
|
||||
domain string
|
||||
authToken string
|
||||
}
|
||||
|
||||
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, responses *ResponseHandler, domain string, authToken string) *Handler {
|
||||
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, authToken string) *Handler {
|
||||
return &Handler{
|
||||
manager: manager,
|
||||
logger: logger,
|
||||
responses: responses,
|
||||
domain: domain,
|
||||
authToken: authToken,
|
||||
headerPool: pool.NewHeaderPool(),
|
||||
bufferPool: pool.NewAdaptiveBufferPool(),
|
||||
manager: manager,
|
||||
logger: logger,
|
||||
domain: domain,
|
||||
authToken: authToken,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Always handle /health and /stats directly, regardless of subdomain
|
||||
// Always handle /health and /stats directly, regardless of subdomain.
|
||||
if r.URL.Path == "/health" {
|
||||
h.serveHealth(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/stats" {
|
||||
h.serveStats(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
subdomain := h.extractSubdomain(r.Host)
|
||||
|
||||
if subdomain == "" {
|
||||
h.serveHomePage(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
conn, ok := h.manager.Get(subdomain)
|
||||
if !ok {
|
||||
tconn, ok := h.manager.Get(subdomain)
|
||||
if !ok || tconn == nil {
|
||||
http.Error(w, "Tunnel not found. The tunnel may have been closed.", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if conn.IsClosed() {
|
||||
if tconn.IsClosed() {
|
||||
http.Error(w, "Tunnel connection closed", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
transport := conn.GetTransport()
|
||||
if transport == nil {
|
||||
http.Error(w, "Tunnel control channel not ready", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
tType := conn.GetTunnelType()
|
||||
tType := tconn.GetTunnelType()
|
||||
if tType != "" && tType != protocol.TunnelTypeHTTP && tType != protocol.TunnelTypeHTTPS {
|
||||
http.Error(w, "Tunnel does not accept HTTP traffic", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
requestID := utils.GenerateID()
|
||||
// Check for WebSocket upgrade
|
||||
if httputil.IsWebSocketUpgrade(r) {
|
||||
h.handleWebSocket(w, r, tconn)
|
||||
return
|
||||
}
|
||||
|
||||
h.handleAdaptiveRequest(w, r, transport, requestID, subdomain)
|
||||
}
|
||||
// Open stream with timeout
|
||||
stream, err := h.openStreamWithTimeout(tconn)
|
||||
if err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string) {
|
||||
const streamingThreshold int64 = 1 * 1024 * 1024
|
||||
// Track active connections
|
||||
tconn.IncActiveConnections()
|
||||
defer tconn.DecActiveConnections()
|
||||
|
||||
// Wrap stream with counting for traffic stats
|
||||
countingStream := netutil.NewCountingConn(stream,
|
||||
tconn.AddBytesOut, // Data read from stream = bytes out to client
|
||||
tconn.AddBytesIn, // Data written to stream = bytes in from client
|
||||
)
|
||||
|
||||
// 1) Write request over the stream (net/http handles large bodies correctly).
|
||||
if err := r.Write(countingStream); err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
_ = r.Body.Close()
|
||||
http.Error(w, "Forward failed", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
// 2) Read response from stream.
|
||||
resp, err := http.ReadResponse(bufio.NewReaderSize(countingStream, 32*1024), r)
|
||||
if err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
http.Error(w, "Read response failed", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 3) Copy headers (strip hop-by-hop).
|
||||
h.copyResponseHeaders(w.Header(), resp.Header, r.Host)
|
||||
|
||||
statusCode := resp.StatusCode
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusOK
|
||||
}
|
||||
|
||||
// Ensure message delimiting works with our custom ResponseWriter:
|
||||
// - If Content-Length is known, send it.
|
||||
// - Otherwise, re-chunk the decoded body ourselves.
|
||||
if r.Method == http.MethodHead || statusCode == http.StatusNoContent || statusCode == http.StatusNotModified {
|
||||
if resp.ContentLength >= 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
|
||||
} else {
|
||||
w.Header().Del("Content-Length")
|
||||
}
|
||||
w.WriteHeader(statusCode)
|
||||
return
|
||||
}
|
||||
|
||||
if resp.ContentLength >= 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
ctx := r.Context()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
stream.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
close(done)
|
||||
stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Del("Content-Length")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
if len(resp.Trailer) > 0 {
|
||||
w.Header().Set("Trailer", trailerKeys(resp.Trailer))
|
||||
}
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
var cancelTransport func()
|
||||
if transport != nil {
|
||||
cancelOnce := sync.Once{}
|
||||
cancelFunc := func() {
|
||||
header := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeClose,
|
||||
IsLast: true,
|
||||
}
|
||||
|
||||
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
||||
if err := transport.SendFrame(frame); err != nil {
|
||||
h.logger.Debug("Failed to send cancel frame to client",
|
||||
zap.String("request_id", requestID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
cancelTransport = func() {
|
||||
cancelOnce.Do(cancelFunc)
|
||||
}
|
||||
|
||||
h.responses.RegisterCancelFunc(requestID, cancelTransport)
|
||||
defer h.responses.CleanupCancelFunc(requestID)
|
||||
}
|
||||
|
||||
largeBufferPtr := h.bufferPool.GetLarge()
|
||||
tempBufPtr := h.bufferPool.GetMedium()
|
||||
|
||||
defer func() {
|
||||
h.bufferPool.PutLarge(largeBufferPtr)
|
||||
h.bufferPool.PutMedium(tempBufPtr)
|
||||
}()
|
||||
|
||||
buffer := (*largeBufferPtr)[:0]
|
||||
tempBuf := (*tempBufPtr)[:pool.MediumBufferSize]
|
||||
|
||||
var totalRead int64
|
||||
var hitThreshold bool
|
||||
|
||||
for totalRead < streamingThreshold {
|
||||
n, err := r.Body.Read(tempBuf)
|
||||
if n > 0 {
|
||||
buffer = append(buffer, tempBuf[:n]...)
|
||||
totalRead += int64(n)
|
||||
}
|
||||
if err == io.EOF {
|
||||
r.Body.Close()
|
||||
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
r.Body.Close()
|
||||
h.logger.Error("Read request body failed", zap.Error(err))
|
||||
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if totalRead >= streamingThreshold {
|
||||
hitThreshold = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hitThreshold {
|
||||
r.Body.Close()
|
||||
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
|
||||
return
|
||||
}
|
||||
|
||||
h.streamLargeRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
|
||||
}
|
||||
|
||||
func (h *Handler) sendBufferedRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), body []byte) {
|
||||
headers := h.headerPool.Get()
|
||||
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
||||
|
||||
httpReq := protocol.HTTPRequest{
|
||||
Method: r.Method,
|
||||
URL: r.URL.String(),
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
}
|
||||
|
||||
reqBytes, err := protocol.EncodeHTTPRequest(&httpReq)
|
||||
h.headerPool.Put(headers)
|
||||
|
||||
if err != nil {
|
||||
h.logger.Error("Encode HTTP request failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
header := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequest,
|
||||
IsLast: true,
|
||||
}
|
||||
|
||||
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, reqBytes)
|
||||
if err != nil {
|
||||
h.logger.Error("Encode data payload failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
||||
|
||||
respChan := h.responses.CreateResponseChan(requestID)
|
||||
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
|
||||
defer func() {
|
||||
h.responses.CleanupResponseChan(requestID)
|
||||
h.responses.CleanupStreamingResponse(requestID)
|
||||
}()
|
||||
|
||||
if err := transport.SendFrame(frame); err != nil {
|
||||
h.logger.Error("Send frame to tunnel failed", zap.Error(err))
|
||||
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case respMsg := <-respChan:
|
||||
if respMsg == nil {
|
||||
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
||||
case <-streamingDone:
|
||||
// Streaming response has been fully written by SendStreamingChunk
|
||||
case <-ctx.Done():
|
||||
if cancelTransport != nil {
|
||||
cancelTransport()
|
||||
}
|
||||
h.logger.Debug("HTTP request context cancelled",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
return
|
||||
case <-time.After(5 * time.Minute):
|
||||
h.logger.Error("Request timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("url", r.URL.String()),
|
||||
)
|
||||
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) streamLargeRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), bufferedData []byte) {
|
||||
headers := h.headerPool.Get()
|
||||
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
||||
|
||||
httpReqHead := protocol.HTTPRequestHead{
|
||||
Method: r.Method,
|
||||
URL: r.URL.String(),
|
||||
Headers: headers,
|
||||
ContentLength: r.ContentLength,
|
||||
}
|
||||
|
||||
headBytes, err := protocol.EncodeHTTPRequestHead(&httpReqHead)
|
||||
h.headerPool.Put(headers)
|
||||
|
||||
if err != nil {
|
||||
h.logger.Error("Encode HTTP request head failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
headHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPHead, // shared streaming head type
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
|
||||
if err != nil {
|
||||
h.logger.Error("Encode head payload failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
|
||||
|
||||
respChan := h.responses.CreateResponseChan(requestID)
|
||||
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
|
||||
defer func() {
|
||||
h.responses.CleanupResponseChan(requestID)
|
||||
h.responses.CleanupStreamingResponse(requestID)
|
||||
}()
|
||||
|
||||
if err := transport.SendFrame(headFrame); err != nil {
|
||||
h.logger.Error("Send head frame failed", zap.Error(err))
|
||||
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
if len(bufferedData) > 0 {
|
||||
chunkHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, bufferedData)
|
||||
if err != nil {
|
||||
h.logger.Error("Encode buffered chunk failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
||||
if err := transport.SendFrame(chunkFrame); err != nil {
|
||||
h.logger.Error("Send buffered chunk failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
streamBufPtr := h.bufferPool.GetMedium()
|
||||
defer h.bufferPool.PutMedium(streamBufPtr)
|
||||
buffer := (*streamBufPtr)[:pool.MediumBufferSize]
|
||||
for {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if cancelTransport != nil {
|
||||
cancelTransport()
|
||||
}
|
||||
h.logger.Debug("Streaming request cancelled via context",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
return
|
||||
default:
|
||||
stream.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
|
||||
n, readErr := r.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
isLast := readErr == io.EOF
|
||||
|
||||
chunkHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
|
||||
IsLast: isLast,
|
||||
}
|
||||
|
||||
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buffer[:n])
|
||||
if err != nil {
|
||||
h.logger.Error("Encode chunk payload failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
||||
if err := transport.SendFrame(chunkFrame); err != nil {
|
||||
h.logger.Error("Send chunk frame failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if readErr == io.EOF {
|
||||
if n == 0 {
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if err == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
if readErr != nil {
|
||||
h.logger.Error("Read request body failed", zap.Error(readErr))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if err == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := writeChunked(w, resp.Body, resp.Trailer); err != nil {
|
||||
h.logger.Debug("Write chunked response failed", zap.Error(err))
|
||||
}
|
||||
close(done)
|
||||
stream.Close()
|
||||
}
|
||||
|
||||
r.Body.Close()
|
||||
func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, error) {
|
||||
type result struct {
|
||||
stream net.Conn
|
||||
err error
|
||||
}
|
||||
ch := make(chan result, 1)
|
||||
|
||||
go func() {
|
||||
s, err := tconn.OpenStream()
|
||||
ch <- result{s, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case respMsg := <-respChan:
|
||||
if respMsg == nil {
|
||||
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
||||
case <-streamingDone:
|
||||
// Streaming response has been fully written by SendStreamingChunk
|
||||
case <-ctx.Done():
|
||||
if cancelTransport != nil {
|
||||
cancelTransport()
|
||||
}
|
||||
h.logger.Debug("Streaming HTTP request context cancelled",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
return
|
||||
case <-time.After(5 * time.Minute):
|
||||
h.logger.Error("Streaming request timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("url", r.URL.String()),
|
||||
)
|
||||
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
|
||||
case r := <-ch:
|
||||
return r.stream, r.err
|
||||
case <-time.After(openStreamTimeout):
|
||||
return nil, fmt.Errorf("open stream timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPResponse, subdomain string, r *http.Request) {
|
||||
if resp == nil {
|
||||
http.Error(w, "Invalid response from tunnel", http.StatusBadGateway)
|
||||
func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection) {
|
||||
stream, err := h.openStreamWithTimeout(tconn)
|
||||
if err != nil {
|
||||
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
// For buffered responses, we have the complete body, so we can set Content-Length
|
||||
// Skip ALL hop-by-hop headers - client should have already cleaned them
|
||||
for key, values := range resp.Headers {
|
||||
tconn.IncActiveConnections()
|
||||
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
stream.Close()
|
||||
tconn.DecActiveConnections()
|
||||
http.Error(w, "WebSocket not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
clientConn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
stream.Close()
|
||||
tconn.DecActiveConnections()
|
||||
http.Error(w, "Failed to hijack connection", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.Write(stream); err != nil {
|
||||
stream.Close()
|
||||
clientConn.Close()
|
||||
tconn.DecActiveConnections()
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer stream.Close()
|
||||
defer clientConn.Close()
|
||||
defer tconn.DecActiveConnections()
|
||||
|
||||
_ = netutil.PipeWithCallbacks(r.Context(), stream, clientConn,
|
||||
func(n int64) { tconn.AddBytesOut(n) },
|
||||
func(n int64) { tconn.AddBytesIn(n) },
|
||||
)
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *Handler) copyResponseHeaders(dst http.Header, src http.Header, proxyHost string) {
|
||||
for key, values := range src {
|
||||
canonicalKey := http.CanonicalHeaderKey(key)
|
||||
|
||||
// Skip hop-by-hop headers completely using canonical key comparison
|
||||
// Hop-by-hop headers must not be forwarded.
|
||||
if canonicalKey == "Connection" ||
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
@@ -496,29 +257,61 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
|
||||
}
|
||||
|
||||
if canonicalKey == "Location" && len(values) > 0 {
|
||||
rewrittenLocation := h.rewriteLocationHeader(values[0], r.Host)
|
||||
w.Header().Set("Location", rewrittenLocation)
|
||||
dst.Set("Location", h.rewriteLocationHeader(values[0], proxyHost))
|
||||
continue
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func trailerKeys(hdr http.Header) string {
|
||||
keys := make([]string, 0, len(hdr))
|
||||
for k := range hdr {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
// Deterministic order is nicer for debugging; no semantic impact.
|
||||
sortStrings(keys)
|
||||
return strings.Join(keys, ", ")
|
||||
}
|
||||
|
||||
func writeChunked(w io.Writer, r io.Reader, trailer http.Header) error {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
if n > 0 {
|
||||
if _, werr := fmt.Fprintf(w, "%x\r\n", n); werr != nil {
|
||||
return werr
|
||||
}
|
||||
if _, werr := w.Write(buf[:n]); werr != nil {
|
||||
return werr
|
||||
}
|
||||
if _, werr := io.WriteString(w, "\r\n"); werr != nil {
|
||||
return werr
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// For buffered mode, always set Content-Length with the actual body size
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
|
||||
|
||||
statusCode := resp.StatusCode
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusOK
|
||||
if _, err := io.WriteString(w, "0\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
if len(resp.Body) > 0 {
|
||||
w.Write(resp.Body)
|
||||
for k, vv := range trailer {
|
||||
for _, v := range vv {
|
||||
if _, err := io.WriteString(w, fmt.Sprintf("%s: %s\r\n", k, v)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err := io.WriteString(w, "\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
|
||||
@@ -535,22 +328,13 @@ func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
|
||||
strings.HasPrefix(locationURL.Host, "localhost:") ||
|
||||
locationURL.Host == "127.0.0.1" ||
|
||||
strings.HasPrefix(locationURL.Host, "127.0.0.1:") {
|
||||
scheme := "https"
|
||||
if strings.Contains(proxyHost, ":") && !strings.Contains(proxyHost, "https") {
|
||||
parts := strings.Split(proxyHost, ":")
|
||||
if len(parts) == 2 && parts[1] != "443" {
|
||||
scheme = "https"
|
||||
}
|
||||
}
|
||||
|
||||
rewritten := fmt.Sprintf("%s://%s%s", scheme, proxyHost, locationURL.Path)
|
||||
rewritten := fmt.Sprintf("https://%s%s", proxyHost, locationURL.Path)
|
||||
if locationURL.RawQuery != "" {
|
||||
rewritten += "?" + locationURL.RawQuery
|
||||
}
|
||||
if locationURL.Fragment != "" {
|
||||
rewritten += "#" + locationURL.Fragment
|
||||
}
|
||||
|
||||
return rewritten
|
||||
}
|
||||
|
||||
@@ -568,8 +352,7 @@ func (h *Handler) extractSubdomain(host string) string {
|
||||
|
||||
suffix := "." + h.domain
|
||||
if strings.HasSuffix(host, suffix) {
|
||||
subdomain := strings.TrimSuffix(host, suffix)
|
||||
return subdomain
|
||||
return strings.TrimSuffix(host, suffix)
|
||||
}
|
||||
|
||||
return ""
|
||||
@@ -652,9 +435,17 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
for _, conn := range connections {
|
||||
if conn == nil {
|
||||
continue
|
||||
}
|
||||
stats["tunnels"] = append(stats["tunnels"].([]map[string]interface{}), map[string]interface{}{
|
||||
"subdomain": conn.Subdomain,
|
||||
"last_active": conn.LastActive.Unix(),
|
||||
"subdomain": conn.Subdomain,
|
||||
"tunnel_type": string(conn.GetTunnelType()),
|
||||
"last_active": conn.LastActive.Unix(),
|
||||
"bytes_in": conn.GetBytesIn(),
|
||||
"bytes_out": conn.GetBytesOut(),
|
||||
"active_connections": conn.GetActiveConnections(),
|
||||
"total_bytes": conn.GetBytesIn() + conn.GetBytesOut(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -668,3 +459,13 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
func sortStrings(s []string) {
|
||||
for i := 0; i < len(s); i++ {
|
||||
for j := i + 1; j < len(s); j++ {
|
||||
if s[j] < s[i] {
|
||||
s[i], s[j] = s[j], s[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,421 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// responseChanEntry holds a response channel and its creation time
|
||||
type responseChanEntry struct {
|
||||
ch chan *protocol.HTTPResponse
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
// streamingResponseEntry holds a streaming response writer
|
||||
type streamingResponseEntry struct {
|
||||
w http.ResponseWriter
|
||||
flusher http.Flusher
|
||||
createdAt time.Time
|
||||
lastActivityAt time.Time
|
||||
headersSent bool
|
||||
done chan struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// ResponseHandler manages response channels for HTTP requests over TCP/Frame protocol
|
||||
type ResponseHandler struct {
|
||||
channels map[string]*responseChanEntry
|
||||
streamingChannels map[string]*streamingResponseEntry
|
||||
cancelFuncs map[string]func()
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewResponseHandler creates a new response handler
|
||||
func NewResponseHandler(logger *zap.Logger) *ResponseHandler {
|
||||
h := &ResponseHandler{
|
||||
channels: make(map[string]*responseChanEntry),
|
||||
streamingChannels: make(map[string]*streamingResponseEntry),
|
||||
cancelFuncs: make(map[string]func()),
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start single cleanup goroutine instead of one per request
|
||||
go h.cleanupLoop()
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// CreateResponseChan creates a response channel for a request ID
|
||||
func (h *ResponseHandler) CreateResponseChan(requestID string) chan *protocol.HTTPResponse {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
ch := make(chan *protocol.HTTPResponse, 1)
|
||||
h.channels[requestID] = &responseChanEntry{
|
||||
ch: ch,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
// CreateStreamingResponse creates a streaming response entry for a request ID
|
||||
func (h *ResponseHandler) CreateStreamingResponse(requestID string, w http.ResponseWriter) chan struct{} {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
flusher, _ := w.(http.Flusher)
|
||||
done := make(chan struct{})
|
||||
now := time.Now()
|
||||
h.streamingChannels[requestID] = &streamingResponseEntry{
|
||||
w: w,
|
||||
flusher: flusher,
|
||||
createdAt: now,
|
||||
lastActivityAt: now,
|
||||
done: done,
|
||||
}
|
||||
|
||||
return done
|
||||
}
|
||||
|
||||
// RegisterCancelFunc registers a callback to be invoked when the downstream disconnects.
|
||||
func (h *ResponseHandler) RegisterCancelFunc(requestID string, cancel func()) {
|
||||
if cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.cancelFuncs[requestID] = cancel
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
// GetResponseChan gets the response channel for a request ID
|
||||
func (h *ResponseHandler) GetResponseChan(requestID string) <-chan *protocol.HTTPResponse {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
if entry := h.channels[requestID]; entry != nil {
|
||||
return entry.ch
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendResponse sends a response to the waiting channel
|
||||
func (h *ResponseHandler) SendResponse(requestID string, resp *protocol.HTTPResponse) {
|
||||
h.mu.RLock()
|
||||
entry, exists := h.channels[requestID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists || entry == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case entry.ch <- resp:
|
||||
case <-time.After(30 * time.Second):
|
||||
h.logger.Error("Timeout sending response to channel - handler may have abandoned",
|
||||
zap.String("request_id", requestID),
|
||||
zap.Int("status_code", resp.StatusCode),
|
||||
zap.Int("body_size", len(resp.Body)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error {
|
||||
h.mu.RLock()
|
||||
entry, exists := h.streamingChannels[requestID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists || entry == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
entry.mu.Lock()
|
||||
defer entry.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-entry.done:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
if entry.headersSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copy headers, removing hop-by-hop headers that were already handled by client
|
||||
// Client's cleanResponseHeaders already removed Transfer-Encoding, Connection, etc.
|
||||
// But we need to check again in case they slipped through
|
||||
hasContentLength := false
|
||||
|
||||
for key, values := range head.Headers {
|
||||
canonicalKey := http.CanonicalHeaderKey(key)
|
||||
|
||||
// Skip ALL hop-by-hop headers
|
||||
if canonicalKey == "Connection" ||
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
canonicalKey == "Upgrade" ||
|
||||
canonicalKey == "Proxy-Connection" ||
|
||||
canonicalKey == "Te" ||
|
||||
canonicalKey == "Trailer" {
|
||||
continue
|
||||
}
|
||||
|
||||
if canonicalKey == "Content-Length" {
|
||||
hasContentLength = true
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
entry.w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// For streaming responses, decide how to indicate message length
|
||||
if head.ContentLength >= 0 && !hasContentLength {
|
||||
entry.w.Header().Set("Content-Length", fmt.Sprintf("%d", head.ContentLength))
|
||||
}
|
||||
|
||||
statusCode := head.StatusCode
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusOK
|
||||
}
|
||||
|
||||
entry.w.WriteHeader(statusCode)
|
||||
entry.headersSent = true
|
||||
entry.lastActivityAt = time.Now()
|
||||
|
||||
if entry.flusher != nil {
|
||||
entry.flusher.Flush()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) SendStreamingChunk(requestID string, chunk []byte, isLast bool) error {
|
||||
h.mu.RLock()
|
||||
entry, exists := h.streamingChannels[requestID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists || entry == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
entry.mu.Lock()
|
||||
defer entry.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-entry.done:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
if len(chunk) > 0 {
|
||||
_, err := entry.w.Write(chunk)
|
||||
if err != nil {
|
||||
if isClientDisconnectError(err) {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
h.triggerCancel(requestID)
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
h.triggerCancel(requestID)
|
||||
return nil
|
||||
}
|
||||
|
||||
entry.lastActivityAt = time.Now()
|
||||
|
||||
if entry.flusher != nil {
|
||||
entry.flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if isLast {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isClientDisconnectError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if netErr, ok := err.(*net.OpError); ok {
|
||||
if netErr.Err != nil {
|
||||
errStr := netErr.Err.Error()
|
||||
if strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection reset") ||
|
||||
strings.Contains(errStr, "connection refused") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection reset") ||
|
||||
strings.Contains(errStr, "use of closed network connection")
|
||||
}
|
||||
|
||||
// triggerCancel invokes and removes the cancel callback for a request.
|
||||
func (h *ResponseHandler) triggerCancel(requestID string) {
|
||||
h.mu.Lock()
|
||||
cancel := h.cancelFuncs[requestID]
|
||||
if cancel != nil {
|
||||
delete(h.cancelFuncs, requestID)
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
if cancel != nil {
|
||||
go func() {
|
||||
cancel()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) CleanupResponseChan(requestID string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if entry, exists := h.channels[requestID]; exists {
|
||||
close(entry.ch)
|
||||
delete(h.channels, requestID)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) CleanupStreamingResponse(requestID string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if entry, exists := h.streamingChannels[requestID]; exists {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
delete(h.streamingChannels, requestID)
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupCancelFunc removes a registered cancel callback.
|
||||
func (h *ResponseHandler) CleanupCancelFunc(requestID string) {
|
||||
h.mu.Lock()
|
||||
delete(h.cancelFuncs, requestID)
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) GetPendingCount() int {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return len(h.channels) + len(h.streamingChannels)
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) cleanupLoop() {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
h.cleanupExpiredChannels()
|
||||
case <-h.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) cleanupExpiredChannels() {
|
||||
now := time.Now()
|
||||
timeout := 5 * time.Minute
|
||||
streamingTimeout := 5 * time.Minute
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
expiredCount := 0
|
||||
cancelList := make([]string, 0)
|
||||
for requestID, entry := range h.channels {
|
||||
if now.Sub(entry.createdAt) > timeout {
|
||||
close(entry.ch)
|
||||
delete(h.channels, requestID)
|
||||
expiredCount++
|
||||
}
|
||||
}
|
||||
|
||||
for requestID, entry := range h.streamingChannels {
|
||||
if now.Sub(entry.lastActivityAt) > streamingTimeout {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
delete(h.streamingChannels, requestID)
|
||||
cancelList = append(cancelList, requestID)
|
||||
expiredCount++
|
||||
}
|
||||
}
|
||||
|
||||
for _, requestID := range cancelList {
|
||||
if cancel := h.cancelFuncs[requestID]; cancel != nil {
|
||||
delete(h.cancelFuncs, requestID)
|
||||
go cancel()
|
||||
}
|
||||
}
|
||||
|
||||
if expiredCount > 0 {
|
||||
h.logger.Debug("Cleaned up expired response channels",
|
||||
zap.Int("count", expiredCount),
|
||||
zap.Int("remaining", len(h.channels)+len(h.streamingChannels)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) Close() {
|
||||
close(h.stopCh)
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
for _, entry := range h.channels {
|
||||
close(entry.ch)
|
||||
}
|
||||
h.channels = make(map[string]*responseChanEntry)
|
||||
|
||||
for _, entry := range h.streamingChannels {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
}
|
||||
h.streamingChannels = make(map[string]*streamingResponseEntry)
|
||||
|
||||
for _, cancel := range h.cancelFuncs {
|
||||
cancel()
|
||||
}
|
||||
h.cancelFuncs = make(map[string]func())
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/hashicorp/yamux"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/constants"
|
||||
@@ -33,36 +34,27 @@ type Connection struct {
|
||||
publicPort int
|
||||
portAlloc *PortAllocator
|
||||
tunnelConn *tunnel.Connection
|
||||
proxy *TunnelProxy
|
||||
stopCh chan struct{}
|
||||
once sync.Once
|
||||
lastHeartbeat time.Time
|
||||
mu sync.RWMutex
|
||||
frameWriter *protocol.FrameWriter
|
||||
httpHandler http.Handler
|
||||
responseChans HTTPResponseHandler
|
||||
tunnelType protocol.TunnelType // Track tunnel type
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Flow control
|
||||
paused bool
|
||||
pauseCond *sync.Cond
|
||||
}
|
||||
// gost-like TCP tunnel (yamux)
|
||||
session *yamux.Session
|
||||
proxy *Proxy
|
||||
|
||||
// HTTPResponseHandler interface for response channel operations
|
||||
type HTTPResponseHandler interface {
|
||||
CreateResponseChan(requestID string) chan *protocol.HTTPResponse
|
||||
GetResponseChan(requestID string) <-chan *protocol.HTTPResponse
|
||||
CleanupResponseChan(requestID string)
|
||||
SendResponse(requestID string, resp *protocol.HTTPResponse)
|
||||
// Streaming response methods
|
||||
SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error
|
||||
SendStreamingChunk(requestID string, chunk []byte, isLast bool) error
|
||||
// Multi-connection support
|
||||
tunnelID string
|
||||
groupManager *ConnectionGroupManager
|
||||
}
|
||||
|
||||
// NewConnection creates a new connection handler
|
||||
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Connection {
|
||||
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager) *Connection {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
c := &Connection{
|
||||
conn: conn,
|
||||
@@ -73,13 +65,12 @@ func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, log
|
||||
domain: domain,
|
||||
publicPort: publicPort,
|
||||
httpHandler: httpHandler,
|
||||
responseChans: responseChans,
|
||||
stopCh: make(chan struct{}),
|
||||
lastHeartbeat: time.Now(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
groupManager: groupManager,
|
||||
}
|
||||
c.pauseCond = sync.NewCond(&c.mu)
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -97,8 +88,8 @@ func (c *Connection) Handle() error {
|
||||
// Use buffered reader to support peeking
|
||||
reader := bufio.NewReader(c.conn)
|
||||
|
||||
// Peek first 8 bytes to detect protocol
|
||||
peek, err := reader.Peek(8)
|
||||
// Peek first 4 bytes to detect protocol (HTTP methods are 4 bytes).
|
||||
peek, err := reader.Peek(4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to peek connection: %w", err)
|
||||
}
|
||||
@@ -127,6 +118,11 @@ func (c *Connection) Handle() error {
|
||||
sf := protocol.WithFrame(frame)
|
||||
defer sf.Close()
|
||||
|
||||
// Handle data connection request (for multi-connection pool)
|
||||
if sf.Frame.Type == protocol.FrameTypeDataConnect {
|
||||
return c.handleDataConnect(sf.Frame, reader)
|
||||
}
|
||||
|
||||
if sf.Frame.Type != protocol.FrameTypeRegister {
|
||||
return fmt.Errorf("expected register frame, got %s", sf.Frame.Type)
|
||||
}
|
||||
@@ -180,7 +176,6 @@ func (c *Connection) Handle() error {
|
||||
|
||||
// Store TCP connection reference and metadata for HTTP proxy routing
|
||||
c.tunnelConn.Conn = nil // We're using TCP, not WebSocket
|
||||
c.tunnelConn.SetTransport(c, req.TunnelType)
|
||||
c.tunnelConn.SetTunnelType(req.TunnelType)
|
||||
c.tunnelType = req.TunnelType
|
||||
|
||||
@@ -208,11 +203,33 @@ func (c *Connection) Handle() error {
|
||||
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.domain, c.port)
|
||||
}
|
||||
|
||||
// Generate TunnelID for multi-connection support if client supports it
|
||||
var tunnelID string
|
||||
var supportsDataConn bool
|
||||
recommendedConns := 0
|
||||
|
||||
if req.PoolCapabilities != nil && req.ConnectionType == "primary" && c.groupManager != nil {
|
||||
// Client supports connection pooling
|
||||
group := c.groupManager.CreateGroup(subdomain, req.Token, c, req.TunnelType)
|
||||
tunnelID = group.TunnelID
|
||||
c.tunnelID = tunnelID
|
||||
supportsDataConn = true
|
||||
recommendedConns = 4 // Recommend 4 data connections
|
||||
|
||||
c.logger.Info("Created connection group for multi-connection support",
|
||||
zap.String("tunnel_id", tunnelID),
|
||||
zap.Int("max_data_conns", req.PoolCapabilities.MaxDataConns),
|
||||
)
|
||||
}
|
||||
|
||||
resp := protocol.RegisterResponse{
|
||||
Subdomain: subdomain,
|
||||
Port: c.port,
|
||||
URL: tunnelURL,
|
||||
Message: "Tunnel registered successfully",
|
||||
Subdomain: subdomain,
|
||||
Port: c.port,
|
||||
URL: tunnelURL,
|
||||
Message: "Tunnel registered successfully",
|
||||
TunnelID: tunnelID,
|
||||
SupportsDataConn: supportsDataConn,
|
||||
RecommendedConns: recommendedConns,
|
||||
}
|
||||
|
||||
respData, _ := json.Marshal(resp)
|
||||
@@ -224,6 +241,17 @@ func (c *Connection) Handle() error {
|
||||
return fmt.Errorf("failed to send registration ack: %w", err)
|
||||
}
|
||||
|
||||
// Clear deadline for tunnel data-plane.
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// gost-like tunnels: switch to yamux after RegisterAck.
|
||||
if req.TunnelType == protocol.TunnelTypeTCP {
|
||||
return c.handleTCPTunnel(reader)
|
||||
}
|
||||
if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS {
|
||||
return c.handleHTTPProxyTunnel(reader)
|
||||
}
|
||||
|
||||
c.frameWriter = protocol.NewFrameWriter(c.conn)
|
||||
|
||||
c.frameWriter.SetWriteErrorHandler(func(err error) {
|
||||
@@ -231,15 +259,6 @@ func (c *Connection) Handle() error {
|
||||
c.Close()
|
||||
})
|
||||
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if req.TunnelType == protocol.TunnelTypeTCP {
|
||||
c.proxy = NewTunnelProxy(c.port, subdomain, c.conn, c.logger)
|
||||
if err := c.proxy.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start TCP proxy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
go c.heartbeatChecker()
|
||||
|
||||
return c.handleFrames(reader)
|
||||
@@ -376,7 +395,7 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
frame, err := protocol.ReadFrame(reader)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
if isTimeoutError(err) {
|
||||
c.logger.Warn("Read timeout, connection may be dead")
|
||||
return fmt.Errorf("read timeout")
|
||||
}
|
||||
@@ -404,15 +423,6 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
|
||||
c.handleHeartbeat()
|
||||
sf.Close()
|
||||
|
||||
case protocol.FrameTypeData:
|
||||
// Data frame from client (response to forwarded request)
|
||||
c.handleDataFrame(sf.Frame)
|
||||
sf.Close()
|
||||
|
||||
case protocol.FrameTypeFlowControl:
|
||||
c.handleFlowControl(sf.Frame)
|
||||
sf.Close()
|
||||
|
||||
case protocol.FrameTypeClose:
|
||||
sf.Close()
|
||||
c.logger.Info("Client requested close")
|
||||
@@ -436,127 +446,12 @@ func (c *Connection) handleHeartbeat() {
|
||||
// Send heartbeat ack
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeHeartbeatAck, nil)
|
||||
|
||||
err := c.frameWriter.WriteFrame(ackFrame)
|
||||
err := c.frameWriter.WriteControl(ackFrame)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to send heartbeat ack", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// handleDataFrame handles data frame (response from client)
|
||||
func (c *Connection) handleDataFrame(frame *protocol.Frame) {
|
||||
// Decode payload (auto-detects protocol version)
|
||||
header, data, err := protocol.DecodeDataPayload(frame.Payload)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to decode data payload",
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
c.logger.Debug("Received data frame",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.String("type", header.Type.String()),
|
||||
zap.Int("data_size", len(data)),
|
||||
)
|
||||
|
||||
switch header.Type {
|
||||
case protocol.DataTypeResponse:
|
||||
// TCP tunnel response, forward to proxy
|
||||
if c.proxy != nil {
|
||||
if err := c.proxy.HandleResponse(header.StreamID, data); err != nil {
|
||||
c.logger.Error("Failed to handle response",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
case protocol.DataTypeHTTPResponse:
|
||||
if c.responseChans == nil {
|
||||
c.logger.Warn("No response channel handler for HTTP response",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Decode HTTP response (auto-detects JSON vs msgpack)
|
||||
httpResp, err := protocol.DecodeHTTPResponse(data)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to decode HTTP response",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Route by request ID when provided to keep request/response aligned.
|
||||
reqID := header.RequestID
|
||||
if reqID == "" {
|
||||
reqID = header.StreamID
|
||||
}
|
||||
|
||||
c.responseChans.SendResponse(reqID, httpResp)
|
||||
case protocol.DataTypeHTTPHead:
|
||||
// Streaming HTTP response headers
|
||||
if c.responseChans == nil {
|
||||
c.logger.Warn("No response handler for streaming HTTP head",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
httpHead, err := protocol.DecodeHTTPResponseHead(data)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to decode HTTP response head",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
reqID := header.RequestID
|
||||
if reqID == "" {
|
||||
reqID = header.StreamID
|
||||
}
|
||||
|
||||
if err := c.responseChans.SendStreamingHead(reqID, httpHead); err != nil {
|
||||
c.logger.Error("Failed to send streaming head",
|
||||
zap.String("request_id", reqID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
case protocol.DataTypeHTTPBodyChunk:
|
||||
// Streaming HTTP response body chunk
|
||||
if c.responseChans == nil {
|
||||
c.logger.Warn("No response handler for streaming HTTP chunk",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
reqID := header.RequestID
|
||||
if reqID == "" {
|
||||
reqID = header.StreamID
|
||||
}
|
||||
|
||||
if err := c.responseChans.SendStreamingChunk(reqID, data, header.IsLast); err != nil {
|
||||
c.logger.Error("Failed to send streaming chunk",
|
||||
zap.String("request_id", reqID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
case protocol.DataTypeClose:
|
||||
// Client is closing the stream
|
||||
if c.proxy != nil {
|
||||
c.proxy.CloseStream(header.StreamID)
|
||||
}
|
||||
default:
|
||||
c.logger.Warn("Unknown data frame type",
|
||||
zap.String("type", header.Type.String()),
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeatChecker checks for heartbeat timeout
|
||||
func (c *Connection) heartbeatChecker() {
|
||||
ticker := time.NewTicker(constants.HeartbeatInterval)
|
||||
@@ -583,16 +478,6 @@ func (c *Connection) heartbeatChecker() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) SendFrame(frame *protocol.Frame) error {
|
||||
if c.frameWriter == nil {
|
||||
return protocol.WriteFrame(c.conn, frame)
|
||||
}
|
||||
if frame.Type == protocol.FrameTypeData {
|
||||
return c.sendWithBackpressure(frame)
|
||||
}
|
||||
return c.frameWriter.WriteFrame(frame)
|
||||
}
|
||||
|
||||
func (c *Connection) sendError(code, message string) {
|
||||
errMsg := protocol.ErrorMessage{
|
||||
Code: code,
|
||||
@@ -618,8 +503,12 @@ func (c *Connection) Close() {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
// Ensure any in-flight writes return quickly on shutdown to avoid hanging.
|
||||
if c.conn != nil {
|
||||
_ = c.conn.SetDeadline(time.Now())
|
||||
}
|
||||
|
||||
if c.frameWriter != nil {
|
||||
c.frameWriter.Flush()
|
||||
c.frameWriter.Close()
|
||||
}
|
||||
|
||||
@@ -627,7 +516,13 @@ func (c *Connection) Close() {
|
||||
c.proxy.Stop()
|
||||
}
|
||||
|
||||
c.conn.Close()
|
||||
if c.session != nil {
|
||||
_ = c.session.Close()
|
||||
}
|
||||
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
}
|
||||
|
||||
if c.port > 0 && c.portAlloc != nil {
|
||||
c.portAlloc.Release(c.port)
|
||||
@@ -635,6 +530,12 @@ func (c *Connection) Close() {
|
||||
|
||||
if c.subdomain != "" {
|
||||
c.manager.Unregister(c.subdomain)
|
||||
|
||||
// Clean up connection group when PRIMARY connection closes
|
||||
// (only primary connections have subdomain set)
|
||||
if c.tunnelID != "" && c.groupManager != nil {
|
||||
c.groupManager.RemoveGroup(c.tunnelID)
|
||||
}
|
||||
}
|
||||
|
||||
c.logger.Info("Connection closed",
|
||||
@@ -643,11 +544,6 @@ func (c *Connection) Close() {
|
||||
})
|
||||
}
|
||||
|
||||
// GetSubdomain returns the assigned subdomain
|
||||
func (c *Connection) GetSubdomain() string {
|
||||
return c.subdomain
|
||||
}
|
||||
|
||||
// httpResponseWriter implements http.ResponseWriter for writing to a net.Conn
|
||||
type httpResponseWriter struct {
|
||||
conn net.Conn
|
||||
@@ -698,39 +594,196 @@ func (w *httpResponseWriter) Write(data []byte) (int, error) {
|
||||
return w.writer.Write(data)
|
||||
}
|
||||
|
||||
func (c *Connection) handleFlowControl(frame *protocol.Frame) {
|
||||
msg, err := protocol.DecodeFlowControlMessage(frame.Payload)
|
||||
// handleDataConnect handles a data connection join request
|
||||
func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Reader) error {
|
||||
var req protocol.DataConnectRequest
|
||||
if err := json.Unmarshal(frame.Payload, &req); err != nil {
|
||||
c.sendError("invalid_request", "Failed to parse data connect request")
|
||||
return fmt.Errorf("failed to parse data connect request: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("Data connection request received",
|
||||
zap.String("tunnel_id", req.TunnelID),
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
// Validate the request
|
||||
if c.groupManager == nil {
|
||||
c.sendDataConnectError("not_supported", "Multi-connection not supported")
|
||||
return fmt.Errorf("group manager not available")
|
||||
}
|
||||
|
||||
// Validate auth token
|
||||
if c.authToken != "" && req.Token != c.authToken {
|
||||
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
|
||||
return fmt.Errorf("authentication failed for data connection")
|
||||
}
|
||||
|
||||
group, ok := c.groupManager.GetGroup(req.TunnelID)
|
||||
if !ok || group == nil {
|
||||
c.sendDataConnectError("join_failed", "Tunnel not found")
|
||||
return fmt.Errorf("tunnel not found: %s", req.TunnelID)
|
||||
}
|
||||
|
||||
// Validate token against the primary registration token.
|
||||
if group.Token != "" && req.Token != group.Token {
|
||||
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
|
||||
return fmt.Errorf("authentication failed for data connection")
|
||||
}
|
||||
|
||||
// Store tunnelID for cleanup
|
||||
c.tunnelID = req.TunnelID
|
||||
|
||||
// For TCP tunnels, the data connection is upgraded to a yamux session and used for
|
||||
// stream forwarding, not framed request/response routing.
|
||||
if group.TunnelType == protocol.TunnelTypeTCP {
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: true,
|
||||
ConnectionID: req.ConnectionID,
|
||||
Message: "Data connection accepted",
|
||||
}
|
||||
|
||||
respData, _ := json.Marshal(resp)
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
|
||||
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
|
||||
return fmt.Errorf("failed to send data connect ack: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("TCP data connection established",
|
||||
zap.String("tunnel_id", req.TunnelID),
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
// Clear deadline for yamux data-plane.
|
||||
_ = c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Public server acts as yamux Client, client connector acts as yamux Server.
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
cfg := yamux.DefaultConfig()
|
||||
cfg.EnableKeepAlive = false
|
||||
cfg.LogOutput = io.Discard
|
||||
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
group.AddSession(req.ConnectionID, session)
|
||||
defer group.RemoveSession(req.ConnectionID)
|
||||
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return nil
|
||||
case <-session.CloseChan():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Add data connection to group
|
||||
dataConn, err := c.groupManager.AddDataConnection(&req, c.conn)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to decode flow control", zap.Error(err))
|
||||
return
|
||||
c.sendDataConnectError("join_failed", err.Error())
|
||||
return fmt.Errorf("failed to join connection group: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
// Send success response
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: true,
|
||||
ConnectionID: req.ConnectionID,
|
||||
Message: "Data connection accepted",
|
||||
}
|
||||
|
||||
switch msg.Action {
|
||||
case protocol.FlowControlPause:
|
||||
c.paused = true
|
||||
c.logger.Warn("Client requested pause",
|
||||
zap.String("stream", msg.StreamID))
|
||||
respData, _ := json.Marshal(resp)
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
|
||||
case protocol.FlowControlResume:
|
||||
c.paused = false
|
||||
c.pauseCond.Broadcast()
|
||||
c.logger.Info("Client requested resume",
|
||||
zap.String("stream", msg.StreamID))
|
||||
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
|
||||
return fmt.Errorf("failed to send data connect ack: %w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
c.logger.Warn("Unknown flow control action",
|
||||
zap.String("action", string(msg.Action)))
|
||||
c.logger.Info("Data connection established",
|
||||
zap.String("tunnel_id", req.TunnelID),
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
// Handle data frames on this connection
|
||||
return c.handleDataConnectionFrames(dataConn, reader)
|
||||
}
|
||||
|
||||
// handleDataConnectionFrames handles frames on a data connection
|
||||
func (c *Connection) handleDataConnectionFrames(dataConn *DataConnection, reader *bufio.Reader) error {
|
||||
defer func() {
|
||||
// Get the group and remove this data connection
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok {
|
||||
group.RemoveDataConnection(dataConn.ID)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-dataConn.stopCh:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
frame, err := protocol.ReadFrame(reader)
|
||||
if err != nil {
|
||||
// Timeout is OK, continue
|
||||
if isTimeoutError(err) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
dataConn.mu.Lock()
|
||||
dataConn.LastActive = time.Now()
|
||||
dataConn.mu.Unlock()
|
||||
|
||||
sf := protocol.WithFrame(frame)
|
||||
|
||||
switch sf.Frame.Type {
|
||||
case protocol.FrameTypeClose:
|
||||
sf.Close()
|
||||
c.logger.Info("Data connection closed by client",
|
||||
zap.String("connection_id", dataConn.ID))
|
||||
return nil
|
||||
|
||||
default:
|
||||
sf.Close()
|
||||
c.logger.Warn("Unexpected frame type on data connection",
|
||||
zap.String("type", sf.Frame.Type.String()),
|
||||
zap.String("connection_id", dataConn.ID),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) sendWithBackpressure(frame *protocol.Frame) error {
|
||||
c.mu.Lock()
|
||||
for c.paused {
|
||||
c.pauseCond.Wait()
|
||||
func isTimeoutError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return c.frameWriter.WriteFrame(frame)
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
// Fallback for wrapped errors without net.Error
|
||||
return strings.Contains(err.Error(), "i/o timeout")
|
||||
}
|
||||
|
||||
// sendDataConnectError sends a data connect error response
|
||||
func (c *Connection) sendDataConnectError(code, message string) {
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: false,
|
||||
Message: fmt.Sprintf("%s: %s", code, message),
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
protocol.WriteFrame(c.conn, frame)
|
||||
}
|
||||
|
||||
438
internal/server/tcp/connection_group.go
Normal file
438
internal/server/tcp/connection_group.go
Normal file
@@ -0,0 +1,438 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
|
||||
type DataConnection struct {
|
||||
ID string
|
||||
Conn net.Conn
|
||||
LastActive time.Time
|
||||
closed bool
|
||||
closedMu sync.RWMutex
|
||||
stopCh chan struct{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type ConnectionGroup struct {
|
||||
TunnelID string
|
||||
Subdomain string
|
||||
Token string
|
||||
PrimaryConn *Connection
|
||||
DataConns map[string]*DataConnection
|
||||
Sessions map[string]*yamux.Session
|
||||
TunnelType protocol.TunnelType
|
||||
RegisteredAt time.Time
|
||||
LastActivity time.Time
|
||||
sessionIdx uint32
|
||||
mu sync.RWMutex
|
||||
stopCh chan struct{}
|
||||
logger *zap.Logger
|
||||
|
||||
heartbeatStarted bool
|
||||
}
|
||||
|
||||
func NewConnectionGroup(tunnelID, subdomain, token string, primaryConn *Connection, tunnelType protocol.TunnelType, logger *zap.Logger) *ConnectionGroup {
|
||||
return &ConnectionGroup{
|
||||
TunnelID: tunnelID,
|
||||
Subdomain: subdomain,
|
||||
Token: token,
|
||||
PrimaryConn: primaryConn,
|
||||
DataConns: make(map[string]*DataConnection),
|
||||
Sessions: make(map[string]*yamux.Session),
|
||||
TunnelType: tunnelType,
|
||||
RegisteredAt: time.Now(),
|
||||
LastActivity: time.Now(),
|
||||
stopCh: make(chan struct{}),
|
||||
logger: logger.With(zap.String("tunnel_id", tunnelID)),
|
||||
}
|
||||
}
|
||||
|
||||
// StartHeartbeat starts a goroutine that periodically pings all sessions
|
||||
// and removes dead ones. The caller should ensure this is only called once.
|
||||
func (g *ConnectionGroup) StartHeartbeat(interval, timeout time.Duration) {
|
||||
go g.heartbeatLoop(interval, timeout)
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
const maxConsecutiveFailures = 3
|
||||
failureCount := make(map[string]int)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-g.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
g.mu.RLock()
|
||||
sessions := make(map[string]*yamux.Session, len(g.Sessions))
|
||||
for id, s := range g.Sessions {
|
||||
sessions[id] = s
|
||||
}
|
||||
g.mu.RUnlock()
|
||||
|
||||
for id, session := range sessions {
|
||||
if session == nil || session.IsClosed() {
|
||||
g.RemoveSession(id)
|
||||
delete(failureCount, id)
|
||||
continue
|
||||
}
|
||||
|
||||
// Ping with timeout
|
||||
done := make(chan error, 1)
|
||||
go func(s *yamux.Session) {
|
||||
_, err := s.Ping()
|
||||
done <- err
|
||||
}(session)
|
||||
|
||||
var err error
|
||||
select {
|
||||
case err = <-done:
|
||||
case <-time.After(timeout):
|
||||
err = fmt.Errorf("ping timeout")
|
||||
case <-g.stopCh:
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
failureCount[id]++
|
||||
g.logger.Debug("Session ping failed",
|
||||
zap.String("session_id", id),
|
||||
zap.Int("consecutive_failures", failureCount[id]),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
if failureCount[id] >= maxConsecutiveFailures {
|
||||
g.logger.Warn("Session ping failed too many times, removing",
|
||||
zap.String("session_id", id),
|
||||
zap.Int("failures", failureCount[id]),
|
||||
)
|
||||
g.RemoveSession(id)
|
||||
delete(failureCount, id)
|
||||
}
|
||||
} else {
|
||||
// Reset on success
|
||||
failureCount[id] = 0
|
||||
g.mu.Lock()
|
||||
g.LastActivity = time.Now()
|
||||
g.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Check if all sessions are gone
|
||||
g.mu.RLock()
|
||||
sessionCount := len(g.Sessions)
|
||||
g.mu.RUnlock()
|
||||
|
||||
if sessionCount == 0 {
|
||||
g.logger.Info("All sessions closed, tunnel will be cleaned up")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) AddDataConnection(connID string, conn net.Conn) *DataConnection {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
dataConn := &DataConnection{
|
||||
ID: connID,
|
||||
Conn: conn,
|
||||
LastActive: time.Now(),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
g.DataConns[connID] = dataConn
|
||||
g.LastActivity = time.Now()
|
||||
return dataConn
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) RemoveDataConnection(connID string) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
if dataConn, ok := g.DataConns[connID]; ok {
|
||||
dataConn.closedMu.Lock()
|
||||
if !dataConn.closed {
|
||||
dataConn.closed = true
|
||||
close(dataConn.stopCh)
|
||||
if dataConn.Conn != nil {
|
||||
_ = dataConn.Conn.SetDeadline(time.Now())
|
||||
dataConn.Conn.Close()
|
||||
}
|
||||
}
|
||||
dataConn.closedMu.Unlock()
|
||||
delete(g.DataConns, connID)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) DataConnectionCount() int {
|
||||
g.mu.RLock()
|
||||
defer g.mu.RUnlock()
|
||||
return len(g.DataConns)
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) Close() {
|
||||
g.mu.Lock()
|
||||
|
||||
select {
|
||||
case <-g.stopCh:
|
||||
g.mu.Unlock()
|
||||
return
|
||||
default:
|
||||
close(g.stopCh)
|
||||
}
|
||||
|
||||
dataConns := make([]*DataConnection, 0, len(g.DataConns))
|
||||
for _, dataConn := range g.DataConns {
|
||||
dataConns = append(dataConns, dataConn)
|
||||
}
|
||||
g.DataConns = make(map[string]*DataConnection)
|
||||
|
||||
sessions := make([]*yamux.Session, 0, len(g.Sessions))
|
||||
for _, session := range g.Sessions {
|
||||
if session != nil {
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
}
|
||||
g.Sessions = make(map[string]*yamux.Session)
|
||||
|
||||
g.mu.Unlock()
|
||||
|
||||
for _, dataConn := range dataConns {
|
||||
dataConn.closedMu.Lock()
|
||||
if !dataConn.closed {
|
||||
dataConn.closed = true
|
||||
close(dataConn.stopCh)
|
||||
if dataConn.Conn != nil {
|
||||
_ = dataConn.Conn.SetDeadline(time.Now())
|
||||
_ = dataConn.Conn.Close()
|
||||
}
|
||||
}
|
||||
dataConn.closedMu.Unlock()
|
||||
}
|
||||
|
||||
for _, session := range sessions {
|
||||
_ = session.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) IsStale(timeout time.Duration) bool {
|
||||
g.mu.RLock()
|
||||
defer g.mu.RUnlock()
|
||||
return time.Since(g.LastActivity) > timeout
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) AddSession(connID string, session *yamux.Session) {
|
||||
if connID == "" || session == nil {
|
||||
return
|
||||
}
|
||||
|
||||
g.mu.Lock()
|
||||
if g.Sessions == nil {
|
||||
g.Sessions = make(map[string]*yamux.Session)
|
||||
}
|
||||
g.Sessions[connID] = session
|
||||
g.LastActivity = time.Now()
|
||||
|
||||
// Start heartbeat on first session
|
||||
shouldStartHeartbeat := !g.heartbeatStarted
|
||||
if shouldStartHeartbeat {
|
||||
g.heartbeatStarted = true
|
||||
}
|
||||
g.mu.Unlock()
|
||||
|
||||
if shouldStartHeartbeat {
|
||||
g.StartHeartbeat(constants.HeartbeatInterval, constants.HeartbeatTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) RemoveSession(connID string) {
|
||||
if connID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
var session *yamux.Session
|
||||
|
||||
g.mu.Lock()
|
||||
if g.Sessions != nil {
|
||||
session = g.Sessions[connID]
|
||||
delete(g.Sessions, connID)
|
||||
}
|
||||
g.mu.Unlock()
|
||||
|
||||
if session != nil {
|
||||
_ = session.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) SessionCount() int {
|
||||
g.mu.RLock()
|
||||
defer g.mu.RUnlock()
|
||||
return len(g.Sessions)
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
|
||||
const (
|
||||
maxStreamsPerSession = 256
|
||||
maxRetries = 3
|
||||
backoffBase = 25 * time.Millisecond
|
||||
)
|
||||
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
select {
|
||||
case <-g.stopCh:
|
||||
return nil, net.ErrClosed
|
||||
default:
|
||||
}
|
||||
|
||||
sessions := g.sessionsSnapshot()
|
||||
if len(sessions) == 0 {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
|
||||
tried := make([]bool, len(sessions))
|
||||
anyUnderCap := false
|
||||
start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
|
||||
|
||||
for range sessions {
|
||||
bestIdx := -1
|
||||
minStreams := int(^uint(0) >> 1)
|
||||
|
||||
for i := 0; i < len(sessions); i++ {
|
||||
idx := (start + i) % len(sessions)
|
||||
if tried[idx] {
|
||||
continue
|
||||
}
|
||||
|
||||
session := sessions[idx]
|
||||
if session == nil || session.IsClosed() {
|
||||
tried[idx] = true
|
||||
continue
|
||||
}
|
||||
|
||||
n := session.NumStreams()
|
||||
if n >= maxStreamsPerSession {
|
||||
continue
|
||||
}
|
||||
anyUnderCap = true
|
||||
|
||||
if n < minStreams {
|
||||
minStreams = n
|
||||
bestIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
if bestIdx == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
tried[bestIdx] = true
|
||||
session := sessions[bestIdx]
|
||||
if session == nil || session.IsClosed() {
|
||||
continue
|
||||
}
|
||||
|
||||
stream, err := session.Open()
|
||||
if err == nil {
|
||||
return stream, nil
|
||||
}
|
||||
lastErr = err
|
||||
|
||||
if session.IsClosed() {
|
||||
g.deleteClosedSessions()
|
||||
}
|
||||
}
|
||||
|
||||
if !anyUnderCap {
|
||||
lastErr = fmt.Errorf("all sessions are at stream capacity (%d)", maxStreamsPerSession)
|
||||
}
|
||||
|
||||
if attempt < maxRetries-1 {
|
||||
select {
|
||||
case <-g.stopCh:
|
||||
return nil, net.ErrClosed
|
||||
case <-time.After(backoffBase * time.Duration(attempt+1)):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lastErr == nil {
|
||||
lastErr = fmt.Errorf("failed to open stream")
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) selectSession() *yamux.Session {
|
||||
sessions := g.sessionsSnapshot()
|
||||
if len(sessions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
|
||||
minStreams := int(^uint(0) >> 1)
|
||||
var best *yamux.Session
|
||||
|
||||
for i := 0; i < len(sessions); i++ {
|
||||
session := sessions[(start+i)%len(sessions)]
|
||||
if session == nil || session.IsClosed() {
|
||||
continue
|
||||
}
|
||||
if n := session.NumStreams(); n < minStreams {
|
||||
minStreams = n
|
||||
best = session
|
||||
}
|
||||
}
|
||||
|
||||
return best
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
if len(g.Sessions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sessions := make([]*yamux.Session, 0, len(g.Sessions))
|
||||
for id, session := range g.Sessions {
|
||||
if session == nil || session.IsClosed() {
|
||||
delete(g.Sessions, id)
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
|
||||
if len(sessions) > 0 {
|
||||
g.LastActivity = time.Now()
|
||||
}
|
||||
|
||||
return sessions
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) deleteClosedSessions() {
|
||||
g.mu.Lock()
|
||||
for id, session := range g.Sessions {
|
||||
if session == nil || session.IsClosed() {
|
||||
delete(g.Sessions, id)
|
||||
}
|
||||
}
|
||||
g.mu.Unlock()
|
||||
}
|
||||
163
internal/server/tcp/connection_group_manager.go
Normal file
163
internal/server/tcp/connection_group_manager.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ConnectionGroupManager manages all connection groups
|
||||
type ConnectionGroupManager struct {
|
||||
groups map[string]*ConnectionGroup // TunnelID -> ConnectionGroup
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
|
||||
// Cleanup
|
||||
cleanupInterval time.Duration
|
||||
staleTimeout time.Duration
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewConnectionGroupManager creates a new connection group manager
|
||||
func NewConnectionGroupManager(logger *zap.Logger) *ConnectionGroupManager {
|
||||
m := &ConnectionGroupManager{
|
||||
groups: make(map[string]*ConnectionGroup),
|
||||
logger: logger,
|
||||
cleanupInterval: 60 * time.Second,
|
||||
staleTimeout: 5 * time.Minute,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
go m.cleanupLoop()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// GenerateTunnelID generates a unique tunnel ID
|
||||
func GenerateTunnelID() string {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// CreateGroup creates a new connection group
|
||||
func (m *ConnectionGroupManager) CreateGroup(subdomain, token string, primaryConn *Connection, tunnelType protocol.TunnelType) *ConnectionGroup {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
tunnelID := GenerateTunnelID()
|
||||
|
||||
group := NewConnectionGroup(tunnelID, subdomain, token, primaryConn, tunnelType, m.logger)
|
||||
|
||||
m.groups[tunnelID] = group
|
||||
|
||||
return group
|
||||
}
|
||||
|
||||
// GetGroup returns a connection group by tunnel ID
|
||||
func (m *ConnectionGroupManager) GetGroup(tunnelID string) (*ConnectionGroup, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
group, ok := m.groups[tunnelID]
|
||||
return group, ok
|
||||
}
|
||||
|
||||
// RemoveGroup removes and closes a connection group
|
||||
func (m *ConnectionGroupManager) RemoveGroup(tunnelID string) {
|
||||
m.mu.Lock()
|
||||
group, ok := m.groups[tunnelID]
|
||||
if ok {
|
||||
delete(m.groups, tunnelID)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
if ok && group != nil {
|
||||
group.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// AddDataConnection adds a data connection to a group
|
||||
func (m *ConnectionGroupManager) AddDataConnection(req *protocol.DataConnectRequest, conn net.Conn) (*DataConnection, error) {
|
||||
m.mu.RLock()
|
||||
group, ok := m.groups[req.TunnelID]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tunnel not found: %s", req.TunnelID)
|
||||
}
|
||||
|
||||
// Validate token
|
||||
if group.Token != "" && req.Token != group.Token {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
dataConn := group.AddDataConnection(req.ConnectionID, conn)
|
||||
|
||||
return dataConn, nil
|
||||
}
|
||||
|
||||
// cleanupLoop periodically cleans up stale groups
|
||||
func (m *ConnectionGroupManager) cleanupLoop() {
|
||||
ticker := time.NewTicker(m.cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
m.cleanupStaleGroups()
|
||||
case <-m.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ConnectionGroupManager) cleanupStaleGroups() {
|
||||
// Collect stale groups under lock
|
||||
m.mu.Lock()
|
||||
var staleGroups []*ConnectionGroup
|
||||
var staleIDs []string
|
||||
for tunnelID, group := range m.groups {
|
||||
if group.IsStale(m.staleTimeout) {
|
||||
staleIDs = append(staleIDs, tunnelID)
|
||||
staleGroups = append(staleGroups, group)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove from map while holding lock
|
||||
for _, tunnelID := range staleIDs {
|
||||
delete(m.groups, tunnelID)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// Close groups without holding lock to avoid blocking other operations
|
||||
for _, group := range staleGroups {
|
||||
group.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the manager
|
||||
func (m *ConnectionGroupManager) Close() {
|
||||
close(m.stopCh)
|
||||
|
||||
// Collect all groups under lock
|
||||
m.mu.Lock()
|
||||
groups := make([]*ConnectionGroup, 0, len(m.groups))
|
||||
for _, group := range m.groups {
|
||||
groups = append(groups, group)
|
||||
}
|
||||
m.groups = make(map[string]*ConnectionGroup)
|
||||
m.mu.Unlock()
|
||||
|
||||
// Close groups without holding lock
|
||||
for _, group := range groups {
|
||||
group.Close()
|
||||
}
|
||||
}
|
||||
@@ -12,32 +12,34 @@ import (
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/recovery"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Listener handles TCP connections with TLS 1.3
|
||||
type Listener struct {
|
||||
address string
|
||||
tlsConfig *tls.Config
|
||||
authToken string
|
||||
manager *tunnel.Manager
|
||||
portAlloc *PortAllocator
|
||||
logger *zap.Logger
|
||||
domain string
|
||||
publicPort int
|
||||
httpHandler http.Handler
|
||||
responseChans HTTPResponseHandler
|
||||
listener net.Listener
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
connections map[string]*Connection
|
||||
connMu sync.RWMutex
|
||||
workerPool *pool.WorkerPool // Worker pool for connection handling
|
||||
recoverer *recovery.Recoverer
|
||||
address string
|
||||
tlsConfig *tls.Config
|
||||
authToken string
|
||||
manager *tunnel.Manager
|
||||
portAlloc *PortAllocator
|
||||
logger *zap.Logger
|
||||
domain string
|
||||
publicPort int
|
||||
httpHandler http.Handler
|
||||
listener net.Listener
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
connections map[string]*Connection
|
||||
connMu sync.RWMutex
|
||||
workerPool *pool.WorkerPool // Worker pool for connection handling
|
||||
recoverer *recovery.Recoverer
|
||||
panicMetrics *recovery.PanicMetrics
|
||||
|
||||
groupManager *ConnectionGroupManager
|
||||
}
|
||||
|
||||
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Listener {
|
||||
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler) *Listener {
|
||||
numCPU := pool.NumCPU()
|
||||
workers := numCPU * 5
|
||||
queueSize := workers * 20
|
||||
@@ -53,21 +55,21 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
|
||||
recoverer := recovery.NewRecoverer(logger, panicMetrics)
|
||||
|
||||
return &Listener{
|
||||
address: address,
|
||||
tlsConfig: tlsConfig,
|
||||
authToken: authToken,
|
||||
manager: manager,
|
||||
portAlloc: portAlloc,
|
||||
logger: logger,
|
||||
domain: domain,
|
||||
publicPort: publicPort,
|
||||
httpHandler: httpHandler,
|
||||
responseChans: responseChans,
|
||||
stopCh: make(chan struct{}),
|
||||
connections: make(map[string]*Connection),
|
||||
workerPool: workerPool,
|
||||
recoverer: recoverer,
|
||||
panicMetrics: panicMetrics,
|
||||
address: address,
|
||||
tlsConfig: tlsConfig,
|
||||
authToken: authToken,
|
||||
manager: manager,
|
||||
portAlloc: portAlloc,
|
||||
logger: logger,
|
||||
domain: domain,
|
||||
publicPort: publicPort,
|
||||
httpHandler: httpHandler,
|
||||
stopCh: make(chan struct{}),
|
||||
connections: make(map[string]*Connection),
|
||||
workerPool: workerPool,
|
||||
recoverer: recoverer,
|
||||
panicMetrics: panicMetrics,
|
||||
groupManager: NewConnectionGroupManager(logger),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,7 +208,7 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.responseChans)
|
||||
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.groupManager)
|
||||
|
||||
connID := netConn.RemoteAddr().String()
|
||||
l.connMu.Lock()
|
||||
@@ -222,14 +224,11 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
if err := conn.Handle(); err != nil {
|
||||
errStr := err.Error()
|
||||
|
||||
// Client disconnection errors - normal network behavior, log as DEBUG
|
||||
if strings.Contains(errStr, "connection reset by peer") ||
|
||||
// Client disconnection errors - normal network behavior, ignore
|
||||
if strings.Contains(errStr, "EOF") ||
|
||||
strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection refused") {
|
||||
l.logger.Debug("Client disconnected",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -277,6 +276,10 @@ func (l *Listener) Stop() error {
|
||||
l.workerPool.Close()
|
||||
}
|
||||
|
||||
if l.groupManager != nil {
|
||||
l.groupManager.Close()
|
||||
}
|
||||
|
||||
l.logger.Info("TCP listener stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,64 +1,79 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TunnelProxy handles TCP connections for a specific tunnel
|
||||
type TunnelProxy struct {
|
||||
port int
|
||||
subdomain string
|
||||
tcpConn net.Conn // The tunnel control connection
|
||||
listener net.Listener
|
||||
logger *zap.Logger
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
clientAddr string
|
||||
streams map[string]*proxyStream // streamID -> stream info
|
||||
streamMu sync.RWMutex
|
||||
frameWriter *protocol.FrameWriter
|
||||
bufferPool *pool.BufferPool
|
||||
// Proxy exposes a public TCP port and forwards each incoming
|
||||
// connection over a dedicated mux stream.
|
||||
type Proxy struct {
|
||||
port int
|
||||
subdomain string
|
||||
logger *zap.Logger
|
||||
|
||||
listener net.Listener
|
||||
stopCh chan struct{}
|
||||
once sync.Once
|
||||
wg sync.WaitGroup
|
||||
|
||||
openStream func() (net.Conn, error)
|
||||
stats trafficStats
|
||||
sem chan struct{}
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// proxyStream holds connection info with close state
|
||||
type proxyStream struct {
|
||||
conn net.Conn
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
type trafficStats interface {
|
||||
AddBytesIn(n int64)
|
||||
AddBytesOut(n int64)
|
||||
IncActiveConnections()
|
||||
DecActiveConnections()
|
||||
}
|
||||
|
||||
// NewTunnelProxy creates a new TCP tunnel proxy
|
||||
func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Logger) *TunnelProxy {
|
||||
return &TunnelProxy{
|
||||
port: port,
|
||||
subdomain: subdomain,
|
||||
tcpConn: tcpConn,
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
clientAddr: tcpConn.RemoteAddr().String(),
|
||||
streams: make(map[string]*proxyStream),
|
||||
bufferPool: pool.NewBufferPool(),
|
||||
frameWriter: protocol.NewFrameWriter(tcpConn),
|
||||
func NewProxy(ctx context.Context, port int, subdomain string, openStream func() (net.Conn, error), stats trafficStats, logger *zap.Logger) *Proxy {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
cctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
const maxConcurrentConnections = 10000
|
||||
var sem chan struct{}
|
||||
if maxConcurrentConnections > 0 {
|
||||
sem = make(chan struct{}, maxConcurrentConnections)
|
||||
}
|
||||
|
||||
return &Proxy{
|
||||
port: port,
|
||||
subdomain: subdomain,
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
openStream: openStream,
|
||||
stats: stats,
|
||||
sem: sem,
|
||||
ctx: cctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts listening on the allocated port
|
||||
func (p *TunnelProxy) Start() error {
|
||||
func (p *Proxy) Start() error {
|
||||
addr := fmt.Sprintf("0.0.0.0:%d", p.port)
|
||||
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
|
||||
}
|
||||
|
||||
p.listener = listener
|
||||
p.listener = ln
|
||||
|
||||
p.logger.Info("TCP proxy started",
|
||||
zap.Int("port", p.port),
|
||||
@@ -67,14 +82,47 @@ func (p *TunnelProxy) Start() error {
|
||||
|
||||
p.wg.Add(1)
|
||||
go p.acceptLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptLoop accepts incoming TCP connections
|
||||
func (p *TunnelProxy) acceptLoop() {
|
||||
func (p *Proxy) Stop() {
|
||||
p.once.Do(func() {
|
||||
close(p.stopCh)
|
||||
p.cancel()
|
||||
|
||||
if p.listener != nil {
|
||||
_ = p.listener.Close()
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
p.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
const stopTimeout = 30 * time.Second
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
p.logger.Info("TCP proxy stopped",
|
||||
zap.Int("port", p.port),
|
||||
zap.String("subdomain", p.subdomain),
|
||||
)
|
||||
case <-time.After(stopTimeout):
|
||||
p.logger.Warn("TCP proxy stop timed out",
|
||||
zap.Int("port", p.port),
|
||||
zap.String("subdomain", p.subdomain),
|
||||
zap.Duration("timeout", stopTimeout),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Proxy) acceptLoop() {
|
||||
defer p.wg.Done()
|
||||
|
||||
tcpLn, _ := p.listener.(*net.TCPListener)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
@@ -82,11 +130,13 @@ func (p *TunnelProxy) acceptLoop() {
|
||||
default:
|
||||
}
|
||||
|
||||
p.listener.(*net.TCPListener).SetDeadline(time.Now().Add(1 * time.Second))
|
||||
if tcpLn != nil {
|
||||
_ = tcpLn.SetDeadline(time.Now().Add(1 * time.Second))
|
||||
}
|
||||
|
||||
conn, err := p.listener.Accept()
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
@@ -98,187 +148,86 @@ func (p *TunnelProxy) acceptLoop() {
|
||||
}
|
||||
|
||||
p.wg.Add(1)
|
||||
go p.handleConnection(conn)
|
||||
go p.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) handleConnection(conn net.Conn) {
|
||||
func (p *Proxy) handleConn(conn net.Conn) {
|
||||
defer p.wg.Done()
|
||||
defer conn.Close()
|
||||
|
||||
streamID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), p.port)
|
||||
|
||||
stream := &proxyStream{
|
||||
conn: conn,
|
||||
closed: false,
|
||||
if p.sem != nil {
|
||||
select {
|
||||
case p.sem <- struct{}{}:
|
||||
defer func() { <-p.sem }()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
p.streamMu.Lock()
|
||||
p.streams[streamID] = stream
|
||||
p.streamMu.Unlock()
|
||||
if p.stats != nil {
|
||||
p.stats.IncActiveConnections()
|
||||
defer p.stats.DecActiveConnections()
|
||||
}
|
||||
|
||||
defer func() {
|
||||
p.streamMu.Lock()
|
||||
delete(p.streams, streamID)
|
||||
p.streamMu.Unlock()
|
||||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||||
_ = tcpConn.SetNoDelay(true)
|
||||
_ = tcpConn.SetKeepAlive(true)
|
||||
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||||
_ = tcpConn.SetReadBuffer(256 * 1024)
|
||||
_ = tcpConn.SetWriteBuffer(256 * 1024)
|
||||
}
|
||||
|
||||
if p.openStream == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Open stream with timeout to prevent goroutine leak
|
||||
const openStreamTimeout = 10 * time.Second
|
||||
type streamResult struct {
|
||||
stream net.Conn
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan streamResult, 1)
|
||||
|
||||
go func() {
|
||||
s, err := p.openStream()
|
||||
resultCh <- streamResult{s, err}
|
||||
}()
|
||||
|
||||
bufPtr := p.bufferPool.Get(pool.SizeMedium)
|
||||
defer p.bufferPool.Put(bufPtr)
|
||||
|
||||
buffer := (*bufPtr)[:pool.SizeMedium]
|
||||
|
||||
for {
|
||||
// Check if stream is closed
|
||||
stream.mu.Lock()
|
||||
closed := stream.closed
|
||||
stream.mu.Unlock()
|
||||
if closed {
|
||||
break
|
||||
}
|
||||
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
if err := p.sendDataToTunnel(streamID, buffer[:n]); err != nil {
|
||||
p.logger.Debug("Send to tunnel failed", zap.Error(err))
|
||||
break
|
||||
var stream net.Conn
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
if result.err != nil {
|
||||
if !errors.Is(result.err, net.ErrClosed) {
|
||||
p.logger.Debug("Open stream failed", zap.Error(result.err))
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
stream = result.stream
|
||||
case <-time.After(openStreamTimeout):
|
||||
p.logger.Debug("Open stream timeout")
|
||||
return
|
||||
case <-p.stopCh:
|
||||
default:
|
||||
p.sendCloseToTunnel(streamID)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) sendDataToTunnel(streamID string, data []byte) error {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return fmt.Errorf("tunnel proxy stopped")
|
||||
default:
|
||||
}
|
||||
|
||||
header := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: streamID,
|
||||
Type: protocol.DataTypeData,
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode payload: %w", err)
|
||||
}
|
||||
|
||||
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
||||
|
||||
err = p.frameWriter.WriteFrame(frame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write frame: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) sendCloseToTunnel(streamID string) {
|
||||
header := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: streamID,
|
||||
Type: protocol.DataTypeClose,
|
||||
IsLast: true,
|
||||
}
|
||||
|
||||
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
||||
p.frameWriter.WriteFrame(frame)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
|
||||
p.streamMu.RLock()
|
||||
stream, ok := p.streams[streamID]
|
||||
p.streamMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
// Stream may have been closed by client, this is normal
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if stream is closed
|
||||
stream.mu.Lock()
|
||||
if stream.closed {
|
||||
stream.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
stream.mu.Unlock()
|
||||
|
||||
if _, err := stream.conn.Write(data); err != nil {
|
||||
p.logger.Debug("Write to client failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseStream closes a stream
|
||||
func (p *TunnelProxy) CloseStream(streamID string) {
|
||||
p.streamMu.RLock()
|
||||
stream, ok := p.streams[streamID]
|
||||
p.streamMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Mark as closed first
|
||||
stream.mu.Lock()
|
||||
if stream.closed {
|
||||
stream.mu.Unlock()
|
||||
return
|
||||
}
|
||||
stream.closed = true
|
||||
stream.mu.Unlock()
|
||||
|
||||
// Now close the connection
|
||||
stream.conn.Close()
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) Stop() {
|
||||
p.logger.Info("Stopping TCP proxy",
|
||||
zap.Int("port", p.port),
|
||||
zap.String("subdomain", p.subdomain),
|
||||
_ = netutil.PipeWithCallbacksAndBufferSize(
|
||||
p.ctx,
|
||||
conn,
|
||||
stream,
|
||||
pool.SizeLarge,
|
||||
func(n int64) {
|
||||
if p.stats != nil {
|
||||
p.stats.AddBytesIn(n)
|
||||
}
|
||||
},
|
||||
func(n int64) {
|
||||
if p.stats != nil {
|
||||
p.stats.AddBytesOut(n)
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
close(p.stopCh)
|
||||
|
||||
if p.listener != nil {
|
||||
p.listener.Close()
|
||||
}
|
||||
|
||||
p.streamMu.Lock()
|
||||
for _, stream := range p.streams {
|
||||
stream.mu.Lock()
|
||||
stream.closed = true
|
||||
stream.mu.Unlock()
|
||||
stream.conn.Close()
|
||||
}
|
||||
p.streams = make(map[string]*proxyStream)
|
||||
p.streamMu.Unlock()
|
||||
|
||||
p.wg.Wait()
|
||||
|
||||
if p.frameWriter != nil {
|
||||
p.frameWriter.Close()
|
||||
}
|
||||
|
||||
p.logger.Info("TCP proxy stopped", zap.Int("port", p.port))
|
||||
}
|
||||
|
||||
98
internal/server/tcp/tunnel.go
Normal file
98
internal/server/tcp/tunnel.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
|
||||
"drip/internal/shared/constants"
|
||||
)
|
||||
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
|
||||
// Public server acts as yamux Client, client connector acts as yamux Server.
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
cfg := yamux.DefaultConfig()
|
||||
cfg.EnableKeepAlive = false
|
||||
cfg.LogOutput = io.Discard
|
||||
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
openStream := session.Open
|
||||
if c.tunnelID != "" && c.groupManager != nil {
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
|
||||
group.AddSession("primary", session)
|
||||
openStream = group.OpenStream
|
||||
}
|
||||
}
|
||||
|
||||
c.proxy = NewProxy(c.ctx, c.port, c.subdomain, openStream, c.tunnelConn, c.logger)
|
||||
if err := c.proxy.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start tcp proxy: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return nil
|
||||
case <-session.CloseChan():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
|
||||
// Public server acts as yamux Client, client connector acts as yamux Server.
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
cfg := yamux.DefaultConfig()
|
||||
cfg.EnableKeepAlive = false
|
||||
cfg.LogOutput = io.Discard
|
||||
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
openStream := session.Open
|
||||
if c.tunnelID != "" && c.groupManager != nil {
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
|
||||
group.AddSession("primary", session)
|
||||
openStream = group.OpenStream
|
||||
}
|
||||
}
|
||||
|
||||
if c.tunnelConn != nil {
|
||||
c.tunnelConn.SetOpenStream(openStream)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return nil
|
||||
case <-session.CloseChan():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/protocol"
|
||||
@@ -9,13 +11,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Transport represents the control channel to the client.
|
||||
// It is implemented by the TCP control connection so the HTTP proxy
|
||||
// can push frames directly to the client without depending on WebSockets.
|
||||
type Transport interface {
|
||||
SendFrame(frame *protocol.Frame) error
|
||||
}
|
||||
|
||||
// Connection represents a tunnel connection from a client
|
||||
type Connection struct {
|
||||
Subdomain string
|
||||
@@ -26,8 +21,12 @@ type Connection struct {
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
closed bool
|
||||
transport Transport
|
||||
tunnelType protocol.TunnelType
|
||||
openStream func() (net.Conn, error)
|
||||
|
||||
bytesIn atomic.Int64
|
||||
bytesOut atomic.Int64
|
||||
activeConnections atomic.Int64
|
||||
}
|
||||
|
||||
// NewConnection creates a new tunnel connection
|
||||
@@ -106,21 +105,6 @@ func (c *Connection) IsClosed() bool {
|
||||
return c.closed
|
||||
}
|
||||
|
||||
// SetTransport attaches the control transport and tunnel type.
|
||||
func (c *Connection) SetTransport(t Transport, tType protocol.TunnelType) {
|
||||
c.mu.Lock()
|
||||
c.transport = t
|
||||
c.tunnelType = tType
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// GetTransport returns the attached transport (if any).
|
||||
func (c *Connection) GetTransport() Transport {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.transport
|
||||
}
|
||||
|
||||
// SetTunnelType sets the tunnel type.
|
||||
func (c *Connection) SetTunnelType(tType protocol.TunnelType) {
|
||||
c.mu.Lock()
|
||||
@@ -135,6 +119,63 @@ func (c *Connection) GetTunnelType() protocol.TunnelType {
|
||||
return c.tunnelType
|
||||
}
|
||||
|
||||
// SetOpenStream registers a yamux stream opener for this tunnel.
|
||||
// It is used by the HTTP proxy to forward each request over a mux stream.
|
||||
func (c *Connection) SetOpenStream(open func() (net.Conn, error)) {
|
||||
c.mu.Lock()
|
||||
c.openStream = open
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// OpenStream opens a new mux stream to the tunnel client.
|
||||
func (c *Connection) OpenStream() (net.Conn, error) {
|
||||
c.mu.RLock()
|
||||
open := c.openStream
|
||||
closed := c.closed
|
||||
c.mu.RUnlock()
|
||||
|
||||
if closed || open == nil {
|
||||
return nil, ErrConnectionClosed
|
||||
}
|
||||
return open()
|
||||
}
|
||||
|
||||
func (c *Connection) AddBytesIn(n int64) {
|
||||
if n <= 0 {
|
||||
return
|
||||
}
|
||||
c.bytesIn.Add(n)
|
||||
}
|
||||
|
||||
func (c *Connection) AddBytesOut(n int64) {
|
||||
if n <= 0 {
|
||||
return
|
||||
}
|
||||
c.bytesOut.Add(n)
|
||||
}
|
||||
|
||||
func (c *Connection) GetBytesIn() int64 {
|
||||
return c.bytesIn.Load()
|
||||
}
|
||||
|
||||
func (c *Connection) GetBytesOut() int64 {
|
||||
return c.bytesOut.Load()
|
||||
}
|
||||
|
||||
func (c *Connection) IncActiveConnections() {
|
||||
c.activeConnections.Add(1)
|
||||
}
|
||||
|
||||
func (c *Connection) DecActiveConnections() {
|
||||
if v := c.activeConnections.Add(-1); v < 0 {
|
||||
c.activeConnections.Store(0)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) GetActiveConnections() int64 {
|
||||
return c.activeConnections.Load()
|
||||
}
|
||||
|
||||
// StartWritePump starts the write pump for sending messages
|
||||
func (c *Connection) StartWritePump() {
|
||||
// Skip write pump for TCP-only connections (no WebSocket)
|
||||
|
||||
@@ -1,280 +0,0 @@
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Decoder decompresses HPACK-encoded headers
|
||||
// Each connection MUST have its own decoder instance to maintain correct state
|
||||
type Decoder struct {
|
||||
mu sync.Mutex
|
||||
dynamicTable *DynamicTable
|
||||
staticTable *StaticTable
|
||||
maxTableSize uint32
|
||||
}
|
||||
|
||||
// NewDecoder creates a new HPACK decoder
|
||||
func NewDecoder(maxTableSize uint32) *Decoder {
|
||||
if maxTableSize == 0 {
|
||||
maxTableSize = DefaultDynamicTableSize
|
||||
}
|
||||
|
||||
return &Decoder{
|
||||
dynamicTable: NewDynamicTable(maxTableSize),
|
||||
staticTable: GetStaticTable(),
|
||||
maxTableSize: maxTableSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Decode decodes HPACK-encoded headers
|
||||
func (d *Decoder) Decode(data []byte) (http.Header, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
if len(data) == 0 {
|
||||
return http.Header{}, nil
|
||||
}
|
||||
|
||||
headers := make(http.Header)
|
||||
buf := bytes.NewReader(data)
|
||||
|
||||
for buf.Len() > 0 {
|
||||
b, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read header byte: %w", err)
|
||||
}
|
||||
|
||||
// Unread the byte so we can process it properly
|
||||
if err := buf.UnreadByte(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var name, value string
|
||||
|
||||
if b&0x80 != 0 {
|
||||
// Indexed header field (10xxxxxx)
|
||||
name, value, err = d.decodeIndexedHeader(buf)
|
||||
} else if b&0x40 != 0 {
|
||||
// Literal with incremental indexing (01xxxxxx)
|
||||
name, value, err = d.decodeLiteralWithIndexing(buf)
|
||||
} else {
|
||||
// Literal without indexing (0000xxxx)
|
||||
name, value, err = d.decodeLiteralWithoutIndexing(buf)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
headers.Add(name, value)
|
||||
}
|
||||
|
||||
return headers, nil
|
||||
}
|
||||
|
||||
// decodeIndexedHeader decodes an indexed header field
|
||||
func (d *Decoder) decodeIndexedHeader(buf *bytes.Reader) (string, string, error) {
|
||||
index, err := d.readInteger(buf, 7)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read index: %w", err)
|
||||
}
|
||||
|
||||
if index == 0 {
|
||||
return "", "", errors.New("invalid index: 0")
|
||||
}
|
||||
|
||||
staticSize := uint32(d.staticTable.Size())
|
||||
|
||||
if index <= staticSize {
|
||||
// Static table
|
||||
return d.staticTable.Get(index - 1)
|
||||
}
|
||||
|
||||
// Dynamic table (indices start after static table)
|
||||
dynamicIndex := index - staticSize - 1
|
||||
return d.dynamicTable.Get(dynamicIndex)
|
||||
}
|
||||
|
||||
// decodeLiteralWithIndexing decodes a literal header with incremental indexing
|
||||
func (d *Decoder) decodeLiteralWithIndexing(buf *bytes.Reader) (string, string, error) {
|
||||
nameIndex, err := d.readInteger(buf, 6)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
var name string
|
||||
if nameIndex == 0 {
|
||||
// Name is literal
|
||||
name, err = d.readString(buf)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read name: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Name is indexed
|
||||
staticSize := uint32(d.staticTable.Size())
|
||||
if nameIndex <= staticSize {
|
||||
name, _, err = d.staticTable.Get(nameIndex - 1)
|
||||
} else {
|
||||
dynamicIndex := nameIndex - staticSize - 1
|
||||
name, _, err = d.dynamicTable.Get(dynamicIndex)
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("get indexed name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Value is always literal
|
||||
value, err := d.readString(buf)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read value: %w", err)
|
||||
}
|
||||
|
||||
// Add to dynamic table
|
||||
d.dynamicTable.Add(name, value)
|
||||
|
||||
return name, value, nil
|
||||
}
|
||||
|
||||
// decodeLiteralWithoutIndexing decodes a literal header without indexing
|
||||
func (d *Decoder) decodeLiteralWithoutIndexing(buf *bytes.Reader) (string, string, error) {
|
||||
nameIndex, err := d.readInteger(buf, 4)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
var name string
|
||||
if nameIndex == 0 {
|
||||
// Name is literal
|
||||
name, err = d.readString(buf)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read name: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Name is indexed
|
||||
staticSize := uint32(d.staticTable.Size())
|
||||
if nameIndex <= staticSize {
|
||||
name, _, err = d.staticTable.Get(nameIndex - 1)
|
||||
} else {
|
||||
dynamicIndex := nameIndex - staticSize - 1
|
||||
name, _, err = d.dynamicTable.Get(dynamicIndex)
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("get indexed name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Value is always literal
|
||||
value, err := d.readString(buf)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read value: %w", err)
|
||||
}
|
||||
|
||||
// Do NOT add to dynamic table
|
||||
|
||||
return name, value, nil
|
||||
}
|
||||
|
||||
// readInteger reads an HPACK integer
|
||||
func (d *Decoder) readInteger(buf *bytes.Reader, prefixBits int) (uint32, error) {
|
||||
if prefixBits < 1 || prefixBits > 8 {
|
||||
return 0, fmt.Errorf("invalid prefix bits: %d", prefixBits)
|
||||
}
|
||||
|
||||
b, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
maxPrefix := uint32((1 << prefixBits) - 1)
|
||||
mask := byte(maxPrefix)
|
||||
|
||||
value := uint32(b & mask)
|
||||
if value < maxPrefix {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Multi-byte integer
|
||||
m := uint32(0)
|
||||
for {
|
||||
b, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
value += uint32(b&0x7f) << m
|
||||
m += 7
|
||||
|
||||
if b&0x80 == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if m > 28 {
|
||||
return 0, errors.New("integer overflow")
|
||||
}
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// readString reads an HPACK string
|
||||
func (d *Decoder) readString(buf *bytes.Reader) (string, error) {
|
||||
b, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := buf.UnreadByte(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
huffmanEncoded := (b & 0x80) != 0
|
||||
|
||||
length, err := d.readInteger(buf, 7)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read string length: %w", err)
|
||||
}
|
||||
|
||||
if length == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if length > uint32(buf.Len()) {
|
||||
return "", fmt.Errorf("string length %d exceeds buffer size %d", length, buf.Len())
|
||||
}
|
||||
|
||||
data := make([]byte, length)
|
||||
n, err := buf.Read(data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if n != int(length) {
|
||||
return "", fmt.Errorf("expected %d bytes, read %d", length, n)
|
||||
}
|
||||
|
||||
if huffmanEncoded {
|
||||
// TODO: Implement Huffman decoding if needed
|
||||
return "", errors.New("huffman decoding not implemented")
|
||||
}
|
||||
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// SetMaxTableSize updates the dynamic table size
|
||||
func (d *Decoder) SetMaxTableSize(size uint32) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
d.maxTableSize = size
|
||||
d.dynamicTable.SetMaxSize(size)
|
||||
}
|
||||
|
||||
// Reset clears the dynamic table
|
||||
func (d *Decoder) Reset() {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
d.dynamicTable = NewDynamicTable(d.maxTableSize)
|
||||
}
|
||||
@@ -1,124 +0,0 @@
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// DynamicTable implements the HPACK dynamic table (RFC 7541 Section 2.3.2)
|
||||
// The dynamic table is a FIFO queue where new entries are added at the beginning
|
||||
// and old entries are evicted when the table size exceeds the maximum
|
||||
type DynamicTable struct {
|
||||
entries []HeaderField
|
||||
size uint32 // Current size in bytes
|
||||
maxSize uint32 // Maximum size in bytes
|
||||
}
|
||||
|
||||
// HeaderField represents a header name-value pair
|
||||
type HeaderField struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
// Size returns the size of this header field in bytes
|
||||
// RFC 7541: size = len(name) + len(value) + 32
|
||||
func (h *HeaderField) Size() uint32 {
|
||||
return uint32(len(h.Name) + len(h.Value) + 32)
|
||||
}
|
||||
|
||||
// NewDynamicTable creates a new dynamic table with the specified maximum size
|
||||
func NewDynamicTable(maxSize uint32) *DynamicTable {
|
||||
return &DynamicTable{
|
||||
entries: make([]HeaderField, 0, 32),
|
||||
size: 0,
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a header field to the dynamic table
|
||||
// New entries are added at the beginning (index 0)
|
||||
func (dt *DynamicTable) Add(name, value string) {
|
||||
field := HeaderField{Name: name, Value: value}
|
||||
fieldSize := field.Size()
|
||||
|
||||
// If the field is larger than maxSize, don't add it
|
||||
if fieldSize > dt.maxSize {
|
||||
dt.evictAll()
|
||||
return
|
||||
}
|
||||
|
||||
// Evict entries if necessary to make room
|
||||
for dt.size+fieldSize > dt.maxSize && len(dt.entries) > 0 {
|
||||
dt.evictOldest()
|
||||
}
|
||||
|
||||
// Add new entry at the beginning
|
||||
dt.entries = append([]HeaderField{field}, dt.entries...)
|
||||
dt.size += fieldSize
|
||||
}
|
||||
|
||||
// Get retrieves a header field by index (0-based)
|
||||
// Index 0 is the most recently added entry
|
||||
func (dt *DynamicTable) Get(index uint32) (string, string, error) {
|
||||
if index >= uint32(len(dt.entries)) {
|
||||
return "", "", fmt.Errorf("index %d out of range (table size: %d)", index, len(dt.entries))
|
||||
}
|
||||
|
||||
field := dt.entries[index]
|
||||
return field.Name, field.Value, nil
|
||||
}
|
||||
|
||||
// FindExact searches for an exact match (name and value)
|
||||
// Returns the index (0-based) and true if found
|
||||
func (dt *DynamicTable) FindExact(name, value string) (uint32, bool) {
|
||||
for i, field := range dt.entries {
|
||||
if field.Name == name && field.Value == value {
|
||||
return uint32(i), true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// FindName searches for a name match
|
||||
// Returns the index (0-based) and true if found
|
||||
func (dt *DynamicTable) FindName(name string) (uint32, bool) {
|
||||
for i, field := range dt.entries {
|
||||
if field.Name == name {
|
||||
return uint32(i), true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// SetMaxSize updates the maximum table size
|
||||
// If the new size is smaller, entries are evicted
|
||||
func (dt *DynamicTable) SetMaxSize(maxSize uint32) {
|
||||
dt.maxSize = maxSize
|
||||
|
||||
// Evict entries if current size exceeds new max
|
||||
for dt.size > dt.maxSize && len(dt.entries) > 0 {
|
||||
dt.evictOldest()
|
||||
}
|
||||
}
|
||||
|
||||
// CurrentSize returns the current size of the table in bytes
|
||||
func (dt *DynamicTable) CurrentSize() uint32 {
|
||||
return dt.size
|
||||
}
|
||||
|
||||
// evictOldest removes the oldest entry (last in the slice)
|
||||
func (dt *DynamicTable) evictOldest() {
|
||||
if len(dt.entries) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
lastIndex := len(dt.entries) - 1
|
||||
evicted := dt.entries[lastIndex]
|
||||
dt.entries = dt.entries[:lastIndex]
|
||||
dt.size -= evicted.Size()
|
||||
}
|
||||
|
||||
// evictAll removes all entries
|
||||
func (dt *DynamicTable) evictAll() {
|
||||
dt.entries = dt.entries[:0]
|
||||
dt.size = 0
|
||||
}
|
||||
@@ -1,200 +0,0 @@
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultDynamicTableSize is the default size of the dynamic table (4KB)
|
||||
DefaultDynamicTableSize = 4096
|
||||
|
||||
// IndexedHeaderField represents a fully indexed header field
|
||||
indexedHeaderField = 0x80 // 10xxxxxx
|
||||
|
||||
// LiteralHeaderFieldWithIndexing represents a literal with incremental indexing
|
||||
literalHeaderFieldWithIndexing = 0x40 // 01xxxxxx
|
||||
)
|
||||
|
||||
// Encoder compresses HTTP headers using HPACK
|
||||
// Each connection MUST have its own encoder instance to avoid state corruption
|
||||
type Encoder struct {
|
||||
mu sync.Mutex
|
||||
dynamicTable *DynamicTable
|
||||
staticTable *StaticTable
|
||||
maxTableSize uint32
|
||||
}
|
||||
|
||||
// NewEncoder creates a new HPACK encoder with the specified dynamic table size
|
||||
// This encoder is NOT thread-safe and should be used by a single connection
|
||||
func NewEncoder(maxTableSize uint32) *Encoder {
|
||||
if maxTableSize == 0 {
|
||||
maxTableSize = DefaultDynamicTableSize
|
||||
}
|
||||
|
||||
return &Encoder{
|
||||
dynamicTable: NewDynamicTable(maxTableSize),
|
||||
staticTable: GetStaticTable(),
|
||||
maxTableSize: maxTableSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Encode encodes HTTP headers into HPACK binary format
|
||||
// This method is safe to call concurrently within the same encoder instance
|
||||
func (e *Encoder) Encode(headers http.Header) ([]byte, error) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
if headers == nil {
|
||||
return nil, errors.New("headers cannot be nil")
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
for name, values := range headers {
|
||||
for _, value := range values {
|
||||
if err := e.encodeHeaderField(buf, name, value); err != nil {
|
||||
return nil, fmt.Errorf("encode header %s: %w", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// encodeHeaderField encodes a single header field
|
||||
func (e *Encoder) encodeHeaderField(buf *bytes.Buffer, name, value string) error {
|
||||
// HTTP/2 requires header names to be lowercase (RFC 7540 Section 8.1.2)
|
||||
// Convert to lowercase for table lookups and storage
|
||||
nameLower := strings.ToLower(name)
|
||||
|
||||
// Try to find in static table first
|
||||
if index, found := e.staticTable.FindExact(nameLower, value); found {
|
||||
return e.writeIndexedHeader(buf, index+1)
|
||||
}
|
||||
|
||||
// Check if name exists in static table (for literal with name reference)
|
||||
if index, found := e.staticTable.FindName(nameLower); found {
|
||||
return e.writeLiteralWithIndexing(buf, index+1, value, true)
|
||||
}
|
||||
|
||||
// Try dynamic table
|
||||
if index, found := e.dynamicTable.FindExact(nameLower, value); found {
|
||||
// Dynamic table indices start after static table
|
||||
dynamicIndex := uint32(e.staticTable.Size()) + index + 1
|
||||
return e.writeIndexedHeader(buf, dynamicIndex)
|
||||
}
|
||||
|
||||
if index, found := e.dynamicTable.FindName(nameLower); found {
|
||||
dynamicIndex := uint32(e.staticTable.Size()) + index + 1
|
||||
return e.writeLiteralWithIndexing(buf, dynamicIndex, value, true)
|
||||
}
|
||||
|
||||
// Not found anywhere - literal with indexing and new name
|
||||
// Write literal flag
|
||||
buf.WriteByte(literalHeaderFieldWithIndexing)
|
||||
|
||||
// Write name as literal string (must come before value)
|
||||
// Use lowercase name for consistency
|
||||
if err := e.writeString(buf, nameLower, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write value as literal string
|
||||
if err := e.writeString(buf, value, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add to dynamic table with lowercase name
|
||||
e.dynamicTable.Add(nameLower, value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeIndexedHeader writes an indexed header field (10xxxxxx)
|
||||
func (e *Encoder) writeIndexedHeader(buf *bytes.Buffer, index uint32) error {
|
||||
return e.writeInteger(buf, index, 7, indexedHeaderField)
|
||||
}
|
||||
|
||||
// writeLiteralWithIndexing writes a literal header with incremental indexing (01xxxxxx)
|
||||
func (e *Encoder) writeLiteralWithIndexing(buf *bytes.Buffer, nameIndex uint32, value string, hasIndex bool) error {
|
||||
if hasIndex {
|
||||
// Write name as index
|
||||
if err := e.writeInteger(buf, nameIndex, 6, literalHeaderFieldWithIndexing); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Write literal flag
|
||||
buf.WriteByte(literalHeaderFieldWithIndexing)
|
||||
}
|
||||
|
||||
// Write value as literal string
|
||||
return e.writeString(buf, value, false)
|
||||
}
|
||||
|
||||
// writeInteger writes an integer using HPACK integer representation
|
||||
func (e *Encoder) writeInteger(buf *bytes.Buffer, value uint32, prefixBits int, prefix byte) error {
|
||||
if prefixBits < 1 || prefixBits > 8 {
|
||||
return fmt.Errorf("invalid prefix bits: %d", prefixBits)
|
||||
}
|
||||
|
||||
maxPrefix := uint32((1 << prefixBits) - 1)
|
||||
|
||||
if value < maxPrefix {
|
||||
buf.WriteByte(prefix | byte(value))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value >= maxPrefix, need multiple bytes
|
||||
buf.WriteByte(prefix | byte(maxPrefix))
|
||||
value -= maxPrefix
|
||||
|
||||
for value >= 128 {
|
||||
buf.WriteByte(byte(value%128) | 0x80)
|
||||
value /= 128
|
||||
}
|
||||
buf.WriteByte(byte(value))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeString writes a string using HPACK string representation
|
||||
func (e *Encoder) writeString(buf *bytes.Buffer, str string, huffmanEncode bool) error {
|
||||
// For simplicity, we don't use Huffman encoding in this implementation
|
||||
// Huffman flag is bit 7, followed by length in remaining 7 bits
|
||||
|
||||
length := uint32(len(str))
|
||||
if huffmanEncode {
|
||||
// TODO: Implement Huffman encoding if needed
|
||||
return errors.New("huffman encoding not implemented")
|
||||
}
|
||||
|
||||
// Write length with H=0 (no Huffman)
|
||||
if err := e.writeInteger(buf, length, 7, 0x00); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write string bytes
|
||||
buf.WriteString(str)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMaxTableSize updates the dynamic table size
|
||||
func (e *Encoder) SetMaxTableSize(size uint32) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
e.maxTableSize = size
|
||||
e.dynamicTable.SetMaxSize(size)
|
||||
}
|
||||
|
||||
// Reset clears the dynamic table
|
||||
func (e *Encoder) Reset() {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.dynamicTable = NewDynamicTable(e.maxTableSize)
|
||||
}
|
||||
@@ -1,150 +0,0 @@
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// StaticTable implements the HPACK static table (RFC 7541 Appendix A)
|
||||
// The static table is predefined and never changes
|
||||
type StaticTable struct {
|
||||
entries []HeaderField
|
||||
nameMap map[string][]uint32 // Maps name to list of indices
|
||||
}
|
||||
|
||||
var (
|
||||
staticTableInstance *StaticTable
|
||||
staticTableOnce sync.Once
|
||||
)
|
||||
|
||||
// GetStaticTable returns the singleton static table instance
|
||||
func GetStaticTable() *StaticTable {
|
||||
staticTableOnce.Do(func() {
|
||||
staticTableInstance = newStaticTable()
|
||||
})
|
||||
return staticTableInstance
|
||||
}
|
||||
|
||||
// newStaticTable creates and initializes the static table
|
||||
func newStaticTable() *StaticTable {
|
||||
// RFC 7541 Appendix A - Static Table Definition
|
||||
// We include the most common headers for HTTP
|
||||
entries := []HeaderField{
|
||||
{Name: ":authority", Value: ""},
|
||||
{Name: ":method", Value: "GET"},
|
||||
{Name: ":method", Value: "POST"},
|
||||
{Name: ":path", Value: "/"},
|
||||
{Name: ":path", Value: "/index.html"},
|
||||
{Name: ":scheme", Value: "http"},
|
||||
{Name: ":scheme", Value: "https"},
|
||||
{Name: ":status", Value: "200"},
|
||||
{Name: ":status", Value: "204"},
|
||||
{Name: ":status", Value: "206"},
|
||||
{Name: ":status", Value: "304"},
|
||||
{Name: ":status", Value: "400"},
|
||||
{Name: ":status", Value: "404"},
|
||||
{Name: ":status", Value: "500"},
|
||||
{Name: "accept-charset", Value: ""},
|
||||
{Name: "accept-encoding", Value: "gzip, deflate"},
|
||||
{Name: "accept-language", Value: ""},
|
||||
{Name: "accept-ranges", Value: ""},
|
||||
{Name: "accept", Value: ""},
|
||||
{Name: "access-control-allow-origin", Value: ""},
|
||||
{Name: "age", Value: ""},
|
||||
{Name: "allow", Value: ""},
|
||||
{Name: "authorization", Value: ""},
|
||||
{Name: "cache-control", Value: ""},
|
||||
{Name: "content-disposition", Value: ""},
|
||||
{Name: "content-encoding", Value: ""},
|
||||
{Name: "content-language", Value: ""},
|
||||
{Name: "content-length", Value: ""},
|
||||
{Name: "content-location", Value: ""},
|
||||
{Name: "content-range", Value: ""},
|
||||
{Name: "content-type", Value: ""},
|
||||
{Name: "cookie", Value: ""},
|
||||
{Name: "date", Value: ""},
|
||||
{Name: "etag", Value: ""},
|
||||
{Name: "expect", Value: ""},
|
||||
{Name: "expires", Value: ""},
|
||||
{Name: "from", Value: ""},
|
||||
{Name: "host", Value: ""},
|
||||
{Name: "if-match", Value: ""},
|
||||
{Name: "if-modified-since", Value: ""},
|
||||
{Name: "if-none-match", Value: ""},
|
||||
{Name: "if-range", Value: ""},
|
||||
{Name: "if-unmodified-since", Value: ""},
|
||||
{Name: "last-modified", Value: ""},
|
||||
{Name: "link", Value: ""},
|
||||
{Name: "location", Value: ""},
|
||||
{Name: "max-forwards", Value: ""},
|
||||
{Name: "proxy-authenticate", Value: ""},
|
||||
{Name: "proxy-authorization", Value: ""},
|
||||
{Name: "range", Value: ""},
|
||||
{Name: "referer", Value: ""},
|
||||
{Name: "refresh", Value: ""},
|
||||
{Name: "retry-after", Value: ""},
|
||||
{Name: "server", Value: ""},
|
||||
{Name: "set-cookie", Value: ""},
|
||||
{Name: "strict-transport-security", Value: ""},
|
||||
{Name: "transfer-encoding", Value: ""},
|
||||
{Name: "user-agent", Value: ""},
|
||||
{Name: "vary", Value: ""},
|
||||
{Name: "via", Value: ""},
|
||||
{Name: "www-authenticate", Value: ""},
|
||||
}
|
||||
|
||||
// Build name index map
|
||||
nameMap := make(map[string][]uint32)
|
||||
for i, entry := range entries {
|
||||
nameMap[entry.Name] = append(nameMap[entry.Name], uint32(i))
|
||||
}
|
||||
|
||||
return &StaticTable{
|
||||
entries: entries,
|
||||
nameMap: nameMap,
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a header field by index (0-based)
|
||||
func (st *StaticTable) Get(index uint32) (string, string, error) {
|
||||
if index >= uint32(len(st.entries)) {
|
||||
return "", "", fmt.Errorf("index %d out of range (static table size: %d)", index, len(st.entries))
|
||||
}
|
||||
|
||||
field := st.entries[index]
|
||||
return field.Name, field.Value, nil
|
||||
}
|
||||
|
||||
// FindExact searches for an exact match (name and value)
|
||||
// Returns the index (0-based) and true if found
|
||||
func (st *StaticTable) FindExact(name, value string) (uint32, bool) {
|
||||
indices, exists := st.nameMap[name]
|
||||
if !exists {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
for _, index := range indices {
|
||||
field := st.entries[index]
|
||||
if field.Value == value {
|
||||
return index, true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// FindName searches for a name match
|
||||
// Returns the first matching index (0-based) and true if found
|
||||
func (st *StaticTable) FindName(name string) (uint32, bool) {
|
||||
indices, exists := st.nameMap[name]
|
||||
if !exists || len(indices) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return indices[0], true
|
||||
}
|
||||
|
||||
// Size returns the number of entries in the static table
|
||||
func (st *StaticTable) Size() int {
|
||||
return len(st.entries)
|
||||
}
|
||||
@@ -9,6 +9,10 @@ const (
|
||||
// DefaultWSPort is the default WebSocket port
|
||||
DefaultWSPort = 8080
|
||||
|
||||
// YamuxAcceptBacklog controls how many incoming streams can be queued
|
||||
// before yamux starts blocking stream opens under load.
|
||||
YamuxAcceptBacklog = 4096
|
||||
|
||||
// HeartbeatInterval is how often clients send heartbeat messages
|
||||
HeartbeatInterval = 2 * time.Second
|
||||
|
||||
|
||||
71
internal/shared/httputil/helpers.go
Normal file
71
internal/shared/httputil/helpers.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package httputil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CopyHeaders copies all headers from src to dst.
|
||||
func CopyHeaders(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CleanHopByHopHeaders removes hop-by-hop headers that should not be forwarded.
|
||||
func CleanHopByHopHeaders(headers http.Header) {
|
||||
if headers == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if connectionHeaders := headers.Get("Connection"); connectionHeaders != "" {
|
||||
for _, token := range strings.Split(connectionHeaders, ",") {
|
||||
if t := strings.TrimSpace(token); t != "" {
|
||||
headers.Del(http.CanonicalHeaderKey(t))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range []string{
|
||||
"Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te",
|
||||
"Trailer",
|
||||
"Transfer-Encoding",
|
||||
"Proxy-Connection",
|
||||
} {
|
||||
headers.Del(key)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteProxyError writes an HTTP error response to the writer.
|
||||
func WriteProxyError(w io.Writer, code int, msg string) {
|
||||
body := msg
|
||||
resp := &http.Response{
|
||||
StatusCode: code,
|
||||
Status: fmt.Sprintf("%d %s", code, http.StatusText(code)),
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
ContentLength: int64(len(body)),
|
||||
Close: true,
|
||||
}
|
||||
resp.Header.Set("Content-Type", "text/plain; charset=utf-8")
|
||||
resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(body)))
|
||||
_ = resp.Write(w)
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
// IsWebSocketUpgrade checks if the request is a WebSocket upgrade request.
|
||||
func IsWebSocketUpgrade(req *http.Request) bool {
|
||||
return strings.EqualFold(req.Header.Get("Upgrade"), "websocket") &&
|
||||
strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")
|
||||
}
|
||||
35
internal/shared/netutil/counting_conn.go
Normal file
35
internal/shared/netutil/counting_conn.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package netutil
|
||||
|
||||
import "net"
|
||||
|
||||
// CountingConn wraps a net.Conn to count bytes read/written.
|
||||
type CountingConn struct {
|
||||
net.Conn
|
||||
OnRead func(int64)
|
||||
OnWrite func(int64)
|
||||
}
|
||||
|
||||
// NewCountingConn creates a new CountingConn.
|
||||
func NewCountingConn(conn net.Conn, onRead, onWrite func(int64)) *CountingConn {
|
||||
return &CountingConn{
|
||||
Conn: conn,
|
||||
OnRead: onRead,
|
||||
OnWrite: onWrite,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CountingConn) Read(p []byte) (int, error) {
|
||||
n, err := c.Conn.Read(p)
|
||||
if n > 0 && c.OnRead != nil {
|
||||
c.OnRead(int64(n))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *CountingConn) Write(p []byte) (int, error) {
|
||||
n, err := c.Conn.Write(p)
|
||||
if n > 0 && c.OnWrite != nil {
|
||||
c.OnWrite(int64(n))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
164
internal/shared/netutil/pipe.go
Normal file
164
internal/shared/netutil/pipe.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/pool"
|
||||
)
|
||||
|
||||
const tcpWaitTimeout = 10 * time.Second
|
||||
|
||||
type closeReader interface {
|
||||
CloseRead() error
|
||||
}
|
||||
|
||||
type closeWriter interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
type readDeadliner interface {
|
||||
SetReadDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// Pipe copies bytes bidirectionally between a and b (gost-like),
|
||||
// and applies TCP half-close when supported.
|
||||
func Pipe(ctx context.Context, a, b io.ReadWriteCloser) error {
|
||||
return PipeWithCallbacksAndBufferSize(ctx, a, b, pool.SizeMedium, nil, nil)
|
||||
}
|
||||
|
||||
// PipeWithCallbacks is Pipe with optional byte counters for each direction:
|
||||
// onAToB is called with bytes copied from a -> b, onBToA for b -> a.
|
||||
func PipeWithCallbacks(ctx context.Context, a, b io.ReadWriteCloser, onAToB func(n int64), onBToA func(n int64)) error {
|
||||
return PipeWithCallbacksAndBufferSize(ctx, a, b, pool.SizeMedium, onAToB, onBToA)
|
||||
}
|
||||
|
||||
// PipeWithBufferSize is Pipe with a custom buffer size.
|
||||
func PipeWithBufferSize(ctx context.Context, a, b io.ReadWriteCloser, bufSize int) error {
|
||||
return PipeWithCallbacksAndBufferSize(ctx, a, b, bufSize, nil, nil)
|
||||
}
|
||||
|
||||
// PipeWithCallbacksAndBufferSize is PipeWithCallbacks with a custom buffer size.
|
||||
func PipeWithCallbacksAndBufferSize(ctx context.Context, a, b io.ReadWriteCloser, bufSize int, onAToB func(n int64), onBToA func(n int64)) error {
|
||||
if bufSize <= 0 {
|
||||
bufSize = pool.SizeMedium
|
||||
}
|
||||
if bufSize > pool.SizeLarge {
|
||||
bufSize = pool.SizeLarge
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
stopCh := make(chan struct{})
|
||||
var closeOnce sync.Once
|
||||
closeAll := func() {
|
||||
closeOnce.Do(func() {
|
||||
close(stopCh)
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
})
|
||||
}
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
if ctx != nil {
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
closeAll()
|
||||
case <-stopCh:
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := pipeBuffer(b, a, bufSize, onAToB, stopCh)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
closeAll()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := pipeBuffer(a, b, bufSize, onBToA, stopCh)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
closeAll()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return err
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func pipeBuffer(dst io.ReadWriteCloser, src io.ReadWriteCloser, bufSize int, onCopied func(n int64), stopCh <-chan struct{}) error {
|
||||
bufPtr := pool.GetBuffer(bufSize)
|
||||
defer pool.PutBuffer(bufPtr)
|
||||
|
||||
buf := (*bufPtr)[:bufSize]
|
||||
_, err := copyBuffer(dst, src, buf, onCopied, stopCh)
|
||||
|
||||
if cr, ok := src.(closeReader); ok {
|
||||
_ = cr.CloseRead()
|
||||
}
|
||||
|
||||
if cw, ok := dst.(closeWriter); ok {
|
||||
if e := cw.CloseWrite(); e != nil {
|
||||
_ = dst.Close()
|
||||
}
|
||||
if rd, ok := dst.(readDeadliner); ok {
|
||||
_ = rd.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
|
||||
}
|
||||
} else {
|
||||
_ = dst.Close()
|
||||
if rd, ok := dst.(readDeadliner); ok {
|
||||
_ = rd.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func copyBuffer(dst io.Writer, src io.Reader, buf []byte, onCopied func(n int64), stopCh <-chan struct{}) (written int64, err error) {
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
return written, io.EOF
|
||||
default:
|
||||
}
|
||||
|
||||
nr, er := src.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, ew := dst.Write(buf[:nr])
|
||||
if nw > 0 {
|
||||
written += int64(nw)
|
||||
if onCopied != nil {
|
||||
onCopied(int64(nw))
|
||||
}
|
||||
}
|
||||
if ew != nil {
|
||||
return written, ew
|
||||
}
|
||||
if nr != nw {
|
||||
return written, io.ErrShortWrite
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
if er == io.EOF {
|
||||
return written, nil
|
||||
}
|
||||
return written, er
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// AdaptiveBufferPool manages reusable buffers of different sizes
|
||||
// This eliminates the massive memory allocation overhead seen in profiling
|
||||
type AdaptiveBufferPool struct {
|
||||
// Large buffers for streaming threshold (1MB)
|
||||
largePool *sync.Pool
|
||||
|
||||
// Medium buffers for temporary reads (32KB)
|
||||
mediumPool *sync.Pool
|
||||
}
|
||||
|
||||
const (
|
||||
// LargeBufferSize is 1MB for streaming threshold
|
||||
LargeBufferSize = 1 * 1024 * 1024
|
||||
|
||||
// MediumBufferSize is 32KB for temporary reads
|
||||
MediumBufferSize = 32 * 1024
|
||||
)
|
||||
|
||||
// NewAdaptiveBufferPool creates a new adaptive buffer pool
|
||||
func NewAdaptiveBufferPool() *AdaptiveBufferPool {
|
||||
return &AdaptiveBufferPool{
|
||||
largePool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
buf := make([]byte, LargeBufferSize)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
mediumPool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
buf := make([]byte, MediumBufferSize)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetLarge returns a large buffer (1MB) from the pool
|
||||
// The returned buffer should be returned via PutLarge when done
|
||||
func (p *AdaptiveBufferPool) GetLarge() *[]byte {
|
||||
return p.largePool.Get().(*[]byte)
|
||||
}
|
||||
|
||||
// PutLarge returns a large buffer to the pool for reuse
|
||||
func (p *AdaptiveBufferPool) PutLarge(buf *[]byte) {
|
||||
if buf == nil {
|
||||
return
|
||||
}
|
||||
// Reset to full capacity to allow reuse
|
||||
*buf = (*buf)[:cap(*buf)]
|
||||
p.largePool.Put(buf)
|
||||
}
|
||||
|
||||
// GetMedium returns a medium buffer (32KB) from the pool
|
||||
// The returned buffer should be returned via PutMedium when done
|
||||
func (p *AdaptiveBufferPool) GetMedium() *[]byte {
|
||||
return p.mediumPool.Get().(*[]byte)
|
||||
}
|
||||
|
||||
// PutMedium returns a medium buffer to the pool for reuse
|
||||
func (p *AdaptiveBufferPool) PutMedium(buf *[]byte) {
|
||||
if buf == nil {
|
||||
return
|
||||
}
|
||||
// Reset to full capacity to allow reuse
|
||||
*buf = (*buf)[:cap(*buf)]
|
||||
p.mediumPool.Put(buf)
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// HeaderPool manages a pool of http.Header objects for reuse.
|
||||
type HeaderPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
// NewHeaderPool creates a new header pool
|
||||
func NewHeaderPool() *HeaderPool {
|
||||
return &HeaderPool{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make(http.Header, 12)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a header from the pool.
|
||||
func (p *HeaderPool) Get() http.Header {
|
||||
h := p.pool.Get().(http.Header)
|
||||
for k := range h {
|
||||
delete(h, k)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// Put returns a header to the pool.
|
||||
func (p *HeaderPool) Put(h http.Header) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
p.pool.Put(h)
|
||||
}
|
||||
|
||||
// Clone creates a copy of src into dst, reusing dst's underlying storage
|
||||
// This is more efficient than creating a new header from scratch
|
||||
func (p *HeaderPool) Clone(dst, src http.Header) {
|
||||
// Clear dst first
|
||||
for k := range dst {
|
||||
delete(dst, k)
|
||||
}
|
||||
|
||||
// Copy all headers from src to dst
|
||||
for k, vv := range src {
|
||||
// Allocate new slice with exact capacity to avoid over-allocation
|
||||
dst[k] = make([]string, len(vv))
|
||||
copy(dst[k], vv)
|
||||
}
|
||||
}
|
||||
|
||||
// CloneWithExtra clones src into dst and adds/overwrites extra headers
|
||||
// This is optimized for the common pattern of cloning + adding Host header
|
||||
func (p *HeaderPool) CloneWithExtra(dst, src http.Header, extraKey, extraValue string) {
|
||||
// Clear dst first
|
||||
for k := range dst {
|
||||
delete(dst, k)
|
||||
}
|
||||
|
||||
// Copy all headers from src to dst
|
||||
for k, vv := range src {
|
||||
dst[k] = make([]string, len(vv))
|
||||
copy(dst[k], vv)
|
||||
}
|
||||
|
||||
// Set extra header (overwrite if exists)
|
||||
dst.Set(extraKey, extraValue)
|
||||
}
|
||||
|
||||
// globalHeaderPool is a package-level pool for convenience
|
||||
var globalHeaderPool = NewHeaderPool()
|
||||
|
||||
// GetHeader retrieves a header from the global pool
|
||||
func GetHeader() http.Header {
|
||||
return globalHeaderPool.Get()
|
||||
}
|
||||
|
||||
// PutHeader returns a header to the global pool
|
||||
func PutHeader(h http.Header) {
|
||||
globalHeaderPool.Put(h)
|
||||
}
|
||||
@@ -2,81 +2,23 @@ package protocol
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/pool"
|
||||
)
|
||||
|
||||
// AdaptivePoolManager dynamically adjusts buffer pool usage based on load
|
||||
// AdaptivePoolManager tracks active connections for load monitoring
|
||||
type AdaptivePoolManager struct {
|
||||
activeConnections atomic.Int64
|
||||
currentThreshold atomic.Int64
|
||||
highLoadConnectionThreshold int64
|
||||
midLoadConnectionThreshold int64
|
||||
midLoadThreshold int64
|
||||
highLoadThreshold int64
|
||||
activeConnections atomic.Int64
|
||||
}
|
||||
|
||||
var globalAdaptiveManager = NewAdaptivePoolManager()
|
||||
|
||||
func NewAdaptivePoolManager() *AdaptivePoolManager {
|
||||
m := &AdaptivePoolManager{
|
||||
highLoadConnectionThreshold: 300,
|
||||
midLoadConnectionThreshold: 150,
|
||||
midLoadThreshold: int64(pool.SizeLarge),
|
||||
highLoadThreshold: int64(pool.SizeMedium),
|
||||
}
|
||||
|
||||
m.currentThreshold.Store(m.midLoadThreshold)
|
||||
go m.monitor()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *AdaptivePoolManager) monitor() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
connections := m.activeConnections.Load()
|
||||
|
||||
if connections >= m.highLoadConnectionThreshold {
|
||||
m.currentThreshold.Store(m.highLoadThreshold)
|
||||
} else if connections < m.midLoadConnectionThreshold {
|
||||
m.currentThreshold.Store(m.midLoadThreshold)
|
||||
}
|
||||
// Hysteresis zone (150-300): maintain current threshold
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AdaptivePoolManager) GetThreshold() int {
|
||||
return int(m.currentThreshold.Load())
|
||||
}
|
||||
|
||||
func (m *AdaptivePoolManager) RegisterConnection() {
|
||||
m.activeConnections.Add(1)
|
||||
}
|
||||
|
||||
func (m *AdaptivePoolManager) UnregisterConnection() {
|
||||
m.activeConnections.Add(-1)
|
||||
}
|
||||
|
||||
func (m *AdaptivePoolManager) GetActiveConnections() int64 {
|
||||
return m.activeConnections.Load()
|
||||
}
|
||||
|
||||
func GetAdaptiveThreshold() int {
|
||||
return globalAdaptiveManager.GetThreshold()
|
||||
}
|
||||
var globalAdaptiveManager = &AdaptivePoolManager{}
|
||||
|
||||
func RegisterConnection() {
|
||||
globalAdaptiveManager.RegisterConnection()
|
||||
globalAdaptiveManager.activeConnections.Add(1)
|
||||
}
|
||||
|
||||
func UnregisterConnection() {
|
||||
globalAdaptiveManager.UnregisterConnection()
|
||||
globalAdaptiveManager.activeConnections.Add(-1)
|
||||
}
|
||||
|
||||
func GetActiveConnections() int64 {
|
||||
return globalAdaptiveManager.GetActiveConnections()
|
||||
return globalAdaptiveManager.activeConnections.Load()
|
||||
}
|
||||
|
||||
@@ -1,162 +0,0 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// DataHeader represents a binary-encoded data header for data plane
|
||||
// All data transmission uses pure binary encoding for performance
|
||||
type DataHeader struct {
|
||||
Type DataType
|
||||
IsLast bool
|
||||
StreamID string
|
||||
RequestID string
|
||||
}
|
||||
|
||||
// DataType represents the type of data frame
|
||||
type DataType uint8
|
||||
|
||||
const (
|
||||
DataTypeData DataType = 0x00 // 000
|
||||
DataTypeResponse DataType = 0x01 // 001
|
||||
DataTypeClose DataType = 0x02 // 010
|
||||
DataTypeHTTPRequest DataType = 0x03 // 011
|
||||
DataTypeHTTPResponse DataType = 0x04 // 100
|
||||
DataTypeHTTPHead DataType = 0x05 // 101 - streaming headers (shared)
|
||||
DataTypeHTTPBodyChunk DataType = 0x06 // 110 - streaming body chunks (shared)
|
||||
|
||||
// Reuse the same type codes for request streaming to stay within 3 bits.
|
||||
DataTypeHTTPRequestHead DataType = DataTypeHTTPHead
|
||||
DataTypeHTTPRequestBodyChunk DataType = DataTypeHTTPBodyChunk
|
||||
)
|
||||
|
||||
// String returns the string representation of DataType
|
||||
func (t DataType) String() string {
|
||||
switch t {
|
||||
case DataTypeData:
|
||||
return "data"
|
||||
case DataTypeResponse:
|
||||
return "response"
|
||||
case DataTypeClose:
|
||||
return "close"
|
||||
case DataTypeHTTPRequest:
|
||||
return "http_request"
|
||||
case DataTypeHTTPResponse:
|
||||
return "http_response"
|
||||
case DataTypeHTTPHead:
|
||||
return "http_head"
|
||||
case DataTypeHTTPBodyChunk:
|
||||
return "http_body_chunk"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// FromString converts a string to DataType
|
||||
func DataTypeFromString(s string) DataType {
|
||||
switch s {
|
||||
case "data":
|
||||
return DataTypeData
|
||||
case "response":
|
||||
return DataTypeResponse
|
||||
case "close":
|
||||
return DataTypeClose
|
||||
case "http_request":
|
||||
return DataTypeHTTPRequest
|
||||
case "http_response":
|
||||
return DataTypeHTTPResponse
|
||||
case "http_head":
|
||||
return DataTypeHTTPHead
|
||||
case "http_body_chunk":
|
||||
return DataTypeHTTPBodyChunk
|
||||
default:
|
||||
return DataTypeData
|
||||
}
|
||||
}
|
||||
|
||||
// Binary format:
|
||||
// +--------+--------+--------+--------+--------+
|
||||
// | Flags | StreamID Length | RequestID Len |
|
||||
// | 1 byte | 2 bytes | 2 bytes |
|
||||
// +--------+--------+--------+--------+--------+
|
||||
// | StreamID (variable) |
|
||||
// +--------+--------+--------+--------+--------+
|
||||
// | RequestID (variable) |
|
||||
// +--------+--------+--------+--------+--------+
|
||||
//
|
||||
// Flags (8 bits):
|
||||
// - Bit 0-2: Type (3 bits)
|
||||
// - Bit 3: IsLast (1 bit)
|
||||
// - Bit 4-7: Reserved (4 bits)
|
||||
|
||||
const (
|
||||
binaryHeaderMinSize = 5 // 1 byte flags + 2 bytes streamID len + 2 bytes requestID len
|
||||
)
|
||||
|
||||
// MarshalBinary encodes the header to binary format
|
||||
func (h *DataHeader) MarshalBinary() []byte {
|
||||
streamIDLen := len(h.StreamID)
|
||||
requestIDLen := len(h.RequestID)
|
||||
|
||||
totalLen := binaryHeaderMinSize + streamIDLen + requestIDLen
|
||||
buf := make([]byte, totalLen)
|
||||
|
||||
// Encode flags
|
||||
flags := uint8(h.Type) & 0x07 // Type uses bits 0-2
|
||||
if h.IsLast {
|
||||
flags |= 0x08 // IsLast uses bit 3
|
||||
}
|
||||
buf[0] = flags
|
||||
|
||||
// Encode lengths (big-endian)
|
||||
binary.BigEndian.PutUint16(buf[1:3], uint16(streamIDLen))
|
||||
binary.BigEndian.PutUint16(buf[3:5], uint16(requestIDLen))
|
||||
|
||||
// Encode StreamID
|
||||
offset := binaryHeaderMinSize
|
||||
copy(buf[offset:], h.StreamID)
|
||||
offset += streamIDLen
|
||||
|
||||
// Encode RequestID
|
||||
copy(buf[offset:], h.RequestID)
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// UnmarshalBinary decodes the header from binary format
|
||||
func (h *DataHeader) UnmarshalBinary(data []byte) error {
|
||||
if len(data) < binaryHeaderMinSize {
|
||||
return errors.New("invalid binary header: too short")
|
||||
}
|
||||
|
||||
// Decode flags
|
||||
flags := data[0]
|
||||
h.Type = DataType(flags & 0x07) // Bits 0-2
|
||||
h.IsLast = (flags & 0x08) != 0 // Bit 3
|
||||
|
||||
// Decode lengths
|
||||
streamIDLen := int(binary.BigEndian.Uint16(data[1:3]))
|
||||
requestIDLen := int(binary.BigEndian.Uint16(data[3:5]))
|
||||
|
||||
// Validate total length
|
||||
expectedLen := binaryHeaderMinSize + streamIDLen + requestIDLen
|
||||
if len(data) < expectedLen {
|
||||
return errors.New("invalid binary header: length mismatch")
|
||||
}
|
||||
|
||||
// Decode StreamID
|
||||
offset := binaryHeaderMinSize
|
||||
h.StreamID = string(data[offset : offset+streamIDLen])
|
||||
offset += streamIDLen
|
||||
|
||||
// Decode RequestID
|
||||
h.RequestID = string(data[offset : offset+requestIDLen])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Size returns the size of the binary-encoded header
|
||||
func (h *DataHeader) Size() int {
|
||||
return binaryHeaderMinSize + len(h.StreamID) + len(h.RequestID)
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
json "github.com/goccy/go-json"
|
||||
)
|
||||
|
||||
type FlowControlAction string
|
||||
|
||||
const (
|
||||
FlowControlPause FlowControlAction = "pause"
|
||||
FlowControlResume FlowControlAction = "resume"
|
||||
)
|
||||
|
||||
type FlowControlMessage struct {
|
||||
StreamID string `json:"stream_id"`
|
||||
Action FlowControlAction `json:"action"`
|
||||
}
|
||||
|
||||
func NewFlowControlFrame(streamID string, action FlowControlAction) *Frame {
|
||||
msg := FlowControlMessage{
|
||||
StreamID: streamID,
|
||||
Action: action,
|
||||
}
|
||||
payload, _ := json.Marshal(&msg)
|
||||
return NewFrame(FrameTypeFlowControl, payload)
|
||||
}
|
||||
|
||||
func DecodeFlowControlMessage(payload []byte) (*FlowControlMessage, error) {
|
||||
var msg FlowControlMessage
|
||||
if err := json.Unmarshal(payload, &msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &msg, nil
|
||||
}
|
||||
@@ -18,14 +18,14 @@ const (
|
||||
type FrameType byte
|
||||
|
||||
const (
|
||||
FrameTypeRegister FrameType = 0x01
|
||||
FrameTypeRegisterAck FrameType = 0x02
|
||||
FrameTypeHeartbeat FrameType = 0x03
|
||||
FrameTypeHeartbeatAck FrameType = 0x04
|
||||
FrameTypeData FrameType = 0x05
|
||||
FrameTypeClose FrameType = 0x06
|
||||
FrameTypeError FrameType = 0x07
|
||||
FrameTypeFlowControl FrameType = 0x08
|
||||
FrameTypeRegister FrameType = 0x01
|
||||
FrameTypeRegisterAck FrameType = 0x02
|
||||
FrameTypeHeartbeat FrameType = 0x03
|
||||
FrameTypeHeartbeatAck FrameType = 0x04
|
||||
FrameTypeClose FrameType = 0x05
|
||||
FrameTypeError FrameType = 0x06
|
||||
FrameTypeDataConnect FrameType = 0x07
|
||||
FrameTypeDataConnectAck FrameType = 0x08
|
||||
)
|
||||
|
||||
// String returns the string representation of frame type
|
||||
@@ -39,14 +39,14 @@ func (t FrameType) String() string {
|
||||
return "Heartbeat"
|
||||
case FrameTypeHeartbeatAck:
|
||||
return "HeartbeatAck"
|
||||
case FrameTypeData:
|
||||
return "Data"
|
||||
case FrameTypeClose:
|
||||
return "Close"
|
||||
case FrameTypeError:
|
||||
return "Error"
|
||||
case FrameTypeFlowControl:
|
||||
return "FlowControl"
|
||||
case FrameTypeDataConnect:
|
||||
return "DataConnect"
|
||||
case FrameTypeDataConnectAck:
|
||||
return "DataConnectAck"
|
||||
default:
|
||||
return fmt.Sprintf("Unknown(%d)", t)
|
||||
}
|
||||
@@ -56,6 +56,9 @@ type Frame struct {
|
||||
Type FrameType
|
||||
Payload []byte
|
||||
poolBuffer *[]byte
|
||||
// queuedBytes is set by FrameWriter when the frame is enqueued.
|
||||
// It allows the writer to decrement backlog counters exactly once.
|
||||
queuedBytes int64
|
||||
}
|
||||
|
||||
func WriteFrame(w io.Writer, frame *Frame) error {
|
||||
@@ -130,6 +133,8 @@ func (f *Frame) Release() {
|
||||
f.poolBuffer = nil
|
||||
f.Payload = nil
|
||||
}
|
||||
// Reset queued marker to avoid carrying over stale state if the frame is reused.
|
||||
f.queuedBytes = 0
|
||||
}
|
||||
|
||||
// NewFrame creates a new frame
|
||||
|
||||
@@ -1,119 +0,0 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
// EncodeHTTPRequest encodes HTTPRequest using msgpack encoding (optimized)
|
||||
func EncodeHTTPRequest(req *HTTPRequest) ([]byte, error) {
|
||||
return msgpack.Marshal(req)
|
||||
}
|
||||
|
||||
// DecodeHTTPRequest decodes HTTPRequest with automatic version detection
|
||||
// Detects based on first byte: '{' = JSON, else = msgpack
|
||||
func DecodeHTTPRequest(data []byte) (*HTTPRequest, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
|
||||
var req HTTPRequest
|
||||
|
||||
// Auto-detect: JSON starts with '{', msgpack starts with 0x80-0x8f (fixmap)
|
||||
if data[0] == '{' {
|
||||
// v1: JSON
|
||||
if err := json.Unmarshal(data, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// v2: msgpack
|
||||
if err := msgpack.Unmarshal(data, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
// EncodeHTTPRequestHead encodes HTTP request headers for streaming
|
||||
func EncodeHTTPRequestHead(head *HTTPRequestHead) ([]byte, error) {
|
||||
return msgpack.Marshal(head)
|
||||
}
|
||||
|
||||
// DecodeHTTPRequestHead decodes HTTP request headers for streaming
|
||||
func DecodeHTTPRequestHead(data []byte) (*HTTPRequestHead, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
|
||||
var head HTTPRequestHead
|
||||
if data[0] == '{' {
|
||||
if err := json.Unmarshal(data, &head); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := msgpack.Unmarshal(data, &head); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &head, nil
|
||||
}
|
||||
|
||||
// EncodeHTTPResponse encodes HTTPResponse using msgpack encoding (optimized)
|
||||
func EncodeHTTPResponse(resp *HTTPResponse) ([]byte, error) {
|
||||
return msgpack.Marshal(resp)
|
||||
}
|
||||
|
||||
// DecodeHTTPResponse decodes HTTPResponse with automatic version detection
|
||||
// Detects based on first byte: '{' = JSON, else = msgpack
|
||||
func DecodeHTTPResponse(data []byte) (*HTTPResponse, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
|
||||
var resp HTTPResponse
|
||||
|
||||
// Auto-detect: JSON starts with '{', msgpack starts with 0x80-0x8f (fixmap)
|
||||
if data[0] == '{' {
|
||||
// v1: JSON
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// v2: msgpack
|
||||
if err := msgpack.Unmarshal(data, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// EncodeHTTPResponseHead encodes HTTP response headers for streaming
|
||||
func EncodeHTTPResponseHead(head *HTTPResponseHead) ([]byte, error) {
|
||||
return msgpack.Marshal(head)
|
||||
}
|
||||
|
||||
// DecodeHTTPResponseHead decodes HTTP response headers for streaming
|
||||
func DecodeHTTPResponseHead(data []byte) (*HTTPResponseHead, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
|
||||
var head HTTPResponseHead
|
||||
if data[0] == '{' {
|
||||
if err := json.Unmarshal(data, &head); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := msgpack.Unmarshal(data, &head); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &head, nil
|
||||
}
|
||||
@@ -1,71 +0,0 @@
|
||||
package protocol
|
||||
|
||||
// MessageType defines the type of tunnel message
|
||||
type MessageType string
|
||||
|
||||
const (
|
||||
// TypeRegister is sent when a client connects and gets a subdomain assigned
|
||||
TypeRegister MessageType = "register"
|
||||
// TypeRequest is sent from server to client when an HTTP request arrives
|
||||
TypeRequest MessageType = "request"
|
||||
// TypeResponse is sent from client to server with the HTTP response
|
||||
TypeResponse MessageType = "response"
|
||||
// TypeHeartbeat is sent periodically to keep the connection alive
|
||||
TypeHeartbeat MessageType = "heartbeat"
|
||||
// TypeError is sent when an error occurs
|
||||
TypeError MessageType = "error"
|
||||
)
|
||||
|
||||
// Message represents a tunnel protocol message
|
||||
type Message struct {
|
||||
Type MessageType `json:"type"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Subdomain string `json:"subdomain,omitempty"`
|
||||
Data map[string]interface{} `json:"data,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// HTTPRequest represents an HTTP request to be forwarded
|
||||
type HTTPRequest struct {
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
Body []byte `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
// HTTPRequestHead represents HTTP request headers for streaming (no body)
|
||||
type HTTPRequestHead struct {
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
ContentLength int64 `json:"content_length"` // -1 for unknown/chunked
|
||||
}
|
||||
|
||||
// HTTPResponse represents an HTTP response from the local service
|
||||
type HTTPResponse struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Status string `json:"status"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
Body []byte `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
// HTTPResponseHead represents HTTP response headers for streaming (no body)
|
||||
type HTTPResponseHead struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Status string `json:"status"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
ContentLength int64 `json:"content_length"` // -1 for unknown/chunked
|
||||
}
|
||||
|
||||
// RegisterData contains information sent when a tunnel is registered
|
||||
type RegisterData struct {
|
||||
Subdomain string `json:"subdomain"`
|
||||
URL string `json:"url"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ErrorData contains error information
|
||||
type ErrorData struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
@@ -2,12 +2,23 @@ package protocol
|
||||
|
||||
import json "github.com/goccy/go-json"
|
||||
|
||||
// PoolCapabilities advertises client connection pool capabilities
|
||||
type PoolCapabilities struct {
|
||||
MaxDataConns int `json:"max_data_conns"` // Maximum data connections client supports
|
||||
Version int `json:"version"` // Protocol version for pool features
|
||||
}
|
||||
|
||||
// RegisterRequest is sent by client to register a tunnel
|
||||
type RegisterRequest struct {
|
||||
Token string `json:"token"` // Authentication token
|
||||
CustomSubdomain string `json:"custom_subdomain"` // Optional custom subdomain
|
||||
TunnelType TunnelType `json:"tunnel_type"` // http, tcp, udp
|
||||
LocalPort int `json:"local_port"` // Local port to forward to
|
||||
|
||||
// Connection pool fields (optional, for multi-connection support)
|
||||
ConnectionType string `json:"connection_type,omitempty"` // "primary" or empty for legacy
|
||||
TunnelID string `json:"tunnel_id,omitempty"` // For data connections to join
|
||||
PoolCapabilities *PoolCapabilities `json:"pool_capabilities,omitempty"` // Client pool capabilities
|
||||
}
|
||||
|
||||
// RegisterResponse is sent by server after successful registration
|
||||
@@ -16,6 +27,25 @@ type RegisterResponse struct {
|
||||
Port int `json:"port,omitempty"` // Assigned TCP port (for TCP tunnels)
|
||||
URL string `json:"url"` // Full tunnel URL
|
||||
Message string `json:"message"` // Success message
|
||||
|
||||
// Connection pool fields (optional, for multi-connection support)
|
||||
TunnelID string `json:"tunnel_id,omitempty"` // Unique tunnel identifier
|
||||
SupportsDataConn bool `json:"supports_data_conn,omitempty"` // Server supports multi-connection
|
||||
RecommendedConns int `json:"recommended_conns,omitempty"` // Suggested data connection count
|
||||
}
|
||||
|
||||
// DataConnectRequest is sent by data connections to join a tunnel
|
||||
type DataConnectRequest struct {
|
||||
TunnelID string `json:"tunnel_id"` // Tunnel to join
|
||||
Token string `json:"token"` // Same auth token as primary
|
||||
ConnectionID string `json:"connection_id"` // Unique connection identifier
|
||||
}
|
||||
|
||||
// DataConnectResponse acknowledges data connection
|
||||
type DataConnectResponse struct {
|
||||
Accepted bool `json:"accepted"` // Whether connection was accepted
|
||||
ConnectionID string `json:"connection_id"` // Echoed connection ID
|
||||
Message string `json:"message,omitempty"` // Optional message
|
||||
}
|
||||
|
||||
// ErrorMessage represents an error
|
||||
@@ -24,9 +54,6 @@ type ErrorMessage struct {
|
||||
Message string `json:"message"` // Error message
|
||||
}
|
||||
|
||||
// Note: DataHeader is now defined in binary_header.go as a pure binary structure
|
||||
// TCPData has been removed - use DataHeader + raw bytes directly
|
||||
|
||||
// Marshal helpers for control plane messages (JSON encoding)
|
||||
func MarshalJSON(v interface{}) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"drip/internal/shared/pool"
|
||||
)
|
||||
|
||||
// encodeDataPayload encodes a data header and payload into a frame payload.
|
||||
func encodeDataPayload(header DataHeader, data []byte) ([]byte, error) {
|
||||
streamIDLen := len(header.StreamID)
|
||||
requestIDLen := len(header.RequestID)
|
||||
|
||||
totalLen := binaryHeaderMinSize + streamIDLen + requestIDLen + len(data)
|
||||
payload := make([]byte, totalLen)
|
||||
|
||||
flags := uint8(header.Type) & 0x07
|
||||
if header.IsLast {
|
||||
flags |= 0x08
|
||||
}
|
||||
payload[0] = flags
|
||||
|
||||
binary.BigEndian.PutUint16(payload[1:3], uint16(streamIDLen))
|
||||
binary.BigEndian.PutUint16(payload[3:5], uint16(requestIDLen))
|
||||
|
||||
offset := binaryHeaderMinSize
|
||||
copy(payload[offset:], header.StreamID)
|
||||
offset += streamIDLen
|
||||
copy(payload[offset:], header.RequestID)
|
||||
offset += requestIDLen
|
||||
copy(payload[offset:], data)
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// EncodeDataPayloadPooled encodes with adaptive allocation based on load.
|
||||
// Returns payload slice and pool buffer pointer (may be nil).
|
||||
func EncodeDataPayloadPooled(header DataHeader, data []byte) (payload []byte, poolBuffer *[]byte, err error) {
|
||||
streamIDLen := len(header.StreamID)
|
||||
requestIDLen := len(header.RequestID)
|
||||
totalLen := binaryHeaderMinSize + streamIDLen + requestIDLen + len(data)
|
||||
|
||||
dynamicThreshold := GetAdaptiveThreshold()
|
||||
|
||||
if totalLen < dynamicThreshold {
|
||||
regularPayload, err := encodeDataPayload(header, data)
|
||||
return regularPayload, nil, err
|
||||
}
|
||||
|
||||
if totalLen > pool.SizeLarge {
|
||||
regularPayload, err := encodeDataPayload(header, data)
|
||||
return regularPayload, nil, err
|
||||
}
|
||||
|
||||
poolBuffer = pool.GetBuffer(totalLen)
|
||||
payload = (*poolBuffer)[:totalLen]
|
||||
|
||||
flags := uint8(header.Type) & 0x07
|
||||
if header.IsLast {
|
||||
flags |= 0x08
|
||||
}
|
||||
payload[0] = flags
|
||||
|
||||
binary.BigEndian.PutUint16(payload[1:3], uint16(streamIDLen))
|
||||
binary.BigEndian.PutUint16(payload[3:5], uint16(requestIDLen))
|
||||
|
||||
offset := binaryHeaderMinSize
|
||||
copy(payload[offset:], header.StreamID)
|
||||
offset += streamIDLen
|
||||
copy(payload[offset:], header.RequestID)
|
||||
offset += requestIDLen
|
||||
copy(payload[offset:], data)
|
||||
|
||||
return payload, poolBuffer, nil
|
||||
}
|
||||
|
||||
// DecodeDataPayload decodes a frame payload into header and data.
|
||||
func DecodeDataPayload(payload []byte) (DataHeader, []byte, error) {
|
||||
if len(payload) < binaryHeaderMinSize {
|
||||
return DataHeader{}, nil, errors.New("invalid payload: too short")
|
||||
}
|
||||
|
||||
var header DataHeader
|
||||
if err := header.UnmarshalBinary(payload); err != nil {
|
||||
return DataHeader{}, nil, err
|
||||
}
|
||||
|
||||
headerSize := header.Size()
|
||||
if len(payload) < headerSize {
|
||||
return DataHeader{}, nil, errors.New("invalid payload: data missing")
|
||||
}
|
||||
|
||||
data := payload[headerSize:]
|
||||
return header, data, nil
|
||||
}
|
||||
@@ -4,20 +4,11 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SafeFrame wraps Frame with automatic resource cleanup
|
||||
type SafeFrame struct {
|
||||
*Frame
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// NewSafeFrame creates a SafeFrame that implements io.Closer
|
||||
func NewSafeFrame(frameType FrameType, payload []byte) *SafeFrame {
|
||||
return &SafeFrame{
|
||||
Frame: NewFrame(frameType, payload),
|
||||
}
|
||||
}
|
||||
|
||||
// Close implements io.Closer, ensures Release is called exactly once
|
||||
func (sf *SafeFrame) Close() error {
|
||||
sf.once.Do(func() {
|
||||
if sf.Frame != nil {
|
||||
@@ -27,14 +18,6 @@ func (sf *SafeFrame) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithFrame wraps an existing Frame with automatic cleanup
|
||||
func WithFrame(frame *Frame) *SafeFrame {
|
||||
return &SafeFrame{Frame: frame}
|
||||
}
|
||||
|
||||
// MustClose is a helper that calls Close and panics on error (for defer cleanup)
|
||||
func (sf *SafeFrame) MustClose() {
|
||||
if err := sf.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,16 +4,18 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type FrameWriter struct {
|
||||
conn io.Writer
|
||||
queue chan *Frame
|
||||
batch []*Frame
|
||||
mu sync.Mutex
|
||||
done chan struct{}
|
||||
closed bool
|
||||
conn io.Writer
|
||||
queue chan *Frame
|
||||
controlQueue chan *Frame
|
||||
batch []*Frame
|
||||
mu sync.Mutex
|
||||
done chan struct{}
|
||||
closed bool
|
||||
|
||||
maxBatch int
|
||||
maxBatchWait time.Duration
|
||||
@@ -24,13 +26,20 @@ type FrameWriter struct {
|
||||
heartbeatControl chan struct{}
|
||||
|
||||
// Error handling
|
||||
writeErr error
|
||||
errOnce sync.Once
|
||||
onWriteError func(error) // Callback for write errors
|
||||
writeErr error
|
||||
errOnce sync.Once
|
||||
onWriteError func(error) // Callback for write errors
|
||||
|
||||
// Adaptive flushing
|
||||
adaptiveFlush bool // Enable adaptive flush based on queue depth
|
||||
lowConcurrencyThreshold int // Queue depth threshold for immediate flush
|
||||
adaptiveFlush bool // Enable adaptive flush based on queue depth
|
||||
lowConcurrencyThreshold int // Queue depth threshold for immediate flush
|
||||
|
||||
// Hooks
|
||||
preWriteHook func(*Frame) // Called right before a frame is written to conn
|
||||
|
||||
// Backlog tracking
|
||||
queuedFrames atomic.Int64
|
||||
queuedBytes atomic.Int64
|
||||
}
|
||||
|
||||
func NewFrameWriter(conn io.Writer) *FrameWriter {
|
||||
@@ -41,8 +50,14 @@ func NewFrameWriter(conn io.Writer) *FrameWriter {
|
||||
|
||||
func NewFrameWriterWithConfig(conn io.Writer, maxBatch int, maxBatchWait time.Duration, queueSize int) *FrameWriter {
|
||||
w := &FrameWriter{
|
||||
conn: conn,
|
||||
queue: make(chan *Frame, queueSize),
|
||||
conn: conn,
|
||||
queue: make(chan *Frame, queueSize),
|
||||
controlQueue: make(chan *Frame, func() int {
|
||||
if queueSize < 256 {
|
||||
return queueSize
|
||||
}
|
||||
return 256
|
||||
}()), // control path needs small, fast buffer
|
||||
batch: make([]*Frame, 0, maxBatch),
|
||||
maxBatch: maxBatch,
|
||||
maxBatchWait: maxBatchWait,
|
||||
@@ -74,6 +89,22 @@ func (w *FrameWriter) writeLoop() {
|
||||
}()
|
||||
|
||||
for {
|
||||
// Always drain control queue first to prioritize control/heartbeat frames.
|
||||
select {
|
||||
case frame, ok := <-w.controlQueue:
|
||||
if !ok {
|
||||
w.mu.Lock()
|
||||
w.flushBatchLocked()
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
w.mu.Lock()
|
||||
w.flushFrameLocked(frame)
|
||||
w.mu.Unlock()
|
||||
continue
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case frame, ok := <-w.queue:
|
||||
if !ok {
|
||||
@@ -105,8 +136,7 @@ func (w *FrameWriter) writeLoop() {
|
||||
w.mu.Lock()
|
||||
if w.heartbeatCallback != nil {
|
||||
if frame := w.heartbeatCallback(); frame != nil {
|
||||
w.batch = append(w.batch, frame)
|
||||
w.flushBatchLocked()
|
||||
w.flushFrameLocked(frame)
|
||||
}
|
||||
}
|
||||
w.mu.Unlock()
|
||||
@@ -139,22 +169,47 @@ func (w *FrameWriter) flushBatchLocked() {
|
||||
}
|
||||
|
||||
for _, frame := range w.batch {
|
||||
if err := WriteFrame(w.conn, frame); err != nil {
|
||||
w.errOnce.Do(func() {
|
||||
w.writeErr = err
|
||||
if w.onWriteError != nil {
|
||||
go w.onWriteError(err)
|
||||
}
|
||||
w.closed = true
|
||||
})
|
||||
}
|
||||
frame.Release()
|
||||
w.flushFrameLocked(frame)
|
||||
}
|
||||
|
||||
w.batch = w.batch[:0]
|
||||
}
|
||||
|
||||
// flushFrameLocked writes a single frame immediately. Caller must hold w.mu.
|
||||
func (w *FrameWriter) flushFrameLocked(frame *Frame) {
|
||||
if frame == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if w.preWriteHook != nil {
|
||||
w.preWriteHook(frame)
|
||||
}
|
||||
|
||||
if err := WriteFrame(w.conn, frame); err != nil {
|
||||
w.errOnce.Do(func() {
|
||||
w.writeErr = err
|
||||
if w.onWriteError != nil {
|
||||
go w.onWriteError(err)
|
||||
}
|
||||
w.closed = true
|
||||
})
|
||||
}
|
||||
|
||||
w.unmarkQueued(frame)
|
||||
frame.Release()
|
||||
}
|
||||
|
||||
func (w *FrameWriter) WriteFrame(frame *Frame) error {
|
||||
return w.WriteFrameWithCancel(frame, nil)
|
||||
}
|
||||
|
||||
// WriteFrameWithCancel writes a frame with an optional cancellation channel
|
||||
// If cancel is closed, the write will be aborted immediately
|
||||
func (w *FrameWriter) WriteFrameWithCancel(frame *Frame, cancel <-chan struct{}) error {
|
||||
if frame == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
@@ -165,10 +220,19 @@ func (w *FrameWriter) WriteFrame(frame *Frame) error {
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
size := int64(len(frame.Payload) + FrameHeaderSize)
|
||||
w.queuedFrames.Add(1)
|
||||
w.queuedBytes.Add(size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, size)
|
||||
|
||||
// Try non-blocking first for best performance
|
||||
select {
|
||||
case w.queue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
w.mu.Lock()
|
||||
err := w.writeErr
|
||||
w.mu.Unlock()
|
||||
@@ -176,6 +240,54 @@ func (w *FrameWriter) WriteFrame(frame *Frame) error {
|
||||
return err
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
default:
|
||||
}
|
||||
|
||||
// Queue full - block with cancellation support
|
||||
if cancel != nil {
|
||||
select {
|
||||
case w.queue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
w.mu.Lock()
|
||||
err := w.writeErr
|
||||
w.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
case <-cancel:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
return errors.New("write cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
// No cancel channel - block with timeout
|
||||
select {
|
||||
case w.queue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
|
||||
w.mu.Lock()
|
||||
err := w.writeErr
|
||||
w.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
case <-time.After(30 * time.Second):
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
return errors.New("write queue full timeout")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,8 +301,14 @@ func (w *FrameWriter) Close() error {
|
||||
w.mu.Unlock()
|
||||
|
||||
close(w.queue)
|
||||
close(w.controlQueue)
|
||||
|
||||
for frame := range w.queue {
|
||||
w.unmarkQueued(frame)
|
||||
frame.Release()
|
||||
}
|
||||
for frame := range w.controlQueue {
|
||||
w.unmarkQueued(frame)
|
||||
frame.Release()
|
||||
}
|
||||
|
||||
@@ -264,3 +382,97 @@ func (w *FrameWriter) DisableAdaptiveFlush() {
|
||||
w.adaptiveFlush = false
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
// WriteControl enqueues a control/prioritized frame to be written ahead of data frames.
|
||||
func (w *FrameWriter) WriteControl(frame *Frame) error {
|
||||
if frame == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
if w.writeErr != nil {
|
||||
return w.writeErr
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
size := int64(len(frame.Payload) + FrameHeaderSize)
|
||||
w.queuedFrames.Add(1)
|
||||
w.queuedBytes.Add(size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, size)
|
||||
|
||||
// Try non-blocking first
|
||||
select {
|
||||
case w.controlQueue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
w.mu.Lock()
|
||||
err := w.writeErr
|
||||
w.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
default:
|
||||
}
|
||||
|
||||
// Queue full - wait with timeout
|
||||
select {
|
||||
case w.controlQueue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
|
||||
w.mu.Lock()
|
||||
err := w.writeErr
|
||||
w.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Control frames should have priority, shorter timeout
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
return errors.New("control queue full timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// SetPreWriteHook registers a callback invoked just before a frame is written to the underlying writer.
|
||||
func (w *FrameWriter) SetPreWriteHook(hook func(*Frame)) {
|
||||
w.mu.Lock()
|
||||
w.preWriteHook = hook
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
// QueuedFrames returns the number of frames currently queued (data + control).
|
||||
func (w *FrameWriter) QueuedFrames() int64 {
|
||||
return w.queuedFrames.Load()
|
||||
}
|
||||
|
||||
// QueuedBytes returns the approximate number of bytes currently queued.
|
||||
func (w *FrameWriter) QueuedBytes() int64 {
|
||||
return w.queuedBytes.Load()
|
||||
}
|
||||
|
||||
// unmarkQueued decrements backlog counters for a frame once it is written or discarded.
|
||||
func (w *FrameWriter) unmarkQueued(frame *Frame) {
|
||||
if frame == nil {
|
||||
return
|
||||
}
|
||||
size := atomic.SwapInt64(&frame.queuedBytes, 0)
|
||||
if size <= 0 {
|
||||
return
|
||||
}
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
}
|
||||
|
||||
77
internal/shared/stats/format.go
Normal file
77
internal/shared/stats/format.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package stats
|
||||
|
||||
// FormatBytes formats bytes to human readable string
|
||||
func FormatBytes(bytes int64) string {
|
||||
const (
|
||||
KB = 1024
|
||||
MB = KB * 1024
|
||||
GB = MB * 1024
|
||||
)
|
||||
|
||||
switch {
|
||||
case bytes >= GB:
|
||||
return formatFloat(float64(bytes)/float64(GB)) + " GB"
|
||||
case bytes >= MB:
|
||||
return formatFloat(float64(bytes)/float64(MB)) + " MB"
|
||||
case bytes >= KB:
|
||||
return formatFloat(float64(bytes)/float64(KB)) + " KB"
|
||||
default:
|
||||
return formatInt(bytes) + " B"
|
||||
}
|
||||
}
|
||||
|
||||
// FormatSpeed formats speed (bytes per second) to human readable string
|
||||
func FormatSpeed(bytesPerSec int64) string {
|
||||
if bytesPerSec == 0 {
|
||||
return "0 B/s"
|
||||
}
|
||||
return FormatBytes(bytesPerSec) + "/s"
|
||||
}
|
||||
|
||||
func formatFloat(f float64) string {
|
||||
if f >= 100 {
|
||||
return formatInt(int64(f))
|
||||
} else if f >= 10 {
|
||||
return formatOneDecimal(f)
|
||||
}
|
||||
return formatTwoDecimal(f)
|
||||
}
|
||||
|
||||
func formatInt(i int64) string {
|
||||
return intToStr(i)
|
||||
}
|
||||
|
||||
func formatOneDecimal(f float64) string {
|
||||
i := int64(f * 10)
|
||||
whole := i / 10
|
||||
frac := i % 10
|
||||
return intToStr(whole) + "." + intToStr(frac)
|
||||
}
|
||||
|
||||
func formatTwoDecimal(f float64) string {
|
||||
i := int64(f * 100)
|
||||
whole := i / 100
|
||||
frac := i % 100
|
||||
if frac < 10 {
|
||||
return intToStr(whole) + ".0" + intToStr(frac)
|
||||
}
|
||||
return intToStr(whole) + "." + intToStr(frac)
|
||||
}
|
||||
|
||||
func intToStr(i int64) string {
|
||||
if i == 0 {
|
||||
return "0"
|
||||
}
|
||||
if i < 0 {
|
||||
return "-" + intToStr(-i)
|
||||
}
|
||||
|
||||
var buf [20]byte
|
||||
pos := len(buf)
|
||||
for i > 0 {
|
||||
pos--
|
||||
buf[pos] = byte('0' + i%10)
|
||||
i /= 10
|
||||
}
|
||||
return string(buf[pos:])
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package tcp
|
||||
package stats
|
||||
|
||||
import (
|
||||
"sync"
|
||||
@@ -13,7 +13,8 @@ type TrafficStats struct {
|
||||
totalBytesOut int64
|
||||
|
||||
// Request counts
|
||||
totalRequests int64
|
||||
totalRequests int64
|
||||
activeConnections int64
|
||||
|
||||
// For speed calculation
|
||||
lastBytesIn int64
|
||||
@@ -53,6 +54,17 @@ func (s *TrafficStats) AddRequest() {
|
||||
atomic.AddInt64(&s.totalRequests, 1)
|
||||
}
|
||||
|
||||
func (s *TrafficStats) IncActiveConnections() {
|
||||
atomic.AddInt64(&s.activeConnections, 1)
|
||||
}
|
||||
|
||||
func (s *TrafficStats) DecActiveConnections() {
|
||||
v := atomic.AddInt64(&s.activeConnections, -1)
|
||||
if v < 0 {
|
||||
atomic.StoreInt64(&s.activeConnections, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// GetTotalBytesIn returns total incoming bytes
|
||||
func (s *TrafficStats) GetTotalBytesIn() int64 {
|
||||
return atomic.LoadInt64(&s.totalBytesIn)
|
||||
@@ -68,6 +80,10 @@ func (s *TrafficStats) GetTotalRequests() int64 {
|
||||
return atomic.LoadInt64(&s.totalRequests)
|
||||
}
|
||||
|
||||
func (s *TrafficStats) GetActiveConnections() int64 {
|
||||
return atomic.LoadInt64(&s.activeConnections)
|
||||
}
|
||||
|
||||
// GetTotalBytes returns total bytes (in + out)
|
||||
func (s *TrafficStats) GetTotalBytes() int64 {
|
||||
return s.GetTotalBytesIn() + s.GetTotalBytesOut()
|
||||
@@ -81,8 +97,10 @@ func (s *TrafficStats) UpdateSpeed() {
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(s.lastTime).Seconds()
|
||||
|
||||
// Require minimum interval of 100ms to avoid division issues
|
||||
if elapsed < 0.1 {
|
||||
return // Avoid division by zero or too frequent updates
|
||||
return
|
||||
}
|
||||
|
||||
currentIn := atomic.LoadInt64(&s.totalBytesIn)
|
||||
@@ -91,8 +109,20 @@ func (s *TrafficStats) UpdateSpeed() {
|
||||
deltaIn := currentIn - s.lastBytesIn
|
||||
deltaOut := currentOut - s.lastBytesOut
|
||||
|
||||
s.speedIn = int64(float64(deltaIn) / elapsed)
|
||||
s.speedOut = int64(float64(deltaOut) / elapsed)
|
||||
// Calculate instantaneous speed
|
||||
if deltaIn > 0 {
|
||||
s.speedIn = int64(float64(deltaIn) / elapsed)
|
||||
} else {
|
||||
// No new bytes - set speed to 0 immediately
|
||||
s.speedIn = 0
|
||||
}
|
||||
|
||||
if deltaOut > 0 {
|
||||
s.speedOut = int64(float64(deltaOut) / elapsed)
|
||||
} else {
|
||||
// No new bytes - set speed to 0 immediately
|
||||
s.speedOut = 0
|
||||
}
|
||||
|
||||
s.lastBytesIn = currentIn
|
||||
s.lastBytesOut = currentOut
|
||||
@@ -119,18 +149,19 @@ func (s *TrafficStats) GetUptime() time.Duration {
|
||||
}
|
||||
|
||||
// Snapshot returns a snapshot of all stats
|
||||
type StatsSnapshot struct {
|
||||
TotalBytesIn int64
|
||||
TotalBytesOut int64
|
||||
TotalBytes int64
|
||||
TotalRequests int64
|
||||
SpeedIn int64 // bytes per second
|
||||
SpeedOut int64 // bytes per second
|
||||
Uptime time.Duration
|
||||
type Snapshot struct {
|
||||
TotalBytesIn int64
|
||||
TotalBytesOut int64
|
||||
TotalBytes int64
|
||||
TotalRequests int64
|
||||
ActiveConnections int64
|
||||
SpeedIn int64 // bytes per second
|
||||
SpeedOut int64 // bytes per second
|
||||
Uptime time.Duration
|
||||
}
|
||||
|
||||
// GetSnapshot returns a snapshot of current stats
|
||||
func (s *TrafficStats) GetSnapshot() StatsSnapshot {
|
||||
func (s *TrafficStats) GetSnapshot() Snapshot {
|
||||
s.speedMu.Lock()
|
||||
speedIn := s.speedIn
|
||||
speedOut := s.speedOut
|
||||
@@ -138,90 +169,16 @@ func (s *TrafficStats) GetSnapshot() StatsSnapshot {
|
||||
|
||||
totalIn := atomic.LoadInt64(&s.totalBytesIn)
|
||||
totalOut := atomic.LoadInt64(&s.totalBytesOut)
|
||||
active := atomic.LoadInt64(&s.activeConnections)
|
||||
|
||||
return StatsSnapshot{
|
||||
TotalBytesIn: totalIn,
|
||||
TotalBytesOut: totalOut,
|
||||
TotalBytes: totalIn + totalOut,
|
||||
TotalRequests: atomic.LoadInt64(&s.totalRequests),
|
||||
SpeedIn: speedIn,
|
||||
SpeedOut: speedOut,
|
||||
Uptime: time.Since(s.startTime),
|
||||
return Snapshot{
|
||||
TotalBytesIn: totalIn,
|
||||
TotalBytesOut: totalOut,
|
||||
TotalBytes: totalIn + totalOut,
|
||||
TotalRequests: atomic.LoadInt64(&s.totalRequests),
|
||||
ActiveConnections: active,
|
||||
SpeedIn: speedIn,
|
||||
SpeedOut: speedOut,
|
||||
Uptime: time.Since(s.startTime),
|
||||
}
|
||||
}
|
||||
|
||||
// FormatBytes formats bytes to human readable string
|
||||
func FormatBytes(bytes int64) string {
|
||||
const (
|
||||
KB = 1024
|
||||
MB = KB * 1024
|
||||
GB = MB * 1024
|
||||
)
|
||||
|
||||
switch {
|
||||
case bytes >= GB:
|
||||
return formatFloat(float64(bytes)/float64(GB)) + " GB"
|
||||
case bytes >= MB:
|
||||
return formatFloat(float64(bytes)/float64(MB)) + " MB"
|
||||
case bytes >= KB:
|
||||
return formatFloat(float64(bytes)/float64(KB)) + " KB"
|
||||
default:
|
||||
return formatInt(bytes) + " B"
|
||||
}
|
||||
}
|
||||
|
||||
// FormatSpeed formats speed (bytes per second) to human readable string
|
||||
func FormatSpeed(bytesPerSec int64) string {
|
||||
if bytesPerSec == 0 {
|
||||
return "0 B/s"
|
||||
}
|
||||
return FormatBytes(bytesPerSec) + "/s"
|
||||
}
|
||||
|
||||
func formatFloat(f float64) string {
|
||||
if f >= 100 {
|
||||
return formatInt(int64(f))
|
||||
} else if f >= 10 {
|
||||
return formatOneDecimal(f)
|
||||
}
|
||||
return formatTwoDecimal(f)
|
||||
}
|
||||
|
||||
func formatInt(i int64) string {
|
||||
return intToStr(i)
|
||||
}
|
||||
|
||||
func formatOneDecimal(f float64) string {
|
||||
i := int64(f * 10)
|
||||
whole := i / 10
|
||||
frac := i % 10
|
||||
return intToStr(whole) + "." + intToStr(frac)
|
||||
}
|
||||
|
||||
func formatTwoDecimal(f float64) string {
|
||||
i := int64(f * 100)
|
||||
whole := i / 100
|
||||
frac := i % 100
|
||||
if frac < 10 {
|
||||
return intToStr(whole) + ".0" + intToStr(frac)
|
||||
}
|
||||
return intToStr(whole) + "." + intToStr(frac)
|
||||
}
|
||||
|
||||
func intToStr(i int64) string {
|
||||
if i == 0 {
|
||||
return "0"
|
||||
}
|
||||
if i < 0 {
|
||||
return "-" + intToStr(-i)
|
||||
}
|
||||
|
||||
var buf [20]byte
|
||||
pos := len(buf)
|
||||
for i > 0 {
|
||||
pos--
|
||||
buf[pos] = byte('0' + i%10)
|
||||
i /= 10
|
||||
}
|
||||
return string(buf[pos:])
|
||||
}
|
||||
@@ -2,6 +2,9 @@ package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// RenderConfigInit renders config initialization UI
|
||||
@@ -100,18 +103,59 @@ func RenderConfigValidation(serverValid bool, serverMsg string, tokenSet bool, t
|
||||
}
|
||||
|
||||
// RenderDaemonStarted renders daemon started message
|
||||
func RenderDaemonStarted(tunnelType string, port int, pid int, logPath string) string {
|
||||
func RenderDaemonStarted(tunnelType string, port int, pid int, logPath string, url string, forwardAddr string, serverAddr string) string {
|
||||
if forwardAddr == "" {
|
||||
forwardAddr = fmt.Sprintf("localhost:%d", port)
|
||||
}
|
||||
|
||||
urlLine := Muted("(resolving...)")
|
||||
if url != "" {
|
||||
urlBadge := lipgloss.NewStyle().
|
||||
Background(successColor).
|
||||
Foreground(lipgloss.Color("#f8fafc")).
|
||||
Bold(true).
|
||||
Padding(0, 1).
|
||||
Render(url)
|
||||
urlLine = urlBadge
|
||||
}
|
||||
|
||||
headline := successStyle.Render("✓ Tunnel Started in Background")
|
||||
|
||||
lines := []string{
|
||||
KeyValue("Type", Highlight(tunnelType)),
|
||||
KeyValue("Port", fmt.Sprintf("%d", port)),
|
||||
KeyValue("PID", fmt.Sprintf("%d", pid)),
|
||||
KeyValue("Forward", forwardAddr),
|
||||
}
|
||||
if serverAddr != "" {
|
||||
lines = append(lines, KeyValue("Server", serverAddr))
|
||||
}
|
||||
|
||||
lines = append(lines,
|
||||
"",
|
||||
Muted("Commands:"),
|
||||
Cyan(" drip list") + Muted(" Check tunnel status"),
|
||||
Cyan(fmt.Sprintf(" drip attach %s %d", tunnelType, port)) + Muted(" View logs"),
|
||||
Cyan(fmt.Sprintf(" drip stop %s %d", tunnelType, port)) + Muted(" Stop tunnel"),
|
||||
Cyan(" drip list")+Muted(" Check tunnel status"),
|
||||
Cyan(fmt.Sprintf(" drip attach %s %d", tunnelType, port))+Muted(" View logs"),
|
||||
Cyan(fmt.Sprintf(" drip stop %s %d", tunnelType, port))+Muted(" Stop tunnel"),
|
||||
"",
|
||||
Muted("Logs: ") + mutedStyle.Render(logPath),
|
||||
Muted("Logs: ")+mutedStyle.Render(logPath),
|
||||
)
|
||||
|
||||
contentWidth := 0
|
||||
for _, line := range append([]string{headline}, lines...) {
|
||||
if w := lipgloss.Width(line); w > contentWidth {
|
||||
contentWidth = w
|
||||
}
|
||||
}
|
||||
return SuccessBox("Tunnel Started in Background", lines...)
|
||||
if w := lipgloss.Width(urlLine); w > contentWidth {
|
||||
contentWidth = w
|
||||
}
|
||||
|
||||
centeredURL := lipgloss.PlaceHorizontal(contentWidth, lipgloss.Center, urlLine)
|
||||
|
||||
contentLines := make([]string, 0, len(lines)+4)
|
||||
contentLines = append(contentLines, headline, "", centeredURL, "")
|
||||
contentLines = append(contentLines, lines...)
|
||||
|
||||
return successBoxStyle.Render(strings.Join(contentLines, "\n"))
|
||||
}
|
||||
@@ -93,6 +93,11 @@ func RenderTunnelStats(status *TunnelStatus) string {
|
||||
|
||||
_, _, accent := tunnelVisuals(status.Type)
|
||||
|
||||
requestLabel := "Requests"
|
||||
if status.Type == "tcp" {
|
||||
requestLabel = "Connections"
|
||||
}
|
||||
|
||||
header := lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
lipgloss.NewStyle().Foreground(accent).Render("◉"),
|
||||
@@ -102,7 +107,7 @@ func RenderTunnelStats(status *TunnelStatus) string {
|
||||
row1 := lipgloss.JoinHorizontal(
|
||||
lipgloss.Top,
|
||||
statColumn("Latency", latencyStr, statsColumnWidth),
|
||||
statColumn("Requests", highlightStyle.Render(requestsStr), statsColumnWidth),
|
||||
statColumn(requestLabel, highlightStyle.Render(requestsStr), statsColumnWidth),
|
||||
)
|
||||
|
||||
row2 := lipgloss.JoinHorizontal(
|
||||
Reference in New Issue
Block a user