Merge pull request #8 from Gouryella/refactor/protocol-v2

refactor(protocol): switch to yamux multiplexing with connection pooling
This commit is contained in:
Gouryella
2025-12-14 11:13:57 +08:00
committed by GitHub
55 changed files with 3485 additions and 4856 deletions

1
.gitignore vendored
View File

@@ -53,3 +53,4 @@ certs/
.drip-server.env
benchmark-results/
drip
drip-linux-amd64

3
go.mod
View File

@@ -6,8 +6,8 @@ require (
github.com/charmbracelet/lipgloss v1.1.0
github.com/goccy/go-json v0.10.5
github.com/gorilla/websocket v1.5.3
github.com/hashicorp/yamux v0.1.2
github.com/spf13/cobra v1.10.1
github.com/vmihailenco/msgpack/v5 v5.4.1
go.uber.org/zap v1.27.1
golang.org/x/crypto v0.45.0
golang.org/x/sys v0.38.0
@@ -28,7 +28,6 @@ require (
github.com/rivo/uniseg v0.4.7 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/stretchr/testify v1.11.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/net v0.47.0 // indirect

6
go.sum
View File

@@ -17,6 +17,8 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8=
github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
@@ -40,10 +42,6 @@ github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=

View File

@@ -12,7 +12,7 @@ import (
"syscall"
"time"
"drip/internal/client/cli/ui"
"drip/internal/shared/ui"
"github.com/spf13/cobra"
)
@@ -24,6 +24,7 @@ var attachCmd = &cobra.Command{
Examples:
drip attach List running tunnels and select one
drip attach http 3000 Attach to HTTP tunnel on port 3000
drip attach https 8443 Attach to HTTPS tunnel on port 8443
drip attach tcp 5432 Attach to TCP tunnel on port 5432
Press Ctrl+C to detach (tunnel will continue running).`,
@@ -36,7 +37,7 @@ func init() {
rootCmd.AddCommand(attachCmd)
}
func runAttach(cmd *cobra.Command, args []string) error {
func runAttach(_ *cobra.Command, args []string) error {
CleanupStaleDaemons()
daemons, err := ListAllDaemons()
@@ -59,8 +60,8 @@ func runAttach(cmd *cobra.Command, args []string) error {
if len(args) == 2 {
tunnelType := args[0]
if tunnelType != "http" && tunnelType != "tcp" {
return fmt.Errorf("invalid tunnel type: %s (must be 'http' or 'tcp')", tunnelType)
if tunnelType != "http" && tunnelType != "https" && tunnelType != "tcp" {
return fmt.Errorf("invalid tunnel type: %s (must be 'http', 'https', or 'tcp')", tunnelType)
}
port, err := strconv.Atoi(args[1])
@@ -119,10 +120,13 @@ func selectDaemonInteractive(daemons []*DaemonInfo) (*DaemonInfo, error) {
uptime := time.Since(d.StartTime)
var typeStr string
if d.Type == "http" {
typeStr = ui.Success("HTTP")
} else {
typeStr = ui.Highlight("TCP")
switch d.Type {
case "http":
typeStr = ui.Highlight("HTTP")
case "https":
typeStr = ui.Highlight("HTTPS")
default:
typeStr = ui.Cyan("TCP")
}
table.AddRow([]string{
@@ -196,7 +200,7 @@ func attachToDaemon(daemon *DaemonInfo) error {
select {
case <-sigCh:
if tailCmd.Process != nil {
tailCmd.Process.Kill()
_ = tailCmd.Process.Kill()
}
fmt.Println()
fmt.Println(ui.Warning("Detached from tunnel (tunnel is still running)"))

View File

@@ -7,7 +7,7 @@ import (
"os"
"strings"
"drip/internal/client/cli/ui"
"drip/internal/shared/ui"
"drip/pkg/config"
"github.com/spf13/cobra"
)
@@ -77,7 +77,7 @@ func init() {
rootCmd.AddCommand(configCmd)
}
func runConfigInit(cmd *cobra.Command, args []string) error {
func runConfigInit(_ *cobra.Command, _ []string) error {
fmt.Print(ui.RenderConfigInit())
reader := bufio.NewReader(os.Stdin)
@@ -109,7 +109,7 @@ func runConfigInit(cmd *cobra.Command, args []string) error {
return nil
}
func runConfigShow(cmd *cobra.Command, args []string) error {
func runConfigShow(_ *cobra.Command, _ []string) error {
cfg, err := config.LoadClientConfig("")
if err != nil {
return err
@@ -137,7 +137,7 @@ func runConfigShow(cmd *cobra.Command, args []string) error {
return nil
}
func runConfigSet(cmd *cobra.Command, args []string) error {
func runConfigSet(_ *cobra.Command, _ []string) error {
cfg, err := config.LoadClientConfig("")
if err != nil {
cfg = &config.ClientConfig{
@@ -173,7 +173,7 @@ func runConfigSet(cmd *cobra.Command, args []string) error {
return nil
}
func runConfigReset(cmd *cobra.Command, args []string) error {
func runConfigReset(_ *cobra.Command, _ []string) error {
configPath := config.DefaultClientConfigPath()
if !config.ConfigExists("") {
@@ -202,7 +202,7 @@ func runConfigReset(cmd *cobra.Command, args []string) error {
return nil
}
func runConfigValidate(cmd *cobra.Command, args []string) error {
func runConfigValidate(_ *cobra.Command, _ []string) error {
cfg, err := config.LoadClientConfig("")
if err != nil {
fmt.Println(ui.Error("Failed to load configuration"))

View File

@@ -5,10 +5,10 @@ import (
"os"
"os/exec"
"path/filepath"
"strconv"
"time"
"drip/internal/client/cli/ui"
"drip/internal/shared/ui"
"drip/pkg/config"
json "github.com/goccy/go-json"
)
@@ -194,14 +194,67 @@ func StartDaemon(tunnelType string, port int, args []string) error {
return fmt.Errorf("failed to start daemon: %w", err)
}
// Don't wait for the process - let it run in background
// The child process will save its own daemon info after connecting
_ = logFile.Close()
_ = devNull.Close()
fmt.Println(ui.RenderDaemonStarted(tunnelType, port, cmd.Process.Pid, logPath))
localHost := parseFlagValue(cleanArgs, "--address", "-a", "127.0.0.1")
displayHost := localHost
if displayHost == "127.0.0.1" {
displayHost = "localhost"
}
forwardAddr := fmt.Sprintf("%s:%d", displayHost, port)
serverAddr := parseFlagValue(cleanArgs, "--server", "-s", "")
if serverAddr == "" {
if cfg, err := config.LoadClientConfig(""); err == nil {
serverAddr = cfg.Server
}
}
var url string
info, err := waitForDaemonInfo(tunnelType, port, cmd.Process.Pid, 30*time.Second)
if err == nil && info != nil && info.PID == cmd.Process.Pid && info.URL != "" {
url = info.URL
if info.Server != "" {
serverAddr = info.Server
}
}
fmt.Println(ui.RenderDaemonStarted(tunnelType, port, cmd.Process.Pid, logPath, url, forwardAddr, serverAddr))
return nil
}
// parseFlagValue scans a raw argv slice for a flag identified by its long
// or short name and returns the flag's value, or defaultValue when the
// flag is absent or carries no usable value.
//
// Both the separated form ("--server addr", "-s addr") and the equals form
// ("--server=addr", "-s=addr") are recognized; the original implementation
// only handled the separated form, so "--server=addr" silently fell back
// to the default. An empty value is treated as absent so a later
// occurrence of the flag may still win.
func parseFlagValue(args []string, longName string, shortName string, defaultValue string) string {
	for i := 0; i < len(args); i++ {
		arg := args[i]

		// "--flag value" / "-f value" form: the value is the next element.
		if arg == longName || arg == shortName {
			if i+1 < len(args) && args[i+1] != "" {
				return args[i+1]
			}
			continue
		}

		// "--flag=value" / "-f=value" form: the value follows the '='.
		for _, name := range []string{longName, shortName} {
			// len(arg) > len(name)+1 guarantees a non-empty value after '='.
			if name != "" && len(arg) > len(name)+1 && arg[:len(name)] == name && arg[len(name)] == '=' {
				return arg[len(name)+1:]
			}
		}
	}
	return defaultValue
}
// waitForDaemonInfo polls the saved daemon-info file until the child
// process (pid) has recorded a non-empty URL, the child exits, or the
// timeout elapses. It returns (nil, nil) when no usable info appeared in
// time; callers treat that as "daemon started but URL unknown".
func waitForDaemonInfo(tunnelType string, port int, pid int, timeout time.Duration) (*DaemonInfo, error) {
	const pollEvery = 50 * time.Millisecond
	stopAt := time.Now().Add(timeout)
	for {
		if !time.Now().Before(stopAt) {
			// Timed out without seeing a URL; not an error.
			return nil, nil
		}
		// If the child died there is nothing left to wait for.
		if !IsProcessRunning(pid) {
			return nil, nil
		}
		info, loadErr := LoadDaemonInfo(tunnelType, port)
		if loadErr == nil && info != nil && info.PID == pid && info.URL != "" {
			return info, nil
		}
		time.Sleep(pollEvery)
	}
}
// CleanupStaleDaemons removes daemon info for processes that are no longer running
func CleanupStaleDaemons() error {
daemons, err := ListAllDaemons()
@@ -220,28 +273,16 @@ func CleanupStaleDaemons() error {
// FormatDuration formats a duration in a human-readable way
func FormatDuration(d time.Duration) string {
if d < time.Minute {
switch {
case d < time.Minute:
return fmt.Sprintf("%ds", int(d.Seconds()))
} else if d < time.Hour {
case d < time.Hour:
return fmt.Sprintf("%dm %ds", int(d.Minutes()), int(d.Seconds())%60)
} else if d < 24*time.Hour {
case d < 24*time.Hour:
return fmt.Sprintf("%dh %dm", int(d.Hours()), int(d.Minutes())%60)
}
days := int(d.Hours()) / 24
hours := int(d.Hours()) % 24
return fmt.Sprintf("%dd %dh", days, hours)
}
// ParsePortFromArgs extracts the first argument that parses as a valid TCP
// port (1-65535), skipping anything that looks like a flag. It returns an
// error when no such argument exists.
func ParsePortFromArgs(args []string) (int, error) {
	for _, token := range args {
		// Flags (anything starting with '-') are never port candidates.
		if token != "" && token[0] == '-' {
			continue
		}
		if n, convErr := strconv.Atoi(token); convErr == nil && n >= 1 && n <= 65535 {
			return n, nil
		}
	}
	return 0, fmt.Errorf("port number not found in arguments")
}

View File

@@ -2,22 +2,14 @@ package cli
import (
"fmt"
"os"
"strconv"
"time"
"drip/internal/client/tcp"
"drip/internal/shared/protocol"
"drip/pkg/config"
"github.com/spf13/cobra"
)
const (
maxReconnectAttempts = 5
reconnectInterval = 3 * time.Second
)
var (
subdomain string
daemonMode bool
@@ -52,55 +44,19 @@ func init() {
rootCmd.AddCommand(httpCmd)
}
func runHTTP(cmd *cobra.Command, args []string) error {
func runHTTP(_ *cobra.Command, args []string) error {
port, err := strconv.Atoi(args[0])
if err != nil || port < 1 || port > 65535 {
return fmt.Errorf("invalid port number: %s", args[0])
}
if daemonMode && !daemonMarker {
daemonArgs := append([]string{"http"}, args...)
daemonArgs = append(daemonArgs, "--daemon-child")
if subdomain != "" {
daemonArgs = append(daemonArgs, "--subdomain", subdomain)
}
if localAddress != "127.0.0.1" {
daemonArgs = append(daemonArgs, "--address", localAddress)
}
if serverURL != "" {
daemonArgs = append(daemonArgs, "--server", serverURL)
}
if authToken != "" {
daemonArgs = append(daemonArgs, "--token", authToken)
}
if insecure {
daemonArgs = append(daemonArgs, "--insecure")
}
if verbose {
daemonArgs = append(daemonArgs, "--verbose")
}
return StartDaemon("http", port, daemonArgs)
return StartDaemon("http", port, buildDaemonArgs("http", args, subdomain, localAddress))
}
var serverAddr, token string
if serverURL == "" {
cfg, err := config.LoadClientConfig("")
if err != nil {
return fmt.Errorf(`configuration not found.
Please run 'drip config init' first, or use flags:
drip http %d --server SERVER:PORT --token TOKEN`, port)
}
serverAddr = cfg.Server
token = cfg.Token
} else {
serverAddr = serverURL
token = authToken
}
if serverAddr == "" {
return fmt.Errorf("server address is required")
serverAddr, token, err := resolveServerAddrAndToken("http", port)
if err != nil {
return err
}
connConfig := &tcp.ConnectorConfig{
@@ -115,15 +71,7 @@ Please run 'drip config init' first, or use flags:
var daemon *DaemonInfo
if daemonMarker {
daemon = &DaemonInfo{
PID: os.Getpid(),
Type: "http",
Port: port,
Subdomain: subdomain,
Server: serverAddr,
StartTime: time.Now(),
Executable: os.Args[0],
}
daemon = newDaemonInfo("http", port, subdomain, serverAddr)
}
return runTunnelWithUI(connConfig, daemon)

View File

@@ -2,24 +2,14 @@ package cli
import (
"fmt"
"os"
"strconv"
"time"
"drip/internal/client/tcp"
"drip/internal/shared/protocol"
"drip/pkg/config"
"github.com/spf13/cobra"
)
var (
httpsSubdomain string
httpsDaemonMode bool
httpsDaemonMarker bool
httpsLocalAddress string
)
var httpsCmd = &cobra.Command{
Use: "https <port>",
Short: "Start HTTPS tunnel",
@@ -39,86 +29,42 @@ Note: Uses TCP over TLS 1.3 for secure communication`,
}
func init() {
httpsCmd.Flags().StringVarP(&httpsSubdomain, "subdomain", "n", "", "Custom subdomain (optional)")
httpsCmd.Flags().BoolVarP(&httpsDaemonMode, "daemon", "d", false, "Run in background (daemon mode)")
httpsCmd.Flags().StringVarP(&httpsLocalAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
httpsCmd.Flags().BoolVar(&httpsDaemonMarker, "daemon-child", false, "Internal flag for daemon child process")
httpsCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)")
httpsCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)")
httpsCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
httpsCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(httpsCmd)
}
func runHTTPS(cmd *cobra.Command, args []string) error {
func runHTTPS(_ *cobra.Command, args []string) error {
port, err := strconv.Atoi(args[0])
if err != nil || port < 1 || port > 65535 {
return fmt.Errorf("invalid port number: %s", args[0])
}
if httpsDaemonMode && !httpsDaemonMarker {
daemonArgs := append([]string{"https"}, args...)
daemonArgs = append(daemonArgs, "--daemon-child")
if httpsSubdomain != "" {
daemonArgs = append(daemonArgs, "--subdomain", httpsSubdomain)
}
if httpsLocalAddress != "127.0.0.1" {
daemonArgs = append(daemonArgs, "--address", httpsLocalAddress)
}
if serverURL != "" {
daemonArgs = append(daemonArgs, "--server", serverURL)
}
if authToken != "" {
daemonArgs = append(daemonArgs, "--token", authToken)
}
if insecure {
daemonArgs = append(daemonArgs, "--insecure")
}
if verbose {
daemonArgs = append(daemonArgs, "--verbose")
}
return StartDaemon("https", port, daemonArgs)
if daemonMode && !daemonMarker {
return StartDaemon("https", port, buildDaemonArgs("https", args, subdomain, localAddress))
}
var serverAddr, token string
if serverURL == "" {
cfg, err := config.LoadClientConfig("")
if err != nil {
return fmt.Errorf(`configuration not found.
Please run 'drip config init' first, or use flags:
drip https %d --server SERVER:PORT --token TOKEN`, port)
}
serverAddr = cfg.Server
token = cfg.Token
} else {
serverAddr = serverURL
token = authToken
}
if serverAddr == "" {
return fmt.Errorf("server address is required")
serverAddr, token, err := resolveServerAddrAndToken("https", port)
if err != nil {
return err
}
connConfig := &tcp.ConnectorConfig{
ServerAddr: serverAddr,
Token: token,
TunnelType: protocol.TunnelTypeHTTPS,
LocalHost: httpsLocalAddress,
LocalHost: localAddress,
LocalPort: port,
Subdomain: httpsSubdomain,
Subdomain: subdomain,
Insecure: insecure,
}
var daemon *DaemonInfo
if httpsDaemonMarker {
daemon = &DaemonInfo{
PID: os.Getpid(),
Type: "https",
Port: port,
Subdomain: httpsSubdomain,
Server: serverAddr,
StartTime: time.Now(),
Executable: os.Args[0],
}
if daemonMarker {
daemon = newDaemonInfo("https", port, subdomain, serverAddr)
}
return runTunnelWithUI(connConfig, daemon)

View File

@@ -8,13 +8,11 @@ import (
"strings"
"time"
"drip/internal/client/cli/ui"
"drip/internal/shared/ui"
"github.com/spf13/cobra"
)
var (
interactiveMode bool
)
var interactiveMode bool
var listCmd = &cobra.Command{
Use: "list",
@@ -26,7 +24,7 @@ Example:
drip list -i Interactive mode (select to attach/stop)
This command shows:
- Tunnel type (HTTP/TCP)
- Tunnel type (HTTP/HTTPS/TCP)
- Local port being tunneled
- Public URL
- Process ID (PID)
@@ -44,7 +42,7 @@ func init() {
rootCmd.AddCommand(listCmd)
}
func runList(cmd *cobra.Command, args []string) error {
func runList(_ *cobra.Command, _ []string) error {
CleanupStaleDaemons()
daemons, err := ListAllDaemons()
@@ -99,7 +97,7 @@ func runList(cmd *cobra.Command, args []string) error {
fmt.Print(table.Render())
if interactiveMode || shouldPromptForAction() {
if interactiveMode {
return runInteractiveList(daemons)
}
@@ -114,13 +112,6 @@ func runList(cmd *cobra.Command, args []string) error {
return nil
}
func shouldPromptForAction() bool {
if fileInfo, _ := os.Stdout.Stat(); (fileInfo.Mode() & os.ModeCharDevice) == 0 {
return false
}
return false
}
func runInteractiveList(daemons []*DaemonInfo) error {
var runningDaemons []*DaemonInfo
for _, d := range daemons {
@@ -161,9 +152,9 @@ func runInteractiveList(daemons []*DaemonInfo) error {
"",
ui.Muted("What would you like to do?"),
"",
ui.Cyan(" 1.") + " Attach (view logs)",
ui.Cyan(" 2.") + " Stop tunnel",
ui.Muted(" q.") + " Cancel",
ui.Cyan(" 1.")+" Attach (view logs)",
ui.Cyan(" 2.")+" Stop tunnel",
ui.Muted(" q.")+" Cancel",
))
fmt.Print(ui.Muted("Choose an action: "))

View File

@@ -3,7 +3,7 @@ package cli
import (
"fmt"
"drip/internal/client/cli/ui"
"drip/internal/shared/ui"
"github.com/spf13/cobra"
)
@@ -56,14 +56,12 @@ func init() {
versionCmd.Flags().BoolVar(&versionPlain, "short", false, "Print version information without styling")
rootCmd.AddCommand(versionCmd)
// http and tcp commands are added in their respective init() functions
// config command is added in config.go init() function
}
var versionCmd = &cobra.Command{
Use: "version",
Short: "Print version information",
Run: func(cmd *cobra.Command, args []string) {
Run: func(_ *cobra.Command, _ []string) {
if versionPlain {
fmt.Printf("Version: %s\nGit Commit: %s\nBuild Time: %s\n", Version, GitCommit, BuildTime)
return

View File

@@ -59,7 +59,7 @@ func init() {
serverCmd.Flags().IntVar(&serverPprofPort, "pprof", getEnvInt("DRIP_PPROF_PORT", 0), "Enable pprof on specified port (env: DRIP_PPROF_PORT)")
}
func runServer(cmd *cobra.Command, args []string) error {
func runServer(_ *cobra.Command, _ []string) error {
if serverTLSCert == "" {
return fmt.Errorf("TLS certificate path is required (use --tls-cert flag or DRIP_TLS_CERT environment variable)")
}
@@ -126,11 +126,9 @@ func runServer(cmd *cobra.Command, args []string) error {
listenAddr := fmt.Sprintf("0.0.0.0:%d", serverPort)
responseHandler := proxy.NewResponseHandler(logger)
httpHandler := proxy.NewHandler(tunnelManager, logger, serverDomain, serverAuthToken)
httpHandler := proxy.NewHandler(tunnelManager, logger, responseHandler, serverDomain, serverAuthToken)
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, displayPort, httpHandler, responseHandler)
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, displayPort, httpHandler)
if err := listener.Start(); err != nil {
logger.Fatal("Failed to start TCP listener", zap.Error(err))

View File

@@ -28,7 +28,7 @@ func init() {
rootCmd.AddCommand(stopCmd)
}
func runStop(cmd *cobra.Command, args []string) error {
func runStop(_ *cobra.Command, args []string) error {
if args[0] == "all" {
return stopAllDaemons()
}

View File

@@ -2,13 +2,10 @@ package cli
import (
"fmt"
"os"
"strconv"
"time"
"drip/internal/client/tcp"
"drip/internal/shared/protocol"
"drip/pkg/config"
"github.com/spf13/cobra"
)
@@ -47,55 +44,19 @@ func init() {
rootCmd.AddCommand(tcpCmd)
}
func runTCP(cmd *cobra.Command, args []string) error {
func runTCP(_ *cobra.Command, args []string) error {
port, err := strconv.Atoi(args[0])
if err != nil || port < 1 || port > 65535 {
return fmt.Errorf("invalid port number: %s", args[0])
}
if daemonMode && !daemonMarker {
daemonArgs := append([]string{"tcp"}, args...)
daemonArgs = append(daemonArgs, "--daemon-child")
if subdomain != "" {
daemonArgs = append(daemonArgs, "--subdomain", subdomain)
}
if localAddress != "127.0.0.1" {
daemonArgs = append(daemonArgs, "--address", localAddress)
}
if serverURL != "" {
daemonArgs = append(daemonArgs, "--server", serverURL)
}
if authToken != "" {
daemonArgs = append(daemonArgs, "--token", authToken)
}
if insecure {
daemonArgs = append(daemonArgs, "--insecure")
}
if verbose {
daemonArgs = append(daemonArgs, "--verbose")
}
return StartDaemon("tcp", port, daemonArgs)
return StartDaemon("tcp", port, buildDaemonArgs("tcp", args, subdomain, localAddress))
}
var serverAddr, token string
if serverURL == "" {
cfg, err := config.LoadClientConfig("")
if err != nil {
return fmt.Errorf(`configuration not found.
Please run 'drip config init' first, or use flags:
drip tcp %d --server SERVER:PORT --token TOKEN`, port)
}
serverAddr = cfg.Server
token = cfg.Token
} else {
serverAddr = serverURL
token = authToken
}
if serverAddr == "" {
return fmt.Errorf("server address is required")
serverAddr, token, err := resolveServerAddrAndToken("tcp", port)
if err != nil {
return err
}
connConfig := &tcp.ConnectorConfig{
@@ -110,15 +71,7 @@ Please run 'drip config init' first, or use flags:
var daemon *DaemonInfo
if daemonMarker {
daemon = &DaemonInfo{
PID: os.Getpid(),
Type: "tcp",
Port: port,
Subdomain: subdomain,
Server: serverAddr,
StartTime: time.Now(),
Executable: os.Args[0],
}
daemon = newDaemonInfo("tcp", port, subdomain, serverAddr)
}
return runTunnelWithUI(connConfig, daemon)

View File

@@ -0,0 +1,67 @@
package cli
import (
"fmt"
"os"
"time"
"drip/pkg/config"
)
// buildDaemonArgs assembles the argv used to re-exec this binary as a
// daemon child for the given tunnel type. The user-visible flags
// (subdomain/address) and the package-level server/token/insecure/verbose
// flags are propagated so the child sees the same configuration.
func buildDaemonArgs(tunnelType string, args []string, subdomain string, localAddress string) []string {
	out := make([]string, 0, len(args)+13)
	out = append(out, tunnelType)
	out = append(out, args...)
	out = append(out, "--daemon-child")
	if subdomain != "" {
		out = append(out, "--subdomain", subdomain)
	}
	// Only forward a non-default bind address.
	if localAddress != "127.0.0.1" {
		out = append(out, "--address", localAddress)
	}
	if serverURL != "" {
		out = append(out, "--server", serverURL)
	}
	if authToken != "" {
		out = append(out, "--token", authToken)
	}
	if insecure {
		out = append(out, "--insecure")
	}
	if verbose {
		out = append(out, "--verbose")
	}
	return out
}
// resolveServerAddrAndToken decides which server address and auth token to
// use: explicit --server/--token flags win; otherwise the client config
// file is consulted. tunnelType and port only shape the hint shown when no
// configuration exists.
func resolveServerAddrAndToken(tunnelType string, port int) (string, string, error) {
	// Explicit flags override any stored configuration.
	if serverURL != "" {
		return serverURL, authToken, nil
	}
	cfg, loadErr := config.LoadClientConfig("")
	switch {
	case loadErr != nil:
		return "", "", fmt.Errorf(`configuration not found.
Please run 'drip config init' first, or use flags:
drip %s %d --server SERVER:PORT --token TOKEN`, tunnelType, port)
	case cfg.Server == "":
		return "", "", fmt.Errorf("server address is required")
	default:
		return cfg.Server, cfg.Token, nil
	}
}
// newDaemonInfo records the identity of the current process as a running
// daemon for the given tunnel so sibling commands (list/attach/stop) can
// locate it later.
func newDaemonInfo(tunnelType string, port int, subdomain string, serverAddr string) *DaemonInfo {
	info := &DaemonInfo{
		Type:      tunnelType,
		Port:      port,
		Subdomain: subdomain,
		Server:    serverAddr,
	}
	// Identity of this process, captured at registration time.
	info.PID = os.Getpid()
	info.StartTime = time.Now()
	info.Executable = os.Args[0]
	return info
}

View File

@@ -8,13 +8,17 @@ import (
"syscall"
"time"
"drip/internal/client/cli/ui"
"drip/internal/client/tcp"
"drip/internal/shared/ui"
"drip/internal/shared/utils"
"go.uber.org/zap"
)
// runTunnelWithUI runs a tunnel with the new UI
const (
maxReconnectAttempts = 5
reconnectInterval = 3 * time.Second
)
func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) error {
if err := utils.InitLogger(verbose); err != nil {
return fmt.Errorf("failed to initialize logger: %w", err)
@@ -28,7 +32,7 @@ func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) er
reconnectAttempts := 0
for {
connector := tcp.NewConnector(connConfig, logger)
connector := tcp.NewTunnelClient(connConfig, logger)
fmt.Println(ui.RenderConnecting(connConfig.ServerAddr, reconnectAttempts, maxReconnectAttempts))
@@ -87,8 +91,8 @@ func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) er
disconnected := make(chan struct{})
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
renderTicker := time.NewTicker(1 * time.Second)
defer renderTicker.Stop()
var lastLatency time.Duration
lastRenderedLines := 0
@@ -97,27 +101,38 @@ func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) er
select {
case latency := <-latencyCh:
lastLatency = latency
case <-ticker.C:
case <-renderTicker.C:
stats := connector.GetStats()
if stats != nil {
stats.UpdateSpeed()
snapshot := stats.GetSnapshot()
status.Latency = lastLatency
status.BytesIn = snapshot.TotalBytesIn
status.BytesOut = snapshot.TotalBytesOut
status.SpeedIn = float64(snapshot.SpeedIn)
status.SpeedOut = float64(snapshot.SpeedOut)
status.TotalRequest = snapshot.TotalRequests
statsView := ui.RenderTunnelStats(status)
if lastRenderedLines > 0 {
fmt.Print(clearLines(lastRenderedLines))
}
fmt.Print(statsView)
lastRenderedLines = countRenderedLines(statsView)
if stats == nil {
continue
}
stats.UpdateSpeed()
snapshot := stats.GetSnapshot()
status.Latency = lastLatency
status.BytesIn = snapshot.TotalBytesIn
status.BytesOut = snapshot.TotalBytesOut
status.SpeedIn = float64(snapshot.SpeedIn)
status.SpeedOut = float64(snapshot.SpeedOut)
if status.Type == "tcp" {
if snapshot.SpeedIn == 0 && snapshot.SpeedOut == 0 {
status.TotalRequest = 0
} else {
status.TotalRequest = snapshot.ActiveConnections
}
} else {
status.TotalRequest = snapshot.TotalRequests
}
statsView := ui.RenderTunnelStats(status)
if lastRenderedLines > 0 {
fmt.Print(clearLines(lastRenderedLines))
}
fmt.Print(statsView)
lastRenderedLines = countRenderedLines(statsView)
case <-stopDisplay:
return
}

View File

@@ -1,493 +1,50 @@
package tcp
import (
"crypto/tls"
"fmt"
"net"
json "github.com/goccy/go-json"
"sync"
"strings"
"time"
"drip/internal/shared/constants"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"drip/internal/shared/recovery"
"drip/pkg/config"
"drip/internal/shared/stats"
"go.uber.org/zap"
)
// LatencyCallback is called when latency is measured
type LatencyCallback func(latency time.Duration)
// Connector manages the TCP connection to the server
type Connector struct {
serverAddr string
tlsConfig *tls.Config
token string
tunnelType protocol.TunnelType
localHost string
localPort int
subdomain string
conn net.Conn
logger *zap.Logger
stopCh chan struct{}
once sync.Once
registered bool
assignedURL string
frameHandler *FrameHandler
frameWriter *protocol.FrameWriter
latencyCallback LatencyCallback
heartbeatSentAt time.Time
heartbeatMu sync.Mutex
lastLatency time.Duration
handlerWg sync.WaitGroup // Tracks active data frame handlers
closed bool
closedMu sync.RWMutex
// Worker pool for handling data frames
dataFrameQueue chan *protocol.Frame
workerCount int
recoverer *recovery.Recoverer
panicMetrics *recovery.PanicMetrics
}
// ConnectorConfig holds connector configuration
type ConnectorConfig struct {
ServerAddr string
Token string
TunnelType protocol.TunnelType
LocalHost string // Local host address (default: 127.0.0.1)
LocalHost string
LocalPort int
Subdomain string // Optional custom subdomain
Insecure bool // Skip TLS verification (testing only)
Subdomain string
Insecure bool
PoolSize int
PoolMin int
PoolMax int
}
// NewConnector creates a new connector
func NewConnector(cfg *ConnectorConfig, logger *zap.Logger) *Connector {
var tlsConfig *tls.Config
if cfg.Insecure {
tlsConfig = config.GetClientTLSConfigInsecure()
} else {
host, _, _ := net.SplitHostPort(cfg.ServerAddr)
tlsConfig = config.GetClientTLSConfig(host)
}
localHost := cfg.LocalHost
if localHost == "" {
localHost = "127.0.0.1"
}
numCPU := pool.NumCPU()
workerCount := max(numCPU+numCPU/2, 4)
panicMetrics := recovery.NewPanicMetrics(logger, nil)
recoverer := recovery.NewRecoverer(logger, panicMetrics)
return &Connector{
serverAddr: cfg.ServerAddr,
tlsConfig: tlsConfig,
token: cfg.Token,
tunnelType: cfg.TunnelType,
localHost: localHost,
localPort: cfg.LocalPort,
subdomain: cfg.Subdomain,
logger: logger,
stopCh: make(chan struct{}),
dataFrameQueue: make(chan *protocol.Frame, workerCount*100),
workerCount: workerCount,
recoverer: recoverer,
panicMetrics: panicMetrics,
}
type TunnelClient interface {
Connect() error
Close() error
Wait()
GetURL() string
GetSubdomain() string
SetLatencyCallback(cb LatencyCallback)
GetLatency() time.Duration
GetStats() *stats.TrafficStats
IsClosed() bool
}
// Connect connects to the server and registers the tunnel
func (c *Connector) Connect() error {
c.logger.Info("Connecting to server",
zap.String("server", c.serverAddr),
zap.String("tunnel_type", string(c.tunnelType)),
zap.String("local_host", c.localHost),
zap.Int("local_port", c.localPort),
)
dialer := &net.Dialer{
Timeout: 10 * time.Second,
}
conn, err := tls.DialWithDialer(dialer, "tcp", c.serverAddr, c.tlsConfig)
if err != nil {
return fmt.Errorf("failed to connect: %w", err)
}
c.conn = conn
state := conn.ConnectionState()
if state.Version != tls.VersionTLS13 {
conn.Close()
return fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
}
c.logger.Info("TLS connection established",
zap.String("cipher_suite", tls.CipherSuiteName(state.CipherSuite)),
)
if err := c.register(); err != nil {
conn.Close()
return fmt.Errorf("registration failed: %w", err)
}
c.frameWriter = protocol.NewFrameWriter(c.conn)
bufferPool := pool.NewBufferPool()
c.frameHandler = NewFrameHandler(
c.conn,
c.frameWriter,
c.localHost,
c.localPort,
c.tunnelType,
c.logger,
c.IsClosed,
bufferPool,
)
c.frameWriter.EnableHeartbeat(constants.HeartbeatInterval, c.createHeartbeatFrame)
for i := 0; i < c.workerCount; i++ {
c.handlerWg.Add(1)
go c.dataFrameWorker(i)
}
go c.frameHandler.WarmupConnectionPool(3)
go c.monitorQueuePressure()
go c.handleFrames()
return nil
// NewTunnelClient constructs the client-side tunnel implementation for the
// given configuration. It currently always returns the pool-based client
// (NewPoolClient); the factory exists so callers depend only on the
// TunnelClient interface.
func NewTunnelClient(cfg *ConnectorConfig, logger *zap.Logger) TunnelClient {
	return NewPoolClient(cfg, logger)
}
// register sends the tunnel registration request over the established
// connection and waits for the server's acknowledgment frame. On success
// it records the server-assigned URL and subdomain on the connector.
// Must be called once, before the frame read loop starts.
func (c *Connector) register() error {
	req := protocol.RegisterRequest{
		Token:           c.token,
		CustomSubdomain: c.subdomain,
		TunnelType:      c.tunnelType,
		LocalPort:       c.localPort,
	}
	payload, err := json.Marshal(req)
	if err != nil {
		return fmt.Errorf("failed to marshal request: %w", err)
	}
	regFrame := protocol.NewFrame(protocol.FrameTypeRegister, payload)
	err = protocol.WriteFrame(c.conn, regFrame)
	if err != nil {
		return fmt.Errorf("failed to send registration: %w", err)
	}
	// Bound the handshake: a dead or unresponsive server must not hang us.
	c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
	ackFrame, err := protocol.ReadFrame(c.conn)
	if err != nil {
		return fmt.Errorf("failed to read ack: %w", err)
	}
	defer ackFrame.Release()
	// Clear the deadline so the subsequent frame loop manages its own.
	c.conn.SetReadDeadline(time.Time{})
	if ackFrame.Type == protocol.FrameTypeError {
		// Best-effort decode of the structured error; fall back to a
		// generic message if the payload does not parse.
		var errMsg protocol.ErrorMessage
		if err := json.Unmarshal(ackFrame.Payload, &errMsg); err == nil {
			return fmt.Errorf("registration error: %s - %s", errMsg.Code, errMsg.Message)
		}
		return fmt.Errorf("registration error")
	}
	if ackFrame.Type != protocol.FrameTypeRegisterAck {
		return fmt.Errorf("unexpected frame type: %s", ackFrame.Type)
	}
	var resp protocol.RegisterResponse
	if err := json.Unmarshal(ackFrame.Payload, &resp); err != nil {
		return fmt.Errorf("failed to parse response: %w", err)
	}
	c.registered = true
	c.assignedURL = resp.URL
	// The server may have assigned a different subdomain than requested.
	c.subdomain = resp.Subdomain
	c.logger.Info("Tunnel registered successfully",
		zap.String("subdomain", resp.Subdomain),
		zap.String("url", resp.URL),
		zap.Int("remote_port", resp.Port),
	)
	return nil
}
// dataFrameWorker is one of workerCount goroutines draining
// c.dataFrameQueue. Each frame is processed inside an anonymous function
// so that the frame wrapper's Close and the panic recovery fire per-frame
// rather than only at worker exit. Exits when the queue is closed or
// stopCh fires; handlerWg tracks its lifetime.
func (c *Connector) dataFrameWorker(workerID int) {
	defer c.handlerWg.Done()
	defer c.recoverer.Recover(fmt.Sprintf("dataFrameWorker-%d", workerID))
	for {
		select {
		case frame, ok := <-c.dataFrameQueue:
			if !ok {
				// Queue closed: shutdown in progress.
				return
			}
			func() {
				// WithFrame takes ownership; Close releases it even if
				// HandleDataFrame panics (recovered below).
				sf := protocol.WithFrame(frame)
				defer sf.Close()
				defer c.recoverer.Recover("handleDataFrame")
				if err := c.frameHandler.HandleDataFrame(sf.Frame); err != nil {
					c.logger.Error("Failed to handle data frame",
						zap.Int("worker_id", workerID),
						zap.Error(err))
				}
			}()
		case <-c.stopCh:
			return
		}
	}
}
// handleFrames is the connector's single read loop: it pulls frames off the
// server connection and dispatches them by type until the connection fails,
// the server closes the tunnel, or stopCh fires. It always closes the
// connector on exit, which tears down the workers and heartbeat.
func (c *Connector) handleFrames() {
	defer c.Close()
	defer c.recoverer.Recover("handleFrames")
	for {
		// Fast-path shutdown check before blocking on a read.
		select {
		case <-c.stopCh:
			return
		default:
		}
		// Per-read deadline doubles as a liveness timeout: if nothing
		// (not even a heartbeat ack) arrives within RequestTimeout, give up.
		c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
		frame, err := protocol.ReadFrame(c.conn)
		if err != nil {
			if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
				c.logger.Warn("Read timeout")
				return
			}
			// Suppress the error log when the failure was caused by our
			// own shutdown closing the connection.
			select {
			case <-c.stopCh:
				return
			default:
				c.logger.Error("Failed to read frame", zap.Error(err))
				return
			}
		}
		sf := protocol.WithFrame(frame)
		switch sf.Frame.Type {
		case protocol.FrameTypeHeartbeatAck:
			// Compute round-trip latency from the send timestamp recorded
			// in createHeartbeatFrame; a zero timestamp means we never
			// sent one (unexpected ack), so just log.
			c.heartbeatMu.Lock()
			if !c.heartbeatSentAt.IsZero() {
				latency := time.Since(c.heartbeatSentAt)
				c.lastLatency = latency
				c.heartbeatMu.Unlock()
				c.logger.Debug("Received heartbeat ack", zap.Duration("latency", latency))
				if c.latencyCallback != nil {
					c.latencyCallback(latency)
				}
			} else {
				c.heartbeatMu.Unlock()
				c.logger.Debug("Received heartbeat ack")
			}
			sf.Close()
		case protocol.FrameTypeData:
			// Hand the frame to the worker pool; ownership transfers on a
			// successful send. Non-blocking: under backpressure the frame
			// is dropped rather than stalling the read loop.
			select {
			case c.dataFrameQueue <- sf.Frame:
			case <-c.stopCh:
				sf.Close()
				return
			default:
				c.logger.Warn("Data frame queue full, dropping frame")
				sf.Close()
			}
		case protocol.FrameTypeClose:
			sf.Close()
			c.logger.Info("Server requested close")
			return
		case protocol.FrameTypeError:
			// Best-effort decode of the structured error before shutting down.
			var errMsg protocol.ErrorMessage
			if err := json.Unmarshal(sf.Frame.Payload, &errMsg); err == nil {
				c.logger.Error("Received error from server",
					zap.String("code", errMsg.Code),
					zap.String("message", errMsg.Message),
				)
			}
			sf.Close()
			return
		default:
			sf.Close()
			c.logger.Warn("Unexpected frame type",
				zap.String("type", sf.Frame.Type.String()),
			)
		}
	}
}
// createHeartbeatFrame records the send timestamp and returns a fresh
// heartbeat frame, or nil when the connector has already been closed.
func (c *Connector) createHeartbeatFrame() *protocol.Frame {
	c.closedMu.RLock()
	isClosed := c.closed
	c.closedMu.RUnlock()
	if isClosed {
		return nil
	}
	c.heartbeatMu.Lock()
	c.heartbeatSentAt = time.Now()
	c.heartbeatMu.Unlock()
	return protocol.NewFrame(protocol.FrameTypeHeartbeat, nil)
}
// SendFrame forwards a frame to the server through the frame writer.
// It fails when the connector has not completed registration yet.
func (c *Connector) SendFrame(frame *protocol.Frame) error {
	if c.registered {
		return c.frameWriter.WriteFrame(frame)
	}
	return fmt.Errorf("not registered")
}
// Close shuts down the connector exactly once: it signals all goroutines via
// stopCh, waits (bounded) for in-flight handlers, closes the frame queue,
// then notifies the server and releases the connection. It always returns
// nil so it is safe to use in defer chains.
func (c *Connector) Close() error {
	c.once.Do(func() {
		c.closedMu.Lock()
		c.closed = true
		c.closedMu.Unlock()
		close(c.stopCh)
		done := make(chan struct{})
		go func() {
			c.handlerWg.Wait()
			close(done)
		}()
		select {
		case <-done:
		case <-time.After(2 * time.Second):
			c.logger.Warn("Force closing: some handlers are still active")
		}
		// Close the queue only after the handlers have stopped (or the grace
		// period expired). Closing it immediately after stopCh risked a
		// "send on closed channel" panic: the read loop forwards frames into
		// dataFrameQueue via a select, and a send case on a closed channel
		// panics even when the stopCh case is simultaneously ready.
		close(c.dataFrameQueue)
		if c.conn != nil {
			// Best-effort notification so the server can reap the tunnel
			// promptly instead of waiting for a heartbeat timeout.
			closeFrame := protocol.NewFrame(protocol.FrameTypeClose, nil)
			if c.frameWriter != nil {
				c.frameWriter.WriteFrame(closeFrame)
				c.frameWriter.Close()
			} else {
				protocol.WriteFrame(c.conn, closeFrame)
			}
			c.conn.Close()
		}
		c.logger.Info("Connector closed")
	})
	return nil
}
// Wait blocks until the connector starts shutting down (stopCh is closed
// by Close).
func (c *Connector) Wait() {
	<-c.stopCh
}
// GetURL returns the public tunnel URL assigned by the server at
// registration time (empty before a successful register).
func (c *Connector) GetURL() string {
	return c.assignedURL
}
// GetSubdomain returns the subdomain assigned by the server (which may
// differ from the requested one).
func (c *Connector) GetSubdomain() string {
	return c.subdomain
}
// SetLatencyCallback sets the callback invoked with each measured heartbeat
// round-trip latency.
// NOTE(review): the field is written here without synchronization while the
// read loop may invoke it concurrently — confirm callers set this before
// connecting, or guard the field.
func (c *Connector) SetLatencyCallback(cb LatencyCallback) {
	c.latencyCallback = cb
}
// GetLatency reports the most recently measured heartbeat round-trip
// latency (zero until the first ack arrives).
func (c *Connector) GetLatency() time.Duration {
	c.heartbeatMu.Lock()
	latency := c.lastLatency
	c.heartbeatMu.Unlock()
	return latency
}
// GetStats returns the traffic statistics tracked by the frame handler,
// or nil when no handler is attached.
func (c *Connector) GetStats() *TrafficStats {
	if c.frameHandler == nil {
		return nil
	}
	return c.frameHandler.GetStats()
}
// IsClosed reports whether Close has been invoked on this connector.
func (c *Connector) IsClosed() bool {
	c.closedMu.RLock()
	defer c.closedMu.RUnlock()
	return c.closed
}
// monitorQueuePressure samples the data-frame queue fill level and emits
// pause/resume flow-control signals to the server with hysteresis: pause
// above 80% usage, resume below 50%, checked every 100ms. The "*" stream ID
// addresses all streams. Exits when stopCh closes.
func (c *Connector) monitorQueuePressure() {
	defer c.recoverer.Recover("monitorQueuePressure")
	const (
		pauseThreshold  = 0.80
		resumeThreshold = 0.50
		checkInterval   = 100 * time.Millisecond
	)
	ticker := time.NewTicker(checkInterval)
	defer ticker.Stop()
	// isPaused tracks the last signal sent so each transition fires once.
	isPaused := false
	for {
		select {
		case <-ticker.C:
			queueLen := len(c.dataFrameQueue)
			queueCap := cap(c.dataFrameQueue)
			usage := float64(queueLen) / float64(queueCap)
			if usage > pauseThreshold && !isPaused {
				c.sendFlowControl("*", protocol.FlowControlPause)
				isPaused = true
				c.logger.Warn("Queue pressure high, sent pause signal",
					zap.Int("queue_len", queueLen),
					zap.Int("queue_cap", queueCap),
					zap.Float64("usage", usage))
			} else if usage < resumeThreshold && isPaused {
				c.sendFlowControl("*", protocol.FlowControlResume)
				isPaused = false
				c.logger.Info("Queue pressure normal, sent resume signal",
					zap.Int("queue_len", queueLen),
					zap.Int("queue_cap", queueCap),
					zap.Float64("usage", usage))
			}
		case <-c.stopCh:
			return
		}
	}
}
func (c *Connector) sendFlowControl(streamID string, action protocol.FlowControlAction) {
frame := protocol.NewFlowControlFrame(streamID, action)
if err := c.SendFrame(frame); err != nil {
c.logger.Error("Failed to send flow control",
zap.String("action", string(action)),
zap.Error(err))
}
func isExpectedCloseError(err error) bool {
s := err.Error()
return strings.Contains(s, "EOF") ||
strings.Contains(s, "use of closed") ||
strings.Contains(s, "connection reset")
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,450 @@
package tcp
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"runtime"
"sync"
"sync/atomic"
"time"
json "github.com/goccy/go-json"
"github.com/hashicorp/yamux"
"go.uber.org/zap"
"drip/internal/shared/constants"
"drip/internal/shared/protocol"
"drip/internal/shared/stats"
"drip/pkg/config"
)
// PoolClient manages a pool of yamux sessions for tunnel connections.
// One "primary" session carries registration and keepalive pings; extra
// "data" sessions are added/removed by the scaler based on load.
type PoolClient struct {
	// Connection target and credentials.
	serverAddr string
	tlsConfig  *tls.Config
	token      string
	tunnelType protocol.TunnelType
	localHost  string
	localPort  int
	// Assigned by the server during registration.
	subdomain   string
	assignedURL string
	tunnelID    string // empty when the server does not support data connections
	// Pool sizing bounds (see NewPoolClient for defaulting rules).
	minSessions     int
	maxSessions     int
	initialSessions int
	stats           *stats.TrafficStats
	// httpClient proxies to the local service for HTTP/HTTPS tunnels only.
	httpClient      *http.Client
	latencyCallback atomic.Value // LatencyCallback
	latencyNanos    atomic.Int64
	// Lifecycle: ctx/cancel for in-flight copies, stopCh to signal workers,
	// doneCh closed once every worker has exited.
	ctx    context.Context
	cancel context.CancelFunc
	stopCh chan struct{}
	doneCh chan struct{}
	once   sync.Once
	wg     sync.WaitGroup
	closed atomic.Bool
	// primary plus dataSessions are guarded by mu; desiredTotal/lastScale
	// drive the scaler loop.
	primary      *sessionHandle
	mu           sync.RWMutex
	dataSessions map[string]*sessionHandle
	desiredTotal int
	lastScale    time.Time
	logger       *zap.Logger
}
// NewPoolClient creates a new pool client from cfg. Zero/unset config values
// are defaulted: local host 127.0.0.1, tunnel type TCP, min sessions 2,
// max sessions 16×CPU (at least min), initial 4 clamped into [min, max].
// A local HTTP client is prepared only for HTTP/HTTPS tunnel types.
func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
	var tlsConfig *tls.Config
	if cfg.Insecure {
		tlsConfig = config.GetClientTLSConfigInsecure()
	} else {
		// Verify the certificate against the host part of the address.
		host, _, _ := net.SplitHostPort(cfg.ServerAddr)
		tlsConfig = config.GetClientTLSConfig(host)
	}
	localHost := cfg.LocalHost
	if localHost == "" {
		localHost = "127.0.0.1"
	}
	tunnelType := cfg.TunnelType
	if tunnelType == "" {
		tunnelType = protocol.TunnelTypeTCP
	}
	numCPU := runtime.NumCPU()
	minSessions := cfg.PoolMin
	if minSessions <= 0 {
		minSessions = 2
	}
	maxSessions := cfg.PoolMax
	if maxSessions <= 0 {
		maxSessions = max(numCPU*16, minSessions)
	}
	if maxSessions < minSessions {
		maxSessions = minSessions
	}
	initialSessions := cfg.PoolSize
	if initialSessions <= 0 {
		initialSessions = 4
	}
	// Clamp the initial size into the [min, max] window.
	initialSessions = min(max(initialSessions, minSessions), maxSessions)
	ctx, cancel := context.WithCancel(context.Background())
	c := &PoolClient{
		serverAddr:      cfg.ServerAddr,
		tlsConfig:       tlsConfig,
		token:           cfg.Token,
		tunnelType:      tunnelType,
		localHost:       localHost,
		localPort:       cfg.LocalPort,
		subdomain:       cfg.Subdomain,
		minSessions:     minSessions,
		maxSessions:     maxSessions,
		initialSessions: initialSessions,
		stats:           stats.NewTrafficStats(),
		ctx:             ctx,
		cancel:          cancel,
		stopCh:          make(chan struct{}),
		doneCh:          make(chan struct{}),
		dataSessions:    make(map[string]*sessionHandle),
		logger:          logger,
	}
	if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
		c.httpClient = newLocalHTTPClient(tunnelType)
	}
	// Seed with a no-op so latencyCallback.Load always yields a callable.
	c.latencyCallback.Store(LatencyCallback(func(time.Duration) {}))
	return c
}
// Connect establishes the primary connection and starts background workers.
// It performs the register handshake over a raw framed connection, then
// upgrades that connection to a yamux session. Note the reversed roles:
// the client runs yamux.Server and *accepts* streams that the tunnel server
// opens toward it. If the server grants a tunnel ID, additional data
// sessions are created and the auto-scaler is started.
func (c *PoolClient) Connect() error {
	primaryConn, err := c.dialTLS()
	if err != nil {
		return err
	}
	// Advertise how many extra data connections we could open (pool max
	// minus the primary itself).
	maxData := max(c.maxSessions-1, 0)
	req := protocol.RegisterRequest{
		Token:           c.token,
		CustomSubdomain: c.subdomain,
		TunnelType:      c.tunnelType,
		LocalPort:       c.localPort,
		ConnectionType:  "primary",
		PoolCapabilities: &protocol.PoolCapabilities{
			MaxDataConns: maxData,
			Version:      1,
		},
	}
	payload, err := json.Marshal(req)
	if err != nil {
		_ = primaryConn.Close()
		return fmt.Errorf("failed to marshal request: %w", err)
	}
	if err := protocol.WriteFrame(primaryConn, protocol.NewFrame(protocol.FrameTypeRegister, payload)); err != nil {
		_ = primaryConn.Close()
		return fmt.Errorf("failed to send registration: %w", err)
	}
	// Bound the wait for the ack, then clear the deadline for yamux use.
	_ = primaryConn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
	ack, err := protocol.ReadFrame(primaryConn)
	if err != nil {
		_ = primaryConn.Close()
		return fmt.Errorf("failed to read register ack: %w", err)
	}
	defer ack.Release()
	_ = primaryConn.SetReadDeadline(time.Time{})
	if ack.Type == protocol.FrameTypeError {
		var errMsg protocol.ErrorMessage
		if e := json.Unmarshal(ack.Payload, &errMsg); e == nil {
			_ = primaryConn.Close()
			return fmt.Errorf("registration error: %s - %s", errMsg.Code, errMsg.Message)
		}
		_ = primaryConn.Close()
		return fmt.Errorf("registration error")
	}
	if ack.Type != protocol.FrameTypeRegisterAck {
		_ = primaryConn.Close()
		return fmt.Errorf("unexpected register ack frame: %s", ack.Type)
	}
	var resp protocol.RegisterResponse
	if err := json.Unmarshal(ack.Payload, &resp); err != nil {
		_ = primaryConn.Close()
		return fmt.Errorf("failed to parse register response: %w", err)
	}
	c.assignedURL = resp.URL
	c.subdomain = resp.Subdomain
	// A tunnel ID is only kept when the server can accept extra data
	// connections; it gates all pool-scaling behavior below.
	if resp.SupportsDataConn && resp.TunnelID != "" {
		c.tunnelID = resp.TunnelID
	}
	yamuxCfg := yamux.DefaultConfig()
	// Liveness is handled by our own pingLoop, not yamux keepalive.
	yamuxCfg.EnableKeepAlive = false
	yamuxCfg.LogOutput = io.Discard
	yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog
	session, err := yamux.Server(primaryConn, yamuxCfg)
	if err != nil {
		_ = primaryConn.Close()
		return fmt.Errorf("failed to init yamux session: %w", err)
	}
	primary := &sessionHandle{
		id:      "primary",
		conn:    primaryConn,
		session: session,
	}
	primary.touch()
	c.primary = primary
	// Placeholder worker that lives until shutdown; presumably it keeps the
	// WaitGroup non-zero so doneCh cannot close before stopCh does.
	c.wg.Add(1)
	go func() {
		defer c.wg.Done()
		<-c.stopCh
	}()
	c.wg.Add(1)
	go c.acceptLoop(primary, true)
	c.wg.Add(1)
	go c.sessionWatcher(primary, true)
	c.wg.Add(1)
	go c.pingLoop(primary)
	if c.tunnelID != "" {
		c.mu.Lock()
		c.desiredTotal = c.initialSessions
		c.mu.Unlock()
		c.ensureSessions()
		c.wg.Add(1)
		go c.scalerLoop()
	}
	// Close doneCh once every tracked worker has exited, unblocking Wait.
	go func() {
		c.wg.Wait()
		close(c.doneCh)
	}()
	return nil
}
// dialTLS opens a TLS connection to the configured server, enforces
// TLS 1.3, and tunes the underlying TCP socket (no Nagle, keepalive,
// 256 KiB socket buffers) for tunnel traffic.
func (c *PoolClient) dialTLS() (net.Conn, error) {
	d := &net.Dialer{Timeout: 10 * time.Second}
	conn, err := tls.DialWithDialer(d, "tcp", c.serverAddr, c.tlsConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to connect: %w", err)
	}
	if v := conn.ConnectionState().Version; v != tls.VersionTLS13 {
		_ = conn.Close()
		return nil, fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", v)
	}
	if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok {
		_ = tcpConn.SetNoDelay(true)
		_ = tcpConn.SetKeepAlive(true)
		_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
		_ = tcpConn.SetReadBuffer(256 * 1024)
		_ = tcpConn.SetWriteBuffer(256 * 1024)
	}
	return conn, nil
}
// acceptLoop accepts server-initiated yamux streams on h and spawns a
// handler goroutine per stream, maintaining per-session and global stats.
// An accept failure on the primary session tears down the whole client;
// a failure on a data session only removes that session from the pool.
func (c *PoolClient) acceptLoop(h *sessionHandle, isPrimary bool) {
	defer c.wg.Done()
	for {
		select {
		case <-c.stopCh:
			return
		default:
		}
		stream, err := h.session.Accept()
		if err != nil {
			if c.IsClosed() || isExpectedCloseError(err) {
				return
			}
			if isPrimary {
				c.logger.Debug("Primary session accept failed", zap.Error(err))
				_ = c.Close()
				return
			}
			c.logger.Debug("Data session accept failed", zap.String("session_id", h.id), zap.Error(err))
			c.removeDataSession(h.id)
			return
		}
		// Count the stream before handing off; handleStream decrements.
		h.active.Add(1)
		h.touch()
		c.stats.AddRequest()
		c.stats.IncActiveConnections()
		c.wg.Add(1)
		go c.handleStream(h, stream)
	}
}
// sessionWatcher blocks until client shutdown or the yamux session closing.
// A dead primary session shuts down the whole client; a dead data session
// is simply dropped from the pool.
func (c *PoolClient) sessionWatcher(h *sessionHandle, isPrimary bool) {
	defer c.wg.Done()
	select {
	case <-c.stopCh:
	case <-h.session.CloseChan():
		if isPrimary {
			_ = c.Close()
		} else {
			c.removeDataSession(h.id)
		}
	}
}
// pingLoop periodically pings the yamux session to measure round-trip
// latency and detect dead peers. After three consecutive failures the
// session is considered dead: the primary takes the whole client down,
// a data session is just removed. Successful pings reset the failure
// counter, refresh the activity timestamp, and feed the latency callback.
func (c *PoolClient) pingLoop(h *sessionHandle) {
	defer c.wg.Done()
	const maxConsecutiveFailures = 3
	ticker := time.NewTicker(constants.HeartbeatInterval)
	defer ticker.Stop()
	consecutiveFailures := 0
	for {
		select {
		case <-c.stopCh:
			return
		case <-ticker.C:
		}
		if h.session == nil || h.session.IsClosed() {
			return
		}
		latency, err := h.session.Ping()
		if err != nil {
			consecutiveFailures++
			c.logger.Debug("Ping failed",
				zap.String("session_id", h.id),
				zap.Int("consecutive_failures", consecutiveFailures),
				zap.Error(err),
			)
			if consecutiveFailures >= maxConsecutiveFailures {
				c.logger.Warn("Session ping failed too many times, closing",
					zap.String("session_id", h.id),
					zap.Int("failures", consecutiveFailures),
				)
				if h.id == "primary" {
					_ = c.Close()
					return
				}
				c.removeDataSession(h.id)
				return
			}
			continue
		}
		consecutiveFailures = 0
		h.touch()
		c.latencyNanos.Store(int64(latency))
		if cb, ok := c.latencyCallback.Load().(LatencyCallback); ok && cb != nil {
			cb(latency)
		}
	}
}
// Close shuts down the client and all sessions. It is idempotent: the first
// call marks the client closed, signals stopCh, cancels in-flight copies,
// then closes every session and connection. Only the primary session's
// close error is returned; data-session errors are discarded.
func (c *PoolClient) Close() error {
	var closeErr error
	c.once.Do(func() {
		c.closed.Store(true)
		close(c.stopCh)
		if c.cancel != nil {
			c.cancel()
		}
		// Snapshot the session set under the lock, then close outside it so
		// session teardown cannot deadlock with pool bookkeeping.
		var data []*sessionHandle
		var primary *sessionHandle
		c.mu.Lock()
		for _, h := range c.dataSessions {
			data = append(data, h)
		}
		c.dataSessions = make(map[string]*sessionHandle)
		primary = c.primary
		c.primary = nil
		c.mu.Unlock()
		for _, h := range data {
			if h == nil {
				continue
			}
			if h.session != nil {
				_ = h.session.Close()
			}
			if h.conn != nil {
				_ = h.conn.Close()
			}
		}
		if primary != nil {
			if primary.session != nil {
				closeErr = primary.session.Close()
			}
			if primary.conn != nil {
				_ = primary.conn.Close()
			}
		}
	})
	return closeErr
}
// Wait blocks until every background worker has exited (doneCh closes).
func (c *PoolClient) Wait() { <-c.doneCh }
// GetURL returns the public tunnel URL assigned at registration.
func (c *PoolClient) GetURL() string { return c.assignedURL }
// GetSubdomain returns the subdomain assigned by the server.
func (c *PoolClient) GetSubdomain() string { return c.subdomain }
// GetLatency reports the most recent ping round-trip time.
func (c *PoolClient) GetLatency() time.Duration { return time.Duration(c.latencyNanos.Load()) }
// GetStats returns the client's traffic statistics.
func (c *PoolClient) GetStats() *stats.TrafficStats { return c.stats }
// IsClosed reports whether Close has been invoked.
func (c *PoolClient) IsClosed() bool { return c.closed.Load() }
// SetLatencyCallback installs the function invoked with each measured ping
// latency. A nil callback is replaced with a no-op so the atomic always
// holds a callable value.
func (c *PoolClient) SetLatencyCallback(cb LatencyCallback) {
	if cb != nil {
		c.latencyCallback.Store(cb)
		return
	}
	c.latencyCallback.Store(LatencyCallback(func(time.Duration) {}))
}

View File

@@ -0,0 +1,253 @@
package tcp
import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"time"
"drip/internal/shared/httputil"
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// handleStream dispatches one accepted yamux stream to the handler matching
// the tunnel type, and unwinds the per-session and global active-stream
// accounting (incremented by acceptLoop) when the handler returns.
func (c *PoolClient) handleStream(h *sessionHandle, stream net.Conn) {
	defer c.wg.Done()
	defer func() {
		h.active.Add(-1)
		c.stats.DecActiveConnections()
	}()
	defer stream.Close()
	if c.tunnelType == protocol.TunnelTypeHTTP || c.tunnelType == protocol.TunnelTypeHTTPS {
		c.handleHTTPStream(stream)
		return
	}
	c.handleTCPStream(stream)
}
// handleTCPStream handles raw TCP tunneling: it dials the local service,
// tunes the socket, and pipes bytes bidirectionally between the tunnel
// stream and the local connection until either side closes or ctx cancels.
func (c *PoolClient) handleTCPStream(stream net.Conn) {
	localConn, err := net.DialTimeout("tcp", net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort)), 10*time.Second)
	if err != nil {
		c.logger.Debug("Dial local failed", zap.Error(err))
		return
	}
	defer localConn.Close()
	if tcpConn, ok := localConn.(*net.TCPConn); ok {
		_ = tcpConn.SetNoDelay(true)
		_ = tcpConn.SetKeepAlive(true)
		_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
		_ = tcpConn.SetReadBuffer(256 * 1024)
		_ = tcpConn.SetWriteBuffer(256 * 1024)
	}
	// Bidirectional copy with byte counters feeding the traffic stats.
	_ = netutil.PipeWithCallbacksAndBufferSize(
		c.ctx,
		stream,
		localConn,
		pool.SizeLarge,
		func(n int64) { c.stats.AddBytesIn(n) },
		func(n int64) { c.stats.AddBytesOut(n) },
	)
}
// handleHTTPStream handles HTTP/HTTPS proxy requests: it parses one HTTP
// request off the tunnel stream, replays it against the local service, and
// streams the response back. WebSocket upgrades are diverted to the raw
// byte-pipe path. All bytes crossing the stream are counted into stats.
func (c *PoolClient) handleHTTPStream(stream net.Conn) {
	// Give the peer 30s to deliver the request head.
	_ = stream.SetReadDeadline(time.Now().Add(30 * time.Second))
	cc := netutil.NewCountingConn(stream,
		func(n int64) { c.stats.AddBytesIn(n) },
		func(n int64) { c.stats.AddBytesOut(n) },
	)
	br := bufio.NewReaderSize(cc, 32*1024)
	req, err := http.ReadRequest(br)
	if err != nil {
		return
	}
	defer req.Body.Close()
	_ = stream.SetReadDeadline(time.Time{})
	if httputil.IsWebSocketUpgrade(req) {
		c.handleWebSocketUpgrade(cc, req)
		return
	}
	ctx, cancel := context.WithCancel(c.ctx)
	defer cancel()
	scheme := "http"
	if c.tunnelType == protocol.TunnelTypeHTTPS {
		scheme = "https"
	}
	targetURL := fmt.Sprintf("%s://%s:%d%s", scheme, c.localHost, c.localPort, req.URL.RequestURI())
	outReq, err := http.NewRequestWithContext(ctx, req.Method, targetURL, req.Body)
	if err != nil {
		httputil.WriteProxyError(cc, http.StatusBadGateway, "Bad Gateway")
		return
	}
	// Rewrite Host to the local target; preserve the original in
	// X-Forwarded-Host so the backend can reconstruct public URLs.
	origHost := req.Host
	httputil.CopyHeaders(outReq.Header, req.Header)
	httputil.CleanHopByHopHeaders(outReq.Header)
	targetHost := c.localHost
	if c.localPort != 80 && c.localPort != 443 {
		targetHost = fmt.Sprintf("%s:%d", c.localHost, c.localPort)
	}
	outReq.Host = targetHost
	outReq.Header.Set("Host", targetHost)
	if origHost != "" {
		outReq.Header.Set("X-Forwarded-Host", origHost)
	}
	outReq.Header.Set("X-Forwarded-Proto", "https")
	resp, err := c.httpClient.Do(outReq)
	if err != nil {
		httputil.WriteProxyError(cc, http.StatusBadGateway, "Local service unavailable")
		return
	}
	defer resp.Body.Close()
	_ = stream.SetWriteDeadline(time.Now().Add(30 * time.Second))
	if err := writeResponseHeader(cc, resp); err != nil {
		return
	}
	// Watcher: abort the stream if the parent context cancels mid-body.
	done := make(chan struct{})
	go func() {
		select {
		case <-ctx.Done():
			stream.Close()
		case <-done:
		}
	}()
	// Manual copy loop so each chunk refreshes the write deadline.
	buf := make([]byte, 32*1024)
	for {
		nr, er := resp.Body.Read(buf)
		if nr > 0 {
			_ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second))
			nw, ew := cc.Write(buf[:nr])
			if ew != nil || nr != nw {
				break
			}
		}
		if er != nil {
			break
		}
	}
	close(done)
}
// handleWebSocketUpgrade handles WebSocket upgrade requests by dialing the
// local backend (with TLS for HTTPS tunnels, certificate checks skipped),
// forwarding the upgrade request verbatim, relaying the response, and —
// on 101 Switching Protocols — piping bytes in both directions.
// NOTE(review): resp.Body is never explicitly closed; the deferred
// localConn.Close releases the socket, but confirm nothing else leaks.
func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
	scheme := "ws"
	if c.tunnelType == protocol.TunnelTypeHTTPS {
		scheme = "wss"
	}
	targetAddr := net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort))
	localConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second)
	if err != nil {
		httputil.WriteProxyError(cc, http.StatusBadGateway, "WebSocket backend unavailable")
		return
	}
	defer localConn.Close()
	if c.tunnelType == protocol.TunnelTypeHTTPS {
		// Local backend is trusted; skip verification for self-signed certs.
		tlsConn := tls.Client(localConn, &tls.Config{InsecureSkipVerify: true})
		if err := tlsConn.Handshake(); err != nil {
			httputil.WriteProxyError(cc, http.StatusBadGateway, "TLS handshake failed")
			return
		}
		localConn = tlsConn
	}
	req.URL.Scheme = scheme
	req.URL.Host = targetAddr
	if err := req.Write(localConn); err != nil {
		httputil.WriteProxyError(cc, http.StatusBadGateway, "Failed to forward upgrade request")
		return
	}
	localBr := bufio.NewReader(localConn)
	resp, err := http.ReadResponse(localBr, req)
	if err != nil {
		httputil.WriteProxyError(cc, http.StatusBadGateway, "Failed to read upgrade response")
		return
	}
	if err := resp.Write(cc); err != nil {
		return
	}
	if resp.StatusCode == http.StatusSwitchingProtocols {
		// Upgrade accepted: the connection is now a raw byte pipe.
		_ = netutil.PipeWithCallbacksAndBufferSize(
			c.ctx,
			cc,
			localConn,
			pool.SizeLarge,
			func(n int64) { c.stats.AddBytesIn(n) },
			func(n int64) { c.stats.AddBytesOut(n) },
		)
	}
}
// newLocalHTTPClient creates an HTTP client for local service requests,
// tuned for a proxy workload: large idle pools, no transparent compression,
// and redirects passed through unmodified to the remote caller. HTTPS
// tunnels skip certificate verification since the backend is local.
func newLocalHTTPClient(tunnelType protocol.TunnelType) *http.Client {
	var tlsConfig *tls.Config
	if tunnelType == protocol.TunnelTypeHTTPS {
		tlsConfig = &tls.Config{InsecureSkipVerify: true}
	}
	return &http.Client{
		Transport: &http.Transport{
			MaxIdleConns:        2000,
			MaxIdleConnsPerHost: 1000,
			MaxConnsPerHost:     0, // unlimited
			IdleConnTimeout:     180 * time.Second,
			// Let the remote client negotiate its own encoding.
			DisableCompression:    true,
			DisableKeepAlives:     false,
			TLSHandshakeTimeout:   5 * time.Second,
			TLSClientConfig:       tlsConfig,
			ResponseHeaderTimeout: 15 * time.Second,
			ExpectContinueTimeout: 500 * time.Millisecond,
			WriteBufferSize:       32 * 1024,
			ReadBufferSize:        32 * 1024,
			DialContext: (&net.Dialer{
				Timeout:   3 * time.Second,
				KeepAlive: 30 * time.Second,
			}).DialContext,
		},
		// Return redirects as-is; following them here would hide them
		// from the real client.
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}
}
func writeResponseHeader(w io.Writer, resp *http.Response) error {
statusLine := fmt.Sprintf("HTTP/%d.%d %d %s\r\n",
resp.ProtoMajor, resp.ProtoMinor,
resp.StatusCode, http.StatusText(resp.StatusCode))
if _, err := io.WriteString(w, statusLine); err != nil {
return err
}
if err := resp.Header.Write(w); err != nil {
return err
}
_, err := io.WriteString(w, "\r\n")
return err
}

View File

@@ -0,0 +1,312 @@
package tcp
import (
"fmt"
"io"
"net"
"sync/atomic"
"time"
json "github.com/goccy/go-json"
"github.com/hashicorp/yamux"
"go.uber.org/zap"
"drip/internal/shared/constants"
"drip/internal/shared/protocol"
)
// sessionHandle wraps a yamux session with metadata used by the pool:
// stream accounting for scaling decisions and a close guard so teardown
// runs at most once.
type sessionHandle struct {
	id         string
	conn       net.Conn       // underlying TLS connection
	session    *yamux.Session // multiplexer running over conn
	active     atomic.Int64   // number of streams currently being handled
	lastActive atomic.Int64   // unix nanos of last stream/ping activity
	closed     atomic.Bool    // set once by removeDataSession via CAS
}
// touch records the current time as the session's last activity.
func (h *sessionHandle) touch() {
	h.lastActive.Store(time.Now().UnixNano())
}
// lastActiveTime converts the stored unix-nano timestamp back to a
// time.Time; the zero Time is returned when the handle was never touched.
func (h *sessionHandle) lastActiveTime() time.Time {
	if n := h.lastActive.Load(); n != 0 {
		return time.Unix(0, n)
	}
	return time.Time{}
}
// scalerLoop monitors load and adjusts the desired session count every 5s.
// Load is active streams divided by (sessions × 64 streams capacity each);
// above 0.7 it grows by one (5s cooldown), below 0.3 it shrinks by one
// (60s cooldown), always within [minSessions, maxSessions]. Each tick ends
// by reconciling the actual pool toward the desired size.
func (c *PoolClient) scalerLoop() {
	defer c.wg.Done()
	const (
		checkInterval      = 5 * time.Second
		scaleUpCooldown    = 5 * time.Second
		scaleDownCooldown  = 60 * time.Second
		capacityPerSession = int64(64)
		scaleUpLoad        = 0.7
		scaleDownLoad      = 0.3
	)
	ticker := time.NewTicker(checkInterval)
	defer ticker.Stop()
	for {
		select {
		case <-c.stopCh:
			return
		case <-ticker.C:
		}
		c.mu.Lock()
		desired := c.desiredTotal
		if desired == 0 {
			// Lazily seed the target the first time through.
			desired = c.initialSessions
			c.desiredTotal = desired
		}
		lastScale := c.lastScale
		c.mu.Unlock()
		current := c.sessionCount()
		if current <= 0 {
			continue
		}
		active := c.stats.GetActiveConnections()
		load := float64(active) / float64(int64(current)*capacityPerSession)
		sinceLastScale := time.Since(lastScale)
		if sinceLastScale >= scaleUpCooldown && load > scaleUpLoad && desired < c.maxSessions {
			c.mu.Lock()
			c.desiredTotal = min(c.desiredTotal+1, c.maxSessions)
			c.lastScale = time.Now()
			c.mu.Unlock()
		} else if sinceLastScale >= scaleDownCooldown && load < scaleDownLoad && desired > c.minSessions {
			c.mu.Lock()
			c.desiredTotal = max(c.desiredTotal-1, c.minSessions)
			c.lastScale = time.Now()
			c.mu.Unlock()
		}
		c.ensureSessions()
	}
}
// ensureSessions grows or shrinks the session pool toward the desired
// total, clamped to [minSessions, maxSessions]. It is a no-op when the
// client is closed or the server never granted a tunnel ID.
func (c *PoolClient) ensureSessions() {
	if c.IsClosed() || c.tunnelID == "" {
		return
	}
	c.mu.RLock()
	want := c.desiredTotal
	c.mu.RUnlock()
	want = min(max(want, c.minSessions), c.maxSessions)
	have := c.sessionCount()
	switch {
	case have < want:
		for i := 0; i < want-have; i++ {
			if err := c.addDataSession(); err != nil {
				c.logger.Debug("Add data session failed", zap.Error(err))
				return
			}
		}
	case have > want:
		c.removeIdleSessions(have - want)
	}
}
// addDataSession creates a new data session: it dials the server, performs
// the DataConnect handshake authenticated by tunnel ID and token, upgrades
// the connection to a yamux session (client accepts server-opened streams,
// mirroring the primary), registers it in the pool, and starts its accept
// and watcher goroutines.
func (c *PoolClient) addDataSession() error {
	select {
	case <-c.stopCh:
		return net.ErrClosed
	default:
	}
	if c.tunnelID == "" {
		return fmt.Errorf("server does not support data connections")
	}
	conn, err := c.dialTLS()
	if err != nil {
		return err
	}
	// Nanosecond timestamp gives a unique-enough ID per connection.
	connID := fmt.Sprintf("data-%d", time.Now().UnixNano())
	req := protocol.DataConnectRequest{
		TunnelID:     c.tunnelID,
		Token:        c.token,
		ConnectionID: connID,
	}
	payload, err := json.Marshal(req)
	if err != nil {
		_ = conn.Close()
		return fmt.Errorf("failed to marshal data connect request: %w", err)
	}
	if err := protocol.WriteFrame(conn, protocol.NewFrame(protocol.FrameTypeDataConnect, payload)); err != nil {
		_ = conn.Close()
		return fmt.Errorf("failed to send data connect: %w", err)
	}
	// Bound the wait for the ack, then clear the deadline for yamux use.
	_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
	ack, err := protocol.ReadFrame(conn)
	if err != nil {
		_ = conn.Close()
		return fmt.Errorf("failed to read data connect ack: %w", err)
	}
	defer ack.Release()
	_ = conn.SetReadDeadline(time.Time{})
	if ack.Type == protocol.FrameTypeError {
		var errMsg protocol.ErrorMessage
		if e := json.Unmarshal(ack.Payload, &errMsg); e == nil {
			_ = conn.Close()
			return fmt.Errorf("data connect error: %s - %s", errMsg.Code, errMsg.Message)
		}
		_ = conn.Close()
		return fmt.Errorf("data connect error")
	}
	if ack.Type != protocol.FrameTypeDataConnectAck {
		_ = conn.Close()
		return fmt.Errorf("unexpected data connect ack frame: %s", ack.Type)
	}
	var resp protocol.DataConnectResponse
	if err := json.Unmarshal(ack.Payload, &resp); err != nil {
		_ = conn.Close()
		return fmt.Errorf("failed to parse data connect response: %w", err)
	}
	if !resp.Accepted {
		_ = conn.Close()
		return fmt.Errorf("data connection rejected: %s", resp.Message)
	}
	yamuxCfg := yamux.DefaultConfig()
	// Liveness is handled by the primary session's pingLoop.
	yamuxCfg.EnableKeepAlive = false
	yamuxCfg.LogOutput = io.Discard
	yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog
	session, err := yamux.Server(conn, yamuxCfg)
	if err != nil {
		_ = conn.Close()
		return fmt.Errorf("failed to init yamux session: %w", err)
	}
	h := &sessionHandle{
		id:      connID,
		conn:    conn,
		session: session,
	}
	h.touch()
	c.mu.Lock()
	c.dataSessions[connID] = h
	c.mu.Unlock()
	c.wg.Add(1)
	go c.acceptLoop(h, false)
	c.wg.Add(1)
	go c.sessionWatcher(h, false)
	return nil
}
// removeIdleSessions removes up to n idle data sessions, preferring the
// least recently active. Only sessions with zero in-flight streams are
// eligible; the loop stops early when none remain. It works on a snapshot
// of the pool taken under the read lock.
func (c *PoolClient) removeIdleSessions(n int) {
	if n <= 0 {
		return
	}
	type candidate struct {
		id         string
		active     int64
		lastActive time.Time
	}
	c.mu.RLock()
	candidates := make([]candidate, 0, len(c.dataSessions))
	for id, h := range c.dataSessions {
		candidates = append(candidates, candidate{
			id:         id,
			active:     h.active.Load(),
			lastActive: h.lastActiveTime(),
		})
	}
	c.mu.RUnlock()
	removed := 0
	for removed < n {
		// Pick the idle (active == 0) candidate with the oldest activity.
		var best candidate
		found := false
		for _, cand := range candidates {
			if cand.active != 0 {
				continue
			}
			if !found || cand.lastActive.Before(best.lastActive) {
				best = cand
				found = true
			}
		}
		if !found {
			return
		}
		if c.removeDataSession(best.id) {
			removed++
		}
		// Mark the chosen candidate as non-idle in the snapshot so the next
		// iteration cannot select it again (whether or not removal succeeded).
		for i := range candidates {
			if candidates[i].id == best.id {
				candidates[i].active = 1
				break
			}
		}
	}
}
// removeDataSession detaches the data session with the given ID from the
// pool and closes it. It returns false when the ID is unknown or another
// goroutine already closed the handle.
func (c *PoolClient) removeDataSession(id string) bool {
	c.mu.Lock()
	h, ok := c.dataSessions[id]
	if ok {
		delete(c.dataSessions, id)
	}
	c.mu.Unlock()
	if !ok || h == nil {
		return false
	}
	// CAS guards against a concurrent remover double-closing.
	if !h.closed.CompareAndSwap(false, true) {
		return false
	}
	if s := h.session; s != nil {
		_ = s.Close()
	}
	if cn := h.conn; cn != nil {
		_ = cn.Close()
	}
	return true
}
// sessionCount reports how many live sessions the client holds, counting
// the primary session when present.
func (c *PoolClient) sessionCount() int {
	c.mu.RLock()
	defer c.mu.RUnlock()
	n := len(c.dataSessions)
	if c.primary == nil {
		return n
	}
	return n + 1
}

View File

@@ -1,490 +1,251 @@
package proxy
import (
"context"
"bufio"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
json "github.com/goccy/go-json"
"drip/internal/server/tunnel"
"drip/internal/shared/pool"
"drip/internal/shared/httputil"
"drip/internal/shared/netutil"
"drip/internal/shared/protocol"
"drip/internal/shared/utils"
"go.uber.org/zap"
)
const openStreamTimeout = 10 * time.Second
type Handler struct {
manager *tunnel.Manager
logger *zap.Logger
responses *ResponseHandler
domain string
authToken string
headerPool *pool.HeaderPool
bufferPool *pool.AdaptiveBufferPool
manager *tunnel.Manager
logger *zap.Logger
domain string
authToken string
}
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, responses *ResponseHandler, domain string, authToken string) *Handler {
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, authToken string) *Handler {
return &Handler{
manager: manager,
logger: logger,
responses: responses,
domain: domain,
authToken: authToken,
headerPool: pool.NewHeaderPool(),
bufferPool: pool.NewAdaptiveBufferPool(),
manager: manager,
logger: logger,
domain: domain,
authToken: authToken,
}
}
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Always handle /health and /stats directly, regardless of subdomain
// Always handle /health and /stats directly, regardless of subdomain.
if r.URL.Path == "/health" {
h.serveHealth(w, r)
return
}
if r.URL.Path == "/stats" {
h.serveStats(w, r)
return
}
subdomain := h.extractSubdomain(r.Host)
if subdomain == "" {
h.serveHomePage(w, r)
return
}
conn, ok := h.manager.Get(subdomain)
if !ok {
tconn, ok := h.manager.Get(subdomain)
if !ok || tconn == nil {
http.Error(w, "Tunnel not found. The tunnel may have been closed.", http.StatusNotFound)
return
}
if conn.IsClosed() {
if tconn.IsClosed() {
http.Error(w, "Tunnel connection closed", http.StatusBadGateway)
return
}
transport := conn.GetTransport()
if transport == nil {
http.Error(w, "Tunnel control channel not ready", http.StatusBadGateway)
return
}
tType := conn.GetTunnelType()
tType := tconn.GetTunnelType()
if tType != "" && tType != protocol.TunnelTypeHTTP && tType != protocol.TunnelTypeHTTPS {
http.Error(w, "Tunnel does not accept HTTP traffic", http.StatusBadGateway)
return
}
requestID := utils.GenerateID()
// Check for WebSocket upgrade
if httputil.IsWebSocketUpgrade(r) {
h.handleWebSocket(w, r, tconn)
return
}
h.handleAdaptiveRequest(w, r, transport, requestID, subdomain)
}
// Open stream with timeout
stream, err := h.openStreamWithTimeout(tconn)
if err != nil {
w.Header().Set("Connection", "close")
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
return
}
defer stream.Close()
func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string) {
const streamingThreshold int64 = 1 * 1024 * 1024
// Track active connections
tconn.IncActiveConnections()
defer tconn.DecActiveConnections()
// Wrap stream with counting for traffic stats
countingStream := netutil.NewCountingConn(stream,
tconn.AddBytesOut, // Data read from stream = bytes out to client
tconn.AddBytesIn, // Data written to stream = bytes in from client
)
// 1) Write request over the stream (net/http handles large bodies correctly).
if err := r.Write(countingStream); err != nil {
w.Header().Set("Connection", "close")
_ = r.Body.Close()
http.Error(w, "Forward failed", http.StatusBadGateway)
return
}
// 2) Read response from stream.
resp, err := http.ReadResponse(bufio.NewReaderSize(countingStream, 32*1024), r)
if err != nil {
w.Header().Set("Connection", "close")
http.Error(w, "Read response failed", http.StatusBadGateway)
return
}
defer resp.Body.Close()
// 3) Copy headers (strip hop-by-hop).
h.copyResponseHeaders(w.Header(), resp.Header, r.Host)
statusCode := resp.StatusCode
if statusCode == 0 {
statusCode = http.StatusOK
}
// Ensure message delimiting works with our custom ResponseWriter:
// - If Content-Length is known, send it.
// - Otherwise, re-chunk the decoded body ourselves.
if r.Method == http.MethodHead || statusCode == http.StatusNoContent || statusCode == http.StatusNotModified {
if resp.ContentLength >= 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
} else {
w.Header().Del("Content-Length")
}
w.WriteHeader(statusCode)
return
}
if resp.ContentLength >= 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
w.WriteHeader(statusCode)
ctx := r.Context()
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
stream.Close()
case <-done:
}
}()
_, _ = io.Copy(w, resp.Body)
close(done)
stream.Close()
return
}
w.Header().Del("Content-Length")
w.Header().Set("Transfer-Encoding", "chunked")
if len(resp.Trailer) > 0 {
w.Header().Set("Trailer", trailerKeys(resp.Trailer))
}
w.WriteHeader(statusCode)
ctx := r.Context()
var cancelTransport func()
if transport != nil {
cancelOnce := sync.Once{}
cancelFunc := func() {
header := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeClose,
IsLast: true,
}
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
if err != nil {
return
}
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
if err := transport.SendFrame(frame); err != nil {
h.logger.Debug("Failed to send cancel frame to client",
zap.String("request_id", requestID),
zap.Error(err),
)
}
}
cancelTransport = func() {
cancelOnce.Do(cancelFunc)
}
h.responses.RegisterCancelFunc(requestID, cancelTransport)
defer h.responses.CleanupCancelFunc(requestID)
}
largeBufferPtr := h.bufferPool.GetLarge()
tempBufPtr := h.bufferPool.GetMedium()
defer func() {
h.bufferPool.PutLarge(largeBufferPtr)
h.bufferPool.PutMedium(tempBufPtr)
}()
buffer := (*largeBufferPtr)[:0]
tempBuf := (*tempBufPtr)[:pool.MediumBufferSize]
var totalRead int64
var hitThreshold bool
for totalRead < streamingThreshold {
n, err := r.Body.Read(tempBuf)
if n > 0 {
buffer = append(buffer, tempBuf[:n]...)
totalRead += int64(n)
}
if err == io.EOF {
r.Body.Close()
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
return
}
if err != nil {
r.Body.Close()
h.logger.Error("Read request body failed", zap.Error(err))
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
return
}
if totalRead >= streamingThreshold {
hitThreshold = true
break
}
}
if !hitThreshold {
r.Body.Close()
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
return
}
h.streamLargeRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
}
// sendBufferedRequest forwards a request whose body the caller has fully
// buffered to the tunnel client as a single data frame, then waits for one
// of: a complete buffered response, an out-of-band streamed response,
// downstream cancellation, or a 5-minute timeout.
//
// cancelTransport may be nil; when non-nil it notifies the tunnel client
// that the downstream connection went away.
func (h *Handler) sendBufferedRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), body []byte) {
	// Clone request headers (plus the Host header) into a pooled map; the
	// map is returned to the pool immediately after encoding below, so it
	// must not be retained past EncodeHTTPRequest.
	headers := h.headerPool.Get()
	h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
	httpReq := protocol.HTTPRequest{
		Method:  r.Method,
		URL:     r.URL.String(),
		Headers: headers,
		Body:    body,
	}
	reqBytes, err := protocol.EncodeHTTPRequest(&httpReq)
	h.headerPool.Put(headers)
	if err != nil {
		h.logger.Error("Encode HTTP request failed", zap.Error(err))
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}
	// IsLast marks the request as complete in one frame (no body chunks follow).
	header := protocol.DataHeader{
		StreamID:  requestID,
		RequestID: requestID,
		Type:      protocol.DataTypeHTTPRequest,
		IsLast:    true,
	}
	payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, reqBytes)
	if err != nil {
		h.logger.Error("Encode data payload failed", zap.Error(err))
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}
	frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
	// Register both delivery paths BEFORE sending so the client's reply
	// cannot race past us: a complete response arrives on respChan, a
	// streamed response signals completion via streamingDone.
	respChan := h.responses.CreateResponseChan(requestID)
	streamingDone := h.responses.CreateStreamingResponse(requestID, w)
	defer func() {
		h.responses.CleanupResponseChan(requestID)
		h.responses.CleanupStreamingResponse(requestID)
	}()
	if err := transport.SendFrame(frame); err != nil {
		h.logger.Error("Send frame to tunnel failed", zap.Error(err))
		http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
		return
	}
	select {
	case respMsg := <-respChan:
		if respMsg == nil {
			http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
			return
		}
		h.writeHTTPResponse(w, respMsg, subdomain, r)
	case <-streamingDone:
		// Streaming response has been fully written by SendStreamingChunk
	case <-ctx.Done():
		// Downstream client went away; tell the tunnel client to stop work.
		if cancelTransport != nil {
			cancelTransport()
		}
		h.logger.Debug("HTTP request context cancelled",
			zap.String("request_id", requestID),
			zap.String("subdomain", subdomain),
		)
		return
	case <-time.After(5 * time.Minute):
		h.logger.Error("Request timeout",
			zap.String("request_id", requestID),
			zap.String("url", r.URL.String()),
		)
		http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
	}
}
func (h *Handler) streamLargeRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), bufferedData []byte) {
headers := h.headerPool.Get()
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
httpReqHead := protocol.HTTPRequestHead{
Method: r.Method,
URL: r.URL.String(),
Headers: headers,
ContentLength: r.ContentLength,
}
headBytes, err := protocol.EncodeHTTPRequestHead(&httpReqHead)
h.headerPool.Put(headers)
if err != nil {
h.logger.Error("Encode HTTP request head failed", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
headHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPHead, // shared streaming head type
IsLast: false,
}
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
if err != nil {
h.logger.Error("Encode head payload failed", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
respChan := h.responses.CreateResponseChan(requestID)
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
defer func() {
h.responses.CleanupResponseChan(requestID)
h.responses.CleanupStreamingResponse(requestID)
}()
if err := transport.SendFrame(headFrame); err != nil {
h.logger.Error("Send head frame failed", zap.Error(err))
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
return
}
if len(bufferedData) > 0 {
chunkHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
IsLast: false,
}
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, bufferedData)
if err != nil {
h.logger.Error("Encode buffered chunk failed", zap.Error(err))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if ferr == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
return
}
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
if err := transport.SendFrame(chunkFrame); err != nil {
h.logger.Error("Send buffered chunk failed", zap.Error(err))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if ferr == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
return
}
}
streamBufPtr := h.bufferPool.GetMedium()
defer h.bufferPool.PutMedium(streamBufPtr)
buffer := (*streamBufPtr)[:pool.MediumBufferSize]
for {
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
if cancelTransport != nil {
cancelTransport()
}
h.logger.Debug("Streaming request cancelled via context",
zap.String("request_id", requestID),
zap.String("subdomain", subdomain),
)
return
default:
stream.Close()
case <-done:
}
}()
n, readErr := r.Body.Read(buffer)
if n > 0 {
isLast := readErr == io.EOF
chunkHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
IsLast: isLast,
}
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buffer[:n])
if err != nil {
h.logger.Error("Encode chunk payload failed", zap.Error(err))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if ferr == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
return
}
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
if err := transport.SendFrame(chunkFrame); err != nil {
h.logger.Error("Send chunk frame failed", zap.Error(err))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if ferr == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
return
}
}
if readErr == io.EOF {
if n == 0 {
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if err == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
}
break
}
if readErr != nil {
h.logger.Error("Read request body failed", zap.Error(readErr))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if err == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
return
}
if err := writeChunked(w, resp.Body, resp.Trailer); err != nil {
h.logger.Debug("Write chunked response failed", zap.Error(err))
}
close(done)
stream.Close()
}
r.Body.Close()
func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, error) {
type result struct {
stream net.Conn
err error
}
ch := make(chan result, 1)
go func() {
s, err := tconn.OpenStream()
ch <- result{s, err}
}()
select {
case respMsg := <-respChan:
if respMsg == nil {
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
return
}
h.writeHTTPResponse(w, respMsg, subdomain, r)
case <-streamingDone:
// Streaming response has been fully written by SendStreamingChunk
case <-ctx.Done():
if cancelTransport != nil {
cancelTransport()
}
h.logger.Debug("Streaming HTTP request context cancelled",
zap.String("request_id", requestID),
zap.String("subdomain", subdomain),
)
return
case <-time.After(5 * time.Minute):
h.logger.Error("Streaming request timeout",
zap.String("request_id", requestID),
zap.String("url", r.URL.String()),
)
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
case r := <-ch:
return r.stream, r.err
case <-time.After(openStreamTimeout):
return nil, fmt.Errorf("open stream timeout")
}
}
func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPResponse, subdomain string, r *http.Request) {
if resp == nil {
http.Error(w, "Invalid response from tunnel", http.StatusBadGateway)
func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection) {
stream, err := h.openStreamWithTimeout(tconn)
if err != nil {
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
return
}
// For buffered responses, we have the complete body, so we can set Content-Length
// Skip ALL hop-by-hop headers - client should have already cleaned them
for key, values := range resp.Headers {
tconn.IncActiveConnections()
hj, ok := w.(http.Hijacker)
if !ok {
stream.Close()
tconn.DecActiveConnections()
http.Error(w, "WebSocket not supported", http.StatusInternalServerError)
return
}
clientConn, _, err := hj.Hijack()
if err != nil {
stream.Close()
tconn.DecActiveConnections()
http.Error(w, "Failed to hijack connection", http.StatusInternalServerError)
return
}
if err := r.Write(stream); err != nil {
stream.Close()
clientConn.Close()
tconn.DecActiveConnections()
return
}
go func() {
defer stream.Close()
defer clientConn.Close()
defer tconn.DecActiveConnections()
_ = netutil.PipeWithCallbacks(r.Context(), stream, clientConn,
func(n int64) { tconn.AddBytesOut(n) },
func(n int64) { tconn.AddBytesIn(n) },
)
}()
}
func (h *Handler) copyResponseHeaders(dst http.Header, src http.Header, proxyHost string) {
for key, values := range src {
canonicalKey := http.CanonicalHeaderKey(key)
// Skip hop-by-hop headers completely using canonical key comparison
// Hop-by-hop headers must not be forwarded.
if canonicalKey == "Connection" ||
canonicalKey == "Keep-Alive" ||
canonicalKey == "Transfer-Encoding" ||
@@ -496,29 +257,61 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
}
if canonicalKey == "Location" && len(values) > 0 {
rewrittenLocation := h.rewriteLocationHeader(values[0], r.Host)
w.Header().Set("Location", rewrittenLocation)
dst.Set("Location", h.rewriteLocationHeader(values[0], proxyHost))
continue
}
for _, value := range values {
w.Header().Add(key, value)
dst.Add(key, value)
}
}
}
// trailerKeys renders the header names of hdr as a comma-separated list.
// The names are sorted so the output is stable across calls.
func trailerKeys(hdr http.Header) string {
	names := make([]string, 0, len(hdr))
	for name := range hdr {
		names = append(names, name)
	}
	// Deterministic order is nicer for debugging; no semantic impact.
	sortStrings(names)
	return strings.Join(names, ", ")
}
func writeChunked(w io.Writer, r io.Reader, trailer http.Header) error {
buf := make([]byte, 32*1024)
for {
n, err := r.Read(buf)
if n > 0 {
if _, werr := fmt.Fprintf(w, "%x\r\n", n); werr != nil {
return werr
}
if _, werr := w.Write(buf[:n]); werr != nil {
return werr
}
if _, werr := io.WriteString(w, "\r\n"); werr != nil {
return werr
}
}
if err == io.EOF {
break
}
if err != nil {
return err
}
}
// For buffered mode, always set Content-Length with the actual body size
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
statusCode := resp.StatusCode
if statusCode == 0 {
statusCode = http.StatusOK
if _, err := io.WriteString(w, "0\r\n"); err != nil {
return err
}
w.WriteHeader(statusCode)
if len(resp.Body) > 0 {
w.Write(resp.Body)
for k, vv := range trailer {
for _, v := range vv {
if _, err := io.WriteString(w, fmt.Sprintf("%s: %s\r\n", k, v)); err != nil {
return err
}
}
}
_, err := io.WriteString(w, "\r\n")
return err
}
func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
@@ -535,22 +328,13 @@ func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
strings.HasPrefix(locationURL.Host, "localhost:") ||
locationURL.Host == "127.0.0.1" ||
strings.HasPrefix(locationURL.Host, "127.0.0.1:") {
scheme := "https"
if strings.Contains(proxyHost, ":") && !strings.Contains(proxyHost, "https") {
parts := strings.Split(proxyHost, ":")
if len(parts) == 2 && parts[1] != "443" {
scheme = "https"
}
}
rewritten := fmt.Sprintf("%s://%s%s", scheme, proxyHost, locationURL.Path)
rewritten := fmt.Sprintf("https://%s%s", proxyHost, locationURL.Path)
if locationURL.RawQuery != "" {
rewritten += "?" + locationURL.RawQuery
}
if locationURL.Fragment != "" {
rewritten += "#" + locationURL.Fragment
}
return rewritten
}
@@ -568,8 +352,7 @@ func (h *Handler) extractSubdomain(host string) string {
suffix := "." + h.domain
if strings.HasSuffix(host, suffix) {
subdomain := strings.TrimSuffix(host, suffix)
return subdomain
return strings.TrimSuffix(host, suffix)
}
return ""
@@ -652,9 +435,17 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
}
for _, conn := range connections {
if conn == nil {
continue
}
stats["tunnels"] = append(stats["tunnels"].([]map[string]interface{}), map[string]interface{}{
"subdomain": conn.Subdomain,
"last_active": conn.LastActive.Unix(),
"subdomain": conn.Subdomain,
"tunnel_type": string(conn.GetTunnelType()),
"last_active": conn.LastActive.Unix(),
"bytes_in": conn.GetBytesIn(),
"bytes_out": conn.GetBytesOut(),
"active_connections": conn.GetActiveConnections(),
"total_bytes": conn.GetBytesIn() + conn.GetBytesOut(),
})
}
@@ -668,3 +459,13 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.Write(data)
}
// sortStrings orders s ascending (lexicographically) in place.
// Header trailer lists are tiny, so a simple insertion sort is plenty.
func sortStrings(s []string) {
	for i := 1; i < len(s); i++ {
		for j := i; j > 0 && s[j] < s[j-1]; j-- {
			s[j], s[j-1] = s[j-1], s[j]
		}
	}
}

View File

@@ -1,421 +0,0 @@
package proxy
import (
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// responseChanEntry holds a response channel and its creation time.
// The creation time lets the cleanup loop expire abandoned requests.
type responseChanEntry struct {
	ch        chan *protocol.HTTPResponse // buffered (cap 1) delivery channel
	createdAt time.Time                   // used for age-based eviction
}
// streamingResponseEntry holds a streaming response writer and the state
// needed to feed it chunks incrementally.
type streamingResponseEntry struct {
	w              http.ResponseWriter // downstream writer for head + body chunks
	flusher        http.Flusher        // nil when w does not support flushing
	createdAt      time.Time
	lastActivityAt time.Time     // refreshed on every head/chunk write; drives idle eviction
	headersSent    bool          // guards against writing response headers twice
	done           chan struct{} // closed once when the stream finishes or aborts
	mu             sync.Mutex    // serializes writes to w and the fields above
}
// ResponseHandler manages response channels for HTTP requests over TCP/Frame protocol
type ResponseHandler struct {
	channels          map[string]*responseChanEntry      // request ID -> buffered response channel
	streamingChannels map[string]*streamingResponseEntry // request ID -> streaming writer state
	cancelFuncs       map[string]func()                  // request ID -> downstream-disconnect callback
	mu                sync.RWMutex                       // guards the three maps above
	logger            *zap.Logger
	stopCh            chan struct{} // closed by Close to stop cleanupLoop
}
// NewResponseHandler constructs a ResponseHandler and starts the background
// goroutine that periodically evicts expired response channels. Call Close
// to stop that goroutine and release all pending channels.
func NewResponseHandler(logger *zap.Logger) *ResponseHandler {
	handler := &ResponseHandler{
		channels:          map[string]*responseChanEntry{},
		streamingChannels: map[string]*streamingResponseEntry{},
		cancelFuncs:       map[string]func(){},
		logger:            logger,
		stopCh:            make(chan struct{}),
	}
	// One shared janitor goroutine serves every request, rather than a
	// timer per request.
	go handler.cleanupLoop()
	return handler
}
// CreateResponseChan registers and returns a buffered channel on which the
// response for requestID will be delivered.
func (h *ResponseHandler) CreateResponseChan(requestID string) chan *protocol.HTTPResponse {
	entry := &responseChanEntry{
		ch:        make(chan *protocol.HTTPResponse, 1),
		createdAt: time.Now(),
	}
	h.mu.Lock()
	h.channels[requestID] = entry
	h.mu.Unlock()
	return entry.ch
}
// CreateStreamingResponse registers a streaming writer for requestID and
// returns the channel that is closed once the stream completes or aborts.
func (h *ResponseHandler) CreateStreamingResponse(requestID string, w http.ResponseWriter) chan struct{} {
	// Not every ResponseWriter can flush; a nil flusher is tolerated.
	flusher, _ := w.(http.Flusher)
	now := time.Now()
	entry := &streamingResponseEntry{
		w:              w,
		flusher:        flusher,
		createdAt:      now,
		lastActivityAt: now,
		done:           make(chan struct{}),
	}
	h.mu.Lock()
	h.streamingChannels[requestID] = entry
	h.mu.Unlock()
	return entry.done
}
// RegisterCancelFunc stores a callback invoked when the downstream client
// disconnects mid-request. A nil callback is ignored.
func (h *ResponseHandler) RegisterCancelFunc(requestID string, cancel func()) {
	if cancel == nil {
		return
	}
	h.mu.Lock()
	defer h.mu.Unlock()
	h.cancelFuncs[requestID] = cancel
}
// GetResponseChan returns the receive side of the response channel for
// requestID, or nil when none is registered.
func (h *ResponseHandler) GetResponseChan(requestID string) <-chan *protocol.HTTPResponse {
	h.mu.RLock()
	entry := h.channels[requestID]
	h.mu.RUnlock()
	if entry == nil {
		return nil
	}
	return entry.ch
}
// SendResponse delivers a decoded HTTP response to the handler goroutine
// waiting on the request's channel. If no channel is registered the
// response is silently dropped. The send blocks for up to 30 seconds in
// case the channel's buffer slot is occupied and the handler is slow.
func (h *ResponseHandler) SendResponse(requestID string, resp *protocol.HTTPResponse) {
	h.mu.RLock()
	entry, exists := h.channels[requestID]
	h.mu.RUnlock()
	if !exists || entry == nil {
		return
	}
	// NOTE(review): CleanupResponseChan closes entry.ch (receiver-side
	// close). If that happens while this select is blocked, the send
	// panics (send on closed channel). Looks racy — confirm the handler
	// only cleans up after it can no longer receive for this request ID.
	select {
	case entry.ch <- resp:
	case <-time.After(30 * time.Second):
		h.logger.Error("Timeout sending response to channel - handler may have abandoned",
			zap.String("request_id", requestID),
			zap.Int("status_code", resp.StatusCode),
			zap.Int("body_size", len(resp.Body)),
		)
	}
}
// SendStreamingHead writes the status line and headers for a streaming
// response. It is a no-op (returning nil) when no streaming entry exists
// for requestID, the stream has already finished, or headers were already
// sent. Always returns nil.
func (h *ResponseHandler) SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error {
	h.mu.RLock()
	entry, exists := h.streamingChannels[requestID]
	h.mu.RUnlock()
	if !exists || entry == nil {
		return nil
	}
	entry.mu.Lock()
	defer entry.mu.Unlock()
	// Stream already completed or aborted: nothing left to write.
	select {
	case <-entry.done:
		return nil
	default:
	}
	if entry.headersSent {
		return nil
	}
	// Copy headers, removing hop-by-hop headers that were already handled by client
	// Client's cleanResponseHeaders already removed Transfer-Encoding, Connection, etc.
	// But we need to check again in case they slipped through
	hasContentLength := false
	for key, values := range head.Headers {
		canonicalKey := http.CanonicalHeaderKey(key)
		// Skip ALL hop-by-hop headers
		if canonicalKey == "Connection" ||
			canonicalKey == "Keep-Alive" ||
			canonicalKey == "Transfer-Encoding" ||
			canonicalKey == "Upgrade" ||
			canonicalKey == "Proxy-Connection" ||
			canonicalKey == "Te" ||
			canonicalKey == "Trailer" {
			continue
		}
		if canonicalKey == "Content-Length" {
			hasContentLength = true
		}
		for _, value := range values {
			entry.w.Header().Add(key, value)
		}
	}
	// For streaming responses, decide how to indicate message length:
	// only set Content-Length when the client reported one and the header
	// was not already copied above.
	if head.ContentLength >= 0 && !hasContentLength {
		entry.w.Header().Set("Content-Length", fmt.Sprintf("%d", head.ContentLength))
	}
	statusCode := head.StatusCode
	if statusCode == 0 {
		statusCode = http.StatusOK // treat an unset status as success
	}
	entry.w.WriteHeader(statusCode)
	entry.headersSent = true
	entry.lastActivityAt = time.Now()
	// Push the headers to the client immediately so streaming can begin.
	if entry.flusher != nil {
		entry.flusher.Flush()
	}
	return nil
}
// SendStreamingChunk writes one body chunk of a streaming response to the
// downstream client, flushing after each write. When isLast is set (or a
// write fails) the stream's done channel is closed so the waiting handler
// returns. Always returns nil: delivery problems terminate the stream and
// notify the tunnel side instead of propagating an error.
func (h *ResponseHandler) SendStreamingChunk(requestID string, chunk []byte, isLast bool) error {
	h.mu.RLock()
	entry, exists := h.streamingChannels[requestID]
	h.mu.RUnlock()
	if !exists || entry == nil {
		// No streaming entry registered for this request; drop the chunk.
		return nil
	}
	entry.mu.Lock()
	defer entry.mu.Unlock()
	// Stream already finished or aborted: ignore late chunks.
	select {
	case <-entry.done:
		return nil
	default:
	}
	if len(chunk) > 0 {
		if _, err := entry.w.Write(chunk); err != nil {
			// Any write error (client disconnect or otherwise) terminates
			// the stream and tells the tunnel client to stop sending.
			// The previous code special-cased isClientDisconnectError(err),
			// but both branches were byte-identical, so the dead distinction
			// is collapsed here; behavior is unchanged.
			closeStreamingDone(entry)
			h.triggerCancel(requestID)
			return nil
		}
		entry.lastActivityAt = time.Now()
		// Flush so the downstream sees the chunk immediately.
		if entry.flusher != nil {
			entry.flusher.Flush()
		}
	}
	if isLast {
		closeStreamingDone(entry)
	}
	return nil
}

// closeStreamingDone signals stream completion exactly once; safe to call
// when entry.done is already closed. Caller must hold entry.mu.
func closeStreamingDone(entry *streamingResponseEntry) {
	select {
	case <-entry.done:
	default:
		close(entry.done)
	}
}
func isClientDisconnectError(err error) bool {
if err == nil {
return false
}
if netErr, ok := err.(*net.OpError); ok {
if netErr.Err != nil {
errStr := netErr.Err.Error()
if strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection reset") ||
strings.Contains(errStr, "connection refused") {
return true
}
}
}
errStr := err.Error()
return strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection reset") ||
strings.Contains(errStr, "use of closed network connection")
}
// triggerCancel removes the cancel callback for requestID and, when one
// was registered, runs it on its own goroutine so the caller never blocks
// on it (and no lock is held while it runs).
func (h *ResponseHandler) triggerCancel(requestID string) {
	h.mu.Lock()
	cancel, ok := h.cancelFuncs[requestID]
	if ok {
		delete(h.cancelFuncs, requestID)
	}
	h.mu.Unlock()
	if cancel != nil {
		go cancel()
	}
}
// CleanupResponseChan removes and closes the response channel registered
// for requestID. Safe to call when no channel exists.
//
// NOTE(review): this closes the channel from the receiving side while
// SendResponse may concurrently be trying to send on it; a send on a
// closed channel panics. Confirm handlers only call this after they can
// no longer receive a response for this request ID.
func (h *ResponseHandler) CleanupResponseChan(requestID string) {
	h.mu.Lock()
	defer h.mu.Unlock()
	if entry, exists := h.channels[requestID]; exists {
		close(entry.ch)
		delete(h.channels, requestID)
	}
}
// CleanupStreamingResponse removes the streaming entry for requestID,
// signalling completion on its done channel if not already signalled.
// Safe to call when no entry exists.
func (h *ResponseHandler) CleanupStreamingResponse(requestID string) {
	h.mu.Lock()
	defer h.mu.Unlock()
	entry, ok := h.streamingChannels[requestID]
	if !ok {
		return
	}
	select {
	case <-entry.done:
		// already signalled
	default:
		close(entry.done)
	}
	delete(h.streamingChannels, requestID)
}
// CleanupCancelFunc discards any cancel callback registered for requestID
// without invoking it.
func (h *ResponseHandler) CleanupCancelFunc(requestID string) {
	h.mu.Lock()
	defer h.mu.Unlock()
	delete(h.cancelFuncs, requestID)
}
// GetPendingCount reports how many requests are currently awaiting a
// response, buffered and streaming combined.
func (h *ResponseHandler) GetPendingCount() int {
	h.mu.RLock()
	defer h.mu.RUnlock()
	pending := len(h.channels)
	pending += len(h.streamingChannels)
	return pending
}
// cleanupLoop evicts expired entries every 5 seconds until stopCh closes.
func (h *ResponseHandler) cleanupLoop() {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-h.stopCh:
			return
		case <-ticker.C:
			h.cleanupExpiredChannels()
		}
	}
}
// cleanupExpiredChannels drops buffered channels older than 5 minutes and
// streaming entries idle for more than 5 minutes, then fires any cancel
// callbacks registered for the evicted streaming requests.
func (h *ResponseHandler) cleanupExpiredChannels() {
	now := time.Now()
	timeout := 5 * time.Minute          // max absolute lifetime of a buffered response channel
	streamingTimeout := 5 * time.Minute // max idle time between streaming chunks
	h.mu.Lock()
	defer h.mu.Unlock()
	expiredCount := 0
	cancelList := make([]string, 0)
	// Buffered channels expire by absolute age.
	for requestID, entry := range h.channels {
		if now.Sub(entry.createdAt) > timeout {
			close(entry.ch)
			delete(h.channels, requestID)
			expiredCount++
		}
	}
	// Streaming entries expire by inactivity rather than age, so long
	// downloads survive as long as chunks keep flowing.
	for requestID, entry := range h.streamingChannels {
		if now.Sub(entry.lastActivityAt) > streamingTimeout {
			select {
			case <-entry.done:
			default:
				close(entry.done)
			}
			delete(h.streamingChannels, requestID)
			cancelList = append(cancelList, requestID)
			expiredCount++
		}
	}
	// Run cancel callbacks on their own goroutines: h.mu is still held
	// here, and a callback that blocked or re-entered the handler would
	// otherwise deadlock the cleanup loop.
	for _, requestID := range cancelList {
		if cancel := h.cancelFuncs[requestID]; cancel != nil {
			delete(h.cancelFuncs, requestID)
			go cancel()
		}
	}
	if expiredCount > 0 {
		h.logger.Debug("Cleaned up expired response channels",
			zap.Int("count", expiredCount),
			zap.Int("remaining", len(h.channels)+len(h.streamingChannels)),
		)
	}
}
// Close stops the cleanup goroutine and releases every pending request:
// buffered channels are closed, streaming done-channels are signalled, and
// all registered cancel callbacks are invoked.
//
// NOTE(review): cancel callbacks run synchronously while h.mu is held,
// unlike triggerCancel/cleanupExpiredChannels which run them on their own
// goroutines. A callback that re-enters this handler would deadlock here —
// confirm callbacks never touch the handler.
func (h *ResponseHandler) Close() {
	close(h.stopCh)
	h.mu.Lock()
	defer h.mu.Unlock()
	for _, entry := range h.channels {
		close(entry.ch)
	}
	h.channels = make(map[string]*responseChanEntry)
	for _, entry := range h.streamingChannels {
		// Signal done only if not already signalled; double close panics.
		select {
		case <-entry.done:
		default:
			close(entry.done)
		}
	}
	h.streamingChannels = make(map[string]*streamingResponseEntry)
	for _, cancel := range h.cancelFuncs {
		cancel()
	}
	h.cancelFuncs = make(map[string]func())
}

View File

@@ -13,6 +13,7 @@ import (
"time"
json "github.com/goccy/go-json"
"github.com/hashicorp/yamux"
"drip/internal/server/tunnel"
"drip/internal/shared/constants"
@@ -33,36 +34,27 @@ type Connection struct {
publicPort int
portAlloc *PortAllocator
tunnelConn *tunnel.Connection
proxy *TunnelProxy
stopCh chan struct{}
once sync.Once
lastHeartbeat time.Time
mu sync.RWMutex
frameWriter *protocol.FrameWriter
httpHandler http.Handler
responseChans HTTPResponseHandler
tunnelType protocol.TunnelType // Track tunnel type
ctx context.Context
cancel context.CancelFunc
// Flow control
paused bool
pauseCond *sync.Cond
}
// gost-like TCP tunnel (yamux)
session *yamux.Session
proxy *Proxy
// HTTPResponseHandler interface for response channel operations
type HTTPResponseHandler interface {
CreateResponseChan(requestID string) chan *protocol.HTTPResponse
GetResponseChan(requestID string) <-chan *protocol.HTTPResponse
CleanupResponseChan(requestID string)
SendResponse(requestID string, resp *protocol.HTTPResponse)
// Streaming response methods
SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error
SendStreamingChunk(requestID string, chunk []byte, isLast bool) error
// Multi-connection support
tunnelID string
groupManager *ConnectionGroupManager
}
// NewConnection creates a new connection handler
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Connection {
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager) *Connection {
ctx, cancel := context.WithCancel(context.Background())
c := &Connection{
conn: conn,
@@ -73,13 +65,12 @@ func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, log
domain: domain,
publicPort: publicPort,
httpHandler: httpHandler,
responseChans: responseChans,
stopCh: make(chan struct{}),
lastHeartbeat: time.Now(),
ctx: ctx,
cancel: cancel,
groupManager: groupManager,
}
c.pauseCond = sync.NewCond(&c.mu)
return c
}
@@ -97,8 +88,8 @@ func (c *Connection) Handle() error {
// Use buffered reader to support peeking
reader := bufio.NewReader(c.conn)
// Peek first 8 bytes to detect protocol
peek, err := reader.Peek(8)
// Peek first 4 bytes to detect protocol (HTTP methods are 4 bytes).
peek, err := reader.Peek(4)
if err != nil {
return fmt.Errorf("failed to peek connection: %w", err)
}
@@ -127,6 +118,11 @@ func (c *Connection) Handle() error {
sf := protocol.WithFrame(frame)
defer sf.Close()
// Handle data connection request (for multi-connection pool)
if sf.Frame.Type == protocol.FrameTypeDataConnect {
return c.handleDataConnect(sf.Frame, reader)
}
if sf.Frame.Type != protocol.FrameTypeRegister {
return fmt.Errorf("expected register frame, got %s", sf.Frame.Type)
}
@@ -180,7 +176,6 @@ func (c *Connection) Handle() error {
// Store TCP connection reference and metadata for HTTP proxy routing
c.tunnelConn.Conn = nil // We're using TCP, not WebSocket
c.tunnelConn.SetTransport(c, req.TunnelType)
c.tunnelConn.SetTunnelType(req.TunnelType)
c.tunnelType = req.TunnelType
@@ -208,11 +203,33 @@ func (c *Connection) Handle() error {
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.domain, c.port)
}
// Generate TunnelID for multi-connection support if client supports it
var tunnelID string
var supportsDataConn bool
recommendedConns := 0
if req.PoolCapabilities != nil && req.ConnectionType == "primary" && c.groupManager != nil {
// Client supports connection pooling
group := c.groupManager.CreateGroup(subdomain, req.Token, c, req.TunnelType)
tunnelID = group.TunnelID
c.tunnelID = tunnelID
supportsDataConn = true
recommendedConns = 4 // Recommend 4 data connections
c.logger.Info("Created connection group for multi-connection support",
zap.String("tunnel_id", tunnelID),
zap.Int("max_data_conns", req.PoolCapabilities.MaxDataConns),
)
}
resp := protocol.RegisterResponse{
Subdomain: subdomain,
Port: c.port,
URL: tunnelURL,
Message: "Tunnel registered successfully",
Subdomain: subdomain,
Port: c.port,
URL: tunnelURL,
Message: "Tunnel registered successfully",
TunnelID: tunnelID,
SupportsDataConn: supportsDataConn,
RecommendedConns: recommendedConns,
}
respData, _ := json.Marshal(resp)
@@ -224,6 +241,17 @@ func (c *Connection) Handle() error {
return fmt.Errorf("failed to send registration ack: %w", err)
}
// Clear deadline for tunnel data-plane.
c.conn.SetReadDeadline(time.Time{})
// gost-like tunnels: switch to yamux after RegisterAck.
if req.TunnelType == protocol.TunnelTypeTCP {
return c.handleTCPTunnel(reader)
}
if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS {
return c.handleHTTPProxyTunnel(reader)
}
c.frameWriter = protocol.NewFrameWriter(c.conn)
c.frameWriter.SetWriteErrorHandler(func(err error) {
@@ -231,15 +259,6 @@ func (c *Connection) Handle() error {
c.Close()
})
c.conn.SetReadDeadline(time.Time{})
if req.TunnelType == protocol.TunnelTypeTCP {
c.proxy = NewTunnelProxy(c.port, subdomain, c.conn, c.logger)
if err := c.proxy.Start(); err != nil {
return fmt.Errorf("failed to start TCP proxy: %w", err)
}
}
go c.heartbeatChecker()
return c.handleFrames(reader)
@@ -376,7 +395,7 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
frame, err := protocol.ReadFrame(reader)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
if isTimeoutError(err) {
c.logger.Warn("Read timeout, connection may be dead")
return fmt.Errorf("read timeout")
}
@@ -404,15 +423,6 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
c.handleHeartbeat()
sf.Close()
case protocol.FrameTypeData:
// Data frame from client (response to forwarded request)
c.handleDataFrame(sf.Frame)
sf.Close()
case protocol.FrameTypeFlowControl:
c.handleFlowControl(sf.Frame)
sf.Close()
case protocol.FrameTypeClose:
sf.Close()
c.logger.Info("Client requested close")
@@ -436,127 +446,12 @@ func (c *Connection) handleHeartbeat() {
// Send heartbeat ack
ackFrame := protocol.NewFrame(protocol.FrameTypeHeartbeatAck, nil)
err := c.frameWriter.WriteFrame(ackFrame)
err := c.frameWriter.WriteControl(ackFrame)
if err != nil {
c.logger.Error("Failed to send heartbeat ack", zap.Error(err))
}
}
// handleDataFrame handles a data frame received from the client and routes it
// by payload type: raw TCP responses go to the tunnel proxy, HTTP responses
// (whole or streamed head/chunk) go to the response-channel handler, and
// close notifications tear down the proxy stream.
func (c *Connection) handleDataFrame(frame *protocol.Frame) {
	// Decode payload (auto-detects protocol version)
	header, data, err := protocol.DecodeDataPayload(frame.Payload)
	if err != nil {
		c.logger.Error("Failed to decode data payload",
			zap.Error(err),
		)
		return
	}
	c.logger.Debug("Received data frame",
		zap.String("stream_id", header.StreamID),
		zap.String("type", header.Type.String()),
		zap.Int("data_size", len(data)),
	)
	switch header.Type {
	case protocol.DataTypeResponse:
		// TCP tunnel response, forward to proxy. A nil proxy means this
		// connection has no TCP tunnel; the frame is silently dropped.
		if c.proxy != nil {
			if err := c.proxy.HandleResponse(header.StreamID, data); err != nil {
				c.logger.Error("Failed to handle response",
					zap.String("stream_id", header.StreamID),
					zap.Error(err),
				)
			}
		}
	case protocol.DataTypeHTTPResponse:
		// Complete (non-streaming) HTTP response.
		if c.responseChans == nil {
			c.logger.Warn("No response channel handler for HTTP response",
				zap.String("stream_id", header.StreamID),
			)
			return
		}
		// Decode HTTP response (auto-detects JSON vs msgpack)
		httpResp, err := protocol.DecodeHTTPResponse(data)
		if err != nil {
			c.logger.Error("Failed to decode HTTP response",
				zap.String("stream_id", header.StreamID),
				zap.Error(err),
			)
			return
		}
		// Route by request ID when provided to keep request/response aligned.
		// Older clients may omit RequestID, so fall back to StreamID.
		reqID := header.RequestID
		if reqID == "" {
			reqID = header.StreamID
		}
		c.responseChans.SendResponse(reqID, httpResp)
	case protocol.DataTypeHTTPHead:
		// Streaming HTTP response headers
		if c.responseChans == nil {
			c.logger.Warn("No response handler for streaming HTTP head",
				zap.String("stream_id", header.StreamID),
			)
			return
		}
		httpHead, err := protocol.DecodeHTTPResponseHead(data)
		if err != nil {
			c.logger.Error("Failed to decode HTTP response head",
				zap.String("stream_id", header.StreamID),
				zap.Error(err),
			)
			return
		}
		// Same RequestID-with-StreamID-fallback routing as above.
		reqID := header.RequestID
		if reqID == "" {
			reqID = header.StreamID
		}
		if err := c.responseChans.SendStreamingHead(reqID, httpHead); err != nil {
			c.logger.Error("Failed to send streaming head",
				zap.String("request_id", reqID),
				zap.Error(err),
			)
		}
	case protocol.DataTypeHTTPBodyChunk:
		// Streaming HTTP response body chunk; header.IsLast marks the final chunk.
		if c.responseChans == nil {
			c.logger.Warn("No response handler for streaming HTTP chunk",
				zap.String("stream_id", header.StreamID),
			)
			return
		}
		reqID := header.RequestID
		if reqID == "" {
			reqID = header.StreamID
		}
		if err := c.responseChans.SendStreamingChunk(reqID, data, header.IsLast); err != nil {
			c.logger.Error("Failed to send streaming chunk",
				zap.String("request_id", reqID),
				zap.Error(err),
			)
		}
	case protocol.DataTypeClose:
		// Client is closing the stream
		if c.proxy != nil {
			c.proxy.CloseStream(header.StreamID)
		}
	default:
		c.logger.Warn("Unknown data frame type",
			zap.String("type", header.Type.String()),
			zap.String("stream_id", header.StreamID),
		)
	}
}
// heartbeatChecker checks for heartbeat timeout
func (c *Connection) heartbeatChecker() {
ticker := time.NewTicker(constants.HeartbeatInterval)
@@ -583,16 +478,6 @@ func (c *Connection) heartbeatChecker() {
}
}
// SendFrame writes a frame to the client. Data frames go through the
// backpressure-aware path; all other frame types are written directly.
// When no frame writer has been set up yet, it falls back to an
// unbuffered write on the raw connection.
func (c *Connection) SendFrame(frame *protocol.Frame) error {
	if c.frameWriter == nil {
		return protocol.WriteFrame(c.conn, frame)
	}
	switch frame.Type {
	case protocol.FrameTypeData:
		return c.sendWithBackpressure(frame)
	default:
		return c.frameWriter.WriteFrame(frame)
	}
}
func (c *Connection) sendError(code, message string) {
errMsg := protocol.ErrorMessage{
Code: code,
@@ -618,8 +503,12 @@ func (c *Connection) Close() {
c.cancel()
}
// Ensure any in-flight writes return quickly on shutdown to avoid hanging.
if c.conn != nil {
_ = c.conn.SetDeadline(time.Now())
}
if c.frameWriter != nil {
c.frameWriter.Flush()
c.frameWriter.Close()
}
@@ -627,7 +516,13 @@ func (c *Connection) Close() {
c.proxy.Stop()
}
c.conn.Close()
if c.session != nil {
_ = c.session.Close()
}
if c.conn != nil {
c.conn.Close()
}
if c.port > 0 && c.portAlloc != nil {
c.portAlloc.Release(c.port)
@@ -635,6 +530,12 @@ func (c *Connection) Close() {
if c.subdomain != "" {
c.manager.Unregister(c.subdomain)
// Clean up connection group when PRIMARY connection closes
// (only primary connections have subdomain set)
if c.tunnelID != "" && c.groupManager != nil {
c.groupManager.RemoveGroup(c.tunnelID)
}
}
c.logger.Info("Connection closed",
@@ -643,11 +544,6 @@ func (c *Connection) Close() {
})
}
// GetSubdomain returns the assigned subdomain
func (c *Connection) GetSubdomain() string {
return c.subdomain
}
// httpResponseWriter implements http.ResponseWriter for writing to a net.Conn
type httpResponseWriter struct {
conn net.Conn
@@ -698,39 +594,196 @@ func (w *httpResponseWriter) Write(data []byte) (int, error) {
return w.writer.Write(data)
}
func (c *Connection) handleFlowControl(frame *protocol.Frame) {
msg, err := protocol.DecodeFlowControlMessage(frame.Payload)
// handleDataConnect handles a data connection join request
func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Reader) error {
var req protocol.DataConnectRequest
if err := json.Unmarshal(frame.Payload, &req); err != nil {
c.sendError("invalid_request", "Failed to parse data connect request")
return fmt.Errorf("failed to parse data connect request: %w", err)
}
c.logger.Info("Data connection request received",
zap.String("tunnel_id", req.TunnelID),
zap.String("connection_id", req.ConnectionID),
)
// Validate the request
if c.groupManager == nil {
c.sendDataConnectError("not_supported", "Multi-connection not supported")
return fmt.Errorf("group manager not available")
}
// Validate auth token
if c.authToken != "" && req.Token != c.authToken {
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
return fmt.Errorf("authentication failed for data connection")
}
group, ok := c.groupManager.GetGroup(req.TunnelID)
if !ok || group == nil {
c.sendDataConnectError("join_failed", "Tunnel not found")
return fmt.Errorf("tunnel not found: %s", req.TunnelID)
}
// Validate token against the primary registration token.
if group.Token != "" && req.Token != group.Token {
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
return fmt.Errorf("authentication failed for data connection")
}
// Store tunnelID for cleanup
c.tunnelID = req.TunnelID
// For TCP tunnels, the data connection is upgraded to a yamux session and used for
// stream forwarding, not framed request/response routing.
if group.TunnelType == protocol.TunnelTypeTCP {
resp := protocol.DataConnectResponse{
Accepted: true,
ConnectionID: req.ConnectionID,
Message: "Data connection accepted",
}
respData, _ := json.Marshal(resp)
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
return fmt.Errorf("failed to send data connect ack: %w", err)
}
c.logger.Info("TCP data connection established",
zap.String("tunnel_id", req.TunnelID),
zap.String("connection_id", req.ConnectionID),
)
// Clear deadline for yamux data-plane.
_ = c.conn.SetReadDeadline(time.Time{})
// Public server acts as yamux Client, client connector acts as yamux Server.
bc := &bufferedConn{
Conn: c.conn,
reader: reader,
}
cfg := yamux.DefaultConfig()
cfg.EnableKeepAlive = false
cfg.LogOutput = io.Discard
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
session, err := yamux.Client(bc, cfg)
if err != nil {
return fmt.Errorf("failed to init yamux session: %w", err)
}
c.session = session
group.AddSession(req.ConnectionID, session)
defer group.RemoveSession(req.ConnectionID)
select {
case <-c.stopCh:
return nil
case <-session.CloseChan():
return nil
}
}
// Add data connection to group
dataConn, err := c.groupManager.AddDataConnection(&req, c.conn)
if err != nil {
c.logger.Error("Failed to decode flow control", zap.Error(err))
return
c.sendDataConnectError("join_failed", err.Error())
return fmt.Errorf("failed to join connection group: %w", err)
}
c.mu.Lock()
defer c.mu.Unlock()
// Send success response
resp := protocol.DataConnectResponse{
Accepted: true,
ConnectionID: req.ConnectionID,
Message: "Data connection accepted",
}
switch msg.Action {
case protocol.FlowControlPause:
c.paused = true
c.logger.Warn("Client requested pause",
zap.String("stream", msg.StreamID))
respData, _ := json.Marshal(resp)
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
case protocol.FlowControlResume:
c.paused = false
c.pauseCond.Broadcast()
c.logger.Info("Client requested resume",
zap.String("stream", msg.StreamID))
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
return fmt.Errorf("failed to send data connect ack: %w", err)
}
default:
c.logger.Warn("Unknown flow control action",
zap.String("action", string(msg.Action)))
c.logger.Info("Data connection established",
zap.String("tunnel_id", req.TunnelID),
zap.String("connection_id", req.ConnectionID),
)
// Handle data frames on this connection
return c.handleDataConnectionFrames(dataConn, reader)
}
// handleDataConnectionFrames handles frames on a data connection until the
// connection is stopped, closed by the client, or a non-timeout read error
// occurs. On exit the data connection is unregistered from its group.
func (c *Connection) handleDataConnectionFrames(dataConn *DataConnection, reader *bufio.Reader) error {
	defer func() {
		// Get the group and remove this data connection
		if group, ok := c.groupManager.GetGroup(c.tunnelID); ok {
			group.RemoveDataConnection(dataConn.ID)
		}
	}()
	for {
		// Bail out promptly once the data connection has been stopped.
		select {
		case <-dataConn.stopCh:
			return nil
		default:
		}
		// Bound each read so the stop check above runs at least once per
		// RequestTimeout even when the client is silent.
		c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
		frame, err := protocol.ReadFrame(reader)
		if err != nil {
			// Timeout is OK, continue
			if isTimeoutError(err) {
				continue
			}
			return err
		}
		// Record liveness for the group's idle/staleness tracking.
		dataConn.mu.Lock()
		dataConn.LastActive = time.Now()
		dataConn.mu.Unlock()
		sf := protocol.WithFrame(frame)
		switch sf.Frame.Type {
		case protocol.FrameTypeClose:
			sf.Close()
			c.logger.Info("Data connection closed by client",
				zap.String("connection_id", dataConn.ID))
			return nil
		default:
			// Data connections only carry Close frames here; anything else
			// is unexpected and dropped after logging.
			sf.Close()
			c.logger.Warn("Unexpected frame type on data connection",
				zap.String("type", sf.Frame.Type.String()),
				zap.String("connection_id", dataConn.ID),
			)
		}
	}
}
func (c *Connection) sendWithBackpressure(frame *protocol.Frame) error {
c.mu.Lock()
for c.paused {
c.pauseCond.Wait()
func isTimeoutError(err error) bool {
if err == nil {
return false
}
c.mu.Unlock()
return c.frameWriter.WriteFrame(frame)
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return true
}
// Fallback for wrapped errors without net.Error
return strings.Contains(err.Error(), "i/o timeout")
}
// sendDataConnectError sends a rejection DataConnectAck to the client with
// the given error code and message folded into the Message field.
// Best effort: the write error is deliberately ignored, since callers
// return an error and tear the connection down immediately afterwards.
func (c *Connection) sendDataConnectError(code, message string) {
	resp := protocol.DataConnectResponse{
		Accepted: false,
		Message:  fmt.Sprintf("%s: %s", code, message),
	}
	respData, _ := json.Marshal(resp)
	frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
	protocol.WriteFrame(c.conn, frame)
}

View File

@@ -0,0 +1,438 @@
package tcp
import (
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/yamux"
"drip/internal/shared/constants"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// DataConnection tracks one secondary (data-plane) connection that has
// joined a tunnel's connection group.
type DataConnection struct {
	ID         string    // connection ID supplied by the client in the join request
	Conn       net.Conn  // underlying network connection
	LastActive time.Time // updated on each received frame; guarded by mu
	closed     bool      // set once teardown has happened; guarded by closedMu
	closedMu   sync.RWMutex
	stopCh     chan struct{} // closed to tell the frame-handling loop to exit
	mu         sync.RWMutex
}
// ConnectionGroup bundles everything belonging to one tunnel: the primary
// control connection, any joined data connections, and the yamux sessions
// used for stream multiplexing. All maps and timestamps are guarded by mu.
type ConnectionGroup struct {
	TunnelID         string
	Subdomain        string
	Token            string // registration token data connections must present
	PrimaryConn      *Connection
	DataConns        map[string]*DataConnection // connection ID -> data connection
	Sessions         map[string]*yamux.Session  // connection ID -> mux session
	TunnelType       protocol.TunnelType
	RegisteredAt     time.Time
	LastActivity     time.Time // bumped on joins, pings, and snapshots; used for staleness
	sessionIdx       uint32    // atomic round-robin cursor for session selection
	mu               sync.RWMutex
	stopCh           chan struct{} // closed in Close to stop the heartbeat loop
	logger           *zap.Logger
	heartbeatStarted bool // ensures the heartbeat goroutine starts only once
}
// NewConnectionGroup builds a connection group for the given tunnel with
// empty data-connection and session maps. The group's logger is derived
// from the supplied logger with the tunnel ID attached.
func NewConnectionGroup(tunnelID, subdomain, token string, primaryConn *Connection, tunnelType protocol.TunnelType, logger *zap.Logger) *ConnectionGroup {
	now := time.Now()
	g := &ConnectionGroup{
		TunnelID:     tunnelID,
		Subdomain:    subdomain,
		Token:        token,
		PrimaryConn:  primaryConn,
		TunnelType:   tunnelType,
		DataConns:    map[string]*DataConnection{},
		Sessions:     map[string]*yamux.Session{},
		RegisteredAt: now,
		LastActivity: now,
		stopCh:       make(chan struct{}),
		logger:       logger.With(zap.String("tunnel_id", tunnelID)),
	}
	return g
}
// StartHeartbeat starts a goroutine that periodically pings all sessions
// and removes dead ones. The caller should ensure this is only called once
// per group (AddSession guards this with the heartbeatStarted flag).
// The goroutine exits when the group's stopCh is closed.
func (g *ConnectionGroup) StartHeartbeat(interval, timeout time.Duration) {
	go g.heartbeatLoop(interval, timeout)
}
// heartbeatLoop pings every registered session once per interval. A session
// is removed after maxConsecutiveFailures failed pings in a row (a single
// success resets its counter), or immediately if it is found closed.
// The loop exits when the group's stopCh is closed.
func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	const maxConsecutiveFailures = 3
	// Per-session consecutive ping failure counts, keyed by connection ID.
	failureCount := make(map[string]int)
	for {
		select {
		case <-g.stopCh:
			return
		case <-ticker.C:
		}
		// Snapshot the session map so pings run without holding the lock.
		g.mu.RLock()
		sessions := make(map[string]*yamux.Session, len(g.Sessions))
		for id, s := range g.Sessions {
			sessions[id] = s
		}
		g.mu.RUnlock()
		for id, session := range sessions {
			if session == nil || session.IsClosed() {
				g.RemoveSession(id)
				delete(failureCount, id)
				continue
			}
			// Ping with timeout. The goroutine writes into a buffered channel
			// so it never leaks even when we stop waiting for it.
			done := make(chan error, 1)
			go func(s *yamux.Session) {
				_, err := s.Ping()
				done <- err
			}(session)
			var err error
			select {
			case err = <-done:
			case <-time.After(timeout):
				err = fmt.Errorf("ping timeout")
			case <-g.stopCh:
				return
			}
			if err != nil {
				failureCount[id]++
				g.logger.Debug("Session ping failed",
					zap.String("session_id", id),
					zap.Int("consecutive_failures", failureCount[id]),
					zap.Error(err),
				)
				if failureCount[id] >= maxConsecutiveFailures {
					g.logger.Warn("Session ping failed too many times, removing",
						zap.String("session_id", id),
						zap.Int("failures", failureCount[id]),
					)
					g.RemoveSession(id)
					delete(failureCount, id)
				}
			} else {
				// Reset on success
				failureCount[id] = 0
				g.mu.Lock()
				g.LastActivity = time.Now()
				g.mu.Unlock()
			}
		}
		// Check if all sessions are gone; cleanup itself is handled by the
		// group manager's staleness sweep.
		g.mu.RLock()
		sessionCount := len(g.Sessions)
		g.mu.RUnlock()
		if sessionCount == 0 {
			g.logger.Info("All sessions closed, tunnel will be cleaned up")
		}
	}
}
// AddDataConnection registers a new data connection under connID, bumps the
// group's activity timestamp, and returns the tracking record.
func (g *ConnectionGroup) AddDataConnection(connID string, conn net.Conn) *DataConnection {
	now := time.Now()
	dc := &DataConnection{
		ID:         connID,
		Conn:       conn,
		LastActive: now,
		stopCh:     make(chan struct{}),
	}
	g.mu.Lock()
	g.DataConns[connID] = dc
	g.LastActivity = now
	g.mu.Unlock()
	return dc
}
// RemoveDataConnection removes the data connection with the given ID from
// the group and tears it down: signals its stop channel, forces any
// in-flight reads/writes to fail via an immediate deadline, and closes the
// socket. Idempotent per connection thanks to the closed flag.
func (g *ConnectionGroup) RemoveDataConnection(connID string) {
	g.mu.Lock()
	defer g.mu.Unlock()
	if dataConn, ok := g.DataConns[connID]; ok {
		dataConn.closedMu.Lock()
		if !dataConn.closed {
			dataConn.closed = true
			close(dataConn.stopCh)
			if dataConn.Conn != nil {
				// Expire any blocked I/O before closing so goroutines
				// reading this connection return promptly.
				_ = dataConn.Conn.SetDeadline(time.Now())
				dataConn.Conn.Close()
			}
		}
		dataConn.closedMu.Unlock()
		delete(g.DataConns, connID)
	}
}
// DataConnectionCount reports how many data connections are currently
// registered with the group.
func (g *ConnectionGroup) DataConnectionCount() int {
	g.mu.RLock()
	n := len(g.DataConns)
	g.mu.RUnlock()
	return n
}
// Close shuts the group down exactly once: it signals the heartbeat loop,
// detaches all data connections and sessions from the maps under the lock,
// then closes them outside the lock so slow Close calls cannot block other
// group operations. Safe to call multiple times.
func (g *ConnectionGroup) Close() {
	g.mu.Lock()
	// stopCh doubles as the "already closed" marker.
	select {
	case <-g.stopCh:
		g.mu.Unlock()
		return
	default:
		close(g.stopCh)
	}
	// Detach everything while holding the lock; close afterwards.
	dataConns := make([]*DataConnection, 0, len(g.DataConns))
	for _, dataConn := range g.DataConns {
		dataConns = append(dataConns, dataConn)
	}
	g.DataConns = make(map[string]*DataConnection)
	sessions := make([]*yamux.Session, 0, len(g.Sessions))
	for _, session := range g.Sessions {
		if session != nil {
			sessions = append(sessions, session)
		}
	}
	g.Sessions = make(map[string]*yamux.Session)
	g.mu.Unlock()
	for _, dataConn := range dataConns {
		dataConn.closedMu.Lock()
		if !dataConn.closed {
			dataConn.closed = true
			close(dataConn.stopCh)
			if dataConn.Conn != nil {
				// Expire blocked I/O before closing the socket.
				_ = dataConn.Conn.SetDeadline(time.Now())
				_ = dataConn.Conn.Close()
			}
		}
		dataConn.closedMu.Unlock()
	}
	for _, session := range sessions {
		_ = session.Close()
	}
}
// IsStale reports whether the group has seen no activity for longer than
// the given timeout.
func (g *ConnectionGroup) IsStale(timeout time.Duration) bool {
	g.mu.RLock()
	last := g.LastActivity
	g.mu.RUnlock()
	return time.Since(last) > timeout
}
// AddSession registers a yamux session under connID and, on the very first
// session, starts the group heartbeat. No-op for an empty connID or a nil
// session.
func (g *ConnectionGroup) AddSession(connID string, session *yamux.Session) {
	if connID == "" || session == nil {
		return
	}
	g.mu.Lock()
	// Defensive: the map is normally created by NewConnectionGroup.
	if g.Sessions == nil {
		g.Sessions = make(map[string]*yamux.Session)
	}
	g.Sessions[connID] = session
	g.LastActivity = time.Now()
	// Start heartbeat on first session. The flag is flipped under the lock;
	// the goroutine itself is launched after the lock is released.
	shouldStartHeartbeat := !g.heartbeatStarted
	if shouldStartHeartbeat {
		g.heartbeatStarted = true
	}
	g.mu.Unlock()
	if shouldStartHeartbeat {
		g.StartHeartbeat(constants.HeartbeatInterval, constants.HeartbeatTimeout)
	}
}
// RemoveSession detaches the session registered under connID (if any) and
// closes it outside the lock. Empty connIDs are ignored.
func (g *ConnectionGroup) RemoveSession(connID string) {
	if connID == "" {
		return
	}
	g.mu.Lock()
	s := g.Sessions[connID] // lookup on a nil map yields nil
	delete(g.Sessions, connID)
	g.mu.Unlock()
	if s != nil {
		_ = s.Close()
	}
}
// SessionCount reports how many yamux sessions the group currently holds.
func (g *ConnectionGroup) SessionCount() int {
	g.mu.RLock()
	n := len(g.Sessions)
	g.mu.RUnlock()
	return n
}
// OpenStream opens a new mux stream on one of the group's live sessions.
// Selection is least-loaded with a round-robin starting offset; sessions
// holding maxStreamsPerSession or more streams are skipped. On failure it
// retries up to maxRetries times with linear backoff. Returns net.ErrClosed
// when the group is stopped or has no sessions.
func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
	const (
		maxStreamsPerSession = 256
		maxRetries           = 3
		backoffBase          = 25 * time.Millisecond
	)
	var lastErr error
	for attempt := 0; attempt < maxRetries; attempt++ {
		select {
		case <-g.stopCh:
			return nil, net.ErrClosed
		default:
		}
		sessions := g.sessionsSnapshot()
		if len(sessions) == 0 {
			return nil, net.ErrClosed
		}
		// tried marks sessions already attempted (or found closed) this round.
		tried := make([]bool, len(sessions))
		anyUnderCap := false
		// Round-robin start offset distributes load across sessions.
		start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
		for range sessions {
			// Pick the untried session with the fewest open streams.
			bestIdx := -1
			minStreams := int(^uint(0) >> 1) // max int
			for i := 0; i < len(sessions); i++ {
				idx := (start + i) % len(sessions)
				if tried[idx] {
					continue
				}
				session := sessions[idx]
				if session == nil || session.IsClosed() {
					tried[idx] = true
					continue
				}
				n := session.NumStreams()
				if n >= maxStreamsPerSession {
					// At capacity: skip, but leave untried so it is
					// reconsidered if its load drops within this round.
					continue
				}
				anyUnderCap = true
				if n < minStreams {
					minStreams = n
					bestIdx = idx
				}
			}
			if bestIdx == -1 {
				break
			}
			tried[bestIdx] = true
			session := sessions[bestIdx]
			if session == nil || session.IsClosed() {
				continue
			}
			stream, err := session.Open()
			if err == nil {
				return stream, nil
			}
			lastErr = err
			if session.IsClosed() {
				g.deleteClosedSessions()
			}
		}
		if !anyUnderCap {
			lastErr = fmt.Errorf("all sessions are at stream capacity (%d)", maxStreamsPerSession)
		}
		// Linear backoff between attempts, abandoned early on shutdown.
		if attempt < maxRetries-1 {
			select {
			case <-g.stopCh:
				return nil, net.ErrClosed
			case <-time.After(backoffBase * time.Duration(attempt+1)):
			}
		}
	}
	if lastErr == nil {
		lastErr = fmt.Errorf("failed to open stream")
	}
	return nil, lastErr
}
// selectSession returns the live session with the fewest open streams,
// scanning from a round-robin starting offset so ties rotate between
// sessions. Returns nil when no usable session exists.
func (g *ConnectionGroup) selectSession() *yamux.Session {
	candidates := g.sessionsSnapshot()
	if len(candidates) == 0 {
		return nil
	}
	offset := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
	fewest := int(^uint(0) >> 1) // max int
	var best *yamux.Session
	for i := range candidates {
		s := candidates[(offset+i)%len(candidates)]
		if s == nil || s.IsClosed() {
			continue
		}
		if n := s.NumStreams(); n < fewest {
			fewest = n
			best = s
		}
	}
	return best
}
// sessionsSnapshot returns the current live sessions as a slice. It takes
// the write lock (not RLock) because it also prunes nil/closed sessions
// from the map and bumps LastActivity when any live session remains.
// Returns nil when the group has no sessions.
func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session {
	g.mu.Lock()
	defer g.mu.Unlock()
	if len(g.Sessions) == 0 {
		return nil
	}
	sessions := make([]*yamux.Session, 0, len(g.Sessions))
	for id, session := range g.Sessions {
		if session == nil || session.IsClosed() {
			// Prune dead entries opportunistically while we hold the lock.
			delete(g.Sessions, id)
			continue
		}
		sessions = append(sessions, session)
	}
	if len(sessions) > 0 {
		g.LastActivity = time.Now()
	}
	return sessions
}
// deleteClosedSessions prunes nil or closed sessions from the session map.
func (g *ConnectionGroup) deleteClosedSessions() {
	g.mu.Lock()
	defer g.mu.Unlock()
	for id, s := range g.Sessions {
		if s == nil || s.IsClosed() {
			delete(g.Sessions, id)
		}
	}
}

View File

@@ -0,0 +1,163 @@
package tcp
import (
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"sync"
"time"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// ConnectionGroupManager manages all connection groups, keyed by tunnel ID,
// and runs a background sweep that closes groups with no recent activity.
type ConnectionGroupManager struct {
	groups map[string]*ConnectionGroup // TunnelID -> ConnectionGroup
	mu     sync.RWMutex
	logger *zap.Logger
	// Cleanup
	cleanupInterval time.Duration // how often the stale sweep runs
	staleTimeout    time.Duration // inactivity threshold before a group is reaped
	stopCh          chan struct{} // closed in Close to stop the cleanup loop
}
// NewConnectionGroupManager creates a new connection group manager and
// immediately starts its background cleanup goroutine (stopped by Close).
// Cleanup runs every 60s and reaps groups idle for more than 5 minutes.
func NewConnectionGroupManager(logger *zap.Logger) *ConnectionGroupManager {
	m := &ConnectionGroupManager{
		groups:          make(map[string]*ConnectionGroup),
		logger:          logger,
		cleanupInterval: 60 * time.Second,
		staleTimeout:    5 * time.Minute,
		stopCh:          make(chan struct{}),
	}
	go m.cleanupLoop()
	return m
}
// GenerateTunnelID generates a unique tunnel ID
func GenerateTunnelID() string {
b := make([]byte, 16)
rand.Read(b)
return hex.EncodeToString(b)
}
// CreateGroup allocates a fresh tunnel ID, builds a connection group for
// it, registers the group, and returns it.
func (m *ConnectionGroupManager) CreateGroup(subdomain, token string, primaryConn *Connection, tunnelType protocol.TunnelType) *ConnectionGroup {
	id := GenerateTunnelID()
	g := NewConnectionGroup(id, subdomain, token, primaryConn, tunnelType, m.logger)
	m.mu.Lock()
	m.groups[id] = g
	m.mu.Unlock()
	return g
}
// GetGroup looks up a connection group by tunnel ID. The boolean reports
// whether the tunnel exists.
func (m *ConnectionGroupManager) GetGroup(tunnelID string) (*ConnectionGroup, bool) {
	m.mu.RLock()
	g, ok := m.groups[tunnelID]
	m.mu.RUnlock()
	return g, ok
}
// RemoveGroup unregisters the group for tunnelID and closes it. Closing
// happens outside the lock so a slow Close cannot stall the manager.
// Unknown tunnel IDs are a no-op.
func (m *ConnectionGroupManager) RemoveGroup(tunnelID string) {
	m.mu.Lock()
	g := m.groups[tunnelID] // nil when absent
	delete(m.groups, tunnelID)
	m.mu.Unlock()
	if g != nil {
		g.Close()
	}
}
// AddDataConnection validates a data-connection join request against the
// target group's registration token and, on success, registers the new data
// connection with that group. Returns an error when the tunnel is unknown
// or the token does not match.
func (m *ConnectionGroupManager) AddDataConnection(req *protocol.DataConnectRequest, conn net.Conn) (*DataConnection, error) {
	m.mu.RLock()
	group, ok := m.groups[req.TunnelID]
	m.mu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("tunnel not found: %s", req.TunnelID)
	}
	// Validate token; an empty group token means no token was required at
	// registration time.
	if group.Token != "" && req.Token != group.Token {
		return nil, fmt.Errorf("invalid token")
	}
	dataConn := group.AddDataConnection(req.ConnectionID, conn)
	return dataConn, nil
}
// cleanupLoop runs the stale-group sweep once per cleanupInterval until the
// manager's stop channel is closed.
func (m *ConnectionGroupManager) cleanupLoop() {
	t := time.NewTicker(m.cleanupInterval)
	defer t.Stop()
	for {
		select {
		case <-m.stopCh:
			return
		case <-t.C:
			m.cleanupStaleGroups()
		}
	}
}
// cleanupStaleGroups removes every group whose last activity is older than
// staleTimeout. Stale groups are collected and deleted from the map under
// the lock, then closed afterwards so Close cannot block the manager.
func (m *ConnectionGroupManager) cleanupStaleGroups() {
	// Collect stale groups under lock
	m.mu.Lock()
	var staleGroups []*ConnectionGroup
	var staleIDs []string
	for tunnelID, group := range m.groups {
		if group.IsStale(m.staleTimeout) {
			staleIDs = append(staleIDs, tunnelID)
			staleGroups = append(staleGroups, group)
		}
	}
	// Remove from map while holding lock
	for _, tunnelID := range staleIDs {
		delete(m.groups, tunnelID)
	}
	m.mu.Unlock()
	// Close groups without holding lock to avoid blocking other operations
	for _, group := range staleGroups {
		group.Close()
	}
}
// Close shuts down the manager: it stops the cleanup loop, detaches all
// groups from the map under the lock, then closes each group outside the
// lock. Must be called at most once (a second call would close stopCh again
// and panic).
func (m *ConnectionGroupManager) Close() {
	close(m.stopCh)
	// Collect all groups under lock
	m.mu.Lock()
	groups := make([]*ConnectionGroup, 0, len(m.groups))
	for _, group := range m.groups {
		groups = append(groups, group)
	}
	m.groups = make(map[string]*ConnectionGroup)
	m.mu.Unlock()
	// Close groups without holding lock
	for _, group := range groups {
		group.Close()
	}
}

View File

@@ -12,32 +12,34 @@ import (
"drip/internal/server/tunnel"
"drip/internal/shared/pool"
"drip/internal/shared/recovery"
"go.uber.org/zap"
)
// Listener handles TCP connections with TLS 1.3
type Listener struct {
address string
tlsConfig *tls.Config
authToken string
manager *tunnel.Manager
portAlloc *PortAllocator
logger *zap.Logger
domain string
publicPort int
httpHandler http.Handler
responseChans HTTPResponseHandler
listener net.Listener
stopCh chan struct{}
wg sync.WaitGroup
connections map[string]*Connection
connMu sync.RWMutex
workerPool *pool.WorkerPool // Worker pool for connection handling
recoverer *recovery.Recoverer
address string
tlsConfig *tls.Config
authToken string
manager *tunnel.Manager
portAlloc *PortAllocator
logger *zap.Logger
domain string
publicPort int
httpHandler http.Handler
listener net.Listener
stopCh chan struct{}
wg sync.WaitGroup
connections map[string]*Connection
connMu sync.RWMutex
workerPool *pool.WorkerPool // Worker pool for connection handling
recoverer *recovery.Recoverer
panicMetrics *recovery.PanicMetrics
groupManager *ConnectionGroupManager
}
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Listener {
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler) *Listener {
numCPU := pool.NumCPU()
workers := numCPU * 5
queueSize := workers * 20
@@ -53,21 +55,21 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
recoverer := recovery.NewRecoverer(logger, panicMetrics)
return &Listener{
address: address,
tlsConfig: tlsConfig,
authToken: authToken,
manager: manager,
portAlloc: portAlloc,
logger: logger,
domain: domain,
publicPort: publicPort,
httpHandler: httpHandler,
responseChans: responseChans,
stopCh: make(chan struct{}),
connections: make(map[string]*Connection),
workerPool: workerPool,
recoverer: recoverer,
panicMetrics: panicMetrics,
address: address,
tlsConfig: tlsConfig,
authToken: authToken,
manager: manager,
portAlloc: portAlloc,
logger: logger,
domain: domain,
publicPort: publicPort,
httpHandler: httpHandler,
stopCh: make(chan struct{}),
connections: make(map[string]*Connection),
workerPool: workerPool,
recoverer: recoverer,
panicMetrics: panicMetrics,
groupManager: NewConnectionGroupManager(logger),
}
}
@@ -206,7 +208,7 @@ func (l *Listener) handleConnection(netConn net.Conn) {
return
}
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.responseChans)
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.groupManager)
connID := netConn.RemoteAddr().String()
l.connMu.Lock()
@@ -222,14 +224,11 @@ func (l *Listener) handleConnection(netConn net.Conn) {
if err := conn.Handle(); err != nil {
errStr := err.Error()
// Client disconnection errors - normal network behavior, log as DEBUG
if strings.Contains(errStr, "connection reset by peer") ||
// Client disconnection errors - normal network behavior, ignore
if strings.Contains(errStr, "EOF") ||
strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection refused") {
l.logger.Debug("Client disconnected",
zap.String("remote_addr", connID),
zap.Error(err),
)
return
}
@@ -277,6 +276,10 @@ func (l *Listener) Stop() error {
l.workerPool.Close()
}
if l.groupManager != nil {
l.groupManager.Close()
}
l.logger.Info("TCP listener stopped")
return nil
}

View File

@@ -1,64 +1,79 @@
package tcp
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// TunnelProxy handles TCP connections for a specific tunnel
type TunnelProxy struct {
port int
subdomain string
tcpConn net.Conn // The tunnel control connection
listener net.Listener
logger *zap.Logger
stopCh chan struct{}
wg sync.WaitGroup
clientAddr string
streams map[string]*proxyStream // streamID -> stream info
streamMu sync.RWMutex
frameWriter *protocol.FrameWriter
bufferPool *pool.BufferPool
// Proxy exposes a public TCP port and forwards each incoming
// connection over a dedicated mux stream.
type Proxy struct {
port int
subdomain string
logger *zap.Logger
listener net.Listener
stopCh chan struct{}
once sync.Once
wg sync.WaitGroup
openStream func() (net.Conn, error)
stats trafficStats
sem chan struct{}
ctx context.Context
cancel context.CancelFunc
}
// proxyStream holds connection info with close state
type proxyStream struct {
conn net.Conn
closed bool
mu sync.Mutex
type trafficStats interface {
AddBytesIn(n int64)
AddBytesOut(n int64)
IncActiveConnections()
DecActiveConnections()
}
// NewTunnelProxy creates a new TCP tunnel proxy
func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Logger) *TunnelProxy {
return &TunnelProxy{
port: port,
subdomain: subdomain,
tcpConn: tcpConn,
logger: logger,
stopCh: make(chan struct{}),
clientAddr: tcpConn.RemoteAddr().String(),
streams: make(map[string]*proxyStream),
bufferPool: pool.NewBufferPool(),
frameWriter: protocol.NewFrameWriter(tcpConn),
// NewProxy constructs a Proxy that listens on the given public port and
// forwards each accepted connection over a stream obtained from openStream.
// A nil ctx is replaced with context.Background(); the derived context is
// cancelled in Stop. stats may be nil, in which case metrics are skipped.
func NewProxy(ctx context.Context, port int, subdomain string, openStream func() (net.Conn, error), stats trafficStats, logger *zap.Logger) *Proxy {
	if ctx == nil {
		ctx = context.Background()
	}
	cctx, cancel := context.WithCancel(ctx)
	// Cap on concurrently proxied connections; when the semaphore is full,
	// excess connections are dropped on accept rather than queued.
	const maxConcurrentConnections = 10000
	var sem chan struct{}
	if maxConcurrentConnections > 0 {
		sem = make(chan struct{}, maxConcurrentConnections)
	}
	return &Proxy{
		port:       port,
		subdomain:  subdomain,
		logger:     logger,
		stopCh:     make(chan struct{}),
		openStream: openStream,
		stats:      stats,
		sem:        sem,
		ctx:        cctx,
		cancel:     cancel,
	}
}
// Start starts listening on the allocated port
func (p *TunnelProxy) Start() error {
func (p *Proxy) Start() error {
addr := fmt.Sprintf("0.0.0.0:%d", p.port)
listener, err := net.Listen("tcp", addr)
ln, err := net.Listen("tcp", addr)
if err != nil {
return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
}
p.listener = listener
p.listener = ln
p.logger.Info("TCP proxy started",
zap.Int("port", p.port),
@@ -67,14 +82,47 @@ func (p *TunnelProxy) Start() error {
p.wg.Add(1)
go p.acceptLoop()
return nil
}
// acceptLoop accepts incoming TCP connections
func (p *TunnelProxy) acceptLoop() {
// Stop shuts the proxy down exactly once (guarded by sync.Once): it signals
// the accept loop, cancels the proxy context, closes the public listener,
// and then waits up to stopTimeout for in-flight connection handlers to
// drain, logging a warning if they do not finish in time.
func (p *Proxy) Stop() {
	p.once.Do(func() {
		close(p.stopCh)
		p.cancel()
		if p.listener != nil {
			_ = p.listener.Close()
		}
		// Wait for the handler WaitGroup in a goroutine so we can bound
		// the wait with a timeout.
		done := make(chan struct{})
		go func() {
			p.wg.Wait()
			close(done)
		}()
		const stopTimeout = 30 * time.Second
		select {
		case <-done:
			p.logger.Info("TCP proxy stopped",
				zap.Int("port", p.port),
				zap.String("subdomain", p.subdomain),
			)
		case <-time.After(stopTimeout):
			p.logger.Warn("TCP proxy stop timed out",
				zap.Int("port", p.port),
				zap.String("subdomain", p.subdomain),
				zap.Duration("timeout", stopTimeout),
			)
		}
	})
}
func (p *Proxy) acceptLoop() {
defer p.wg.Done()
tcpLn, _ := p.listener.(*net.TCPListener)
for {
select {
case <-p.stopCh:
@@ -82,11 +130,13 @@ func (p *TunnelProxy) acceptLoop() {
default:
}
p.listener.(*net.TCPListener).SetDeadline(time.Now().Add(1 * time.Second))
if tcpLn != nil {
_ = tcpLn.SetDeadline(time.Now().Add(1 * time.Second))
}
conn, err := p.listener.Accept()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
if ne, ok := err.(net.Error); ok && ne.Timeout() {
continue
}
select {
@@ -98,187 +148,86 @@ func (p *TunnelProxy) acceptLoop() {
}
p.wg.Add(1)
go p.handleConnection(conn)
go p.handleConn(conn)
}
}
func (p *TunnelProxy) handleConnection(conn net.Conn) {
func (p *Proxy) handleConn(conn net.Conn) {
defer p.wg.Done()
defer conn.Close()
streamID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), p.port)
stream := &proxyStream{
conn: conn,
closed: false,
if p.sem != nil {
select {
case p.sem <- struct{}{}:
defer func() { <-p.sem }()
default:
return
}
}
p.streamMu.Lock()
p.streams[streamID] = stream
p.streamMu.Unlock()
if p.stats != nil {
p.stats.IncActiveConnections()
defer p.stats.DecActiveConnections()
}
defer func() {
p.streamMu.Lock()
delete(p.streams, streamID)
p.streamMu.Unlock()
if tcpConn, ok := conn.(*net.TCPConn); ok {
_ = tcpConn.SetNoDelay(true)
_ = tcpConn.SetKeepAlive(true)
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
_ = tcpConn.SetReadBuffer(256 * 1024)
_ = tcpConn.SetWriteBuffer(256 * 1024)
}
if p.openStream == nil {
return
}
// Open stream with timeout to prevent goroutine leak
const openStreamTimeout = 10 * time.Second
type streamResult struct {
stream net.Conn
err error
}
resultCh := make(chan streamResult, 1)
go func() {
s, err := p.openStream()
resultCh <- streamResult{s, err}
}()
bufPtr := p.bufferPool.Get(pool.SizeMedium)
defer p.bufferPool.Put(bufPtr)
buffer := (*bufPtr)[:pool.SizeMedium]
for {
// Check if stream is closed
stream.mu.Lock()
closed := stream.closed
stream.mu.Unlock()
if closed {
break
}
n, err := conn.Read(buffer)
if err != nil {
break
}
if n > 0 {
if err := p.sendDataToTunnel(streamID, buffer[:n]); err != nil {
p.logger.Debug("Send to tunnel failed", zap.Error(err))
break
var stream net.Conn
select {
case result := <-resultCh:
if result.err != nil {
if !errors.Is(result.err, net.ErrClosed) {
p.logger.Debug("Open stream failed", zap.Error(result.err))
}
return
}
}
select {
stream = result.stream
case <-time.After(openStreamTimeout):
p.logger.Debug("Open stream timeout")
return
case <-p.stopCh:
default:
p.sendCloseToTunnel(streamID)
}
}
func (p *TunnelProxy) sendDataToTunnel(streamID string, data []byte) error {
select {
case <-p.stopCh:
return fmt.Errorf("tunnel proxy stopped")
default:
}
header := protocol.DataHeader{
StreamID: streamID,
RequestID: streamID,
Type: protocol.DataTypeData,
IsLast: false,
}
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, data)
if err != nil {
return fmt.Errorf("failed to encode payload: %w", err)
}
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
err = p.frameWriter.WriteFrame(frame)
if err != nil {
return fmt.Errorf("failed to write frame: %w", err)
}
return nil
}
func (p *TunnelProxy) sendCloseToTunnel(streamID string) {
header := protocol.DataHeader{
StreamID: streamID,
RequestID: streamID,
Type: protocol.DataTypeClose,
IsLast: true,
}
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
if err != nil {
return
}
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
p.frameWriter.WriteFrame(frame)
}
defer stream.Close()
func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
p.streamMu.RLock()
stream, ok := p.streams[streamID]
p.streamMu.RUnlock()
if !ok {
// Stream may have been closed by client, this is normal
return nil
}
// Check if stream is closed
stream.mu.Lock()
if stream.closed {
stream.mu.Unlock()
return nil
}
stream.mu.Unlock()
if _, err := stream.conn.Write(data); err != nil {
p.logger.Debug("Write to client failed", zap.Error(err))
return err
}
return nil
}
// CloseStream marks the proxy stream identified by streamID as closed and
// closes its underlying client connection. Calling it for an unknown or
// already-closed stream is a no-op.
func (p *TunnelProxy) CloseStream(streamID string) {
	p.streamMu.RLock()
	st, found := p.streams[streamID]
	p.streamMu.RUnlock()
	if !found {
		return
	}

	// Flip the closed flag under the stream lock so concurrent writers
	// observe the closure before the connection goes away.
	st.mu.Lock()
	alreadyClosed := st.closed
	if !alreadyClosed {
		st.closed = true
	}
	st.mu.Unlock()
	if alreadyClosed {
		return
	}

	st.conn.Close()
}
func (p *TunnelProxy) Stop() {
p.logger.Info("Stopping TCP proxy",
zap.Int("port", p.port),
zap.String("subdomain", p.subdomain),
_ = netutil.PipeWithCallbacksAndBufferSize(
p.ctx,
conn,
stream,
pool.SizeLarge,
func(n int64) {
if p.stats != nil {
p.stats.AddBytesIn(n)
}
},
func(n int64) {
if p.stats != nil {
p.stats.AddBytesOut(n)
}
},
)
close(p.stopCh)
if p.listener != nil {
p.listener.Close()
}
p.streamMu.Lock()
for _, stream := range p.streams {
stream.mu.Lock()
stream.closed = true
stream.mu.Unlock()
stream.conn.Close()
}
p.streams = make(map[string]*proxyStream)
p.streamMu.Unlock()
p.wg.Wait()
if p.frameWriter != nil {
p.frameWriter.Close()
}
p.logger.Info("TCP proxy stopped", zap.Int("port", p.port))
}

View File

@@ -0,0 +1,98 @@
package tcp
import (
"bufio"
"fmt"
"io"
"net"
"github.com/hashicorp/yamux"
"drip/internal/shared/constants"
)
// bufferedConn wraps a net.Conn so that bytes already sitting in reader
// (e.g. consumed while sniffing the protocol before the yamux upgrade) are
// not lost: reads drain the bufio.Reader, which falls back to the underlying
// connection once its buffer is empty.
type bufferedConn struct {
	net.Conn
	reader *bufio.Reader
}

// Read serves data from the buffered reader instead of the raw connection.
func (c *bufferedConn) Read(p []byte) (int, error) {
	return c.reader.Read(p)
}
// handleTCPTunnel upgrades the control connection to a yamux session and
// starts the TCP proxy that forwards public connections over mux streams.
// It blocks until either the connection is stopped or the session dies.
func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
	// Public server acts as yamux Client, client connector acts as yamux Server.
	muxConn := &bufferedConn{Conn: c.conn, reader: reader}

	muxCfg := yamux.DefaultConfig()
	muxCfg.EnableKeepAlive = false
	muxCfg.LogOutput = io.Discard
	muxCfg.AcceptBacklog = constants.YamuxAcceptBacklog

	session, err := yamux.Client(muxConn, muxCfg)
	if err != nil {
		return fmt.Errorf("failed to init yamux session: %w", err)
	}
	c.session = session

	// Default to opening streams on this session; when the tunnel belongs
	// to a registered group, route stream opens through the group instead.
	openStream := session.Open
	if c.tunnelID != "" && c.groupManager != nil {
		if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
			group.AddSession("primary", session)
			openStream = group.OpenStream
		}
	}

	c.proxy = NewProxy(c.ctx, c.port, c.subdomain, openStream, c.tunnelConn, c.logger)
	if err := c.proxy.Start(); err != nil {
		return fmt.Errorf("failed to start tcp proxy: %w", err)
	}

	// Park until shutdown or session teardown.
	select {
	case <-c.stopCh:
	case <-session.CloseChan():
	}
	return nil
}
// handleHTTPProxyTunnel upgrades the control connection to a yamux session
// for HTTP tunnels. Unlike the TCP path it does not start a local listener;
// instead it registers a stream opener on the tunnel connection so the HTTP
// proxy can forward each request over a fresh mux stream. Blocks until the
// connection is stopped or the session closes.
func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
	// Public server acts as yamux Client, client connector acts as yamux Server.
	bc := &bufferedConn{
		Conn:   c.conn,
		reader: reader,
	}
	cfg := yamux.DefaultConfig()
	cfg.EnableKeepAlive = false
	cfg.LogOutput = io.Discard
	cfg.AcceptBacklog = constants.YamuxAcceptBacklog
	session, err := yamux.Client(bc, cfg)
	if err != nil {
		return fmt.Errorf("failed to init yamux session: %w", err)
	}
	c.session = session
	// Prefer the session group's opener when this tunnel is part of a
	// registered group (pools streams across sessions).
	openStream := session.Open
	if c.tunnelID != "" && c.groupManager != nil {
		if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
			group.AddSession("primary", session)
			openStream = group.OpenStream
		}
	}
	if c.tunnelConn != nil {
		c.tunnelConn.SetOpenStream(openStream)
	}
	// Park until shutdown or session teardown.
	select {
	case <-c.stopCh:
		return nil
	case <-session.CloseChan():
		return nil
	}
}

View File

@@ -1,7 +1,9 @@
package tunnel
import (
"net"
"sync"
"sync/atomic"
"time"
"drip/internal/shared/protocol"
@@ -9,13 +11,6 @@ import (
"go.uber.org/zap"
)
// Transport represents the control channel to the client.
// It is implemented by the TCP control connection so the HTTP proxy
// can push frames directly to the client without depending on WebSockets.
type Transport interface {
SendFrame(frame *protocol.Frame) error
}
// Connection represents a tunnel connection from a client
type Connection struct {
Subdomain string
@@ -26,8 +21,12 @@ type Connection struct {
mu sync.RWMutex
logger *zap.Logger
closed bool
transport Transport
tunnelType protocol.TunnelType
openStream func() (net.Conn, error)
bytesIn atomic.Int64
bytesOut atomic.Int64
activeConnections atomic.Int64
}
// NewConnection creates a new tunnel connection
@@ -106,21 +105,6 @@ func (c *Connection) IsClosed() bool {
return c.closed
}
// SetTransport attaches the control transport and tunnel type.
func (c *Connection) SetTransport(t Transport, tType protocol.TunnelType) {
c.mu.Lock()
c.transport = t
c.tunnelType = tType
c.mu.Unlock()
}
// GetTransport returns the attached transport (if any).
func (c *Connection) GetTransport() Transport {
c.mu.RLock()
defer c.mu.RUnlock()
return c.transport
}
// SetTunnelType sets the tunnel type.
func (c *Connection) SetTunnelType(tType protocol.TunnelType) {
c.mu.Lock()
@@ -135,6 +119,63 @@ func (c *Connection) GetTunnelType() protocol.TunnelType {
return c.tunnelType
}
// SetOpenStream registers a yamux stream opener for this tunnel.
// It is used by the HTTP proxy to forward each request over a mux stream.
// The opener is replaced atomically under the connection mutex.
func (c *Connection) SetOpenStream(open func() (net.Conn, error)) {
	c.mu.Lock()
	c.openStream = open
	c.mu.Unlock()
}
// OpenStream opens a new mux stream to the tunnel client. It returns
// ErrConnectionClosed when the tunnel has been closed or no opener has been
// registered via SetOpenStream.
func (c *Connection) OpenStream() (net.Conn, error) {
	c.mu.RLock()
	opener, isClosed := c.openStream, c.closed
	c.mu.RUnlock()

	if isClosed || opener == nil {
		return nil, ErrConnectionClosed
	}
	return opener()
}
// AddBytesIn adds n to the tunnel's inbound byte counter.
// Non-positive values are ignored so callers can pass raw copy results.
func (c *Connection) AddBytesIn(n int64) {
	if n <= 0 {
		return
	}
	c.bytesIn.Add(n)
}

// AddBytesOut adds n to the tunnel's outbound byte counter.
// Non-positive values are ignored so callers can pass raw copy results.
func (c *Connection) AddBytesOut(n int64) {
	if n <= 0 {
		return
	}
	c.bytesOut.Add(n)
}
// GetBytesIn returns the total inbound bytes recorded for this tunnel.
func (c *Connection) GetBytesIn() int64 {
	return c.bytesIn.Load()
}

// GetBytesOut returns the total outbound bytes recorded for this tunnel.
func (c *Connection) GetBytesOut() int64 {
	return c.bytesOut.Load()
}
// IncActiveConnections increments the count of in-flight proxied connections.
func (c *Connection) IncActiveConnections() {
	c.activeConnections.Add(1)
}
// DecActiveConnections decrements the count of in-flight proxied connections,
// clamping at zero. A CAS loop is used instead of Add(-1) followed by
// Store(0): the store-based clamp can race with a concurrent increment and
// erase it, whereas CAS simply refuses to decrement below zero.
func (c *Connection) DecActiveConnections() {
	for {
		cur := c.activeConnections.Load()
		if cur <= 0 {
			return
		}
		if c.activeConnections.CompareAndSwap(cur, cur-1) {
			return
		}
	}
}
// GetActiveConnections returns the current number of in-flight proxied
// connections.
func (c *Connection) GetActiveConnections() int64 {
	return c.activeConnections.Load()
}
// StartWritePump starts the write pump for sending messages
func (c *Connection) StartWritePump() {
// Skip write pump for TCP-only connections (no WebSocket)

View File

@@ -1,280 +0,0 @@
package hpack
import (
"bytes"
"errors"
"fmt"
"net/http"
"sync"
)
// Decoder decompresses HPACK-encoded headers
// Each connection MUST have its own decoder instance to maintain correct state
type Decoder struct {
mu sync.Mutex
dynamicTable *DynamicTable
staticTable *StaticTable
maxTableSize uint32
}
// NewDecoder creates a new HPACK decoder
func NewDecoder(maxTableSize uint32) *Decoder {
if maxTableSize == 0 {
maxTableSize = DefaultDynamicTableSize
}
return &Decoder{
dynamicTable: NewDynamicTable(maxTableSize),
staticTable: GetStaticTable(),
maxTableSize: maxTableSize,
}
}
// Decode decodes a block of HPACK-encoded header fields into an http.Header.
// The decoder's dynamic table is updated as a side effect, so all blocks from
// one connection must be fed to the same decoder instance in order. The
// internal mutex makes the call safe for concurrent use.
func (d *Decoder) Decode(data []byte) (http.Header, error) {
	d.mu.Lock()
	defer d.mu.Unlock()
	// An empty block decodes to an empty (non-nil) header set.
	if len(data) == 0 {
		return http.Header{}, nil
	}
	headers := make(http.Header)
	buf := bytes.NewReader(data)
	for buf.Len() > 0 {
		b, err := buf.ReadByte()
		if err != nil {
			return nil, fmt.Errorf("read header byte: %w", err)
		}
		// Unread the byte so the per-representation decoder can consume it
		// (the leading byte carries both the tag bits and the first integer
		// prefix).
		if err := buf.UnreadByte(); err != nil {
			return nil, err
		}
		var name, value string
		// Dispatch on the representation's leading bits (RFC 7541 section 6).
		if b&0x80 != 0 {
			// Indexed header field (10xxxxxx)
			name, value, err = d.decodeIndexedHeader(buf)
		} else if b&0x40 != 0 {
			// Literal with incremental indexing (01xxxxxx)
			name, value, err = d.decodeLiteralWithIndexing(buf)
		} else {
			// Literal without indexing (0000xxxx)
			name, value, err = d.decodeLiteralWithoutIndexing(buf)
		}
		if err != nil {
			return nil, err
		}
		headers.Add(name, value)
	}
	return headers, nil
}
// decodeIndexedHeader decodes an indexed header field
func (d *Decoder) decodeIndexedHeader(buf *bytes.Reader) (string, string, error) {
index, err := d.readInteger(buf, 7)
if err != nil {
return "", "", fmt.Errorf("read index: %w", err)
}
if index == 0 {
return "", "", errors.New("invalid index: 0")
}
staticSize := uint32(d.staticTable.Size())
if index <= staticSize {
// Static table
return d.staticTable.Get(index - 1)
}
// Dynamic table (indices start after static table)
dynamicIndex := index - staticSize - 1
return d.dynamicTable.Get(dynamicIndex)
}
// decodeLiteralWithIndexing decodes a literal header with incremental indexing
func (d *Decoder) decodeLiteralWithIndexing(buf *bytes.Reader) (string, string, error) {
nameIndex, err := d.readInteger(buf, 6)
if err != nil {
return "", "", err
}
var name string
if nameIndex == 0 {
// Name is literal
name, err = d.readString(buf)
if err != nil {
return "", "", fmt.Errorf("read name: %w", err)
}
} else {
// Name is indexed
staticSize := uint32(d.staticTable.Size())
if nameIndex <= staticSize {
name, _, err = d.staticTable.Get(nameIndex - 1)
} else {
dynamicIndex := nameIndex - staticSize - 1
name, _, err = d.dynamicTable.Get(dynamicIndex)
}
if err != nil {
return "", "", fmt.Errorf("get indexed name: %w", err)
}
}
// Value is always literal
value, err := d.readString(buf)
if err != nil {
return "", "", fmt.Errorf("read value: %w", err)
}
// Add to dynamic table
d.dynamicTable.Add(name, value)
return name, value, nil
}
// decodeLiteralWithoutIndexing decodes a literal header without indexing
func (d *Decoder) decodeLiteralWithoutIndexing(buf *bytes.Reader) (string, string, error) {
nameIndex, err := d.readInteger(buf, 4)
if err != nil {
return "", "", err
}
var name string
if nameIndex == 0 {
// Name is literal
name, err = d.readString(buf)
if err != nil {
return "", "", fmt.Errorf("read name: %w", err)
}
} else {
// Name is indexed
staticSize := uint32(d.staticTable.Size())
if nameIndex <= staticSize {
name, _, err = d.staticTable.Get(nameIndex - 1)
} else {
dynamicIndex := nameIndex - staticSize - 1
name, _, err = d.dynamicTable.Get(dynamicIndex)
}
if err != nil {
return "", "", fmt.Errorf("get indexed name: %w", err)
}
}
// Value is always literal
value, err := d.readString(buf)
if err != nil {
return "", "", fmt.Errorf("read value: %w", err)
}
// Do NOT add to dynamic table
return name, value, nil
}
// readInteger reads an HPACK-encoded integer with the given prefix size
// (RFC 7541 section 5.1). prefixBits must be in [1, 8].
//
// Continuation bytes are accumulated in uint64 and checked against the
// uint32 range on every step: the previous code only rejected m > 28 after
// adding, so a final continuation byte shifted by 28 bits could silently
// wrap the uint32 accumulator instead of reporting an overflow.
func (d *Decoder) readInteger(buf *bytes.Reader, prefixBits int) (uint32, error) {
	if prefixBits < 1 || prefixBits > 8 {
		return 0, fmt.Errorf("invalid prefix bits: %d", prefixBits)
	}
	b, err := buf.ReadByte()
	if err != nil {
		return 0, err
	}
	maxPrefix := uint32((1 << prefixBits) - 1)
	mask := byte(maxPrefix)
	value := uint64(b & mask)
	if value < uint64(maxPrefix) {
		return uint32(value), nil
	}
	// Multi-byte integer: 7 payload bits per continuation byte, LSB first.
	m := uint(0)
	for {
		b, err := buf.ReadByte()
		if err != nil {
			return 0, err
		}
		value += uint64(b&0x7f) << m
		if value > 0xFFFFFFFF {
			return 0, errors.New("integer overflow")
		}
		m += 7
		if b&0x80 == 0 {
			break
		}
		// Bound the shift so a stream of 0x80 bytes cannot loop unbounded.
		if m > 28 {
			return 0, errors.New("integer overflow")
		}
	}
	return uint32(value), nil
}
// readString reads an HPACK string literal (RFC 7541 section 5.2): a length
// integer with a 7-bit prefix whose top bit flags Huffman coding, followed by
// that many octets. Huffman-coded strings are rejected since decoding is not
// implemented here.
func (d *Decoder) readString(buf *bytes.Reader) (string, error) {
	b, err := buf.ReadByte()
	if err != nil {
		return "", err
	}
	// Peek only: the byte is pushed back so readInteger can parse the
	// 7-bit length prefix from the same byte.
	if err := buf.UnreadByte(); err != nil {
		return "", err
	}
	huffmanEncoded := (b & 0x80) != 0
	length, err := d.readInteger(buf, 7)
	if err != nil {
		return "", fmt.Errorf("read string length: %w", err)
	}
	if length == 0 {
		return "", nil
	}
	// Reject lengths that claim more data than remains; prevents huge
	// allocations from a corrupt length field.
	if length > uint32(buf.Len()) {
		return "", fmt.Errorf("string length %d exceeds buffer size %d", length, buf.Len())
	}
	data := make([]byte, length)
	n, err := buf.Read(data)
	if err != nil {
		return "", err
	}
	if n != int(length) {
		return "", fmt.Errorf("expected %d bytes, read %d", length, n)
	}
	if huffmanEncoded {
		// TODO: Implement Huffman decoding if needed
		return "", errors.New("huffman decoding not implemented")
	}
	return string(data), nil
}
// SetMaxTableSize updates the dynamic table size
func (d *Decoder) SetMaxTableSize(size uint32) {
d.mu.Lock()
defer d.mu.Unlock()
d.maxTableSize = size
d.dynamicTable.SetMaxSize(size)
}
// Reset clears the dynamic table
func (d *Decoder) Reset() {
d.mu.Lock()
defer d.mu.Unlock()
d.dynamicTable = NewDynamicTable(d.maxTableSize)
}

View File

@@ -1,124 +0,0 @@
package hpack
import (
"fmt"
)
// DynamicTable implements the HPACK dynamic table (RFC 7541 Section 2.3.2)
// The dynamic table is a FIFO queue where new entries are added at the beginning
// and old entries are evicted when the table size exceeds the maximum
type DynamicTable struct {
entries []HeaderField
size uint32 // Current size in bytes
maxSize uint32 // Maximum size in bytes
}
// HeaderField is a single HPACK table entry: a header name paired with its
// value.
type HeaderField struct {
	Name  string
	Value string
}

// Size reports the entry's size in octets as defined by RFC 7541
// section 4.1: name length + value length + 32 bytes of fixed overhead.
func (h *HeaderField) Size() uint32 {
	const entryOverhead = 32
	return uint32(len(h.Name)) + uint32(len(h.Value)) + entryOverhead
}
// NewDynamicTable creates a new dynamic table with the specified maximum size
func NewDynamicTable(maxSize uint32) *DynamicTable {
return &DynamicTable{
entries: make([]HeaderField, 0, 32),
size: 0,
maxSize: maxSize,
}
}
// Add inserts a header field at the front of the dynamic table (index 0 is
// always the most recent entry), evicting the oldest entries as needed to
// stay within maxSize. Per RFC 7541 section 4.4, an entry larger than the
// whole table empties the table and is not inserted.
func (dt *DynamicTable) Add(name, value string) {
	field := HeaderField{Name: name, Value: value}
	fieldSize := field.Size()
	// If the field is larger than maxSize, don't add it
	if fieldSize > dt.maxSize {
		dt.evictAll()
		return
	}
	// Evict entries if necessary to make room
	for dt.size+fieldSize > dt.maxSize && len(dt.entries) > 0 {
		dt.evictOldest()
	}
	// Add new entry at the beginning. NOTE(review): this prepend copies the
	// entire slice, so each insertion is O(n); acceptable for typical small
	// tables but worth revisiting if tables grow large.
	dt.entries = append([]HeaderField{field}, dt.entries...)
	dt.size += fieldSize
}
// Get retrieves a header field by index (0-based)
// Index 0 is the most recently added entry
func (dt *DynamicTable) Get(index uint32) (string, string, error) {
if index >= uint32(len(dt.entries)) {
return "", "", fmt.Errorf("index %d out of range (table size: %d)", index, len(dt.entries))
}
field := dt.entries[index]
return field.Name, field.Value, nil
}
// FindExact searches for an exact match (name and value)
// Returns the index (0-based) and true if found
func (dt *DynamicTable) FindExact(name, value string) (uint32, bool) {
for i, field := range dt.entries {
if field.Name == name && field.Value == value {
return uint32(i), true
}
}
return 0, false
}
// FindName searches for a name match
// Returns the index (0-based) and true if found
func (dt *DynamicTable) FindName(name string) (uint32, bool) {
for i, field := range dt.entries {
if field.Name == name {
return uint32(i), true
}
}
return 0, false
}
// SetMaxSize updates the maximum table size
// If the new size is smaller, entries are evicted
func (dt *DynamicTable) SetMaxSize(maxSize uint32) {
dt.maxSize = maxSize
// Evict entries if current size exceeds new max
for dt.size > dt.maxSize && len(dt.entries) > 0 {
dt.evictOldest()
}
}
// CurrentSize returns the current size of the table in bytes
func (dt *DynamicTable) CurrentSize() uint32 {
return dt.size
}
// evictOldest removes the oldest entry (last in the slice)
func (dt *DynamicTable) evictOldest() {
if len(dt.entries) == 0 {
return
}
lastIndex := len(dt.entries) - 1
evicted := dt.entries[lastIndex]
dt.entries = dt.entries[:lastIndex]
dt.size -= evicted.Size()
}
// evictAll removes all entries
func (dt *DynamicTable) evictAll() {
dt.entries = dt.entries[:0]
dt.size = 0
}

View File

@@ -1,200 +0,0 @@
package hpack
import (
"bytes"
"errors"
"fmt"
"net/http"
"strings"
"sync"
)
const (
// DefaultDynamicTableSize is the default size of the dynamic table (4KB)
DefaultDynamicTableSize = 4096
// IndexedHeaderField represents a fully indexed header field
indexedHeaderField = 0x80 // 10xxxxxx
// LiteralHeaderFieldWithIndexing represents a literal with incremental indexing
literalHeaderFieldWithIndexing = 0x40 // 01xxxxxx
)
// Encoder compresses HTTP headers using HPACK
// Each connection MUST have its own encoder instance to avoid state corruption
type Encoder struct {
mu sync.Mutex
dynamicTable *DynamicTable
staticTable *StaticTable
maxTableSize uint32
}
// NewEncoder creates a new HPACK encoder with the specified dynamic table size
// This encoder is NOT thread-safe and should be used by a single connection
func NewEncoder(maxTableSize uint32) *Encoder {
if maxTableSize == 0 {
maxTableSize = DefaultDynamicTableSize
}
return &Encoder{
dynamicTable: NewDynamicTable(maxTableSize),
staticTable: GetStaticTable(),
maxTableSize: maxTableSize,
}
}
// Encode encodes HTTP headers into HPACK binary format
// This method is safe to call concurrently within the same encoder instance
func (e *Encoder) Encode(headers http.Header) ([]byte, error) {
e.mu.Lock()
defer e.mu.Unlock()
if headers == nil {
return nil, errors.New("headers cannot be nil")
}
buf := &bytes.Buffer{}
for name, values := range headers {
for _, value := range values {
if err := e.encodeHeaderField(buf, name, value); err != nil {
return nil, fmt.Errorf("encode header %s: %w", name, err)
}
}
}
return buf.Bytes(), nil
}
// encodeHeaderField encodes a single header field
func (e *Encoder) encodeHeaderField(buf *bytes.Buffer, name, value string) error {
// HTTP/2 requires header names to be lowercase (RFC 7540 Section 8.1.2)
// Convert to lowercase for table lookups and storage
nameLower := strings.ToLower(name)
// Try to find in static table first
if index, found := e.staticTable.FindExact(nameLower, value); found {
return e.writeIndexedHeader(buf, index+1)
}
// Check if name exists in static table (for literal with name reference)
if index, found := e.staticTable.FindName(nameLower); found {
return e.writeLiteralWithIndexing(buf, index+1, value, true)
}
// Try dynamic table
if index, found := e.dynamicTable.FindExact(nameLower, value); found {
// Dynamic table indices start after static table
dynamicIndex := uint32(e.staticTable.Size()) + index + 1
return e.writeIndexedHeader(buf, dynamicIndex)
}
if index, found := e.dynamicTable.FindName(nameLower); found {
dynamicIndex := uint32(e.staticTable.Size()) + index + 1
return e.writeLiteralWithIndexing(buf, dynamicIndex, value, true)
}
// Not found anywhere - literal with indexing and new name
// Write literal flag
buf.WriteByte(literalHeaderFieldWithIndexing)
// Write name as literal string (must come before value)
// Use lowercase name for consistency
if err := e.writeString(buf, nameLower, false); err != nil {
return err
}
// Write value as literal string
if err := e.writeString(buf, value, false); err != nil {
return err
}
// Add to dynamic table with lowercase name
e.dynamicTable.Add(nameLower, value)
return nil
}
// writeIndexedHeader writes an indexed header field (10xxxxxx)
func (e *Encoder) writeIndexedHeader(buf *bytes.Buffer, index uint32) error {
return e.writeInteger(buf, index, 7, indexedHeaderField)
}
// writeLiteralWithIndexing writes a literal header with incremental indexing (01xxxxxx)
func (e *Encoder) writeLiteralWithIndexing(buf *bytes.Buffer, nameIndex uint32, value string, hasIndex bool) error {
if hasIndex {
// Write name as index
if err := e.writeInteger(buf, nameIndex, 6, literalHeaderFieldWithIndexing); err != nil {
return err
}
} else {
// Write literal flag
buf.WriteByte(literalHeaderFieldWithIndexing)
}
// Write value as literal string
return e.writeString(buf, value, false)
}
// writeInteger writes an integer using HPACK integer representation
func (e *Encoder) writeInteger(buf *bytes.Buffer, value uint32, prefixBits int, prefix byte) error {
if prefixBits < 1 || prefixBits > 8 {
return fmt.Errorf("invalid prefix bits: %d", prefixBits)
}
maxPrefix := uint32((1 << prefixBits) - 1)
if value < maxPrefix {
buf.WriteByte(prefix | byte(value))
return nil
}
// Value >= maxPrefix, need multiple bytes
buf.WriteByte(prefix | byte(maxPrefix))
value -= maxPrefix
for value >= 128 {
buf.WriteByte(byte(value%128) | 0x80)
value /= 128
}
buf.WriteByte(byte(value))
return nil
}
// writeString writes a string using HPACK string representation
func (e *Encoder) writeString(buf *bytes.Buffer, str string, huffmanEncode bool) error {
// For simplicity, we don't use Huffman encoding in this implementation
// Huffman flag is bit 7, followed by length in remaining 7 bits
length := uint32(len(str))
if huffmanEncode {
// TODO: Implement Huffman encoding if needed
return errors.New("huffman encoding not implemented")
}
// Write length with H=0 (no Huffman)
if err := e.writeInteger(buf, length, 7, 0x00); err != nil {
return err
}
// Write string bytes
buf.WriteString(str)
return nil
}
// SetMaxTableSize updates the dynamic table size
func (e *Encoder) SetMaxTableSize(size uint32) {
e.mu.Lock()
defer e.mu.Unlock()
e.maxTableSize = size
e.dynamicTable.SetMaxSize(size)
}
// Reset clears the dynamic table
func (e *Encoder) Reset() {
e.mu.Lock()
defer e.mu.Unlock()
e.dynamicTable = NewDynamicTable(e.maxTableSize)
}

View File

@@ -1,150 +0,0 @@
package hpack
import (
"fmt"
"sync"
)
// StaticTable implements the HPACK static table (RFC 7541 Appendix A)
// The static table is predefined and never changes
type StaticTable struct {
entries []HeaderField
nameMap map[string][]uint32 // Maps name to list of indices
}
var (
staticTableInstance *StaticTable
staticTableOnce sync.Once
)
// GetStaticTable returns the singleton static table instance
func GetStaticTable() *StaticTable {
staticTableOnce.Do(func() {
staticTableInstance = newStaticTable()
})
return staticTableInstance
}
// newStaticTable creates and initializes the static table
func newStaticTable() *StaticTable {
// RFC 7541 Appendix A - Static Table Definition
// We include the most common headers for HTTP
entries := []HeaderField{
{Name: ":authority", Value: ""},
{Name: ":method", Value: "GET"},
{Name: ":method", Value: "POST"},
{Name: ":path", Value: "/"},
{Name: ":path", Value: "/index.html"},
{Name: ":scheme", Value: "http"},
{Name: ":scheme", Value: "https"},
{Name: ":status", Value: "200"},
{Name: ":status", Value: "204"},
{Name: ":status", Value: "206"},
{Name: ":status", Value: "304"},
{Name: ":status", Value: "400"},
{Name: ":status", Value: "404"},
{Name: ":status", Value: "500"},
{Name: "accept-charset", Value: ""},
{Name: "accept-encoding", Value: "gzip, deflate"},
{Name: "accept-language", Value: ""},
{Name: "accept-ranges", Value: ""},
{Name: "accept", Value: ""},
{Name: "access-control-allow-origin", Value: ""},
{Name: "age", Value: ""},
{Name: "allow", Value: ""},
{Name: "authorization", Value: ""},
{Name: "cache-control", Value: ""},
{Name: "content-disposition", Value: ""},
{Name: "content-encoding", Value: ""},
{Name: "content-language", Value: ""},
{Name: "content-length", Value: ""},
{Name: "content-location", Value: ""},
{Name: "content-range", Value: ""},
{Name: "content-type", Value: ""},
{Name: "cookie", Value: ""},
{Name: "date", Value: ""},
{Name: "etag", Value: ""},
{Name: "expect", Value: ""},
{Name: "expires", Value: ""},
{Name: "from", Value: ""},
{Name: "host", Value: ""},
{Name: "if-match", Value: ""},
{Name: "if-modified-since", Value: ""},
{Name: "if-none-match", Value: ""},
{Name: "if-range", Value: ""},
{Name: "if-unmodified-since", Value: ""},
{Name: "last-modified", Value: ""},
{Name: "link", Value: ""},
{Name: "location", Value: ""},
{Name: "max-forwards", Value: ""},
{Name: "proxy-authenticate", Value: ""},
{Name: "proxy-authorization", Value: ""},
{Name: "range", Value: ""},
{Name: "referer", Value: ""},
{Name: "refresh", Value: ""},
{Name: "retry-after", Value: ""},
{Name: "server", Value: ""},
{Name: "set-cookie", Value: ""},
{Name: "strict-transport-security", Value: ""},
{Name: "transfer-encoding", Value: ""},
{Name: "user-agent", Value: ""},
{Name: "vary", Value: ""},
{Name: "via", Value: ""},
{Name: "www-authenticate", Value: ""},
}
// Build name index map
nameMap := make(map[string][]uint32)
for i, entry := range entries {
nameMap[entry.Name] = append(nameMap[entry.Name], uint32(i))
}
return &StaticTable{
entries: entries,
nameMap: nameMap,
}
}
// Get retrieves a header field by index (0-based)
func (st *StaticTable) Get(index uint32) (string, string, error) {
if index >= uint32(len(st.entries)) {
return "", "", fmt.Errorf("index %d out of range (static table size: %d)", index, len(st.entries))
}
field := st.entries[index]
return field.Name, field.Value, nil
}
// FindExact searches for an exact match (name and value)
// Returns the index (0-based) and true if found
func (st *StaticTable) FindExact(name, value string) (uint32, bool) {
indices, exists := st.nameMap[name]
if !exists {
return 0, false
}
for _, index := range indices {
field := st.entries[index]
if field.Value == value {
return index, true
}
}
return 0, false
}
// FindName searches for a name match
// Returns the first matching index (0-based) and true if found
func (st *StaticTable) FindName(name string) (uint32, bool) {
indices, exists := st.nameMap[name]
if !exists || len(indices) == 0 {
return 0, false
}
return indices[0], true
}
// Size returns the number of entries in the static table
func (st *StaticTable) Size() int {
return len(st.entries)
}

View File

@@ -9,6 +9,10 @@ const (
// DefaultWSPort is the default WebSocket port
DefaultWSPort = 8080
// YamuxAcceptBacklog controls how many incoming streams can be queued
// before yamux starts blocking stream opens under load.
YamuxAcceptBacklog = 4096
// HeartbeatInterval is how often clients send heartbeat messages
HeartbeatInterval = 2 * time.Second

View File

@@ -0,0 +1,71 @@
package httputil
import (
"fmt"
"io"
"net/http"
"strings"
)
// CopyHeaders appends every value of every key in src onto dst.
// Existing values in dst are preserved; src is not modified.
func CopyHeaders(dst, src http.Header) {
	for key, values := range src {
		for _, value := range values {
			dst.Add(key, value)
		}
	}
}
// CleanHopByHopHeaders removes hop-by-hop headers that must not be forwarded
// by a proxy: first every header named in the Connection header's token list
// (RFC 7230 section 6.1), then the well-known fixed set. A nil header map is
// a no-op.
func CleanHopByHopHeaders(headers http.Header) {
	if headers == nil {
		return
	}

	// Headers nominated by the Connection header are hop-by-hop.
	if tokens := headers.Get("Connection"); tokens != "" {
		for _, token := range strings.Split(tokens, ",") {
			if name := strings.TrimSpace(token); name != "" {
				headers.Del(http.CanonicalHeaderKey(name))
			}
		}
	}

	// The standard hop-by-hop set is always stripped.
	hopByHop := [...]string{
		"Connection",
		"Keep-Alive",
		"Proxy-Authenticate",
		"Proxy-Authorization",
		"Te",
		"Trailer",
		"Transfer-Encoding",
		"Proxy-Connection",
	}
	for _, name := range hopByHop {
		headers.Del(name)
	}
}
// WriteProxyError writes a minimal plain-text HTTP/1.1 error response with
// the given status code and body to w. Write errors are deliberately
// ignored: the peer is typically already gone when an error is being
// reported.
func WriteProxyError(w io.Writer, code int, msg string) {
	resp := http.Response{
		StatusCode:    code,
		Status:        fmt.Sprintf("%d %s", code, http.StatusText(code)),
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		Header:        make(http.Header),
		Body:          io.NopCloser(strings.NewReader(msg)),
		ContentLength: int64(len(msg)),
		Close:         true,
	}
	resp.Header.Set("Content-Type", "text/plain; charset=utf-8")
	resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(msg)))
	_ = resp.Write(w)
	_ = resp.Body.Close()
}
// IsWebSocketUpgrade reports whether req carries the Upgrade/Connection
// header pair of a WebSocket handshake (case-insensitive, and Connection
// may list multiple tokens).
func IsWebSocketUpgrade(req *http.Request) bool {
	if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
		return false
	}
	connection := strings.ToLower(req.Header.Get("Connection"))
	return strings.Contains(connection, "upgrade")
}

View File

@@ -0,0 +1,35 @@
package netutil
import "net"
// CountingConn decorates a net.Conn with per-direction byte counters.
// OnRead and OnWrite, when non-nil, receive the number of bytes actually
// transferred on each successful Read or Write call.
type CountingConn struct {
	net.Conn
	OnRead  func(int64)
	OnWrite func(int64)
}

// NewCountingConn wraps conn with the given callbacks; either callback may
// be nil.
func NewCountingConn(conn net.Conn, onRead, onWrite func(int64)) *CountingConn {
	return &CountingConn{Conn: conn, OnRead: onRead, OnWrite: onWrite}
}

// Read delegates to the wrapped connection and reports transferred bytes.
func (c *CountingConn) Read(p []byte) (int, error) {
	n, err := c.Conn.Read(p)
	if c.OnRead != nil && n > 0 {
		c.OnRead(int64(n))
	}
	return n, err
}

// Write delegates to the wrapped connection and reports transferred bytes.
func (c *CountingConn) Write(p []byte) (int, error) {
	n, err := c.Conn.Write(p)
	if c.OnWrite != nil && n > 0 {
		c.OnWrite(int64(n))
	}
	return n, err
}

View File

@@ -0,0 +1,164 @@
package netutil
import (
"context"
"io"
"sync"
"time"
"drip/internal/shared/pool"
)
const tcpWaitTimeout = 10 * time.Second
type closeReader interface {
CloseRead() error
}
type closeWriter interface {
CloseWrite() error
}
type readDeadliner interface {
SetReadDeadline(t time.Time) error
}
// Pipe copies bytes bidirectionally between a and b (gost-like),
// and applies TCP half-close when supported. It returns the first error
// recorded by either direction, or nil when both finished cleanly.
func Pipe(ctx context.Context, a, b io.ReadWriteCloser) error {
	return PipeWithCallbacksAndBufferSize(ctx, a, b, pool.SizeMedium, nil, nil)
}

// PipeWithCallbacks is Pipe with optional byte counters for each direction:
// onAToB is called with bytes copied from a -> b, onBToA for b -> a.
func PipeWithCallbacks(ctx context.Context, a, b io.ReadWriteCloser, onAToB func(n int64), onBToA func(n int64)) error {
	return PipeWithCallbacksAndBufferSize(ctx, a, b, pool.SizeMedium, onAToB, onBToA)
}

// PipeWithBufferSize is Pipe with a custom buffer size.
func PipeWithBufferSize(ctx context.Context, a, b io.ReadWriteCloser, bufSize int) error {
	return PipeWithCallbacksAndBufferSize(ctx, a, b, bufSize, nil, nil)
}
// PipeWithCallbacksAndBufferSize is PipeWithCallbacks with a custom buffer size.
// bufSize is clamped into (0, pool.SizeLarge], falling back to
// pool.SizeMedium when non-positive. Both endpoints are closed as soon as
// either direction finishes or ctx is cancelled; the call returns once both
// copy goroutines have exited, with the first recorded error (if any).
func PipeWithCallbacksAndBufferSize(ctx context.Context, a, b io.ReadWriteCloser, bufSize int, onAToB func(n int64), onBToA func(n int64)) error {
	if bufSize <= 0 {
		bufSize = pool.SizeMedium
	}
	if bufSize > pool.SizeLarge {
		bufSize = pool.SizeLarge
	}
	var wg sync.WaitGroup
	wg.Add(2)
	stopCh := make(chan struct{})
	var closeOnce sync.Once
	// closeAll tears down both endpoints exactly once and signals the
	// ctx-watcher goroutine (below) to exit.
	closeAll := func() {
		closeOnce.Do(func() {
			close(stopCh)
			_ = a.Close()
			_ = b.Close()
		})
	}
	errCh := make(chan error, 2)
	if ctx != nil {
		// Watch for context cancellation; stopCh keeps this goroutine from
		// leaking once the pipe finishes on its own.
		go func() {
			select {
			case <-ctx.Done():
				closeAll()
			case <-stopCh:
			}
		}()
	}
	go func() {
		defer wg.Done()
		err := pipeBuffer(b, a, bufSize, onAToB, stopCh)
		if err != nil {
			errCh <- err
		}
		closeAll()
	}()
	go func() {
		defer wg.Done()
		err := pipeBuffer(a, b, bufSize, onBToA, stopCh)
		if err != nil {
			errCh <- err
		}
		closeAll()
	}()
	wg.Wait()
	// Report the first error, if either direction recorded one.
	select {
	case err := <-errCh:
		return err
	default:
		return nil
	}
}
// pipeBuffer copies src -> dst using a pooled buffer, then applies TCP-style
// half-close: src's read side is shut down, dst's write side is shut down
// (or dst fully closed when half-close is unsupported or fails), and a read
// deadline is set so the other direction does not wait forever for the
// peer's remaining data.
func pipeBuffer(dst io.ReadWriteCloser, src io.ReadWriteCloser, bufSize int, onCopied func(n int64), stopCh <-chan struct{}) error {
	bufPtr := pool.GetBuffer(bufSize)
	defer pool.PutBuffer(bufPtr)
	// Re-slice to the requested size; the pooled buffer may be larger.
	buf := (*bufPtr)[:bufSize]
	_, err := copyBuffer(dst, src, buf, onCopied, stopCh)
	if cr, ok := src.(closeReader); ok {
		_ = cr.CloseRead()
	}
	if cw, ok := dst.(closeWriter); ok {
		// Half-close the write side; fall back to a full close on failure.
		if e := cw.CloseWrite(); e != nil {
			_ = dst.Close()
		}
		if rd, ok := dst.(readDeadliner); ok {
			_ = rd.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
		}
	} else {
		_ = dst.Close()
		if rd, ok := dst.(readDeadliner); ok {
			_ = rd.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
		}
	}
	return err
}
// copyBuffer copies src to dst through buf until EOF, an error, or stopCh
// closing. onCopied (if non-nil) receives the byte count of each successful
// write. io.EOF from src is reported as clean completion (nil error).
// NOTE(review): when stopCh closes mid-stream this returns io.EOF as an
// error value, which the caller forwards into the pipe's error channel —
// confirm callers intend to treat cancellation as io.EOF.
func copyBuffer(dst io.Writer, src io.Reader, buf []byte, onCopied func(n int64), stopCh <-chan struct{}) (written int64, err error) {
	for {
		// Cooperative cancellation check before each read; the read itself
		// is expected to fail once the underlying connections are closed.
		select {
		case <-stopCh:
			return written, io.EOF
		default:
		}
		nr, er := src.Read(buf)
		if nr > 0 {
			nw, ew := dst.Write(buf[:nr])
			if nw > 0 {
				written += int64(nw)
				if onCopied != nil {
					onCopied(int64(nw))
				}
			}
			if ew != nil {
				return written, ew
			}
			if nr != nw {
				return written, io.ErrShortWrite
			}
		}
		if er != nil {
			if er == io.EOF {
				return written, nil
			}
			return written, er
		}
	}
}

View File

@@ -1,73 +0,0 @@
package pool
import (
"sync"
)
// AdaptiveBufferPool recycles byte buffers in two size classes, eliminating
// the allocation overhead that profiling showed for per-transfer buffers.
type AdaptiveBufferPool struct {
	// Large buffers for streaming threshold (1MB)
	largePool *sync.Pool
	// Medium buffers for temporary reads (32KB)
	mediumPool *sync.Pool
}

const (
	// LargeBufferSize is 1MB for streaming threshold
	LargeBufferSize = 1 * 1024 * 1024
	// MediumBufferSize is 32KB for temporary reads
	MediumBufferSize = 32 * 1024
)

// newSizedBufferPool builds a sync.Pool whose New allocates buffers of size bytes.
func newSizedBufferPool(size int) *sync.Pool {
	return &sync.Pool{
		New: func() interface{} {
			b := make([]byte, size)
			return &b
		},
	}
}

// NewAdaptiveBufferPool creates a new adaptive buffer pool.
func NewAdaptiveBufferPool() *AdaptiveBufferPool {
	return &AdaptiveBufferPool{
		largePool:  newSizedBufferPool(LargeBufferSize),
		mediumPool: newSizedBufferPool(MediumBufferSize),
	}
}

// recycleBuffer restores buf to its full capacity and returns it to pool.
// A nil buf is ignored.
func recycleBuffer(pool *sync.Pool, buf *[]byte) {
	if buf == nil {
		return
	}
	*buf = (*buf)[:cap(*buf)]
	pool.Put(buf)
}

// GetLarge returns a large buffer (1MB) from the pool.
// The returned buffer should be handed back via PutLarge when done.
func (p *AdaptiveBufferPool) GetLarge() *[]byte {
	return p.largePool.Get().(*[]byte)
}

// PutLarge returns a large buffer to the pool for reuse.
func (p *AdaptiveBufferPool) PutLarge(buf *[]byte) {
	recycleBuffer(p.largePool, buf)
}

// GetMedium returns a medium buffer (32KB) from the pool.
// The returned buffer should be handed back via PutMedium when done.
func (p *AdaptiveBufferPool) GetMedium() *[]byte {
	return p.mediumPool.Get().(*[]byte)
}

// PutMedium returns a medium buffer to the pool for reuse.
func (p *AdaptiveBufferPool) PutMedium(buf *[]byte) {
	recycleBuffer(p.mediumPool, buf)
}

View File

@@ -1,86 +0,0 @@
package pool
import (
"net/http"
"sync"
)
// HeaderPool manages a pool of http.Header objects for reuse.
type HeaderPool struct {
pool sync.Pool
}
// NewHeaderPool creates a new header pool
func NewHeaderPool() *HeaderPool {
return &HeaderPool{
pool: sync.Pool{
New: func() interface{} {
return make(http.Header, 12)
},
},
}
}
// Get retrieves a header from the pool.
func (p *HeaderPool) Get() http.Header {
h := p.pool.Get().(http.Header)
for k := range h {
delete(h, k)
}
return h
}
// Put returns a header to the pool.
func (p *HeaderPool) Put(h http.Header) {
if h == nil {
return
}
p.pool.Put(h)
}
// Clone creates a copy of src into dst, reusing dst's underlying storage
// This is more efficient than creating a new header from scratch
func (p *HeaderPool) Clone(dst, src http.Header) {
// Clear dst first
for k := range dst {
delete(dst, k)
}
// Copy all headers from src to dst
for k, vv := range src {
// Allocate new slice with exact capacity to avoid over-allocation
dst[k] = make([]string, len(vv))
copy(dst[k], vv)
}
}
// CloneWithExtra clones src into dst and adds/overwrites extra headers
// This is optimized for the common pattern of cloning + adding Host header
func (p *HeaderPool) CloneWithExtra(dst, src http.Header, extraKey, extraValue string) {
// Clear dst first
for k := range dst {
delete(dst, k)
}
// Copy all headers from src to dst
for k, vv := range src {
dst[k] = make([]string, len(vv))
copy(dst[k], vv)
}
// Set extra header (overwrite if exists)
dst.Set(extraKey, extraValue)
}
// globalHeaderPool is a package-level pool for convenience
var globalHeaderPool = NewHeaderPool()
// GetHeader retrieves a header from the global pool
func GetHeader() http.Header {
return globalHeaderPool.Get()
}
// PutHeader returns a header to the global pool
func PutHeader(h http.Header) {
globalHeaderPool.Put(h)
}

View File

@@ -2,81 +2,23 @@ package protocol
import (
"sync/atomic"
"time"
"drip/internal/shared/pool"
)
// AdaptivePoolManager dynamically adjusts buffer pool usage based on load
// AdaptivePoolManager tracks active connections for load monitoring
type AdaptivePoolManager struct {
activeConnections atomic.Int64
currentThreshold atomic.Int64
highLoadConnectionThreshold int64
midLoadConnectionThreshold int64
midLoadThreshold int64
highLoadThreshold int64
activeConnections atomic.Int64
}
var globalAdaptiveManager = NewAdaptivePoolManager()
func NewAdaptivePoolManager() *AdaptivePoolManager {
m := &AdaptivePoolManager{
highLoadConnectionThreshold: 300,
midLoadConnectionThreshold: 150,
midLoadThreshold: int64(pool.SizeLarge),
highLoadThreshold: int64(pool.SizeMedium),
}
m.currentThreshold.Store(m.midLoadThreshold)
go m.monitor()
return m
}
func (m *AdaptivePoolManager) monitor() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for range ticker.C {
connections := m.activeConnections.Load()
if connections >= m.highLoadConnectionThreshold {
m.currentThreshold.Store(m.highLoadThreshold)
} else if connections < m.midLoadConnectionThreshold {
m.currentThreshold.Store(m.midLoadThreshold)
}
// Hysteresis zone (150-300): maintain current threshold
}
}
func (m *AdaptivePoolManager) GetThreshold() int {
return int(m.currentThreshold.Load())
}
func (m *AdaptivePoolManager) RegisterConnection() {
m.activeConnections.Add(1)
}
func (m *AdaptivePoolManager) UnregisterConnection() {
m.activeConnections.Add(-1)
}
func (m *AdaptivePoolManager) GetActiveConnections() int64 {
return m.activeConnections.Load()
}
func GetAdaptiveThreshold() int {
return globalAdaptiveManager.GetThreshold()
}
var globalAdaptiveManager = &AdaptivePoolManager{}
func RegisterConnection() {
globalAdaptiveManager.RegisterConnection()
globalAdaptiveManager.activeConnections.Add(1)
}
func UnregisterConnection() {
globalAdaptiveManager.UnregisterConnection()
globalAdaptiveManager.activeConnections.Add(-1)
}
func GetActiveConnections() int64 {
return globalAdaptiveManager.GetActiveConnections()
return globalAdaptiveManager.activeConnections.Load()
}

View File

@@ -1,162 +0,0 @@
package protocol
import (
"encoding/binary"
"errors"
)
// DataHeader represents a binary-encoded data header for data plane
// All data transmission uses pure binary encoding for performance
type DataHeader struct {
	Type      DataType
	IsLast    bool
	StreamID  string
	RequestID string
}

// DataType represents the type of data frame
type DataType uint8

const (
	DataTypeData          DataType = 0x00 // 000
	DataTypeResponse      DataType = 0x01 // 001
	DataTypeClose         DataType = 0x02 // 010
	DataTypeHTTPRequest   DataType = 0x03 // 011
	DataTypeHTTPResponse  DataType = 0x04 // 100
	DataTypeHTTPHead      DataType = 0x05 // 101 - streaming headers (shared)
	DataTypeHTTPBodyChunk DataType = 0x06 // 110 - streaming body chunks (shared)
	// Reuse the same type codes for request streaming to stay within 3 bits.
	DataTypeHTTPRequestHead      DataType = DataTypeHTTPHead
	DataTypeHTTPRequestBodyChunk DataType = DataTypeHTTPBodyChunk
)

// dataTypeNames maps each distinct wire code to its canonical name.
var dataTypeNames = map[DataType]string{
	DataTypeData:          "data",
	DataTypeResponse:      "response",
	DataTypeClose:         "close",
	DataTypeHTTPRequest:   "http_request",
	DataTypeHTTPResponse:  "http_response",
	DataTypeHTTPHead:      "http_head",
	DataTypeHTTPBodyChunk: "http_body_chunk",
}

// String returns the string representation of DataType;
// codes outside the table render as "unknown".
func (t DataType) String() string {
	if name, ok := dataTypeNames[t]; ok {
		return name
	}
	return "unknown"
}

// DataTypeFromString converts a string to DataType.
// Unrecognized names fall back to DataTypeData, matching the wire default.
func DataTypeFromString(s string) DataType {
	for code, name := range dataTypeNames {
		if name == s {
			return code
		}
	}
	return DataTypeData
}

// Binary format:
// +--------+--------+--------+--------+--------+
// | Flags  | StreamID Length | RequestID Len   |
// | 1 byte | 2 bytes         | 2 bytes         |
// +--------+--------+--------+--------+--------+
// | StreamID (variable)                        |
// +--------+--------+--------+--------+--------+
// | RequestID (variable)                       |
// +--------+--------+--------+--------+--------+
//
// Flags (8 bits):
// - Bit 0-2: Type (3 bits)
// - Bit 3:   IsLast (1 bit)
// - Bit 4-7: Reserved (4 bits)
const (
	binaryHeaderMinSize = 5 // 1 byte flags + 2 bytes streamID len + 2 bytes requestID len
)

// packFlags folds Type (bits 0-2) and IsLast (bit 3) into a single flags byte.
func (h *DataHeader) packFlags() uint8 {
	flags := uint8(h.Type) & 0x07
	if h.IsLast {
		flags |= 0x08
	}
	return flags
}

// MarshalBinary encodes the header to binary format.
func (h *DataHeader) MarshalBinary() []byte {
	out := make([]byte, h.Size())
	out[0] = h.packFlags()
	binary.BigEndian.PutUint16(out[1:3], uint16(len(h.StreamID)))
	binary.BigEndian.PutUint16(out[3:5], uint16(len(h.RequestID)))
	pos := binaryHeaderMinSize
	pos += copy(out[pos:], h.StreamID)
	copy(out[pos:], h.RequestID)
	return out
}

// UnmarshalBinary decodes the header from binary format.
// It fails when data is shorter than the fixed prefix or than the
// total length implied by the encoded ID lengths.
func (h *DataHeader) UnmarshalBinary(data []byte) error {
	if len(data) < binaryHeaderMinSize {
		return errors.New("invalid binary header: too short")
	}
	flags := data[0]
	h.Type = DataType(flags & 0x07)  // Bits 0-2
	h.IsLast = flags&0x08 != 0       // Bit 3
	sidLen := int(binary.BigEndian.Uint16(data[1:3]))
	ridLen := int(binary.BigEndian.Uint16(data[3:5]))
	if len(data) < binaryHeaderMinSize+sidLen+ridLen {
		return errors.New("invalid binary header: length mismatch")
	}
	pos := binaryHeaderMinSize
	h.StreamID = string(data[pos : pos+sidLen])
	pos += sidLen
	h.RequestID = string(data[pos : pos+ridLen])
	return nil
}

// Size returns the size of the binary-encoded header.
func (h *DataHeader) Size() int {
	return binaryHeaderMinSize + len(h.StreamID) + len(h.RequestID)
}

View File

@@ -1,34 +0,0 @@
package protocol
import (
json "github.com/goccy/go-json"
)
// FlowControlAction identifies a flow-control command carried on the wire.
type FlowControlAction string

const (
	FlowControlPause  FlowControlAction = "pause"
	FlowControlResume FlowControlAction = "resume"
)

// FlowControlMessage pairs a stream with the action to apply to it.
type FlowControlMessage struct {
	StreamID string            `json:"stream_id"`
	Action   FlowControlAction `json:"action"`
}

// NewFlowControlFrame builds a FrameTypeFlowControl frame carrying the
// given stream/action pair as a JSON payload.
func NewFlowControlFrame(streamID string, action FlowControlAction) *Frame {
	// Marshal of this fixed struct cannot fail, so the error is discarded.
	payload, _ := json.Marshal(&FlowControlMessage{
		StreamID: streamID,
		Action:   action,
	})
	return NewFrame(FrameTypeFlowControl, payload)
}

// DecodeFlowControlMessage parses a flow-control JSON payload.
func DecodeFlowControlMessage(payload []byte) (*FlowControlMessage, error) {
	msg := new(FlowControlMessage)
	if err := json.Unmarshal(payload, msg); err != nil {
		return nil, err
	}
	return msg, nil
}

View File

@@ -18,14 +18,14 @@ const (
type FrameType byte
const (
FrameTypeRegister FrameType = 0x01
FrameTypeRegisterAck FrameType = 0x02
FrameTypeHeartbeat FrameType = 0x03
FrameTypeHeartbeatAck FrameType = 0x04
FrameTypeData FrameType = 0x05
FrameTypeClose FrameType = 0x06
FrameTypeError FrameType = 0x07
FrameTypeFlowControl FrameType = 0x08
FrameTypeRegister FrameType = 0x01
FrameTypeRegisterAck FrameType = 0x02
FrameTypeHeartbeat FrameType = 0x03
FrameTypeHeartbeatAck FrameType = 0x04
FrameTypeClose FrameType = 0x05
FrameTypeError FrameType = 0x06
FrameTypeDataConnect FrameType = 0x07
FrameTypeDataConnectAck FrameType = 0x08
)
// String returns the string representation of frame type
@@ -39,14 +39,14 @@ func (t FrameType) String() string {
return "Heartbeat"
case FrameTypeHeartbeatAck:
return "HeartbeatAck"
case FrameTypeData:
return "Data"
case FrameTypeClose:
return "Close"
case FrameTypeError:
return "Error"
case FrameTypeFlowControl:
return "FlowControl"
case FrameTypeDataConnect:
return "DataConnect"
case FrameTypeDataConnectAck:
return "DataConnectAck"
default:
return fmt.Sprintf("Unknown(%d)", t)
}
@@ -56,6 +56,9 @@ type Frame struct {
Type FrameType
Payload []byte
poolBuffer *[]byte
// queuedBytes is set by FrameWriter when the frame is enqueued.
// It allows the writer to decrement backlog counters exactly once.
queuedBytes int64
}
func WriteFrame(w io.Writer, frame *Frame) error {
@@ -130,6 +133,8 @@ func (f *Frame) Release() {
f.poolBuffer = nil
f.Payload = nil
}
// Reset queued marker to avoid carrying over stale state if the frame is reused.
f.queuedBytes = 0
}
// NewFrame creates a new frame

View File

@@ -1,119 +0,0 @@
package protocol
import (
"errors"
json "github.com/goccy/go-json"
"github.com/vmihailenco/msgpack/v5"
)
// decodeAuto unmarshals data into v, auto-detecting the wire format:
// payloads beginning with '{' are legacy v1 JSON, anything else is v2
// msgpack (fixmap payloads start with 0x80-0x8f).
func decodeAuto(data []byte, v interface{}) error {
	if len(data) == 0 {
		return errors.New("empty data")
	}
	if data[0] == '{' {
		return json.Unmarshal(data, v)
	}
	return msgpack.Unmarshal(data, v)
}

// EncodeHTTPRequest encodes HTTPRequest using msgpack encoding (optimized).
func EncodeHTTPRequest(req *HTTPRequest) ([]byte, error) {
	return msgpack.Marshal(req)
}

// DecodeHTTPRequest decodes HTTPRequest with automatic version detection.
func DecodeHTTPRequest(data []byte) (*HTTPRequest, error) {
	var req HTTPRequest
	if err := decodeAuto(data, &req); err != nil {
		return nil, err
	}
	return &req, nil
}

// EncodeHTTPRequestHead encodes HTTP request headers for streaming.
func EncodeHTTPRequestHead(head *HTTPRequestHead) ([]byte, error) {
	return msgpack.Marshal(head)
}

// DecodeHTTPRequestHead decodes HTTP request headers for streaming.
func DecodeHTTPRequestHead(data []byte) (*HTTPRequestHead, error) {
	var head HTTPRequestHead
	if err := decodeAuto(data, &head); err != nil {
		return nil, err
	}
	return &head, nil
}

// EncodeHTTPResponse encodes HTTPResponse using msgpack encoding (optimized).
func EncodeHTTPResponse(resp *HTTPResponse) ([]byte, error) {
	return msgpack.Marshal(resp)
}

// DecodeHTTPResponse decodes HTTPResponse with automatic version detection.
func DecodeHTTPResponse(data []byte) (*HTTPResponse, error) {
	var resp HTTPResponse
	if err := decodeAuto(data, &resp); err != nil {
		return nil, err
	}
	return &resp, nil
}

// EncodeHTTPResponseHead encodes HTTP response headers for streaming.
func EncodeHTTPResponseHead(head *HTTPResponseHead) ([]byte, error) {
	return msgpack.Marshal(head)
}

// DecodeHTTPResponseHead decodes HTTP response headers for streaming.
func DecodeHTTPResponseHead(data []byte) (*HTTPResponseHead, error) {
	var head HTTPResponseHead
	if err := decodeAuto(data, &head); err != nil {
		return nil, err
	}
	return &head, nil
}

View File

@@ -1,71 +0,0 @@
package protocol
// MessageType defines the type of tunnel message exchanged over the control channel.
type MessageType string

const (
	// TypeRegister is sent when a client connects and gets a subdomain assigned.
	TypeRegister MessageType = "register"
	// TypeRequest is sent from server to client when an HTTP request arrives.
	TypeRequest MessageType = "request"
	// TypeResponse is sent from client to server with the HTTP response.
	TypeResponse MessageType = "response"
	// TypeHeartbeat is sent periodically to keep the connection alive.
	TypeHeartbeat MessageType = "heartbeat"
	// TypeError is sent when an error occurs.
	TypeError MessageType = "error"
)

// Message represents a tunnel protocol message. Only the fields relevant to
// a given Type are populated; the rest are omitted from the JSON encoding.
type Message struct {
	Type      MessageType            `json:"type"`
	ID        string                 `json:"id,omitempty"`
	Subdomain string                 `json:"subdomain,omitempty"`
	Data      map[string]interface{} `json:"data,omitempty"`
	Error     string                 `json:"error,omitempty"`
}

// HTTPRequest represents an HTTP request to be forwarded, with the full
// body carried inline (non-streaming path).
type HTTPRequest struct {
	Method  string              `json:"method"`
	URL     string              `json:"url"`
	Headers map[string][]string `json:"headers"`
	Body    []byte              `json:"body,omitempty"`
}

// HTTPRequestHead represents HTTP request headers for streaming (no body);
// the body follows as separate chunks.
type HTTPRequestHead struct {
	Method        string              `json:"method"`
	URL           string              `json:"url"`
	Headers       map[string][]string `json:"headers"`
	ContentLength int64               `json:"content_length"` // -1 for unknown/chunked
}

// HTTPResponse represents an HTTP response from the local service, with the
// full body carried inline (non-streaming path).
type HTTPResponse struct {
	StatusCode int                 `json:"status_code"`
	Status     string              `json:"status"`
	Headers    map[string][]string `json:"headers"`
	Body       []byte              `json:"body,omitempty"`
}

// HTTPResponseHead represents HTTP response headers for streaming (no body);
// the body follows as separate chunks.
type HTTPResponseHead struct {
	StatusCode    int                 `json:"status_code"`
	Status        string              `json:"status"`
	Headers       map[string][]string `json:"headers"`
	ContentLength int64               `json:"content_length"` // -1 for unknown/chunked
}

// RegisterData contains information sent when a tunnel is registered.
type RegisterData struct {
	Subdomain string `json:"subdomain"`
	URL       string `json:"url"`
	Message   string `json:"message"`
}

// ErrorData contains error information (machine-readable code plus
// human-readable message).
type ErrorData struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}

View File

@@ -2,12 +2,23 @@ package protocol
import json "github.com/goccy/go-json"
// PoolCapabilities advertises client connection pool capabilities so the
// server can decide how many data connections to allow/recommend.
type PoolCapabilities struct {
	MaxDataConns int `json:"max_data_conns"` // Maximum data connections client supports
	Version      int `json:"version"`        // Protocol version for pool features
}

// RegisterRequest is sent by client to register a tunnel.
type RegisterRequest struct {
	Token           string     `json:"token"`            // Authentication token
	CustomSubdomain string     `json:"custom_subdomain"` // Optional custom subdomain
	TunnelType      TunnelType `json:"tunnel_type"`      // http, tcp, udp
	LocalPort       int        `json:"local_port"`       // Local port to forward to
	// Connection pool fields (optional, for multi-connection support).
	// Omitted entirely by legacy single-connection clients.
	ConnectionType   string            `json:"connection_type,omitempty"`   // "primary" or empty for legacy
	TunnelID         string            `json:"tunnel_id,omitempty"`         // For data connections to join
	PoolCapabilities *PoolCapabilities `json:"pool_capabilities,omitempty"` // Client pool capabilities
}
// RegisterResponse is sent by server after successful registration
@@ -16,6 +27,25 @@ type RegisterResponse struct {
Port int `json:"port,omitempty"` // Assigned TCP port (for TCP tunnels)
URL string `json:"url"` // Full tunnel URL
Message string `json:"message"` // Success message
// Connection pool fields (optional, for multi-connection support)
TunnelID string `json:"tunnel_id,omitempty"` // Unique tunnel identifier
SupportsDataConn bool `json:"supports_data_conn,omitempty"` // Server supports multi-connection
RecommendedConns int `json:"recommended_conns,omitempty"` // Suggested data connection count
}
// DataConnectRequest is sent by data connections to join an existing tunnel
// that was established by a primary connection.
type DataConnectRequest struct {
	TunnelID     string `json:"tunnel_id"`     // Tunnel to join
	Token        string `json:"token"`         // Same auth token as primary
	ConnectionID string `json:"connection_id"` // Unique connection identifier
}

// DataConnectResponse acknowledges (or rejects) a data connection join.
type DataConnectResponse struct {
	Accepted     bool   `json:"accepted"`          // Whether connection was accepted
	ConnectionID string `json:"connection_id"`     // Echoed connection ID
	Message      string `json:"message,omitempty"` // Optional message (e.g. rejection reason)
}
// ErrorMessage represents an error
@@ -24,9 +54,6 @@ type ErrorMessage struct {
Message string `json:"message"` // Error message
}
// Note: DataHeader is now defined in binary_header.go as a pure binary structure
// TCPData has been removed - use DataHeader + raw bytes directly
// Marshal helpers for control plane messages (JSON encoding)
func MarshalJSON(v interface{}) ([]byte, error) {
return json.Marshal(v)

View File

@@ -1,96 +0,0 @@
package protocol
import (
"encoding/binary"
"errors"
"drip/internal/shared/pool"
)
// encodeDataPayload serializes a data header followed by its payload bytes
// into a single frame payload.
// Layout: [flags:1][streamID len:2][requestID len:2][streamID][requestID][data].
func encodeDataPayload(header DataHeader, data []byte) ([]byte, error) {
	sidLen := len(header.StreamID)
	ridLen := len(header.RequestID)
	buf := make([]byte, binaryHeaderMinSize+sidLen+ridLen+len(data))

	// Flags: bits 0-2 carry the type, bit 3 marks the final chunk.
	flags := uint8(header.Type) & 0x07
	if header.IsLast {
		flags |= 0x08
	}
	buf[0] = flags
	binary.BigEndian.PutUint16(buf[1:3], uint16(sidLen))
	binary.BigEndian.PutUint16(buf[3:5], uint16(ridLen))

	pos := binaryHeaderMinSize
	pos += copy(buf[pos:], header.StreamID)
	pos += copy(buf[pos:], header.RequestID)
	copy(buf[pos:], data)
	return buf, nil
}
// EncodeDataPayloadPooled encodes with adaptive allocation based on load.
// Returns payload slice and pool buffer pointer (may be nil).
//
// Ownership: when poolBuffer is non-nil, payload aliases the pooled buffer,
// so the caller must hand poolBuffer back to the pool after the payload has
// been consumed (presumably via the pool package's Put — confirm with callers).
func EncodeDataPayloadPooled(header DataHeader, data []byte) (payload []byte, poolBuffer *[]byte, err error) {
	streamIDLen := len(header.StreamID)
	requestIDLen := len(header.RequestID)
	totalLen := binaryHeaderMinSize + streamIDLen + requestIDLen + len(data)
	// Below the load-dependent threshold a plain allocation is cheaper than
	// pool round-tripping.
	dynamicThreshold := GetAdaptiveThreshold()
	if totalLen < dynamicThreshold {
		regularPayload, err := encodeDataPayload(header, data)
		return regularPayload, nil, err
	}
	// Larger than the biggest pool class: the pool cannot serve it, allocate directly.
	if totalLen > pool.SizeLarge {
		regularPayload, err := encodeDataPayload(header, data)
		return regularPayload, nil, err
	}
	poolBuffer = pool.GetBuffer(totalLen)
	payload = (*poolBuffer)[:totalLen]
	// Flags byte: bits 0-2 = type, bit 3 = IsLast (same layout as encodeDataPayload).
	flags := uint8(header.Type) & 0x07
	if header.IsLast {
		flags |= 0x08
	}
	payload[0] = flags
	binary.BigEndian.PutUint16(payload[1:3], uint16(streamIDLen))
	binary.BigEndian.PutUint16(payload[3:5], uint16(requestIDLen))
	offset := binaryHeaderMinSize
	copy(payload[offset:], header.StreamID)
	offset += streamIDLen
	copy(payload[offset:], header.RequestID)
	offset += requestIDLen
	copy(payload[offset:], data)
	return payload, poolBuffer, nil
}
// DecodeDataPayload splits a frame payload into its DataHeader and the raw
// bytes that follow it. The returned data slice aliases payload (no copy).
func DecodeDataPayload(payload []byte) (DataHeader, []byte, error) {
	if len(payload) < binaryHeaderMinSize {
		return DataHeader{}, nil, errors.New("invalid payload: too short")
	}
	var header DataHeader
	if err := header.UnmarshalBinary(payload); err != nil {
		return DataHeader{}, nil, err
	}
	hdrLen := header.Size()
	if len(payload) < hdrLen {
		return DataHeader{}, nil, errors.New("invalid payload: data missing")
	}
	return header, payload[hdrLen:], nil
}

View File

@@ -4,20 +4,11 @@ import (
"sync"
)
// SafeFrame wraps Frame with automatic resource cleanup
type SafeFrame struct {
*Frame
once sync.Once
}
// NewSafeFrame creates a SafeFrame that implements io.Closer
func NewSafeFrame(frameType FrameType, payload []byte) *SafeFrame {
return &SafeFrame{
Frame: NewFrame(frameType, payload),
}
}
// Close implements io.Closer, ensures Release is called exactly once
func (sf *SafeFrame) Close() error {
sf.once.Do(func() {
if sf.Frame != nil {
@@ -27,14 +18,6 @@ func (sf *SafeFrame) Close() error {
return nil
}
// WithFrame wraps an existing Frame with automatic cleanup
func WithFrame(frame *Frame) *SafeFrame {
return &SafeFrame{Frame: frame}
}
// MustClose is a helper that calls Close and panics on error (for defer cleanup)
func (sf *SafeFrame) MustClose() {
if err := sf.Close(); err != nil {
panic(err)
}
}

View File

@@ -4,16 +4,18 @@ import (
"errors"
"io"
"sync"
"sync/atomic"
"time"
)
type FrameWriter struct {
conn io.Writer
queue chan *Frame
batch []*Frame
mu sync.Mutex
done chan struct{}
closed bool
conn io.Writer
queue chan *Frame
controlQueue chan *Frame
batch []*Frame
mu sync.Mutex
done chan struct{}
closed bool
maxBatch int
maxBatchWait time.Duration
@@ -24,13 +26,20 @@ type FrameWriter struct {
heartbeatControl chan struct{}
// Error handling
writeErr error
errOnce sync.Once
onWriteError func(error) // Callback for write errors
writeErr error
errOnce sync.Once
onWriteError func(error) // Callback for write errors
// Adaptive flushing
adaptiveFlush bool // Enable adaptive flush based on queue depth
lowConcurrencyThreshold int // Queue depth threshold for immediate flush
adaptiveFlush bool // Enable adaptive flush based on queue depth
lowConcurrencyThreshold int // Queue depth threshold for immediate flush
// Hooks
preWriteHook func(*Frame) // Called right before a frame is written to conn
// Backlog tracking
queuedFrames atomic.Int64
queuedBytes atomic.Int64
}
func NewFrameWriter(conn io.Writer) *FrameWriter {
@@ -41,8 +50,14 @@ func NewFrameWriter(conn io.Writer) *FrameWriter {
func NewFrameWriterWithConfig(conn io.Writer, maxBatch int, maxBatchWait time.Duration, queueSize int) *FrameWriter {
w := &FrameWriter{
conn: conn,
queue: make(chan *Frame, queueSize),
conn: conn,
queue: make(chan *Frame, queueSize),
controlQueue: make(chan *Frame, func() int {
if queueSize < 256 {
return queueSize
}
return 256
}()), // control path needs small, fast buffer
batch: make([]*Frame, 0, maxBatch),
maxBatch: maxBatch,
maxBatchWait: maxBatchWait,
@@ -74,6 +89,22 @@ func (w *FrameWriter) writeLoop() {
}()
for {
// Always drain control queue first to prioritize control/heartbeat frames.
select {
case frame, ok := <-w.controlQueue:
if !ok {
w.mu.Lock()
w.flushBatchLocked()
w.mu.Unlock()
return
}
w.mu.Lock()
w.flushFrameLocked(frame)
w.mu.Unlock()
continue
default:
}
select {
case frame, ok := <-w.queue:
if !ok {
@@ -105,8 +136,7 @@ func (w *FrameWriter) writeLoop() {
w.mu.Lock()
if w.heartbeatCallback != nil {
if frame := w.heartbeatCallback(); frame != nil {
w.batch = append(w.batch, frame)
w.flushBatchLocked()
w.flushFrameLocked(frame)
}
}
w.mu.Unlock()
@@ -139,22 +169,47 @@ func (w *FrameWriter) flushBatchLocked() {
}
for _, frame := range w.batch {
if err := WriteFrame(w.conn, frame); err != nil {
w.errOnce.Do(func() {
w.writeErr = err
if w.onWriteError != nil {
go w.onWriteError(err)
}
w.closed = true
})
}
frame.Release()
w.flushFrameLocked(frame)
}
w.batch = w.batch[:0]
}
// flushFrameLocked writes a single frame immediately. Caller must hold w.mu.
// On write failure it records the first error exactly once (errOnce), fires
// the error callback on a fresh goroutine, and marks the writer closed.
// Backlog accounting is reversed and the frame's buffers released whether or
// not the write succeeded.
func (w *FrameWriter) flushFrameLocked(frame *Frame) {
	if frame == nil {
		return
	}
	if w.preWriteHook != nil {
		w.preWriteHook(frame)
	}
	if err := WriteFrame(w.conn, frame); err != nil {
		w.errOnce.Do(func() {
			w.writeErr = err
			if w.onWriteError != nil {
				// Invoke off this goroutine: w.mu is held here, so a callback
				// that re-enters the writer would otherwise deadlock.
				go w.onWriteError(err)
			}
			w.closed = true
		})
	}
	w.unmarkQueued(frame)
	frame.Release()
}
// WriteFrame enqueues a data frame without a cancellation channel;
// it is shorthand for WriteFrameWithCancel(frame, nil).
func (w *FrameWriter) WriteFrame(frame *Frame) error {
	return w.WriteFrameWithCancel(frame, nil)
}
// WriteFrameWithCancel writes a frame with an optional cancellation channel
// If cancel is closed, the write will be aborted immediately
func (w *FrameWriter) WriteFrameWithCancel(frame *Frame, cancel <-chan struct{}) error {
if frame == nil {
return nil
}
w.mu.Lock()
if w.closed {
w.mu.Unlock()
@@ -165,10 +220,19 @@ func (w *FrameWriter) WriteFrame(frame *Frame) error {
}
w.mu.Unlock()
size := int64(len(frame.Payload) + FrameHeaderSize)
w.queuedFrames.Add(1)
w.queuedBytes.Add(size)
atomic.StoreInt64(&frame.queuedBytes, size)
// Try non-blocking first for best performance
select {
case w.queue <- frame:
return nil
case <-w.done:
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
w.mu.Lock()
err := w.writeErr
w.mu.Unlock()
@@ -176,6 +240,54 @@ func (w *FrameWriter) WriteFrame(frame *Frame) error {
return err
}
return errors.New("writer closed")
default:
}
// Queue full - block with cancellation support
if cancel != nil {
select {
case w.queue <- frame:
return nil
case <-w.done:
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
w.mu.Lock()
err := w.writeErr
w.mu.Unlock()
if err != nil {
return err
}
return errors.New("writer closed")
case <-cancel:
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
return errors.New("write cancelled")
}
}
// No cancel channel - block with timeout
select {
case w.queue <- frame:
return nil
case <-w.done:
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
w.mu.Lock()
err := w.writeErr
w.mu.Unlock()
if err != nil {
return err
}
return errors.New("writer closed")
case <-time.After(30 * time.Second):
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
return errors.New("write queue full timeout")
}
}
@@ -189,8 +301,14 @@ func (w *FrameWriter) Close() error {
w.mu.Unlock()
close(w.queue)
close(w.controlQueue)
for frame := range w.queue {
w.unmarkQueued(frame)
frame.Release()
}
for frame := range w.controlQueue {
w.unmarkQueued(frame)
frame.Release()
}
@@ -264,3 +382,97 @@ func (w *FrameWriter) DisableAdaptiveFlush() {
w.adaptiveFlush = false
w.mu.Unlock()
}
// WriteControl enqueues a control/prioritized frame to be written ahead of data frames.
// Backlog counters are incremented up front and rolled back on every failure
// path, so accounting stays balanced with unmarkQueued.
// A nil frame is a no-op.
func (w *FrameWriter) WriteControl(frame *Frame) error {
	if frame == nil {
		return nil
	}
	w.mu.Lock()
	if w.closed {
		w.mu.Unlock()
		// Surface the original write error if one was recorded.
		if w.writeErr != nil {
			return w.writeErr
		}
		return errors.New("writer closed")
	}
	w.mu.Unlock()
	size := int64(len(frame.Payload) + FrameHeaderSize)
	w.queuedFrames.Add(1)
	w.queuedBytes.Add(size)
	// Mark the frame so unmarkQueued can decrement the counters exactly once.
	atomic.StoreInt64(&frame.queuedBytes, size)
	// Try non-blocking first
	select {
	case w.controlQueue <- frame:
		return nil
	case <-w.done:
		// Writer shut down while enqueuing: undo the backlog accounting.
		w.queuedFrames.Add(-1)
		w.queuedBytes.Add(-size)
		atomic.StoreInt64(&frame.queuedBytes, 0)
		w.mu.Lock()
		err := w.writeErr
		w.mu.Unlock()
		if err != nil {
			return err
		}
		return errors.New("writer closed")
	default:
	}
	// Queue full - wait with timeout
	select {
	case w.controlQueue <- frame:
		return nil
	case <-w.done:
		w.queuedFrames.Add(-1)
		w.queuedBytes.Add(-size)
		atomic.StoreInt64(&frame.queuedBytes, 0)
		w.mu.Lock()
		err := w.writeErr
		w.mu.Unlock()
		if err != nil {
			return err
		}
		return errors.New("writer closed")
	case <-time.After(50 * time.Millisecond):
		// Control frames should have priority, shorter timeout
		w.queuedFrames.Add(-1)
		w.queuedBytes.Add(-size)
		atomic.StoreInt64(&frame.queuedBytes, 0)
		return errors.New("control queue full timeout")
	}
}
// SetPreWriteHook registers a callback invoked just before a frame is written to the underlying writer.
// The hook runs on the writer goroutine while w.mu is held (see flushFrameLocked),
// so it must be fast and must not re-enter the FrameWriter.
func (w *FrameWriter) SetPreWriteHook(hook func(*Frame)) {
	w.mu.Lock()
	w.preWriteHook = hook
	w.mu.Unlock()
}

// QueuedFrames returns the number of frames currently queued (data + control).
func (w *FrameWriter) QueuedFrames() int64 {
	return w.queuedFrames.Load()
}

// QueuedBytes returns the approximate number of bytes currently queued.
func (w *FrameWriter) QueuedBytes() int64 {
	return w.queuedBytes.Load()
}
// unmarkQueued decrements the writer's backlog counters exactly once per
// frame, using the frame's queuedBytes marker (atomically swapped to zero)
// as the idempotency guard: a second call sees zero and does nothing.
func (w *FrameWriter) unmarkQueued(frame *Frame) {
	if frame == nil {
		return
	}
	if size := atomic.SwapInt64(&frame.queuedBytes, 0); size > 0 {
		w.queuedFrames.Add(-1)
		w.queuedBytes.Add(-size)
	}
}

View File

@@ -0,0 +1,77 @@
package stats
// FormatBytes renders a byte count as a human-readable string using binary
// units (B, KB, MB, GB). Precision shrinks as magnitude grows: two decimals
// below 10, one below 100, none above.
func FormatBytes(bytes int64) string {
	const (
		KB = int64(1024)
		MB = KB * 1024
		GB = MB * 1024
	)
	switch {
	case bytes >= GB:
		return formatFloat(float64(bytes)/float64(GB)) + " GB"
	case bytes >= MB:
		return formatFloat(float64(bytes)/float64(MB)) + " MB"
	case bytes >= KB:
		return formatFloat(float64(bytes)/float64(KB)) + " KB"
	default:
		return formatInt(bytes) + " B"
	}
}

// FormatSpeed renders a rate in bytes per second as a human-readable string.
func FormatSpeed(bytesPerSec int64) string {
	if bytesPerSec == 0 {
		return "0 B/s"
	}
	return FormatBytes(bytesPerSec) + "/s"
}

// formatFloat picks a decimal precision based on magnitude (truncating,
// not rounding).
func formatFloat(f float64) string {
	switch {
	case f >= 100:
		return formatInt(int64(f))
	case f >= 10:
		return formatOneDecimal(f)
	default:
		return formatTwoDecimal(f)
	}
}

func formatInt(i int64) string {
	return intToStr(i)
}

// formatOneDecimal truncates f to one decimal place, e.g. 10.37 -> "10.3".
func formatOneDecimal(f float64) string {
	scaled := int64(f * 10)
	return intToStr(scaled/10) + "." + intToStr(scaled%10)
}

// formatTwoDecimal truncates f to two decimal places, zero-padding the
// fraction, e.g. 1.056 -> "1.05".
func formatTwoDecimal(f float64) string {
	scaled := int64(f * 100)
	whole, frac := scaled/100, scaled%100
	if frac < 10 {
		return intToStr(whole) + ".0" + intToStr(frac)
	}
	return intToStr(whole) + "." + intToStr(frac)
}

// intToStr converts an int64 to decimal without pulling in fmt/strconv.
func intToStr(i int64) string {
	if i == 0 {
		return "0"
	}
	if i < 0 {
		return "-" + intToStr(-i)
	}
	digits := make([]byte, 0, 20)
	for ; i > 0; i /= 10 {
		digits = append(digits, byte('0'+i%10))
	}
	// Digits were produced least-significant first; reverse in place.
	for l, r := 0, len(digits)-1; l < r; l, r = l+1, r-1 {
		digits[l], digits[r] = digits[r], digits[l]
	}
	return string(digits)
}

View File

@@ -1,4 +1,4 @@
package tcp
package stats
import (
"sync"
@@ -13,7 +13,8 @@ type TrafficStats struct {
totalBytesOut int64
// Request counts
totalRequests int64
totalRequests int64
activeConnections int64
// For speed calculation
lastBytesIn int64
@@ -53,6 +54,17 @@ func (s *TrafficStats) AddRequest() {
atomic.AddInt64(&s.totalRequests, 1)
}
// IncActiveConnections atomically increments the live-connection gauge.
func (s *TrafficStats) IncActiveConnections() {
	atomic.AddInt64(&s.activeConnections, 1)
}

// DecActiveConnections atomically decrements the live-connection gauge,
// clamping it at zero so mismatched Inc/Dec pairs cannot report a negative
// count. The check-then-store clamp is not atomic as a whole, so a concurrent
// Inc landing in that window can be overwritten — best-effort correction only.
func (s *TrafficStats) DecActiveConnections() {
	v := atomic.AddInt64(&s.activeConnections, -1)
	if v < 0 {
		atomic.StoreInt64(&s.activeConnections, 0)
	}
}
// GetTotalBytesIn returns total incoming bytes
func (s *TrafficStats) GetTotalBytesIn() int64 {
return atomic.LoadInt64(&s.totalBytesIn)
@@ -68,6 +80,10 @@ func (s *TrafficStats) GetTotalRequests() int64 {
return atomic.LoadInt64(&s.totalRequests)
}
// GetActiveConnections returns the current value of the live-connection gauge.
func (s *TrafficStats) GetActiveConnections() int64 {
	return atomic.LoadInt64(&s.activeConnections)
}
// GetTotalBytes returns total bytes (in + out)
func (s *TrafficStats) GetTotalBytes() int64 {
return s.GetTotalBytesIn() + s.GetTotalBytesOut()
@@ -81,8 +97,10 @@ func (s *TrafficStats) UpdateSpeed() {
now := time.Now()
elapsed := now.Sub(s.lastTime).Seconds()
// Require minimum interval of 100ms to avoid division issues
if elapsed < 0.1 {
return // Avoid division by zero or too frequent updates
return
}
currentIn := atomic.LoadInt64(&s.totalBytesIn)
@@ -91,8 +109,20 @@ func (s *TrafficStats) UpdateSpeed() {
deltaIn := currentIn - s.lastBytesIn
deltaOut := currentOut - s.lastBytesOut
s.speedIn = int64(float64(deltaIn) / elapsed)
s.speedOut = int64(float64(deltaOut) / elapsed)
// Calculate instantaneous speed
if deltaIn > 0 {
s.speedIn = int64(float64(deltaIn) / elapsed)
} else {
// No new bytes - set speed to 0 immediately
s.speedIn = 0
}
if deltaOut > 0 {
s.speedOut = int64(float64(deltaOut) / elapsed)
} else {
// No new bytes - set speed to 0 immediately
s.speedOut = 0
}
s.lastBytesIn = currentIn
s.lastBytesOut = currentOut
@@ -119,18 +149,19 @@ func (s *TrafficStats) GetUptime() time.Duration {
}
// Snapshot returns a snapshot of all stats
type StatsSnapshot struct {
TotalBytesIn int64
TotalBytesOut int64
TotalBytes int64
TotalRequests int64
SpeedIn int64 // bytes per second
SpeedOut int64 // bytes per second
Uptime time.Duration
type Snapshot struct {
TotalBytesIn int64
TotalBytesOut int64
TotalBytes int64
TotalRequests int64
ActiveConnections int64
SpeedIn int64 // bytes per second
SpeedOut int64 // bytes per second
Uptime time.Duration
}
// GetSnapshot returns a snapshot of current stats
func (s *TrafficStats) GetSnapshot() StatsSnapshot {
func (s *TrafficStats) GetSnapshot() Snapshot {
s.speedMu.Lock()
speedIn := s.speedIn
speedOut := s.speedOut
@@ -138,90 +169,16 @@ func (s *TrafficStats) GetSnapshot() StatsSnapshot {
totalIn := atomic.LoadInt64(&s.totalBytesIn)
totalOut := atomic.LoadInt64(&s.totalBytesOut)
active := atomic.LoadInt64(&s.activeConnections)
return StatsSnapshot{
TotalBytesIn: totalIn,
TotalBytesOut: totalOut,
TotalBytes: totalIn + totalOut,
TotalRequests: atomic.LoadInt64(&s.totalRequests),
SpeedIn: speedIn,
SpeedOut: speedOut,
Uptime: time.Since(s.startTime),
return Snapshot{
TotalBytesIn: totalIn,
TotalBytesOut: totalOut,
TotalBytes: totalIn + totalOut,
TotalRequests: atomic.LoadInt64(&s.totalRequests),
ActiveConnections: active,
SpeedIn: speedIn,
SpeedOut: speedOut,
Uptime: time.Since(s.startTime),
}
}
// FormatBytes formats a byte count as a human-readable string using binary
// units (1 KB = 1024 B); counts below 1 KB are printed as plain bytes.
func FormatBytes(bytes int64) string {
	const unit = int64(1024)
	// Largest unit first; the first threshold the count reaches wins.
	thresholds := []struct {
		size   int64
		suffix string
	}{
		{unit * unit * unit, " GB"},
		{unit * unit, " MB"},
		{unit, " KB"},
	}
	for _, th := range thresholds {
		if bytes >= th.size {
			return formatFloat(float64(bytes)/float64(th.size)) + th.suffix
		}
	}
	return formatInt(bytes) + " B"
}
// FormatSpeed formats a throughput value in bytes per second for display;
// zero is special-cased so an idle tunnel reads "0 B/s".
func FormatSpeed(bytesPerSec int64) string {
	if bytesPerSec != 0 {
		return FormatBytes(bytesPerSec) + "/s"
	}
	return "0 B/s"
}
// formatFloat picks a precision for f so the output keeps roughly three
// significant figures: no decimals at >=100, one at >=10, two below that.
func formatFloat(f float64) string {
	switch {
	case f >= 100:
		return formatInt(int64(f))
	case f >= 10:
		return formatOneDecimal(f)
	default:
		return formatTwoDecimal(f)
	}
}
// formatInt renders v as its base-10 decimal string.
func formatInt(v int64) string {
	return intToStr(v)
}
// formatOneDecimal renders f with exactly one decimal place, truncating
// (not rounding) the remaining digits, e.g. 12.37 -> "12.3".
func formatOneDecimal(f float64) string {
	i := int64(f * 10)
	whole := i / 10
	frac := i % 10
	// Go's % keeps the dividend's sign, so a negative f used to render the
	// fraction with its own minus sign (e.g. -3.5 -> "-3.-5"). Normalize it.
	if frac < 0 {
		frac = -frac
	}
	// For values in (-1, 0) the whole part is 0 and would drop the sign.
	if i < 0 && whole == 0 {
		return "-" + intToStr(whole) + "." + intToStr(frac)
	}
	return intToStr(whole) + "." + intToStr(frac)
}
// formatTwoDecimal renders f with exactly two decimal places, zero-padding
// the fraction and truncating extra digits, e.g. 1.056 -> "1.05".
func formatTwoDecimal(f float64) string {
	i := int64(f * 100)
	whole := i / 100
	frac := i % 100
	// % keeps the dividend's sign; strip it so a negative f cannot render
	// the fraction as "-5" (e.g. -3.05 -> "-3.-5" in the old code).
	if frac < 0 {
		frac = -frac
	}
	sign := ""
	if i < 0 && whole == 0 {
		sign = "-" // preserve the sign for values in (-1, 0)
	}
	if frac < 10 {
		return sign + intToStr(whole) + ".0" + intToStr(frac)
	}
	return sign + intToStr(whole) + "." + intToStr(frac)
}
// intToStr converts i to its base-10 string representation without using
// fmt or strconv, allocating only the result string.
func intToStr(i int64) string {
	if i == 0 {
		return "0"
	}
	// Work with the uint64 magnitude: the old recursive "-" + intToStr(-i)
	// never terminated for math.MinInt64, because -i overflows back to
	// itself. Negating via uint64 two's complement is well-defined.
	neg := i < 0
	u := uint64(i)
	if neg {
		u = -u
	}
	// 20 digits cover any uint64; one extra byte for a possible '-'.
	var buf [21]byte
	pos := len(buf)
	for u > 0 {
		pos--
		buf[pos] = byte('0' + u%10)
		u /= 10
	}
	if neg {
		pos--
		buf[pos] = '-'
	}
	return string(buf[pos:])
}

View File

@@ -2,6 +2,9 @@ package ui
import (
"fmt"
"strings"
"github.com/charmbracelet/lipgloss"
)
// RenderConfigInit renders config initialization UI
@@ -100,18 +103,59 @@ func RenderConfigValidation(serverValid bool, serverMsg string, tokenSet bool, t
}
// RenderDaemonStarted renders daemon started message
func RenderDaemonStarted(tunnelType string, port int, pid int, logPath string) string {
func RenderDaemonStarted(tunnelType string, port int, pid int, logPath string, url string, forwardAddr string, serverAddr string) string {
if forwardAddr == "" {
forwardAddr = fmt.Sprintf("localhost:%d", port)
}
urlLine := Muted("(resolving...)")
if url != "" {
urlBadge := lipgloss.NewStyle().
Background(successColor).
Foreground(lipgloss.Color("#f8fafc")).
Bold(true).
Padding(0, 1).
Render(url)
urlLine = urlBadge
}
headline := successStyle.Render("✓ Tunnel Started in Background")
lines := []string{
KeyValue("Type", Highlight(tunnelType)),
KeyValue("Port", fmt.Sprintf("%d", port)),
KeyValue("PID", fmt.Sprintf("%d", pid)),
KeyValue("Forward", forwardAddr),
}
if serverAddr != "" {
lines = append(lines, KeyValue("Server", serverAddr))
}
lines = append(lines,
"",
Muted("Commands:"),
Cyan(" drip list") + Muted(" Check tunnel status"),
Cyan(fmt.Sprintf(" drip attach %s %d", tunnelType, port)) + Muted(" View logs"),
Cyan(fmt.Sprintf(" drip stop %s %d", tunnelType, port)) + Muted(" Stop tunnel"),
Cyan(" drip list")+Muted(" Check tunnel status"),
Cyan(fmt.Sprintf(" drip attach %s %d", tunnelType, port))+Muted(" View logs"),
Cyan(fmt.Sprintf(" drip stop %s %d", tunnelType, port))+Muted(" Stop tunnel"),
"",
Muted("Logs: ") + mutedStyle.Render(logPath),
Muted("Logs: ")+mutedStyle.Render(logPath),
)
contentWidth := 0
for _, line := range append([]string{headline}, lines...) {
if w := lipgloss.Width(line); w > contentWidth {
contentWidth = w
}
}
return SuccessBox("Tunnel Started in Background", lines...)
if w := lipgloss.Width(urlLine); w > contentWidth {
contentWidth = w
}
centeredURL := lipgloss.PlaceHorizontal(contentWidth, lipgloss.Center, urlLine)
contentLines := make([]string, 0, len(lines)+4)
contentLines = append(contentLines, headline, "", centeredURL, "")
contentLines = append(contentLines, lines...)
return successBoxStyle.Render(strings.Join(contentLines, "\n"))
}

View File

@@ -93,6 +93,11 @@ func RenderTunnelStats(status *TunnelStatus) string {
_, _, accent := tunnelVisuals(status.Type)
requestLabel := "Requests"
if status.Type == "tcp" {
requestLabel = "Connections"
}
header := lipgloss.JoinHorizontal(
lipgloss.Left,
lipgloss.NewStyle().Foreground(accent).Render("◉"),
@@ -102,7 +107,7 @@ func RenderTunnelStats(status *TunnelStatus) string {
row1 := lipgloss.JoinHorizontal(
lipgloss.Top,
statColumn("Latency", latencyStr, statsColumnWidth),
statColumn("Requests", highlightStyle.Render(requestsStr), statsColumnWidth),
statColumn(requestLabel, highlightStyle.Render(requestsStr), statsColumnWidth),
)
row2 := lipgloss.JoinHorizontal(