feat(tunnel): switch to yamux stream proxying and connection pooling

- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server
- Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly
- Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
This commit is contained in:
Gouryella
2025-12-13 18:03:44 +08:00
parent 3c93789266
commit 0c19c3300c
55 changed files with 3380 additions and 4849 deletions

View File

@@ -12,7 +12,7 @@ import (
"syscall"
"time"
"drip/internal/client/cli/ui"
"drip/internal/shared/ui"
"github.com/spf13/cobra"
)
@@ -24,6 +24,7 @@ var attachCmd = &cobra.Command{
Examples:
drip attach List running tunnels and select one
drip attach http 3000 Attach to HTTP tunnel on port 3000
drip attach https 8443 Attach to HTTPS tunnel on port 8443
drip attach tcp 5432 Attach to TCP tunnel on port 5432
Press Ctrl+C to detach (tunnel will continue running).`,
@@ -36,7 +37,7 @@ func init() {
rootCmd.AddCommand(attachCmd)
}
func runAttach(cmd *cobra.Command, args []string) error {
func runAttach(_ *cobra.Command, args []string) error {
CleanupStaleDaemons()
daemons, err := ListAllDaemons()
@@ -59,8 +60,8 @@ func runAttach(cmd *cobra.Command, args []string) error {
if len(args) == 2 {
tunnelType := args[0]
if tunnelType != "http" && tunnelType != "tcp" {
return fmt.Errorf("invalid tunnel type: %s (must be 'http' or 'tcp')", tunnelType)
if tunnelType != "http" && tunnelType != "https" && tunnelType != "tcp" {
return fmt.Errorf("invalid tunnel type: %s (must be 'http', 'https', or 'tcp')", tunnelType)
}
port, err := strconv.Atoi(args[1])
@@ -119,10 +120,13 @@ func selectDaemonInteractive(daemons []*DaemonInfo) (*DaemonInfo, error) {
uptime := time.Since(d.StartTime)
var typeStr string
if d.Type == "http" {
typeStr = ui.Success("HTTP")
} else {
typeStr = ui.Highlight("TCP")
switch d.Type {
case "http":
typeStr = ui.Highlight("HTTP")
case "https":
typeStr = ui.Highlight("HTTPS")
default:
typeStr = ui.Cyan("TCP")
}
table.AddRow([]string{
@@ -196,7 +200,7 @@ func attachToDaemon(daemon *DaemonInfo) error {
select {
case <-sigCh:
if tailCmd.Process != nil {
tailCmd.Process.Kill()
_ = tailCmd.Process.Kill()
}
fmt.Println()
fmt.Println(ui.Warning("Detached from tunnel (tunnel is still running)"))

View File

@@ -7,7 +7,7 @@ import (
"os"
"strings"
"drip/internal/client/cli/ui"
"drip/internal/shared/ui"
"drip/pkg/config"
"github.com/spf13/cobra"
)
@@ -77,7 +77,7 @@ func init() {
rootCmd.AddCommand(configCmd)
}
func runConfigInit(cmd *cobra.Command, args []string) error {
func runConfigInit(_ *cobra.Command, _ []string) error {
fmt.Print(ui.RenderConfigInit())
reader := bufio.NewReader(os.Stdin)
@@ -109,7 +109,7 @@ func runConfigInit(cmd *cobra.Command, args []string) error {
return nil
}
func runConfigShow(cmd *cobra.Command, args []string) error {
func runConfigShow(_ *cobra.Command, _ []string) error {
cfg, err := config.LoadClientConfig("")
if err != nil {
return err
@@ -137,7 +137,7 @@ func runConfigShow(cmd *cobra.Command, args []string) error {
return nil
}
func runConfigSet(cmd *cobra.Command, args []string) error {
func runConfigSet(_ *cobra.Command, _ []string) error {
cfg, err := config.LoadClientConfig("")
if err != nil {
cfg = &config.ClientConfig{
@@ -173,7 +173,7 @@ func runConfigSet(cmd *cobra.Command, args []string) error {
return nil
}
func runConfigReset(cmd *cobra.Command, args []string) error {
func runConfigReset(_ *cobra.Command, _ []string) error {
configPath := config.DefaultClientConfigPath()
if !config.ConfigExists("") {
@@ -202,7 +202,7 @@ func runConfigReset(cmd *cobra.Command, args []string) error {
return nil
}
func runConfigValidate(cmd *cobra.Command, args []string) error {
func runConfigValidate(_ *cobra.Command, _ []string) error {
cfg, err := config.LoadClientConfig("")
if err != nil {
fmt.Println(ui.Error("Failed to load configuration"))

View File

@@ -5,10 +5,9 @@ import (
"os"
"os/exec"
"path/filepath"
"strconv"
"time"
"drip/internal/client/cli/ui"
"drip/internal/shared/ui"
json "github.com/goccy/go-json"
)
@@ -194,8 +193,8 @@ func StartDaemon(tunnelType string, port int, args []string) error {
return fmt.Errorf("failed to start daemon: %w", err)
}
// Don't wait for the process - let it run in background
// The child process will save its own daemon info after connecting
_ = logFile.Close()
_ = devNull.Close()
fmt.Println(ui.RenderDaemonStarted(tunnelType, port, cmd.Process.Pid, logPath))
@@ -220,28 +219,16 @@ func CleanupStaleDaemons() error {
// FormatDuration formats a duration in a human-readable way
func FormatDuration(d time.Duration) string {
if d < time.Minute {
switch {
case d < time.Minute:
return fmt.Sprintf("%ds", int(d.Seconds()))
} else if d < time.Hour {
case d < time.Hour:
return fmt.Sprintf("%dm %ds", int(d.Minutes()), int(d.Seconds())%60)
} else if d < 24*time.Hour {
case d < 24*time.Hour:
return fmt.Sprintf("%dh %dm", int(d.Hours()), int(d.Minutes())%60)
}
days := int(d.Hours()) / 24
hours := int(d.Hours()) % 24
return fmt.Sprintf("%dd %dh", days, hours)
}
// ParsePortFromArgs extracts the port number from command arguments
func ParsePortFromArgs(args []string) (int, error) {
for _, arg := range args {
if len(arg) > 0 && arg[0] == '-' {
continue
}
port, err := strconv.Atoi(arg)
if err == nil && port > 0 && port <= 65535 {
return port, nil
}
}
return 0, fmt.Errorf("port number not found in arguments")
}

View File

@@ -2,22 +2,14 @@ package cli
import (
"fmt"
"os"
"strconv"
"time"
"drip/internal/client/tcp"
"drip/internal/shared/protocol"
"drip/pkg/config"
"github.com/spf13/cobra"
)
const (
maxReconnectAttempts = 5
reconnectInterval = 3 * time.Second
)
var (
subdomain string
daemonMode bool
@@ -52,55 +44,19 @@ func init() {
rootCmd.AddCommand(httpCmd)
}
func runHTTP(cmd *cobra.Command, args []string) error {
func runHTTP(_ *cobra.Command, args []string) error {
port, err := strconv.Atoi(args[0])
if err != nil || port < 1 || port > 65535 {
return fmt.Errorf("invalid port number: %s", args[0])
}
if daemonMode && !daemonMarker {
daemonArgs := append([]string{"http"}, args...)
daemonArgs = append(daemonArgs, "--daemon-child")
if subdomain != "" {
daemonArgs = append(daemonArgs, "--subdomain", subdomain)
}
if localAddress != "127.0.0.1" {
daemonArgs = append(daemonArgs, "--address", localAddress)
}
if serverURL != "" {
daemonArgs = append(daemonArgs, "--server", serverURL)
}
if authToken != "" {
daemonArgs = append(daemonArgs, "--token", authToken)
}
if insecure {
daemonArgs = append(daemonArgs, "--insecure")
}
if verbose {
daemonArgs = append(daemonArgs, "--verbose")
}
return StartDaemon("http", port, daemonArgs)
return StartDaemon("http", port, buildDaemonArgs("http", args, subdomain, localAddress))
}
var serverAddr, token string
if serverURL == "" {
cfg, err := config.LoadClientConfig("")
if err != nil {
return fmt.Errorf(`configuration not found.
Please run 'drip config init' first, or use flags:
drip http %d --server SERVER:PORT --token TOKEN`, port)
}
serverAddr = cfg.Server
token = cfg.Token
} else {
serverAddr = serverURL
token = authToken
}
if serverAddr == "" {
return fmt.Errorf("server address is required")
serverAddr, token, err := resolveServerAddrAndToken("http", port)
if err != nil {
return err
}
connConfig := &tcp.ConnectorConfig{
@@ -115,15 +71,7 @@ Please run 'drip config init' first, or use flags:
var daemon *DaemonInfo
if daemonMarker {
daemon = &DaemonInfo{
PID: os.Getpid(),
Type: "http",
Port: port,
Subdomain: subdomain,
Server: serverAddr,
StartTime: time.Now(),
Executable: os.Args[0],
}
daemon = newDaemonInfo("http", port, subdomain, serverAddr)
}
return runTunnelWithUI(connConfig, daemon)

View File

@@ -2,24 +2,14 @@ package cli
import (
"fmt"
"os"
"strconv"
"time"
"drip/internal/client/tcp"
"drip/internal/shared/protocol"
"drip/pkg/config"
"github.com/spf13/cobra"
)
var (
httpsSubdomain string
httpsDaemonMode bool
httpsDaemonMarker bool
httpsLocalAddress string
)
var httpsCmd = &cobra.Command{
Use: "https <port>",
Short: "Start HTTPS tunnel",
@@ -39,86 +29,42 @@ Note: Uses TCP over TLS 1.3 for secure communication`,
}
func init() {
httpsCmd.Flags().StringVarP(&httpsSubdomain, "subdomain", "n", "", "Custom subdomain (optional)")
httpsCmd.Flags().BoolVarP(&httpsDaemonMode, "daemon", "d", false, "Run in background (daemon mode)")
httpsCmd.Flags().StringVarP(&httpsLocalAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
httpsCmd.Flags().BoolVar(&httpsDaemonMarker, "daemon-child", false, "Internal flag for daemon child process")
httpsCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)")
httpsCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)")
httpsCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
httpsCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(httpsCmd)
}
func runHTTPS(cmd *cobra.Command, args []string) error {
func runHTTPS(_ *cobra.Command, args []string) error {
port, err := strconv.Atoi(args[0])
if err != nil || port < 1 || port > 65535 {
return fmt.Errorf("invalid port number: %s", args[0])
}
if httpsDaemonMode && !httpsDaemonMarker {
daemonArgs := append([]string{"https"}, args...)
daemonArgs = append(daemonArgs, "--daemon-child")
if httpsSubdomain != "" {
daemonArgs = append(daemonArgs, "--subdomain", httpsSubdomain)
}
if httpsLocalAddress != "127.0.0.1" {
daemonArgs = append(daemonArgs, "--address", httpsLocalAddress)
}
if serverURL != "" {
daemonArgs = append(daemonArgs, "--server", serverURL)
}
if authToken != "" {
daemonArgs = append(daemonArgs, "--token", authToken)
}
if insecure {
daemonArgs = append(daemonArgs, "--insecure")
}
if verbose {
daemonArgs = append(daemonArgs, "--verbose")
}
return StartDaemon("https", port, daemonArgs)
if daemonMode && !daemonMarker {
return StartDaemon("https", port, buildDaemonArgs("https", args, subdomain, localAddress))
}
var serverAddr, token string
if serverURL == "" {
cfg, err := config.LoadClientConfig("")
if err != nil {
return fmt.Errorf(`configuration not found.
Please run 'drip config init' first, or use flags:
drip https %d --server SERVER:PORT --token TOKEN`, port)
}
serverAddr = cfg.Server
token = cfg.Token
} else {
serverAddr = serverURL
token = authToken
}
if serverAddr == "" {
return fmt.Errorf("server address is required")
serverAddr, token, err := resolveServerAddrAndToken("https", port)
if err != nil {
return err
}
connConfig := &tcp.ConnectorConfig{
ServerAddr: serverAddr,
Token: token,
TunnelType: protocol.TunnelTypeHTTPS,
LocalHost: httpsLocalAddress,
LocalHost: localAddress,
LocalPort: port,
Subdomain: httpsSubdomain,
Subdomain: subdomain,
Insecure: insecure,
}
var daemon *DaemonInfo
if httpsDaemonMarker {
daemon = &DaemonInfo{
PID: os.Getpid(),
Type: "https",
Port: port,
Subdomain: httpsSubdomain,
Server: serverAddr,
StartTime: time.Now(),
Executable: os.Args[0],
}
if daemonMarker {
daemon = newDaemonInfo("https", port, subdomain, serverAddr)
}
return runTunnelWithUI(connConfig, daemon)

View File

@@ -8,13 +8,11 @@ import (
"strings"
"time"
"drip/internal/client/cli/ui"
"drip/internal/shared/ui"
"github.com/spf13/cobra"
)
var (
interactiveMode bool
)
var interactiveMode bool
var listCmd = &cobra.Command{
Use: "list",
@@ -26,7 +24,7 @@ Example:
drip list -i Interactive mode (select to attach/stop)
This command shows:
- Tunnel type (HTTP/TCP)
- Tunnel type (HTTP/HTTPS/TCP)
- Local port being tunneled
- Public URL
- Process ID (PID)
@@ -44,7 +42,7 @@ func init() {
rootCmd.AddCommand(listCmd)
}
func runList(cmd *cobra.Command, args []string) error {
func runList(_ *cobra.Command, _ []string) error {
CleanupStaleDaemons()
daemons, err := ListAllDaemons()
@@ -99,7 +97,7 @@ func runList(cmd *cobra.Command, args []string) error {
fmt.Print(table.Render())
if interactiveMode || shouldPromptForAction() {
if interactiveMode {
return runInteractiveList(daemons)
}
@@ -114,13 +112,6 @@ func runList(cmd *cobra.Command, args []string) error {
return nil
}
func shouldPromptForAction() bool {
if fileInfo, _ := os.Stdout.Stat(); (fileInfo.Mode() & os.ModeCharDevice) == 0 {
return false
}
return false
}
func runInteractiveList(daemons []*DaemonInfo) error {
var runningDaemons []*DaemonInfo
for _, d := range daemons {
@@ -161,9 +152,9 @@ func runInteractiveList(daemons []*DaemonInfo) error {
"",
ui.Muted("What would you like to do?"),
"",
ui.Cyan(" 1.") + " Attach (view logs)",
ui.Cyan(" 2.") + " Stop tunnel",
ui.Muted(" q.") + " Cancel",
ui.Cyan(" 1.")+" Attach (view logs)",
ui.Cyan(" 2.")+" Stop tunnel",
ui.Muted(" q.")+" Cancel",
))
fmt.Print(ui.Muted("Choose an action: "))

View File

@@ -3,7 +3,7 @@ package cli
import (
"fmt"
"drip/internal/client/cli/ui"
"drip/internal/shared/ui"
"github.com/spf13/cobra"
)
@@ -56,14 +56,12 @@ func init() {
versionCmd.Flags().BoolVar(&versionPlain, "short", false, "Print version information without styling")
rootCmd.AddCommand(versionCmd)
// http and tcp commands are added in their respective init() functions
// config command is added in config.go init() function
}
var versionCmd = &cobra.Command{
Use: "version",
Short: "Print version information",
Run: func(cmd *cobra.Command, args []string) {
Run: func(_ *cobra.Command, _ []string) {
if versionPlain {
fmt.Printf("Version: %s\nGit Commit: %s\nBuild Time: %s\n", Version, GitCommit, BuildTime)
return

View File

@@ -59,7 +59,7 @@ func init() {
serverCmd.Flags().IntVar(&serverPprofPort, "pprof", getEnvInt("DRIP_PPROF_PORT", 0), "Enable pprof on specified port (env: DRIP_PPROF_PORT)")
}
func runServer(cmd *cobra.Command, args []string) error {
func runServer(_ *cobra.Command, _ []string) error {
if serverTLSCert == "" {
return fmt.Errorf("TLS certificate path is required (use --tls-cert flag or DRIP_TLS_CERT environment variable)")
}
@@ -126,11 +126,9 @@ func runServer(cmd *cobra.Command, args []string) error {
listenAddr := fmt.Sprintf("0.0.0.0:%d", serverPort)
responseHandler := proxy.NewResponseHandler(logger)
httpHandler := proxy.NewHandler(tunnelManager, logger, serverDomain, serverAuthToken)
httpHandler := proxy.NewHandler(tunnelManager, logger, responseHandler, serverDomain, serverAuthToken)
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, displayPort, httpHandler, responseHandler)
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, displayPort, httpHandler)
if err := listener.Start(); err != nil {
logger.Fatal("Failed to start TCP listener", zap.Error(err))

View File

@@ -28,7 +28,7 @@ func init() {
rootCmd.AddCommand(stopCmd)
}
func runStop(cmd *cobra.Command, args []string) error {
func runStop(_ *cobra.Command, args []string) error {
if args[0] == "all" {
return stopAllDaemons()
}

View File

@@ -2,13 +2,10 @@ package cli
import (
"fmt"
"os"
"strconv"
"time"
"drip/internal/client/tcp"
"drip/internal/shared/protocol"
"drip/pkg/config"
"github.com/spf13/cobra"
)
@@ -47,55 +44,19 @@ func init() {
rootCmd.AddCommand(tcpCmd)
}
func runTCP(cmd *cobra.Command, args []string) error {
func runTCP(_ *cobra.Command, args []string) error {
port, err := strconv.Atoi(args[0])
if err != nil || port < 1 || port > 65535 {
return fmt.Errorf("invalid port number: %s", args[0])
}
if daemonMode && !daemonMarker {
daemonArgs := append([]string{"tcp"}, args...)
daemonArgs = append(daemonArgs, "--daemon-child")
if subdomain != "" {
daemonArgs = append(daemonArgs, "--subdomain", subdomain)
}
if localAddress != "127.0.0.1" {
daemonArgs = append(daemonArgs, "--address", localAddress)
}
if serverURL != "" {
daemonArgs = append(daemonArgs, "--server", serverURL)
}
if authToken != "" {
daemonArgs = append(daemonArgs, "--token", authToken)
}
if insecure {
daemonArgs = append(daemonArgs, "--insecure")
}
if verbose {
daemonArgs = append(daemonArgs, "--verbose")
}
return StartDaemon("tcp", port, daemonArgs)
return StartDaemon("tcp", port, buildDaemonArgs("tcp", args, subdomain, localAddress))
}
var serverAddr, token string
if serverURL == "" {
cfg, err := config.LoadClientConfig("")
if err != nil {
return fmt.Errorf(`configuration not found.
Please run 'drip config init' first, or use flags:
drip tcp %d --server SERVER:PORT --token TOKEN`, port)
}
serverAddr = cfg.Server
token = cfg.Token
} else {
serverAddr = serverURL
token = authToken
}
if serverAddr == "" {
return fmt.Errorf("server address is required")
serverAddr, token, err := resolveServerAddrAndToken("tcp", port)
if err != nil {
return err
}
connConfig := &tcp.ConnectorConfig{
@@ -110,15 +71,7 @@ Please run 'drip config init' first, or use flags:
var daemon *DaemonInfo
if daemonMarker {
daemon = &DaemonInfo{
PID: os.Getpid(),
Type: "tcp",
Port: port,
Subdomain: subdomain,
Server: serverAddr,
StartTime: time.Now(),
Executable: os.Args[0],
}
daemon = newDaemonInfo("tcp", port, subdomain, serverAddr)
}
return runTunnelWithUI(connConfig, daemon)

View File

@@ -0,0 +1,67 @@
package cli
import (
"fmt"
"os"
"time"
"drip/pkg/config"
)
func buildDaemonArgs(tunnelType string, args []string, subdomain string, localAddress string) []string {
daemonArgs := append([]string{tunnelType}, args...)
daemonArgs = append(daemonArgs, "--daemon-child")
if subdomain != "" {
daemonArgs = append(daemonArgs, "--subdomain", subdomain)
}
if localAddress != "127.0.0.1" {
daemonArgs = append(daemonArgs, "--address", localAddress)
}
if serverURL != "" {
daemonArgs = append(daemonArgs, "--server", serverURL)
}
if authToken != "" {
daemonArgs = append(daemonArgs, "--token", authToken)
}
if insecure {
daemonArgs = append(daemonArgs, "--insecure")
}
if verbose {
daemonArgs = append(daemonArgs, "--verbose")
}
return daemonArgs
}
func resolveServerAddrAndToken(tunnelType string, port int) (string, string, error) {
if serverURL != "" {
return serverURL, authToken, nil
}
cfg, err := config.LoadClientConfig("")
if err != nil {
return "", "", fmt.Errorf(`configuration not found.
Please run 'drip config init' first, or use flags:
drip %s %d --server SERVER:PORT --token TOKEN`, tunnelType, port)
}
if cfg.Server == "" {
return "", "", fmt.Errorf("server address is required")
}
return cfg.Server, cfg.Token, nil
}
func newDaemonInfo(tunnelType string, port int, subdomain string, serverAddr string) *DaemonInfo {
return &DaemonInfo{
PID: os.Getpid(),
Type: tunnelType,
Port: port,
Subdomain: subdomain,
Server: serverAddr,
StartTime: time.Now(),
Executable: os.Args[0],
}
}

View File

@@ -8,13 +8,17 @@ import (
"syscall"
"time"
"drip/internal/client/cli/ui"
"drip/internal/client/tcp"
"drip/internal/shared/ui"
"drip/internal/shared/utils"
"go.uber.org/zap"
)
// runTunnelWithUI runs a tunnel with the new UI
const (
maxReconnectAttempts = 5
reconnectInterval = 3 * time.Second
)
func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) error {
if err := utils.InitLogger(verbose); err != nil {
return fmt.Errorf("failed to initialize logger: %w", err)
@@ -28,7 +32,7 @@ func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) er
reconnectAttempts := 0
for {
connector := tcp.NewConnector(connConfig, logger)
connector := tcp.NewTunnelClient(connConfig, logger)
fmt.Println(ui.RenderConnecting(connConfig.ServerAddr, reconnectAttempts, maxReconnectAttempts))
@@ -87,8 +91,8 @@ func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) er
disconnected := make(chan struct{})
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
renderTicker := time.NewTicker(1 * time.Second)
defer renderTicker.Stop()
var lastLatency time.Duration
lastRenderedLines := 0
@@ -97,27 +101,38 @@ func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) er
select {
case latency := <-latencyCh:
lastLatency = latency
case <-ticker.C:
case <-renderTicker.C:
stats := connector.GetStats()
if stats != nil {
stats.UpdateSpeed()
snapshot := stats.GetSnapshot()
status.Latency = lastLatency
status.BytesIn = snapshot.TotalBytesIn
status.BytesOut = snapshot.TotalBytesOut
status.SpeedIn = float64(snapshot.SpeedIn)
status.SpeedOut = float64(snapshot.SpeedOut)
status.TotalRequest = snapshot.TotalRequests
statsView := ui.RenderTunnelStats(status)
if lastRenderedLines > 0 {
fmt.Print(clearLines(lastRenderedLines))
}
fmt.Print(statsView)
lastRenderedLines = countRenderedLines(statsView)
if stats == nil {
continue
}
stats.UpdateSpeed()
snapshot := stats.GetSnapshot()
status.Latency = lastLatency
status.BytesIn = snapshot.TotalBytesIn
status.BytesOut = snapshot.TotalBytesOut
status.SpeedIn = float64(snapshot.SpeedIn)
status.SpeedOut = float64(snapshot.SpeedOut)
if status.Type == "tcp" {
if snapshot.SpeedIn == 0 && snapshot.SpeedOut == 0 {
status.TotalRequest = 0
} else {
status.TotalRequest = snapshot.ActiveConnections
}
} else {
status.TotalRequest = snapshot.TotalRequests
}
statsView := ui.RenderTunnelStats(status)
if lastRenderedLines > 0 {
fmt.Print(clearLines(lastRenderedLines))
}
fmt.Print(statsView)
lastRenderedLines = countRenderedLines(statsView)
case <-stopDisplay:
return
}

View File

@@ -1,117 +0,0 @@
package ui
import (
"fmt"
)
// RenderConfigInit renders config initialization UI
func RenderConfigInit() string {
title := "Drip Configuration Setup"
box := boxStyle.Width(50)
return "\n" + box.Render(titleStyle.Render(title)) + "\n"
}
// RenderConfigShow renders the config display
func RenderConfigShow(server, token string, tokenHidden bool, tlsEnabled bool, configPath string) string {
lines := []string{
KeyValue("Server", server),
}
if token != "" {
if tokenHidden {
if len(token) > 10 {
displayToken := token[:3] + "***" + token[len(token)-3:]
lines = append(lines, KeyValue("Token", Muted(displayToken+" (hidden)")))
} else {
lines = append(lines, KeyValue("Token", Muted(token[:3]+"*** (hidden)")))
}
} else {
lines = append(lines, KeyValue("Token", token))
}
} else {
lines = append(lines, KeyValue("Token", Muted("(not set)")))
}
tlsStatus := "enabled"
if !tlsEnabled {
tlsStatus = "disabled"
}
lines = append(lines, KeyValue("TLS", tlsStatus))
lines = append(lines, KeyValue("Config", Muted(configPath)))
return Info("Current Configuration", lines...)
}
// RenderConfigSaved renders config saved message
func RenderConfigSaved(configPath string) string {
return SuccessBox(
"Configuration Saved",
Muted("Config saved to: ")+configPath,
"",
Muted("You can now use 'drip' without --server and --token flags"),
)
}
// RenderConfigUpdated renders config updated message
func RenderConfigUpdated(updates []string) string {
lines := make([]string, len(updates)+1)
for i, update := range updates {
lines[i] = Success(update)
}
lines[len(updates)] = ""
lines = append(lines, Muted("Configuration has been updated"))
return SuccessBox("Configuration Updated", lines...)
}
// RenderConfigDeleted renders config deleted message
func RenderConfigDeleted() string {
return SuccessBox("Configuration Deleted", Muted("Configuration file has been removed"))
}
// RenderConfigValidation renders config validation results
func RenderConfigValidation(serverValid bool, serverMsg string, tokenSet bool, tokenMsg string, tlsEnabled bool) string {
lines := []string{}
if serverValid {
lines = append(lines, Success(serverMsg))
} else {
lines = append(lines, Error(serverMsg))
}
if tokenSet {
lines = append(lines, Success(tokenMsg))
} else {
lines = append(lines, Warning(tokenMsg))
}
if tlsEnabled {
lines = append(lines, Success("TLS is enabled"))
} else {
lines = append(lines, Warning("TLS is disabled (not recommended for production)"))
}
lines = append(lines, "")
lines = append(lines, Muted("Configuration validation complete"))
if serverValid && tokenSet && tlsEnabled {
return SuccessBox("Configuration Valid", lines...)
}
return WarningBox("Configuration Validation", lines...)
}
// RenderDaemonStarted renders daemon started message
func RenderDaemonStarted(tunnelType string, port int, pid int, logPath string) string {
lines := []string{
KeyValue("Type", Highlight(tunnelType)),
KeyValue("Port", fmt.Sprintf("%d", port)),
KeyValue("PID", fmt.Sprintf("%d", pid)),
"",
Muted("Commands:"),
Cyan(" drip list") + Muted(" Check tunnel status"),
Cyan(fmt.Sprintf(" drip attach %s %d", tunnelType, port)) + Muted(" View logs"),
Cyan(fmt.Sprintf(" drip stop %s %d", tunnelType, port)) + Muted(" Stop tunnel"),
"",
Muted("Logs: ") + mutedStyle.Render(logPath),
}
return SuccessBox("Tunnel Started in Background", lines...)
}

View File

@@ -1,184 +0,0 @@
package ui
import (
"github.com/charmbracelet/lipgloss"
)
var (
// Colors inspired by Vercel CLI
successColor = lipgloss.Color("#0070F3")
warningColor = lipgloss.Color("#F5A623")
errorColor = lipgloss.Color("#E00")
mutedColor = lipgloss.Color("#888")
highlightColor = lipgloss.Color("#0070F3")
cyanColor = lipgloss.Color("#50E3C2")
// Box styles - Vercel-like clean box
boxStyle = lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
Padding(1, 2).
MarginTop(1).
MarginBottom(1)
successBoxStyle = boxStyle.BorderForeground(successColor)
warningBoxStyle = boxStyle.BorderForeground(warningColor)
errorBoxStyle = boxStyle.BorderForeground(errorColor)
// Text styles
titleStyle = lipgloss.NewStyle().
Bold(true)
subtitleStyle = lipgloss.NewStyle().
Foreground(mutedColor)
successStyle = lipgloss.NewStyle().
Foreground(successColor).
Bold(true)
errorStyle = lipgloss.NewStyle().
Foreground(errorColor).
Bold(true)
warningStyle = lipgloss.NewStyle().
Foreground(warningColor).
Bold(true)
mutedStyle = lipgloss.NewStyle().
Foreground(mutedColor)
highlightStyle = lipgloss.NewStyle().
Foreground(highlightColor).
Bold(true)
cyanStyle = lipgloss.NewStyle().
Foreground(cyanColor)
urlStyle = lipgloss.NewStyle().
Foreground(highlightColor).
Underline(true).
Bold(true)
labelStyle = lipgloss.NewStyle().
Foreground(mutedColor).
Width(12)
valueStyle = lipgloss.NewStyle().
Bold(true)
// Table styles (padding handled manually for consistent Windows output)
tableHeaderStyle = lipgloss.NewStyle().
Foreground(mutedColor).
Bold(true)
)
// Success returns a styled success message
func Success(text string) string {
return successStyle.Render("✓ " + text)
}
// Error returns a styled error message
func Error(text string) string {
return errorStyle.Render("✗ " + text)
}
// Warning returns a styled warning message
func Warning(text string) string {
return warningStyle.Render("⚠ " + text)
}
// Muted returns a styled muted text
func Muted(text string) string {
return mutedStyle.Render(text)
}
// Highlight returns a styled highlighted text
func Highlight(text string) string {
return highlightStyle.Render(text)
}
// Cyan returns a styled cyan text
func Cyan(text string) string {
return cyanStyle.Render(text)
}
// URL returns a styled URL
func URL(text string) string {
return urlStyle.Render(text)
}
// Title returns a styled title
func Title(text string) string {
return titleStyle.Render(text)
}
// Subtitle returns a styled subtitle
func Subtitle(text string) string {
return subtitleStyle.Render(text)
}
// KeyValue returns a styled key-value pair
func KeyValue(key, value string) string {
return labelStyle.Render(key+":") + " " + valueStyle.Render(value)
}
// Info renders an info box (Vercel-style)
func Info(title string, lines ...string) string {
content := titleStyle.Render(title)
if len(lines) > 0 {
content += "\n\n"
for i, line := range lines {
if i > 0 {
content += "\n"
}
content += line
}
}
return boxStyle.Render(content)
}
// SuccessBox renders a success box
func SuccessBox(title string, lines ...string) string {
content := successStyle.Render("✓ " + title)
if len(lines) > 0 {
content += "\n\n"
for i, line := range lines {
if i > 0 {
content += "\n"
}
content += line
}
}
return successBoxStyle.Render(content)
}
// WarningBox renders a warning box
func WarningBox(title string, lines ...string) string {
content := warningStyle.Render("⚠ " + title)
if len(lines) > 0 {
content += "\n\n"
for i, line := range lines {
if i > 0 {
content += "\n"
}
content += line
}
}
return warningBoxStyle.Render(content)
}
// ErrorBox renders an error box
func ErrorBox(title string, lines ...string) string {
content := errorStyle.Render("✗ " + title)
if len(lines) > 0 {
content += "\n\n"
for i, line := range lines {
if i > 0 {
content += "\n"
}
content += line
}
}
return errorBoxStyle.Render(content)
}

View File

@@ -1,145 +0,0 @@
package ui
import (
"fmt"
"runtime"
"strings"
"github.com/charmbracelet/lipgloss"
)
// Table represents a simple table for CLI output
type Table struct {
headers []string
rows [][]string
title string
}
// NewTable creates a new table
func NewTable(headers []string) *Table {
return &Table{
headers: headers,
rows: [][]string{},
}
}
// WithTitle sets the table title
func (t *Table) WithTitle(title string) *Table {
t.title = title
return t
}
// AddRow adds a row to the table
func (t *Table) AddRow(row []string) *Table {
t.rows = append(t.rows, row)
return t
}
// Render renders the table (Vercel-style)
func (t *Table) Render() string {
if len(t.rows) == 0 {
return ""
}
// Calculate column widths
colWidths := make([]int, len(t.headers))
for i, header := range t.headers {
colWidths[i] = lipgloss.Width(header)
}
for _, row := range t.rows {
for i, cell := range row {
if i < len(colWidths) {
width := lipgloss.Width(cell)
if width > colWidths[i] {
colWidths[i] = width
}
}
}
}
var output strings.Builder
// Title
if t.title != "" {
output.WriteString("\n")
output.WriteString(titleStyle.Render(t.title))
output.WriteString("\n\n")
}
// Header
headerParts := make([]string, len(t.headers))
for i, header := range t.headers {
styled := tableHeaderStyle.Render(header)
headerParts[i] = padRight(styled, colWidths[i])
}
output.WriteString(strings.Join(headerParts, " "))
output.WriteString("\n")
// Separator line
separatorChar := "─"
if runtime.GOOS == "windows" {
separatorChar = "-"
}
separatorParts := make([]string, len(t.headers))
for i := range t.headers {
separatorParts[i] = mutedStyle.Render(strings.Repeat(separatorChar, colWidths[i]))
}
output.WriteString(strings.Join(separatorParts, " "))
output.WriteString("\n")
// Rows
for _, row := range t.rows {
rowParts := make([]string, len(t.headers))
for i, cell := range row {
if i < len(colWidths) {
rowParts[i] = padRight(cell, colWidths[i])
}
}
output.WriteString(strings.Join(rowParts, " "))
output.WriteString("\n")
}
output.WriteString("\n")
return output.String()
}
// padRight pads
func padRight(text string, targetWidth int) string {
visibleWidth := lipgloss.Width(text)
if visibleWidth >= targetWidth {
return text
}
padding := strings.Repeat(" ", targetWidth-visibleWidth)
return text + padding
}
// Print prints the table
func (t *Table) Print() {
fmt.Print(t.Render())
}
// RenderList renders a simple list with bullet points
func RenderList(items []string) string {
bullet := "•"
if runtime.GOOS == "windows" {
bullet = "*"
}
var output strings.Builder
for _, item := range items {
output.WriteString(mutedStyle.Render(" " + bullet + " "))
output.WriteString(item)
output.WriteString("\n")
}
return output.String()
}
// RenderNumberedList renders a numbered list
func RenderNumberedList(items []string) string {
var output strings.Builder
for i, item := range items {
output.WriteString(mutedStyle.Render(fmt.Sprintf(" %d. ", i+1)))
output.WriteString(item)
output.WriteString("\n")
}
return output.String()
}

View File

@@ -1,246 +0,0 @@
package ui
import (
"fmt"
"strings"
"time"
"github.com/charmbracelet/lipgloss"
)
const (
tunnelCardWidth = 76
statsColumnWidth = 32
)
var (
latencyFastColor = lipgloss.Color("#22c55e") // green
latencyYellowColor = lipgloss.Color("#eab308") // yellow
latencyOrangeColor = lipgloss.Color("#f97316") // orange
latencyRedColor = lipgloss.Color("#ef4444") // red
)
// TunnelStatus represents the status of a tunnel
type TunnelStatus struct {
Type string // "http", "https", "tcp"
URL string // Public URL
LocalAddr string // Local address
Latency time.Duration // Current latency
BytesIn int64 // Bytes received
BytesOut int64 // Bytes sent
SpeedIn float64 // Download speed
SpeedOut float64 // Upload speed
TotalRequest int64 // Total requests
}
// RenderTunnelConnected renders the tunnel connection card
func RenderTunnelConnected(status *TunnelStatus) string {
icon, typeStr, accent := tunnelVisuals(status.Type)
card := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(accent).
Padding(1, 2).
Width(tunnelCardWidth)
typeBadge := lipgloss.NewStyle().
Background(accent).
Foreground(lipgloss.Color("#f8fafc")).
Bold(true).
Padding(0, 1).
Render(strings.ToUpper(typeStr) + " TUNNEL")
headline := lipgloss.JoinHorizontal(
lipgloss.Left,
lipgloss.NewStyle().Foreground(accent).Render(icon),
lipgloss.NewStyle().Bold(true).MarginLeft(1).Render("Tunnel Connected"),
lipgloss.NewStyle().MarginLeft(2).Render(typeBadge),
)
urlLine := lipgloss.JoinHorizontal(
lipgloss.Left,
urlStyle.Foreground(accent).Render(status.URL),
lipgloss.NewStyle().MarginLeft(1).Foreground(mutedColor).Render("(forwarded address)"),
)
forwardLine := lipgloss.NewStyle().
MarginLeft(2).
Render(Muted("⇢ ") + valueStyle.Render(status.LocalAddr))
hint := lipgloss.NewStyle().
Foreground(latencyOrangeColor).
Render("Ctrl+C to stop • reconnects automatically")
content := lipgloss.JoinVertical(
lipgloss.Left,
headline,
"",
urlLine,
forwardLine,
"",
hint,
)
return "\n" + card.Render(content) + "\n"
}
// RenderTunnelStats renders real-time tunnel statistics in a card
func RenderTunnelStats(status *TunnelStatus) string {
latencyStr := formatLatency(status.Latency)
trafficStr := fmt.Sprintf("↓ %s ↑ %s", formatBytes(status.BytesIn), formatBytes(status.BytesOut))
speedStr := fmt.Sprintf("↓ %s ↑ %s", formatSpeed(status.SpeedIn), formatSpeed(status.SpeedOut))
requestsStr := fmt.Sprintf("%d", status.TotalRequest)
_, _, accent := tunnelVisuals(status.Type)
header := lipgloss.JoinHorizontal(
lipgloss.Left,
lipgloss.NewStyle().Foreground(accent).Render("◉"),
lipgloss.NewStyle().Bold(true).MarginLeft(1).Render("Live Metrics"),
)
row1 := lipgloss.JoinHorizontal(
lipgloss.Top,
statColumn("Latency", latencyStr, statsColumnWidth),
statColumn("Requests", highlightStyle.Render(requestsStr), statsColumnWidth),
)
row2 := lipgloss.JoinHorizontal(
lipgloss.Top,
statColumn("Traffic", Cyan(trafficStr), statsColumnWidth),
statColumn("Speed", warningStyle.Render(speedStr), statsColumnWidth),
)
card := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(accent).
Padding(1, 2).
Width(tunnelCardWidth)
body := lipgloss.JoinVertical(
lipgloss.Left,
header,
"",
row1,
row2,
)
return "\n" + card.Render(body) + "\n"
}
// RenderConnecting renders the connecting message
func RenderConnecting(serverAddr string, attempt int, maxAttempts int) string {
if attempt == 0 {
return Highlight("◌") + " Connecting to " + Muted(serverAddr) + "..."
}
return Warning(fmt.Sprintf("◌ Reconnecting to %s (attempt %d/%d)...", serverAddr, attempt, maxAttempts))
}
// RenderConnectionFailed renders connection failure message
func RenderConnectionFailed(err error) string {
return Error(fmt.Sprintf("Connection failed: %v", err))
}
// RenderShuttingDown renders shutdown message
func RenderShuttingDown() string {
return Warning("⏹ Shutting down...")
}
// RenderConnectionLost renders connection lost message
func RenderConnectionLost() string {
return Error("⚠ Connection lost!")
}
// RenderRetrying renders retry message
func RenderRetrying(interval time.Duration) string {
return Muted(fmt.Sprintf(" Retrying in %v...", interval))
}
// formatLatency formats latency with color
func formatLatency(d time.Duration) string {
if d == 0 {
return mutedStyle.Render("measuring...")
}
ms := d.Milliseconds()
var style lipgloss.Style
switch {
case ms < 50:
style = lipgloss.NewStyle().Foreground(latencyFastColor)
case ms < 150:
style = lipgloss.NewStyle().Foreground(latencyYellowColor)
case ms < 300:
style = lipgloss.NewStyle().Foreground(latencyOrangeColor)
default:
style = lipgloss.NewStyle().Foreground(latencyRedColor)
}
if ms == 0 {
us := d.Microseconds()
return style.Render(fmt.Sprintf("%dµs", us))
}
return style.Render(fmt.Sprintf("%dms", ms))
}
// formatBytes formats bytes to human readable format
func formatBytes(bytes int64) string {
const unit = 1024
if bytes < unit {
return fmt.Sprintf("%d B", bytes)
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
// formatSpeed formats speed to human readable format
func formatSpeed(bytesPerSec float64) string {
const unit = 1024.0
if bytesPerSec < unit {
return fmt.Sprintf("%.0f B/s", bytesPerSec)
}
div, exp := unit, 0
for n := bytesPerSec / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %cB/s", bytesPerSec/div, "KMGTPE"[exp])
}
func statColumn(label, value string, width int) string {
labelView := lipgloss.NewStyle().
Foreground(mutedColor).
Render(strings.ToUpper(label))
block := lipgloss.JoinHorizontal(
lipgloss.Left,
labelView,
lipgloss.NewStyle().MarginLeft(1).Render(value),
)
if width <= 0 {
return block
}
return lipgloss.NewStyle().
Width(width).
Render(block)
}
func tunnelVisuals(tunnelType string) (string, string, lipgloss.Color) {
switch tunnelType {
case "http":
return "🚀", "HTTP", lipgloss.Color("#0070F3")
case "https":
return "🔒", "HTTPS", lipgloss.Color("#2D8CFF")
case "tcp":
return "🔌", "TCP", lipgloss.Color("#50E3C2")
default:
return "🌐", strings.ToUpper(tunnelType), lipgloss.Color("#0070F3")
}
}

View File

@@ -1,493 +1,50 @@
package tcp
import (
"crypto/tls"
"fmt"
"net"
json "github.com/goccy/go-json"
"sync"
"strings"
"time"
"drip/internal/shared/constants"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"drip/internal/shared/recovery"
"drip/pkg/config"
"drip/internal/shared/stats"
"go.uber.org/zap"
)
// LatencyCallback is called when latency is measured
type LatencyCallback func(latency time.Duration)
// Connector manages the TCP connection to the server
type Connector struct {
serverAddr string
tlsConfig *tls.Config
token string
tunnelType protocol.TunnelType
localHost string
localPort int
subdomain string
conn net.Conn
logger *zap.Logger
stopCh chan struct{}
once sync.Once
registered bool
assignedURL string
frameHandler *FrameHandler
frameWriter *protocol.FrameWriter
latencyCallback LatencyCallback
heartbeatSentAt time.Time
heartbeatMu sync.Mutex
lastLatency time.Duration
handlerWg sync.WaitGroup // Tracks active data frame handlers
closed bool
closedMu sync.RWMutex
// Worker pool for handling data frames
dataFrameQueue chan *protocol.Frame
workerCount int
recoverer *recovery.Recoverer
panicMetrics *recovery.PanicMetrics
}
// ConnectorConfig holds connector configuration
type ConnectorConfig struct {
ServerAddr string
Token string
TunnelType protocol.TunnelType
LocalHost string // Local host address (default: 127.0.0.1)
LocalHost string
LocalPort int
Subdomain string // Optional custom subdomain
Insecure bool // Skip TLS verification (testing only)
Subdomain string
Insecure bool
PoolSize int
PoolMin int
PoolMax int
}
// NewConnector creates a new connector
func NewConnector(cfg *ConnectorConfig, logger *zap.Logger) *Connector {
var tlsConfig *tls.Config
if cfg.Insecure {
tlsConfig = config.GetClientTLSConfigInsecure()
} else {
host, _, _ := net.SplitHostPort(cfg.ServerAddr)
tlsConfig = config.GetClientTLSConfig(host)
}
localHost := cfg.LocalHost
if localHost == "" {
localHost = "127.0.0.1"
}
numCPU := pool.NumCPU()
workerCount := max(numCPU+numCPU/2, 4)
panicMetrics := recovery.NewPanicMetrics(logger, nil)
recoverer := recovery.NewRecoverer(logger, panicMetrics)
return &Connector{
serverAddr: cfg.ServerAddr,
tlsConfig: tlsConfig,
token: cfg.Token,
tunnelType: cfg.TunnelType,
localHost: localHost,
localPort: cfg.LocalPort,
subdomain: cfg.Subdomain,
logger: logger,
stopCh: make(chan struct{}),
dataFrameQueue: make(chan *protocol.Frame, workerCount*100),
workerCount: workerCount,
recoverer: recoverer,
panicMetrics: panicMetrics,
}
type TunnelClient interface {
Connect() error
Close() error
Wait()
GetURL() string
GetSubdomain() string
SetLatencyCallback(cb LatencyCallback)
GetLatency() time.Duration
GetStats() *stats.TrafficStats
IsClosed() bool
}
// Connect connects to the server and registers the tunnel
func (c *Connector) Connect() error {
c.logger.Info("Connecting to server",
zap.String("server", c.serverAddr),
zap.String("tunnel_type", string(c.tunnelType)),
zap.String("local_host", c.localHost),
zap.Int("local_port", c.localPort),
)
dialer := &net.Dialer{
Timeout: 10 * time.Second,
}
conn, err := tls.DialWithDialer(dialer, "tcp", c.serverAddr, c.tlsConfig)
if err != nil {
return fmt.Errorf("failed to connect: %w", err)
}
c.conn = conn
state := conn.ConnectionState()
if state.Version != tls.VersionTLS13 {
conn.Close()
return fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
}
c.logger.Info("TLS connection established",
zap.String("cipher_suite", tls.CipherSuiteName(state.CipherSuite)),
)
if err := c.register(); err != nil {
conn.Close()
return fmt.Errorf("registration failed: %w", err)
}
c.frameWriter = protocol.NewFrameWriter(c.conn)
bufferPool := pool.NewBufferPool()
c.frameHandler = NewFrameHandler(
c.conn,
c.frameWriter,
c.localHost,
c.localPort,
c.tunnelType,
c.logger,
c.IsClosed,
bufferPool,
)
c.frameWriter.EnableHeartbeat(constants.HeartbeatInterval, c.createHeartbeatFrame)
for i := 0; i < c.workerCount; i++ {
c.handlerWg.Add(1)
go c.dataFrameWorker(i)
}
go c.frameHandler.WarmupConnectionPool(3)
go c.monitorQueuePressure()
go c.handleFrames()
return nil
func NewTunnelClient(cfg *ConnectorConfig, logger *zap.Logger) TunnelClient {
return NewPoolClient(cfg, logger)
}
// register sends registration request and waits for acknowledgment
func (c *Connector) register() error {
req := protocol.RegisterRequest{
Token: c.token,
CustomSubdomain: c.subdomain,
TunnelType: c.tunnelType,
LocalPort: c.localPort,
}
payload, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
regFrame := protocol.NewFrame(protocol.FrameTypeRegister, payload)
err = protocol.WriteFrame(c.conn, regFrame)
if err != nil {
return fmt.Errorf("failed to send registration: %w", err)
}
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
ackFrame, err := protocol.ReadFrame(c.conn)
if err != nil {
return fmt.Errorf("failed to read ack: %w", err)
}
defer ackFrame.Release()
c.conn.SetReadDeadline(time.Time{})
if ackFrame.Type == protocol.FrameTypeError {
var errMsg protocol.ErrorMessage
if err := json.Unmarshal(ackFrame.Payload, &errMsg); err == nil {
return fmt.Errorf("registration error: %s - %s", errMsg.Code, errMsg.Message)
}
return fmt.Errorf("registration error")
}
if ackFrame.Type != protocol.FrameTypeRegisterAck {
return fmt.Errorf("unexpected frame type: %s", ackFrame.Type)
}
var resp protocol.RegisterResponse
if err := json.Unmarshal(ackFrame.Payload, &resp); err != nil {
return fmt.Errorf("failed to parse response: %w", err)
}
c.registered = true
c.assignedURL = resp.URL
c.subdomain = resp.Subdomain
c.logger.Info("Tunnel registered successfully",
zap.String("subdomain", resp.Subdomain),
zap.String("url", resp.URL),
zap.Int("remote_port", resp.Port),
)
return nil
}
func (c *Connector) dataFrameWorker(workerID int) {
defer c.handlerWg.Done()
defer c.recoverer.Recover(fmt.Sprintf("dataFrameWorker-%d", workerID))
for {
select {
case frame, ok := <-c.dataFrameQueue:
if !ok {
return
}
func() {
sf := protocol.WithFrame(frame)
defer sf.Close()
defer c.recoverer.Recover("handleDataFrame")
if err := c.frameHandler.HandleDataFrame(sf.Frame); err != nil {
c.logger.Error("Failed to handle data frame",
zap.Int("worker_id", workerID),
zap.Error(err))
}
}()
case <-c.stopCh:
return
}
}
}
// handleFrames handles incoming frames from server
func (c *Connector) handleFrames() {
defer c.Close()
defer c.recoverer.Recover("handleFrames")
for {
select {
case <-c.stopCh:
return
default:
}
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
frame, err := protocol.ReadFrame(c.conn)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
c.logger.Warn("Read timeout")
return
}
select {
case <-c.stopCh:
return
default:
c.logger.Error("Failed to read frame", zap.Error(err))
return
}
}
sf := protocol.WithFrame(frame)
switch sf.Frame.Type {
case protocol.FrameTypeHeartbeatAck:
c.heartbeatMu.Lock()
if !c.heartbeatSentAt.IsZero() {
latency := time.Since(c.heartbeatSentAt)
c.lastLatency = latency
c.heartbeatMu.Unlock()
c.logger.Debug("Received heartbeat ack", zap.Duration("latency", latency))
if c.latencyCallback != nil {
c.latencyCallback(latency)
}
} else {
c.heartbeatMu.Unlock()
c.logger.Debug("Received heartbeat ack")
}
sf.Close()
case protocol.FrameTypeData:
select {
case c.dataFrameQueue <- sf.Frame:
case <-c.stopCh:
sf.Close()
return
default:
c.logger.Warn("Data frame queue full, dropping frame")
sf.Close()
}
case protocol.FrameTypeClose:
sf.Close()
c.logger.Info("Server requested close")
return
case protocol.FrameTypeError:
var errMsg protocol.ErrorMessage
if err := json.Unmarshal(sf.Frame.Payload, &errMsg); err == nil {
c.logger.Error("Received error from server",
zap.String("code", errMsg.Code),
zap.String("message", errMsg.Message),
)
}
sf.Close()
return
default:
sf.Close()
c.logger.Warn("Unexpected frame type",
zap.String("type", sf.Frame.Type.String()),
)
}
}
}
func (c *Connector) createHeartbeatFrame() *protocol.Frame {
c.closedMu.RLock()
if c.closed {
c.closedMu.RUnlock()
return nil
}
c.closedMu.RUnlock()
c.heartbeatMu.Lock()
c.heartbeatSentAt = time.Now()
c.heartbeatMu.Unlock()
return protocol.NewFrame(protocol.FrameTypeHeartbeat, nil)
}
// SendFrame sends a frame to the server
func (c *Connector) SendFrame(frame *protocol.Frame) error {
if !c.registered {
return fmt.Errorf("not registered")
}
return c.frameWriter.WriteFrame(frame)
}
func (c *Connector) Close() error {
c.once.Do(func() {
c.closedMu.Lock()
c.closed = true
c.closedMu.Unlock()
close(c.stopCh)
close(c.dataFrameQueue)
done := make(chan struct{})
go func() {
c.handlerWg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(2 * time.Second):
c.logger.Warn("Force closing: some handlers are still active")
}
if c.conn != nil {
closeFrame := protocol.NewFrame(protocol.FrameTypeClose, nil)
if c.frameWriter != nil {
c.frameWriter.WriteFrame(closeFrame)
c.frameWriter.Close()
} else {
protocol.WriteFrame(c.conn, closeFrame)
}
c.conn.Close()
}
c.logger.Info("Connector closed")
})
return nil
}
// Wait blocks until connection is closed
func (c *Connector) Wait() {
<-c.stopCh
}
// GetURL returns the assigned tunnel URL
func (c *Connector) GetURL() string {
return c.assignedURL
}
// GetSubdomain returns the assigned subdomain
func (c *Connector) GetSubdomain() string {
return c.subdomain
}
// SetLatencyCallback sets the callback for latency updates
func (c *Connector) SetLatencyCallback(cb LatencyCallback) {
c.latencyCallback = cb
}
// GetLatency returns the last measured latency
func (c *Connector) GetLatency() time.Duration {
c.heartbeatMu.Lock()
defer c.heartbeatMu.Unlock()
return c.lastLatency
}
// GetStats returns the traffic stats from the frame handler
func (c *Connector) GetStats() *TrafficStats {
if c.frameHandler != nil {
return c.frameHandler.GetStats()
}
return nil
}
// IsClosed returns whether the connector has been closed
func (c *Connector) IsClosed() bool {
c.closedMu.RLock()
defer c.closedMu.RUnlock()
return c.closed
}
func (c *Connector) monitorQueuePressure() {
defer c.recoverer.Recover("monitorQueuePressure")
const (
pauseThreshold = 0.80
resumeThreshold = 0.50
checkInterval = 100 * time.Millisecond
)
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
isPaused := false
for {
select {
case <-ticker.C:
queueLen := len(c.dataFrameQueue)
queueCap := cap(c.dataFrameQueue)
usage := float64(queueLen) / float64(queueCap)
if usage > pauseThreshold && !isPaused {
c.sendFlowControl("*", protocol.FlowControlPause)
isPaused = true
c.logger.Warn("Queue pressure high, sent pause signal",
zap.Int("queue_len", queueLen),
zap.Int("queue_cap", queueCap),
zap.Float64("usage", usage))
} else if usage < resumeThreshold && isPaused {
c.sendFlowControl("*", protocol.FlowControlResume)
isPaused = false
c.logger.Info("Queue pressure normal, sent resume signal",
zap.Int("queue_len", queueLen),
zap.Int("queue_cap", queueCap),
zap.Float64("usage", usage))
}
case <-c.stopCh:
return
}
}
}
func (c *Connector) sendFlowControl(streamID string, action protocol.FlowControlAction) {
frame := protocol.NewFlowControlFrame(streamID, action)
if err := c.SendFrame(frame); err != nil {
c.logger.Error("Failed to send flow control",
zap.String("action", string(action)),
zap.Error(err))
}
func isExpectedCloseError(err error) bool {
s := err.Error()
return strings.Contains(s, "EOF") ||
strings.Contains(s, "use of closed") ||
strings.Contains(s, "connection reset")
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,450 @@
package tcp
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"runtime"
"sync"
"sync/atomic"
"time"
json "github.com/goccy/go-json"
"github.com/hashicorp/yamux"
"go.uber.org/zap"
"drip/internal/shared/constants"
"drip/internal/shared/protocol"
"drip/internal/shared/stats"
"drip/pkg/config"
)
// PoolClient manages a pool of yamux sessions for tunnel connections.
type PoolClient struct {
serverAddr string
tlsConfig *tls.Config
token string
tunnelType protocol.TunnelType
localHost string
localPort int
subdomain string
assignedURL string
tunnelID string
minSessions int
maxSessions int
initialSessions int
stats *stats.TrafficStats
httpClient *http.Client
latencyCallback atomic.Value // LatencyCallback
latencyNanos atomic.Int64
ctx context.Context
cancel context.CancelFunc
stopCh chan struct{}
doneCh chan struct{}
once sync.Once
wg sync.WaitGroup
closed atomic.Bool
primary *sessionHandle
mu sync.RWMutex
dataSessions map[string]*sessionHandle
desiredTotal int
lastScale time.Time
logger *zap.Logger
}
// NewPoolClient creates a new pool client.
func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
var tlsConfig *tls.Config
if cfg.Insecure {
tlsConfig = config.GetClientTLSConfigInsecure()
} else {
host, _, _ := net.SplitHostPort(cfg.ServerAddr)
tlsConfig = config.GetClientTLSConfig(host)
}
localHost := cfg.LocalHost
if localHost == "" {
localHost = "127.0.0.1"
}
tunnelType := cfg.TunnelType
if tunnelType == "" {
tunnelType = protocol.TunnelTypeTCP
}
numCPU := runtime.NumCPU()
minSessions := cfg.PoolMin
if minSessions <= 0 {
minSessions = 2
}
maxSessions := cfg.PoolMax
if maxSessions <= 0 {
maxSessions = max(numCPU*16, minSessions)
}
if maxSessions < minSessions {
maxSessions = minSessions
}
initialSessions := cfg.PoolSize
if initialSessions <= 0 {
initialSessions = 4
}
initialSessions = min(max(initialSessions, minSessions), maxSessions)
ctx, cancel := context.WithCancel(context.Background())
c := &PoolClient{
serverAddr: cfg.ServerAddr,
tlsConfig: tlsConfig,
token: cfg.Token,
tunnelType: tunnelType,
localHost: localHost,
localPort: cfg.LocalPort,
subdomain: cfg.Subdomain,
minSessions: minSessions,
maxSessions: maxSessions,
initialSessions: initialSessions,
stats: stats.NewTrafficStats(),
ctx: ctx,
cancel: cancel,
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
dataSessions: make(map[string]*sessionHandle),
logger: logger,
}
if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
c.httpClient = newLocalHTTPClient(tunnelType)
}
c.latencyCallback.Store(LatencyCallback(func(time.Duration) {}))
return c
}
// Connect establishes the primary connection and starts background workers.
func (c *PoolClient) Connect() error {
primaryConn, err := c.dialTLS()
if err != nil {
return err
}
maxData := max(c.maxSessions-1, 0)
req := protocol.RegisterRequest{
Token: c.token,
CustomSubdomain: c.subdomain,
TunnelType: c.tunnelType,
LocalPort: c.localPort,
ConnectionType: "primary",
PoolCapabilities: &protocol.PoolCapabilities{
MaxDataConns: maxData,
Version: 1,
},
}
payload, err := json.Marshal(req)
if err != nil {
_ = primaryConn.Close()
return fmt.Errorf("failed to marshal request: %w", err)
}
if err := protocol.WriteFrame(primaryConn, protocol.NewFrame(protocol.FrameTypeRegister, payload)); err != nil {
_ = primaryConn.Close()
return fmt.Errorf("failed to send registration: %w", err)
}
_ = primaryConn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
ack, err := protocol.ReadFrame(primaryConn)
if err != nil {
_ = primaryConn.Close()
return fmt.Errorf("failed to read register ack: %w", err)
}
defer ack.Release()
_ = primaryConn.SetReadDeadline(time.Time{})
if ack.Type == protocol.FrameTypeError {
var errMsg protocol.ErrorMessage
if e := json.Unmarshal(ack.Payload, &errMsg); e == nil {
_ = primaryConn.Close()
return fmt.Errorf("registration error: %s - %s", errMsg.Code, errMsg.Message)
}
_ = primaryConn.Close()
return fmt.Errorf("registration error")
}
if ack.Type != protocol.FrameTypeRegisterAck {
_ = primaryConn.Close()
return fmt.Errorf("unexpected register ack frame: %s", ack.Type)
}
var resp protocol.RegisterResponse
if err := json.Unmarshal(ack.Payload, &resp); err != nil {
_ = primaryConn.Close()
return fmt.Errorf("failed to parse register response: %w", err)
}
c.assignedURL = resp.URL
c.subdomain = resp.Subdomain
if resp.SupportsDataConn && resp.TunnelID != "" {
c.tunnelID = resp.TunnelID
}
yamuxCfg := yamux.DefaultConfig()
yamuxCfg.EnableKeepAlive = false
yamuxCfg.LogOutput = io.Discard
yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog
session, err := yamux.Server(primaryConn, yamuxCfg)
if err != nil {
_ = primaryConn.Close()
return fmt.Errorf("failed to init yamux session: %w", err)
}
primary := &sessionHandle{
id: "primary",
conn: primaryConn,
session: session,
}
primary.touch()
c.primary = primary
c.wg.Add(1)
go func() {
defer c.wg.Done()
<-c.stopCh
}()
c.wg.Add(1)
go c.acceptLoop(primary, true)
c.wg.Add(1)
go c.sessionWatcher(primary, true)
c.wg.Add(1)
go c.pingLoop(primary)
if c.tunnelID != "" {
c.mu.Lock()
c.desiredTotal = c.initialSessions
c.mu.Unlock()
c.ensureSessions()
c.wg.Add(1)
go c.scalerLoop()
}
go func() {
c.wg.Wait()
close(c.doneCh)
}()
return nil
}
func (c *PoolClient) dialTLS() (net.Conn, error) {
dialer := &net.Dialer{Timeout: 10 * time.Second}
conn, err := tls.DialWithDialer(dialer, "tcp", c.serverAddr, c.tlsConfig)
if err != nil {
return nil, fmt.Errorf("failed to connect: %w", err)
}
state := conn.ConnectionState()
if state.Version != tls.VersionTLS13 {
_ = conn.Close()
return nil, fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
}
if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok {
_ = tcpConn.SetNoDelay(true)
_ = tcpConn.SetKeepAlive(true)
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
_ = tcpConn.SetReadBuffer(256 * 1024)
_ = tcpConn.SetWriteBuffer(256 * 1024)
}
return conn, nil
}
func (c *PoolClient) acceptLoop(h *sessionHandle, isPrimary bool) {
defer c.wg.Done()
for {
select {
case <-c.stopCh:
return
default:
}
stream, err := h.session.Accept()
if err != nil {
if c.IsClosed() || isExpectedCloseError(err) {
return
}
if isPrimary {
c.logger.Debug("Primary session accept failed", zap.Error(err))
_ = c.Close()
return
}
c.logger.Debug("Data session accept failed", zap.String("session_id", h.id), zap.Error(err))
c.removeDataSession(h.id)
return
}
h.active.Add(1)
h.touch()
c.stats.AddRequest()
c.stats.IncActiveConnections()
c.wg.Add(1)
go c.handleStream(h, stream)
}
}
func (c *PoolClient) sessionWatcher(h *sessionHandle, isPrimary bool) {
defer c.wg.Done()
select {
case <-c.stopCh:
return
case <-h.session.CloseChan():
if isPrimary {
_ = c.Close()
return
}
c.removeDataSession(h.id)
}
}
func (c *PoolClient) pingLoop(h *sessionHandle) {
defer c.wg.Done()
const maxConsecutiveFailures = 3
ticker := time.NewTicker(constants.HeartbeatInterval)
defer ticker.Stop()
consecutiveFailures := 0
for {
select {
case <-c.stopCh:
return
case <-ticker.C:
}
if h.session == nil || h.session.IsClosed() {
return
}
latency, err := h.session.Ping()
if err != nil {
consecutiveFailures++
c.logger.Debug("Ping failed",
zap.String("session_id", h.id),
zap.Int("consecutive_failures", consecutiveFailures),
zap.Error(err),
)
if consecutiveFailures >= maxConsecutiveFailures {
c.logger.Warn("Session ping failed too many times, closing",
zap.String("session_id", h.id),
zap.Int("failures", consecutiveFailures),
)
if h.id == "primary" {
_ = c.Close()
return
}
c.removeDataSession(h.id)
return
}
continue
}
consecutiveFailures = 0
h.touch()
c.latencyNanos.Store(int64(latency))
if cb, ok := c.latencyCallback.Load().(LatencyCallback); ok && cb != nil {
cb(latency)
}
}
}
// Close shuts down the client and all sessions.
func (c *PoolClient) Close() error {
var closeErr error
c.once.Do(func() {
c.closed.Store(true)
close(c.stopCh)
if c.cancel != nil {
c.cancel()
}
var data []*sessionHandle
var primary *sessionHandle
c.mu.Lock()
for _, h := range c.dataSessions {
data = append(data, h)
}
c.dataSessions = make(map[string]*sessionHandle)
primary = c.primary
c.primary = nil
c.mu.Unlock()
for _, h := range data {
if h == nil {
continue
}
if h.session != nil {
_ = h.session.Close()
}
if h.conn != nil {
_ = h.conn.Close()
}
}
if primary != nil {
if primary.session != nil {
closeErr = primary.session.Close()
}
if primary.conn != nil {
_ = primary.conn.Close()
}
}
})
return closeErr
}
func (c *PoolClient) Wait() { <-c.doneCh }
func (c *PoolClient) GetURL() string { return c.assignedURL }
func (c *PoolClient) GetSubdomain() string { return c.subdomain }
func (c *PoolClient) GetLatency() time.Duration { return time.Duration(c.latencyNanos.Load()) }
func (c *PoolClient) GetStats() *stats.TrafficStats { return c.stats }
func (c *PoolClient) IsClosed() bool { return c.closed.Load() }
func (c *PoolClient) SetLatencyCallback(cb LatencyCallback) {
if cb == nil {
cb = func(time.Duration) {}
}
c.latencyCallback.Store(cb)
}

View File

@@ -0,0 +1,253 @@
package tcp
import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"time"
"drip/internal/shared/httputil"
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// handleStream routes incoming stream to appropriate handler.
func (c *PoolClient) handleStream(h *sessionHandle, stream net.Conn) {
defer c.wg.Done()
defer func() {
h.active.Add(-1)
c.stats.DecActiveConnections()
}()
defer stream.Close()
switch c.tunnelType {
case protocol.TunnelTypeHTTP, protocol.TunnelTypeHTTPS:
c.handleHTTPStream(stream)
default:
c.handleTCPStream(stream)
}
}
// handleTCPStream handles raw TCP tunneling.
func (c *PoolClient) handleTCPStream(stream net.Conn) {
localConn, err := net.DialTimeout("tcp", net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort)), 10*time.Second)
if err != nil {
c.logger.Debug("Dial local failed", zap.Error(err))
return
}
defer localConn.Close()
if tcpConn, ok := localConn.(*net.TCPConn); ok {
_ = tcpConn.SetNoDelay(true)
_ = tcpConn.SetKeepAlive(true)
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
_ = tcpConn.SetReadBuffer(256 * 1024)
_ = tcpConn.SetWriteBuffer(256 * 1024)
}
_ = netutil.PipeWithCallbacksAndBufferSize(
c.ctx,
stream,
localConn,
pool.SizeLarge,
func(n int64) { c.stats.AddBytesIn(n) },
func(n int64) { c.stats.AddBytesOut(n) },
)
}
// handleHTTPStream handles HTTP/HTTPS proxy requests.
func (c *PoolClient) handleHTTPStream(stream net.Conn) {
_ = stream.SetReadDeadline(time.Now().Add(30 * time.Second))
cc := netutil.NewCountingConn(stream,
func(n int64) { c.stats.AddBytesIn(n) },
func(n int64) { c.stats.AddBytesOut(n) },
)
br := bufio.NewReaderSize(cc, 32*1024)
req, err := http.ReadRequest(br)
if err != nil {
return
}
defer req.Body.Close()
_ = stream.SetReadDeadline(time.Time{})
if httputil.IsWebSocketUpgrade(req) {
c.handleWebSocketUpgrade(cc, req)
return
}
ctx, cancel := context.WithCancel(c.ctx)
defer cancel()
scheme := "http"
if c.tunnelType == protocol.TunnelTypeHTTPS {
scheme = "https"
}
targetURL := fmt.Sprintf("%s://%s:%d%s", scheme, c.localHost, c.localPort, req.URL.RequestURI())
outReq, err := http.NewRequestWithContext(ctx, req.Method, targetURL, req.Body)
if err != nil {
httputil.WriteProxyError(cc, http.StatusBadGateway, "Bad Gateway")
return
}
origHost := req.Host
httputil.CopyHeaders(outReq.Header, req.Header)
httputil.CleanHopByHopHeaders(outReq.Header)
targetHost := c.localHost
if c.localPort != 80 && c.localPort != 443 {
targetHost = fmt.Sprintf("%s:%d", c.localHost, c.localPort)
}
outReq.Host = targetHost
outReq.Header.Set("Host", targetHost)
if origHost != "" {
outReq.Header.Set("X-Forwarded-Host", origHost)
}
outReq.Header.Set("X-Forwarded-Proto", "https")
resp, err := c.httpClient.Do(outReq)
if err != nil {
httputil.WriteProxyError(cc, http.StatusBadGateway, "Local service unavailable")
return
}
defer resp.Body.Close()
_ = stream.SetWriteDeadline(time.Now().Add(30 * time.Second))
if err := writeResponseHeader(cc, resp); err != nil {
return
}
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
stream.Close()
case <-done:
}
}()
buf := make([]byte, 32*1024)
for {
nr, er := resp.Body.Read(buf)
if nr > 0 {
_ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second))
nw, ew := cc.Write(buf[:nr])
if ew != nil || nr != nw {
break
}
}
if er != nil {
break
}
}
close(done)
}
// handleWebSocketUpgrade handles WebSocket upgrade requests.
func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
scheme := "ws"
if c.tunnelType == protocol.TunnelTypeHTTPS {
scheme = "wss"
}
targetAddr := net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort))
localConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second)
if err != nil {
httputil.WriteProxyError(cc, http.StatusBadGateway, "WebSocket backend unavailable")
return
}
defer localConn.Close()
if c.tunnelType == protocol.TunnelTypeHTTPS {
tlsConn := tls.Client(localConn, &tls.Config{InsecureSkipVerify: true})
if err := tlsConn.Handshake(); err != nil {
httputil.WriteProxyError(cc, http.StatusBadGateway, "TLS handshake failed")
return
}
localConn = tlsConn
}
req.URL.Scheme = scheme
req.URL.Host = targetAddr
if err := req.Write(localConn); err != nil {
httputil.WriteProxyError(cc, http.StatusBadGateway, "Failed to forward upgrade request")
return
}
localBr := bufio.NewReader(localConn)
resp, err := http.ReadResponse(localBr, req)
if err != nil {
httputil.WriteProxyError(cc, http.StatusBadGateway, "Failed to read upgrade response")
return
}
if err := resp.Write(cc); err != nil {
return
}
if resp.StatusCode == http.StatusSwitchingProtocols {
_ = netutil.PipeWithCallbacksAndBufferSize(
c.ctx,
cc,
localConn,
pool.SizeLarge,
func(n int64) { c.stats.AddBytesIn(n) },
func(n int64) { c.stats.AddBytesOut(n) },
)
}
}
// newLocalHTTPClient creates an HTTP client for local service requests.
func newLocalHTTPClient(tunnelType protocol.TunnelType) *http.Client {
var tlsConfig *tls.Config
if tunnelType == protocol.TunnelTypeHTTPS {
tlsConfig = &tls.Config{InsecureSkipVerify: true}
}
return &http.Client{
Transport: &http.Transport{
MaxIdleConns: 2000,
MaxIdleConnsPerHost: 1000,
MaxConnsPerHost: 0,
IdleConnTimeout: 180 * time.Second,
DisableCompression: true,
DisableKeepAlives: false,
TLSHandshakeTimeout: 5 * time.Second,
TLSClientConfig: tlsConfig,
ResponseHeaderTimeout: 15 * time.Second,
ExpectContinueTimeout: 500 * time.Millisecond,
WriteBufferSize: 32 * 1024,
ReadBufferSize: 32 * 1024,
DialContext: (&net.Dialer{
Timeout: 3 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
}
func writeResponseHeader(w io.Writer, resp *http.Response) error {
statusLine := fmt.Sprintf("HTTP/%d.%d %d %s\r\n",
resp.ProtoMajor, resp.ProtoMinor,
resp.StatusCode, http.StatusText(resp.StatusCode))
if _, err := io.WriteString(w, statusLine); err != nil {
return err
}
if err := resp.Header.Write(w); err != nil {
return err
}
_, err := io.WriteString(w, "\r\n")
return err
}

View File

@@ -0,0 +1,312 @@
package tcp
import (
"fmt"
"io"
"net"
"sync/atomic"
"time"
json "github.com/goccy/go-json"
"github.com/hashicorp/yamux"
"go.uber.org/zap"
"drip/internal/shared/constants"
"drip/internal/shared/protocol"
)
// sessionHandle wraps a yamux session with metadata.
type sessionHandle struct {
id string
conn net.Conn
session *yamux.Session
active atomic.Int64
lastActive atomic.Int64 // unix nanos
closed atomic.Bool
}
func (h *sessionHandle) touch() {
h.lastActive.Store(time.Now().UnixNano())
}
func (h *sessionHandle) lastActiveTime() time.Time {
n := h.lastActive.Load()
if n == 0 {
return time.Time{}
}
return time.Unix(0, n)
}
// scalerLoop monitors load and adjusts session count.
func (c *PoolClient) scalerLoop() {
defer c.wg.Done()
const (
checkInterval = 5 * time.Second
scaleUpCooldown = 5 * time.Second
scaleDownCooldown = 60 * time.Second
capacityPerSession = int64(64)
scaleUpLoad = 0.7
scaleDownLoad = 0.3
)
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
for {
select {
case <-c.stopCh:
return
case <-ticker.C:
}
c.mu.Lock()
desired := c.desiredTotal
if desired == 0 {
desired = c.initialSessions
c.desiredTotal = desired
}
lastScale := c.lastScale
c.mu.Unlock()
current := c.sessionCount()
if current <= 0 {
continue
}
active := c.stats.GetActiveConnections()
load := float64(active) / float64(int64(current)*capacityPerSession)
sinceLastScale := time.Since(lastScale)
if sinceLastScale >= scaleUpCooldown && load > scaleUpLoad && desired < c.maxSessions {
c.mu.Lock()
c.desiredTotal = min(c.desiredTotal+1, c.maxSessions)
c.lastScale = time.Now()
c.mu.Unlock()
} else if sinceLastScale >= scaleDownCooldown && load < scaleDownLoad && desired > c.minSessions {
c.mu.Lock()
c.desiredTotal = max(c.desiredTotal-1, c.minSessions)
c.lastScale = time.Now()
c.mu.Unlock()
}
c.ensureSessions()
}
}
// ensureSessions adjusts session count to match desired.
func (c *PoolClient) ensureSessions() {
if c.IsClosed() || c.tunnelID == "" {
return
}
c.mu.RLock()
desired := c.desiredTotal
c.mu.RUnlock()
desired = min(max(desired, c.minSessions), c.maxSessions)
current := c.sessionCount()
if current < desired {
for i := 0; i < desired-current; i++ {
if err := c.addDataSession(); err != nil {
c.logger.Debug("Add data session failed", zap.Error(err))
break
}
}
return
}
if current > desired {
c.removeIdleSessions(current - desired)
}
}
// addDataSession creates a new data session.
func (c *PoolClient) addDataSession() error {
select {
case <-c.stopCh:
return net.ErrClosed
default:
}
if c.tunnelID == "" {
return fmt.Errorf("server does not support data connections")
}
conn, err := c.dialTLS()
if err != nil {
return err
}
connID := fmt.Sprintf("data-%d", time.Now().UnixNano())
req := protocol.DataConnectRequest{
TunnelID: c.tunnelID,
Token: c.token,
ConnectionID: connID,
}
payload, err := json.Marshal(req)
if err != nil {
_ = conn.Close()
return fmt.Errorf("failed to marshal data connect request: %w", err)
}
if err := protocol.WriteFrame(conn, protocol.NewFrame(protocol.FrameTypeDataConnect, payload)); err != nil {
_ = conn.Close()
return fmt.Errorf("failed to send data connect: %w", err)
}
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
ack, err := protocol.ReadFrame(conn)
if err != nil {
_ = conn.Close()
return fmt.Errorf("failed to read data connect ack: %w", err)
}
defer ack.Release()
_ = conn.SetReadDeadline(time.Time{})
if ack.Type == protocol.FrameTypeError {
var errMsg protocol.ErrorMessage
if e := json.Unmarshal(ack.Payload, &errMsg); e == nil {
_ = conn.Close()
return fmt.Errorf("data connect error: %s - %s", errMsg.Code, errMsg.Message)
}
_ = conn.Close()
return fmt.Errorf("data connect error")
}
if ack.Type != protocol.FrameTypeDataConnectAck {
_ = conn.Close()
return fmt.Errorf("unexpected data connect ack frame: %s", ack.Type)
}
var resp protocol.DataConnectResponse
if err := json.Unmarshal(ack.Payload, &resp); err != nil {
_ = conn.Close()
return fmt.Errorf("failed to parse data connect response: %w", err)
}
if !resp.Accepted {
_ = conn.Close()
return fmt.Errorf("data connection rejected: %s", resp.Message)
}
yamuxCfg := yamux.DefaultConfig()
yamuxCfg.EnableKeepAlive = false
yamuxCfg.LogOutput = io.Discard
yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog
session, err := yamux.Server(conn, yamuxCfg)
if err != nil {
_ = conn.Close()
return fmt.Errorf("failed to init yamux session: %w", err)
}
h := &sessionHandle{
id: connID,
conn: conn,
session: session,
}
h.touch()
c.mu.Lock()
c.dataSessions[connID] = h
c.mu.Unlock()
c.wg.Add(1)
go c.acceptLoop(h, false)
c.wg.Add(1)
go c.sessionWatcher(h, false)
return nil
}
// removeIdleSessions removes n idle sessions.
func (c *PoolClient) removeIdleSessions(n int) {
if n <= 0 {
return
}
type candidate struct {
id string
active int64
lastActive time.Time
}
c.mu.RLock()
candidates := make([]candidate, 0, len(c.dataSessions))
for id, h := range c.dataSessions {
candidates = append(candidates, candidate{
id: id,
active: h.active.Load(),
lastActive: h.lastActiveTime(),
})
}
c.mu.RUnlock()
removed := 0
for removed < n {
var best candidate
found := false
for _, cand := range candidates {
if cand.active != 0 {
continue
}
if !found || cand.lastActive.Before(best.lastActive) {
best = cand
found = true
}
}
if !found {
return
}
if c.removeDataSession(best.id) {
removed++
}
for i := range candidates {
if candidates[i].id == best.id {
candidates[i].active = 1
break
}
}
}
}
// removeDataSession removes a data session by ID.
func (c *PoolClient) removeDataSession(id string) bool {
var h *sessionHandle
c.mu.Lock()
h = c.dataSessions[id]
delete(c.dataSessions, id)
c.mu.Unlock()
if h == nil {
return false
}
if !h.closed.CompareAndSwap(false, true) {
return false
}
if h.session != nil {
_ = h.session.Close()
}
if h.conn != nil {
_ = h.conn.Close()
}
return true
}
// sessionCount returns the total number of active sessions.
func (c *PoolClient) sessionCount() int {
c.mu.RLock()
defer c.mu.RUnlock()
count := len(c.dataSessions)
if c.primary != nil {
count++
}
return count
}

View File

@@ -1,227 +0,0 @@
package tcp
import (
"sync"
"sync/atomic"
"time"
)
// TrafficStats tracks traffic statistics for a tunnel connection
type TrafficStats struct {
// Total bytes
totalBytesIn int64
totalBytesOut int64
// Request counts
totalRequests int64
// For speed calculation
lastBytesIn int64
lastBytesOut int64
lastTime time.Time
speedMu sync.Mutex
// Current speed (bytes per second)
speedIn int64
speedOut int64
// Start time
startTime time.Time
}
// NewTrafficStats creates a new traffic stats tracker
func NewTrafficStats() *TrafficStats {
now := time.Now()
return &TrafficStats{
startTime: now,
lastTime: now,
}
}
// AddBytesIn adds incoming bytes to the counter
func (s *TrafficStats) AddBytesIn(n int64) {
atomic.AddInt64(&s.totalBytesIn, n)
}
// AddBytesOut adds outgoing bytes to the counter
func (s *TrafficStats) AddBytesOut(n int64) {
atomic.AddInt64(&s.totalBytesOut, n)
}
// AddRequest increments the request counter
func (s *TrafficStats) AddRequest() {
atomic.AddInt64(&s.totalRequests, 1)
}
// GetTotalBytesIn returns total incoming bytes
func (s *TrafficStats) GetTotalBytesIn() int64 {
return atomic.LoadInt64(&s.totalBytesIn)
}
// GetTotalBytesOut returns total outgoing bytes
func (s *TrafficStats) GetTotalBytesOut() int64 {
return atomic.LoadInt64(&s.totalBytesOut)
}
// GetTotalRequests returns total request count
func (s *TrafficStats) GetTotalRequests() int64 {
return atomic.LoadInt64(&s.totalRequests)
}
// GetTotalBytes returns total bytes (in + out)
func (s *TrafficStats) GetTotalBytes() int64 {
return s.GetTotalBytesIn() + s.GetTotalBytesOut()
}
// UpdateSpeed calculates current transfer speed
// Should be called periodically (e.g., every second)
func (s *TrafficStats) UpdateSpeed() {
s.speedMu.Lock()
defer s.speedMu.Unlock()
now := time.Now()
elapsed := now.Sub(s.lastTime).Seconds()
if elapsed < 0.1 {
return // Avoid division by zero or too frequent updates
}
currentIn := atomic.LoadInt64(&s.totalBytesIn)
currentOut := atomic.LoadInt64(&s.totalBytesOut)
deltaIn := currentIn - s.lastBytesIn
deltaOut := currentOut - s.lastBytesOut
s.speedIn = int64(float64(deltaIn) / elapsed)
s.speedOut = int64(float64(deltaOut) / elapsed)
s.lastBytesIn = currentIn
s.lastBytesOut = currentOut
s.lastTime = now
}
// GetSpeedIn returns current incoming speed in bytes per second
func (s *TrafficStats) GetSpeedIn() int64 {
s.speedMu.Lock()
defer s.speedMu.Unlock()
return s.speedIn
}
// GetSpeedOut returns current outgoing speed in bytes per second
func (s *TrafficStats) GetSpeedOut() int64 {
s.speedMu.Lock()
defer s.speedMu.Unlock()
return s.speedOut
}
// GetUptime returns how long the connection has been active
func (s *TrafficStats) GetUptime() time.Duration {
return time.Since(s.startTime)
}
// Snapshot returns a snapshot of all stats
type StatsSnapshot struct {
TotalBytesIn int64
TotalBytesOut int64
TotalBytes int64
TotalRequests int64
SpeedIn int64 // bytes per second
SpeedOut int64 // bytes per second
Uptime time.Duration
}
// GetSnapshot returns a snapshot of current stats
func (s *TrafficStats) GetSnapshot() StatsSnapshot {
s.speedMu.Lock()
speedIn := s.speedIn
speedOut := s.speedOut
s.speedMu.Unlock()
totalIn := atomic.LoadInt64(&s.totalBytesIn)
totalOut := atomic.LoadInt64(&s.totalBytesOut)
return StatsSnapshot{
TotalBytesIn: totalIn,
TotalBytesOut: totalOut,
TotalBytes: totalIn + totalOut,
TotalRequests: atomic.LoadInt64(&s.totalRequests),
SpeedIn: speedIn,
SpeedOut: speedOut,
Uptime: time.Since(s.startTime),
}
}
// FormatBytes formats bytes to human readable string
func FormatBytes(bytes int64) string {
const (
KB = 1024
MB = KB * 1024
GB = MB * 1024
)
switch {
case bytes >= GB:
return formatFloat(float64(bytes)/float64(GB)) + " GB"
case bytes >= MB:
return formatFloat(float64(bytes)/float64(MB)) + " MB"
case bytes >= KB:
return formatFloat(float64(bytes)/float64(KB)) + " KB"
default:
return formatInt(bytes) + " B"
}
}
// FormatSpeed formats speed (bytes per second) to human readable string
func FormatSpeed(bytesPerSec int64) string {
if bytesPerSec == 0 {
return "0 B/s"
}
return FormatBytes(bytesPerSec) + "/s"
}
func formatFloat(f float64) string {
if f >= 100 {
return formatInt(int64(f))
} else if f >= 10 {
return formatOneDecimal(f)
}
return formatTwoDecimal(f)
}
func formatInt(i int64) string {
return intToStr(i)
}
func formatOneDecimal(f float64) string {
i := int64(f * 10)
whole := i / 10
frac := i % 10
return intToStr(whole) + "." + intToStr(frac)
}
func formatTwoDecimal(f float64) string {
i := int64(f * 100)
whole := i / 100
frac := i % 100
if frac < 10 {
return intToStr(whole) + ".0" + intToStr(frac)
}
return intToStr(whole) + "." + intToStr(frac)
}
func intToStr(i int64) string {
if i == 0 {
return "0"
}
if i < 0 {
return "-" + intToStr(-i)
}
var buf [20]byte
pos := len(buf)
for i > 0 {
pos--
buf[pos] = byte('0' + i%10)
i /= 10
}
return string(buf[pos:])
}