mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-26 22:31:35 +00:00
feat(tunnel): switch to yamux stream proxying and connection pooling
- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server - Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly - Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
This commit is contained in:
@@ -12,7 +12,7 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"drip/internal/shared/ui"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ var attachCmd = &cobra.Command{
|
||||
Examples:
|
||||
drip attach List running tunnels and select one
|
||||
drip attach http 3000 Attach to HTTP tunnel on port 3000
|
||||
drip attach https 8443 Attach to HTTPS tunnel on port 8443
|
||||
drip attach tcp 5432 Attach to TCP tunnel on port 5432
|
||||
|
||||
Press Ctrl+C to detach (tunnel will continue running).`,
|
||||
@@ -36,7 +37,7 @@ func init() {
|
||||
rootCmd.AddCommand(attachCmd)
|
||||
}
|
||||
|
||||
func runAttach(cmd *cobra.Command, args []string) error {
|
||||
func runAttach(_ *cobra.Command, args []string) error {
|
||||
CleanupStaleDaemons()
|
||||
|
||||
daemons, err := ListAllDaemons()
|
||||
@@ -59,8 +60,8 @@ func runAttach(cmd *cobra.Command, args []string) error {
|
||||
|
||||
if len(args) == 2 {
|
||||
tunnelType := args[0]
|
||||
if tunnelType != "http" && tunnelType != "tcp" {
|
||||
return fmt.Errorf("invalid tunnel type: %s (must be 'http' or 'tcp')", tunnelType)
|
||||
if tunnelType != "http" && tunnelType != "https" && tunnelType != "tcp" {
|
||||
return fmt.Errorf("invalid tunnel type: %s (must be 'http', 'https', or 'tcp')", tunnelType)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(args[1])
|
||||
@@ -119,10 +120,13 @@ func selectDaemonInteractive(daemons []*DaemonInfo) (*DaemonInfo, error) {
|
||||
uptime := time.Since(d.StartTime)
|
||||
|
||||
var typeStr string
|
||||
if d.Type == "http" {
|
||||
typeStr = ui.Success("HTTP")
|
||||
} else {
|
||||
typeStr = ui.Highlight("TCP")
|
||||
switch d.Type {
|
||||
case "http":
|
||||
typeStr = ui.Highlight("HTTP")
|
||||
case "https":
|
||||
typeStr = ui.Highlight("HTTPS")
|
||||
default:
|
||||
typeStr = ui.Cyan("TCP")
|
||||
}
|
||||
|
||||
table.AddRow([]string{
|
||||
@@ -196,7 +200,7 @@ func attachToDaemon(daemon *DaemonInfo) error {
|
||||
select {
|
||||
case <-sigCh:
|
||||
if tailCmd.Process != nil {
|
||||
tailCmd.Process.Kill()
|
||||
_ = tailCmd.Process.Kill()
|
||||
}
|
||||
fmt.Println()
|
||||
fmt.Println(ui.Warning("Detached from tunnel (tunnel is still running)"))
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"drip/internal/shared/ui"
|
||||
"drip/pkg/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@@ -77,7 +77,7 @@ func init() {
|
||||
rootCmd.AddCommand(configCmd)
|
||||
}
|
||||
|
||||
func runConfigInit(cmd *cobra.Command, args []string) error {
|
||||
func runConfigInit(_ *cobra.Command, _ []string) error {
|
||||
fmt.Print(ui.RenderConfigInit())
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
@@ -109,7 +109,7 @@ func runConfigInit(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func runConfigShow(cmd *cobra.Command, args []string) error {
|
||||
func runConfigShow(_ *cobra.Command, _ []string) error {
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -137,7 +137,7 @@ func runConfigShow(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func runConfigSet(cmd *cobra.Command, args []string) error {
|
||||
func runConfigSet(_ *cobra.Command, _ []string) error {
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
cfg = &config.ClientConfig{
|
||||
@@ -173,7 +173,7 @@ func runConfigSet(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func runConfigReset(cmd *cobra.Command, args []string) error {
|
||||
func runConfigReset(_ *cobra.Command, _ []string) error {
|
||||
configPath := config.DefaultClientConfigPath()
|
||||
|
||||
if !config.ConfigExists("") {
|
||||
@@ -202,7 +202,7 @@ func runConfigReset(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func runConfigValidate(cmd *cobra.Command, args []string) error {
|
||||
func runConfigValidate(_ *cobra.Command, _ []string) error {
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
fmt.Println(ui.Error("Failed to load configuration"))
|
||||
|
||||
@@ -5,10 +5,9 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"drip/internal/shared/ui"
|
||||
json "github.com/goccy/go-json"
|
||||
)
|
||||
|
||||
@@ -194,8 +193,8 @@ func StartDaemon(tunnelType string, port int, args []string) error {
|
||||
return fmt.Errorf("failed to start daemon: %w", err)
|
||||
}
|
||||
|
||||
// Don't wait for the process - let it run in background
|
||||
// The child process will save its own daemon info after connecting
|
||||
_ = logFile.Close()
|
||||
_ = devNull.Close()
|
||||
|
||||
fmt.Println(ui.RenderDaemonStarted(tunnelType, port, cmd.Process.Pid, logPath))
|
||||
|
||||
@@ -220,28 +219,16 @@ func CleanupStaleDaemons() error {
|
||||
|
||||
// FormatDuration formats a duration in a human-readable way
|
||||
func FormatDuration(d time.Duration) string {
|
||||
if d < time.Minute {
|
||||
switch {
|
||||
case d < time.Minute:
|
||||
return fmt.Sprintf("%ds", int(d.Seconds()))
|
||||
} else if d < time.Hour {
|
||||
case d < time.Hour:
|
||||
return fmt.Sprintf("%dm %ds", int(d.Minutes()), int(d.Seconds())%60)
|
||||
} else if d < 24*time.Hour {
|
||||
case d < 24*time.Hour:
|
||||
return fmt.Sprintf("%dh %dm", int(d.Hours()), int(d.Minutes())%60)
|
||||
}
|
||||
|
||||
days := int(d.Hours()) / 24
|
||||
hours := int(d.Hours()) % 24
|
||||
return fmt.Sprintf("%dd %dh", days, hours)
|
||||
}
|
||||
|
||||
// ParsePortFromArgs extracts the port number from command arguments
|
||||
func ParsePortFromArgs(args []string) (int, error) {
|
||||
for _, arg := range args {
|
||||
if len(arg) > 0 && arg[0] == '-' {
|
||||
continue
|
||||
}
|
||||
port, err := strconv.Atoi(arg)
|
||||
if err == nil && port > 0 && port <= 65535 {
|
||||
return port, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("port number not found in arguments")
|
||||
}
|
||||
|
||||
@@ -2,22 +2,14 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/pkg/config"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const (
|
||||
maxReconnectAttempts = 5
|
||||
reconnectInterval = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
subdomain string
|
||||
daemonMode bool
|
||||
@@ -52,55 +44,19 @@ func init() {
|
||||
rootCmd.AddCommand(httpCmd)
|
||||
}
|
||||
|
||||
func runHTTP(cmd *cobra.Command, args []string) error {
|
||||
func runHTTP(_ *cobra.Command, args []string) error {
|
||||
port, err := strconv.Atoi(args[0])
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("invalid port number: %s", args[0])
|
||||
}
|
||||
|
||||
if daemonMode && !daemonMarker {
|
||||
daemonArgs := append([]string{"http"}, args...)
|
||||
daemonArgs = append(daemonArgs, "--daemon-child")
|
||||
if subdomain != "" {
|
||||
daemonArgs = append(daemonArgs, "--subdomain", subdomain)
|
||||
}
|
||||
if localAddress != "127.0.0.1" {
|
||||
daemonArgs = append(daemonArgs, "--address", localAddress)
|
||||
}
|
||||
if serverURL != "" {
|
||||
daemonArgs = append(daemonArgs, "--server", serverURL)
|
||||
}
|
||||
if authToken != "" {
|
||||
daemonArgs = append(daemonArgs, "--token", authToken)
|
||||
}
|
||||
if insecure {
|
||||
daemonArgs = append(daemonArgs, "--insecure")
|
||||
}
|
||||
if verbose {
|
||||
daemonArgs = append(daemonArgs, "--verbose")
|
||||
}
|
||||
return StartDaemon("http", port, daemonArgs)
|
||||
return StartDaemon("http", port, buildDaemonArgs("http", args, subdomain, localAddress))
|
||||
}
|
||||
|
||||
var serverAddr, token string
|
||||
|
||||
if serverURL == "" {
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return fmt.Errorf(`configuration not found.
|
||||
|
||||
Please run 'drip config init' first, or use flags:
|
||||
drip http %d --server SERVER:PORT --token TOKEN`, port)
|
||||
}
|
||||
serverAddr = cfg.Server
|
||||
token = cfg.Token
|
||||
} else {
|
||||
serverAddr = serverURL
|
||||
token = authToken
|
||||
}
|
||||
|
||||
if serverAddr == "" {
|
||||
return fmt.Errorf("server address is required")
|
||||
serverAddr, token, err := resolveServerAddrAndToken("http", port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
@@ -115,15 +71,7 @@ Please run 'drip config init' first, or use flags:
|
||||
|
||||
var daemon *DaemonInfo
|
||||
if daemonMarker {
|
||||
daemon = &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: "http",
|
||||
Port: port,
|
||||
Subdomain: subdomain,
|
||||
Server: serverAddr,
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
daemon = newDaemonInfo("http", port, subdomain, serverAddr)
|
||||
}
|
||||
|
||||
return runTunnelWithUI(connConfig, daemon)
|
||||
|
||||
@@ -2,24 +2,14 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/pkg/config"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
httpsSubdomain string
|
||||
httpsDaemonMode bool
|
||||
httpsDaemonMarker bool
|
||||
httpsLocalAddress string
|
||||
)
|
||||
|
||||
var httpsCmd = &cobra.Command{
|
||||
Use: "https <port>",
|
||||
Short: "Start HTTPS tunnel",
|
||||
@@ -39,86 +29,42 @@ Note: Uses TCP over TLS 1.3 for secure communication`,
|
||||
}
|
||||
|
||||
func init() {
|
||||
httpsCmd.Flags().StringVarP(&httpsSubdomain, "subdomain", "n", "", "Custom subdomain (optional)")
|
||||
httpsCmd.Flags().BoolVarP(&httpsDaemonMode, "daemon", "d", false, "Run in background (daemon mode)")
|
||||
httpsCmd.Flags().StringVarP(&httpsLocalAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
|
||||
httpsCmd.Flags().BoolVar(&httpsDaemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
httpsCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)")
|
||||
httpsCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)")
|
||||
httpsCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
|
||||
httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
httpsCmd.Flags().MarkHidden("daemon-child")
|
||||
rootCmd.AddCommand(httpsCmd)
|
||||
}
|
||||
|
||||
func runHTTPS(cmd *cobra.Command, args []string) error {
|
||||
func runHTTPS(_ *cobra.Command, args []string) error {
|
||||
port, err := strconv.Atoi(args[0])
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("invalid port number: %s", args[0])
|
||||
}
|
||||
|
||||
if httpsDaemonMode && !httpsDaemonMarker {
|
||||
daemonArgs := append([]string{"https"}, args...)
|
||||
daemonArgs = append(daemonArgs, "--daemon-child")
|
||||
if httpsSubdomain != "" {
|
||||
daemonArgs = append(daemonArgs, "--subdomain", httpsSubdomain)
|
||||
}
|
||||
if httpsLocalAddress != "127.0.0.1" {
|
||||
daemonArgs = append(daemonArgs, "--address", httpsLocalAddress)
|
||||
}
|
||||
if serverURL != "" {
|
||||
daemonArgs = append(daemonArgs, "--server", serverURL)
|
||||
}
|
||||
if authToken != "" {
|
||||
daemonArgs = append(daemonArgs, "--token", authToken)
|
||||
}
|
||||
if insecure {
|
||||
daemonArgs = append(daemonArgs, "--insecure")
|
||||
}
|
||||
if verbose {
|
||||
daemonArgs = append(daemonArgs, "--verbose")
|
||||
}
|
||||
return StartDaemon("https", port, daemonArgs)
|
||||
if daemonMode && !daemonMarker {
|
||||
return StartDaemon("https", port, buildDaemonArgs("https", args, subdomain, localAddress))
|
||||
}
|
||||
|
||||
var serverAddr, token string
|
||||
|
||||
if serverURL == "" {
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return fmt.Errorf(`configuration not found.
|
||||
|
||||
Please run 'drip config init' first, or use flags:
|
||||
drip https %d --server SERVER:PORT --token TOKEN`, port)
|
||||
}
|
||||
serverAddr = cfg.Server
|
||||
token = cfg.Token
|
||||
} else {
|
||||
serverAddr = serverURL
|
||||
token = authToken
|
||||
}
|
||||
|
||||
if serverAddr == "" {
|
||||
return fmt.Errorf("server address is required")
|
||||
serverAddr, token, err := resolveServerAddrAndToken("https", port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
ServerAddr: serverAddr,
|
||||
Token: token,
|
||||
TunnelType: protocol.TunnelTypeHTTPS,
|
||||
LocalHost: httpsLocalAddress,
|
||||
LocalHost: localAddress,
|
||||
LocalPort: port,
|
||||
Subdomain: httpsSubdomain,
|
||||
Subdomain: subdomain,
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
var daemon *DaemonInfo
|
||||
if httpsDaemonMarker {
|
||||
daemon = &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: "https",
|
||||
Port: port,
|
||||
Subdomain: httpsSubdomain,
|
||||
Server: serverAddr,
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
if daemonMarker {
|
||||
daemon = newDaemonInfo("https", port, subdomain, serverAddr)
|
||||
}
|
||||
|
||||
return runTunnelWithUI(connConfig, daemon)
|
||||
|
||||
@@ -8,13 +8,11 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"drip/internal/shared/ui"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
interactiveMode bool
|
||||
)
|
||||
var interactiveMode bool
|
||||
|
||||
var listCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
@@ -26,7 +24,7 @@ Example:
|
||||
drip list -i Interactive mode (select to attach/stop)
|
||||
|
||||
This command shows:
|
||||
- Tunnel type (HTTP/TCP)
|
||||
- Tunnel type (HTTP/HTTPS/TCP)
|
||||
- Local port being tunneled
|
||||
- Public URL
|
||||
- Process ID (PID)
|
||||
@@ -44,7 +42,7 @@ func init() {
|
||||
rootCmd.AddCommand(listCmd)
|
||||
}
|
||||
|
||||
func runList(cmd *cobra.Command, args []string) error {
|
||||
func runList(_ *cobra.Command, _ []string) error {
|
||||
CleanupStaleDaemons()
|
||||
|
||||
daemons, err := ListAllDaemons()
|
||||
@@ -99,7 +97,7 @@ func runList(cmd *cobra.Command, args []string) error {
|
||||
|
||||
fmt.Print(table.Render())
|
||||
|
||||
if interactiveMode || shouldPromptForAction() {
|
||||
if interactiveMode {
|
||||
return runInteractiveList(daemons)
|
||||
}
|
||||
|
||||
@@ -114,13 +112,6 @@ func runList(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldPromptForAction() bool {
|
||||
if fileInfo, _ := os.Stdout.Stat(); (fileInfo.Mode() & os.ModeCharDevice) == 0 {
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func runInteractiveList(daemons []*DaemonInfo) error {
|
||||
var runningDaemons []*DaemonInfo
|
||||
for _, d := range daemons {
|
||||
@@ -161,9 +152,9 @@ func runInteractiveList(daemons []*DaemonInfo) error {
|
||||
"",
|
||||
ui.Muted("What would you like to do?"),
|
||||
"",
|
||||
ui.Cyan(" 1.") + " Attach (view logs)",
|
||||
ui.Cyan(" 2.") + " Stop tunnel",
|
||||
ui.Muted(" q.") + " Cancel",
|
||||
ui.Cyan(" 1.")+" Attach (view logs)",
|
||||
ui.Cyan(" 2.")+" Stop tunnel",
|
||||
ui.Muted(" q.")+" Cancel",
|
||||
))
|
||||
|
||||
fmt.Print(ui.Muted("Choose an action: "))
|
||||
|
||||
@@ -3,7 +3,7 @@ package cli
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"drip/internal/shared/ui"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@@ -56,14 +56,12 @@ func init() {
|
||||
versionCmd.Flags().BoolVar(&versionPlain, "short", false, "Print version information without styling")
|
||||
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
// http and tcp commands are added in their respective init() functions
|
||||
// config command is added in config.go init() function
|
||||
}
|
||||
|
||||
var versionCmd = &cobra.Command{
|
||||
Use: "version",
|
||||
Short: "Print version information",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
Run: func(_ *cobra.Command, _ []string) {
|
||||
if versionPlain {
|
||||
fmt.Printf("Version: %s\nGit Commit: %s\nBuild Time: %s\n", Version, GitCommit, BuildTime)
|
||||
return
|
||||
|
||||
@@ -59,7 +59,7 @@ func init() {
|
||||
serverCmd.Flags().IntVar(&serverPprofPort, "pprof", getEnvInt("DRIP_PPROF_PORT", 0), "Enable pprof on specified port (env: DRIP_PPROF_PORT)")
|
||||
}
|
||||
|
||||
func runServer(cmd *cobra.Command, args []string) error {
|
||||
func runServer(_ *cobra.Command, _ []string) error {
|
||||
if serverTLSCert == "" {
|
||||
return fmt.Errorf("TLS certificate path is required (use --tls-cert flag or DRIP_TLS_CERT environment variable)")
|
||||
}
|
||||
@@ -126,11 +126,9 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
|
||||
listenAddr := fmt.Sprintf("0.0.0.0:%d", serverPort)
|
||||
|
||||
responseHandler := proxy.NewResponseHandler(logger)
|
||||
httpHandler := proxy.NewHandler(tunnelManager, logger, serverDomain, serverAuthToken)
|
||||
|
||||
httpHandler := proxy.NewHandler(tunnelManager, logger, responseHandler, serverDomain, serverAuthToken)
|
||||
|
||||
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, displayPort, httpHandler, responseHandler)
|
||||
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, displayPort, httpHandler)
|
||||
|
||||
if err := listener.Start(); err != nil {
|
||||
logger.Fatal("Failed to start TCP listener", zap.Error(err))
|
||||
|
||||
@@ -28,7 +28,7 @@ func init() {
|
||||
rootCmd.AddCommand(stopCmd)
|
||||
}
|
||||
|
||||
func runStop(cmd *cobra.Command, args []string) error {
|
||||
func runStop(_ *cobra.Command, args []string) error {
|
||||
if args[0] == "all" {
|
||||
return stopAllDaemons()
|
||||
}
|
||||
|
||||
@@ -2,13 +2,10 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/pkg/config"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@@ -47,55 +44,19 @@ func init() {
|
||||
rootCmd.AddCommand(tcpCmd)
|
||||
}
|
||||
|
||||
func runTCP(cmd *cobra.Command, args []string) error {
|
||||
func runTCP(_ *cobra.Command, args []string) error {
|
||||
port, err := strconv.Atoi(args[0])
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("invalid port number: %s", args[0])
|
||||
}
|
||||
|
||||
if daemonMode && !daemonMarker {
|
||||
daemonArgs := append([]string{"tcp"}, args...)
|
||||
daemonArgs = append(daemonArgs, "--daemon-child")
|
||||
if subdomain != "" {
|
||||
daemonArgs = append(daemonArgs, "--subdomain", subdomain)
|
||||
}
|
||||
if localAddress != "127.0.0.1" {
|
||||
daemonArgs = append(daemonArgs, "--address", localAddress)
|
||||
}
|
||||
if serverURL != "" {
|
||||
daemonArgs = append(daemonArgs, "--server", serverURL)
|
||||
}
|
||||
if authToken != "" {
|
||||
daemonArgs = append(daemonArgs, "--token", authToken)
|
||||
}
|
||||
if insecure {
|
||||
daemonArgs = append(daemonArgs, "--insecure")
|
||||
}
|
||||
if verbose {
|
||||
daemonArgs = append(daemonArgs, "--verbose")
|
||||
}
|
||||
return StartDaemon("tcp", port, daemonArgs)
|
||||
return StartDaemon("tcp", port, buildDaemonArgs("tcp", args, subdomain, localAddress))
|
||||
}
|
||||
|
||||
var serverAddr, token string
|
||||
|
||||
if serverURL == "" {
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return fmt.Errorf(`configuration not found.
|
||||
|
||||
Please run 'drip config init' first, or use flags:
|
||||
drip tcp %d --server SERVER:PORT --token TOKEN`, port)
|
||||
}
|
||||
serverAddr = cfg.Server
|
||||
token = cfg.Token
|
||||
} else {
|
||||
serverAddr = serverURL
|
||||
token = authToken
|
||||
}
|
||||
|
||||
if serverAddr == "" {
|
||||
return fmt.Errorf("server address is required")
|
||||
serverAddr, token, err := resolveServerAddrAndToken("tcp", port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
@@ -110,15 +71,7 @@ Please run 'drip config init' first, or use flags:
|
||||
|
||||
var daemon *DaemonInfo
|
||||
if daemonMarker {
|
||||
daemon = &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: "tcp",
|
||||
Port: port,
|
||||
Subdomain: subdomain,
|
||||
Server: serverAddr,
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
daemon = newDaemonInfo("tcp", port, subdomain, serverAddr)
|
||||
}
|
||||
|
||||
return runTunnelWithUI(connConfig, daemon)
|
||||
|
||||
67
internal/client/cli/tunnel_helpers.go
Normal file
67
internal/client/cli/tunnel_helpers.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"drip/pkg/config"
|
||||
)
|
||||
|
||||
func buildDaemonArgs(tunnelType string, args []string, subdomain string, localAddress string) []string {
|
||||
daemonArgs := append([]string{tunnelType}, args...)
|
||||
daemonArgs = append(daemonArgs, "--daemon-child")
|
||||
|
||||
if subdomain != "" {
|
||||
daemonArgs = append(daemonArgs, "--subdomain", subdomain)
|
||||
}
|
||||
if localAddress != "127.0.0.1" {
|
||||
daemonArgs = append(daemonArgs, "--address", localAddress)
|
||||
}
|
||||
if serverURL != "" {
|
||||
daemonArgs = append(daemonArgs, "--server", serverURL)
|
||||
}
|
||||
if authToken != "" {
|
||||
daemonArgs = append(daemonArgs, "--token", authToken)
|
||||
}
|
||||
if insecure {
|
||||
daemonArgs = append(daemonArgs, "--insecure")
|
||||
}
|
||||
if verbose {
|
||||
daemonArgs = append(daemonArgs, "--verbose")
|
||||
}
|
||||
|
||||
return daemonArgs
|
||||
}
|
||||
|
||||
func resolveServerAddrAndToken(tunnelType string, port int) (string, string, error) {
|
||||
if serverURL != "" {
|
||||
return serverURL, authToken, nil
|
||||
}
|
||||
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(`configuration not found.
|
||||
|
||||
Please run 'drip config init' first, or use flags:
|
||||
drip %s %d --server SERVER:PORT --token TOKEN`, tunnelType, port)
|
||||
}
|
||||
|
||||
if cfg.Server == "" {
|
||||
return "", "", fmt.Errorf("server address is required")
|
||||
}
|
||||
|
||||
return cfg.Server, cfg.Token, nil
|
||||
}
|
||||
|
||||
func newDaemonInfo(tunnelType string, port int, subdomain string, serverAddr string) *DaemonInfo {
|
||||
return &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: tunnelType,
|
||||
Port: port,
|
||||
Subdomain: subdomain,
|
||||
Server: serverAddr,
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
}
|
||||
@@ -8,13 +8,17 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/ui"
|
||||
"drip/internal/shared/utils"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// runTunnelWithUI runs a tunnel with the new UI
|
||||
const (
|
||||
maxReconnectAttempts = 5
|
||||
reconnectInterval = 3 * time.Second
|
||||
)
|
||||
|
||||
func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) error {
|
||||
if err := utils.InitLogger(verbose); err != nil {
|
||||
return fmt.Errorf("failed to initialize logger: %w", err)
|
||||
@@ -28,7 +32,7 @@ func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) er
|
||||
|
||||
reconnectAttempts := 0
|
||||
for {
|
||||
connector := tcp.NewConnector(connConfig, logger)
|
||||
connector := tcp.NewTunnelClient(connConfig, logger)
|
||||
|
||||
fmt.Println(ui.RenderConnecting(connConfig.ServerAddr, reconnectAttempts, maxReconnectAttempts))
|
||||
|
||||
@@ -87,8 +91,8 @@ func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) er
|
||||
disconnected := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
renderTicker := time.NewTicker(1 * time.Second)
|
||||
defer renderTicker.Stop()
|
||||
|
||||
var lastLatency time.Duration
|
||||
lastRenderedLines := 0
|
||||
@@ -97,27 +101,38 @@ func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) er
|
||||
select {
|
||||
case latency := <-latencyCh:
|
||||
lastLatency = latency
|
||||
case <-ticker.C:
|
||||
case <-renderTicker.C:
|
||||
stats := connector.GetStats()
|
||||
if stats != nil {
|
||||
stats.UpdateSpeed()
|
||||
snapshot := stats.GetSnapshot()
|
||||
|
||||
status.Latency = lastLatency
|
||||
status.BytesIn = snapshot.TotalBytesIn
|
||||
status.BytesOut = snapshot.TotalBytesOut
|
||||
status.SpeedIn = float64(snapshot.SpeedIn)
|
||||
status.SpeedOut = float64(snapshot.SpeedOut)
|
||||
status.TotalRequest = snapshot.TotalRequests
|
||||
|
||||
statsView := ui.RenderTunnelStats(status)
|
||||
if lastRenderedLines > 0 {
|
||||
fmt.Print(clearLines(lastRenderedLines))
|
||||
}
|
||||
|
||||
fmt.Print(statsView)
|
||||
lastRenderedLines = countRenderedLines(statsView)
|
||||
if stats == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
stats.UpdateSpeed()
|
||||
snapshot := stats.GetSnapshot()
|
||||
|
||||
status.Latency = lastLatency
|
||||
status.BytesIn = snapshot.TotalBytesIn
|
||||
status.BytesOut = snapshot.TotalBytesOut
|
||||
status.SpeedIn = float64(snapshot.SpeedIn)
|
||||
status.SpeedOut = float64(snapshot.SpeedOut)
|
||||
|
||||
if status.Type == "tcp" {
|
||||
if snapshot.SpeedIn == 0 && snapshot.SpeedOut == 0 {
|
||||
status.TotalRequest = 0
|
||||
} else {
|
||||
status.TotalRequest = snapshot.ActiveConnections
|
||||
}
|
||||
} else {
|
||||
status.TotalRequest = snapshot.TotalRequests
|
||||
}
|
||||
|
||||
statsView := ui.RenderTunnelStats(status)
|
||||
if lastRenderedLines > 0 {
|
||||
fmt.Print(clearLines(lastRenderedLines))
|
||||
}
|
||||
|
||||
fmt.Print(statsView)
|
||||
lastRenderedLines = countRenderedLines(statsView)
|
||||
case <-stopDisplay:
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,117 +0,0 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// RenderConfigInit renders config initialization UI
|
||||
func RenderConfigInit() string {
|
||||
title := "Drip Configuration Setup"
|
||||
box := boxStyle.Width(50)
|
||||
return "\n" + box.Render(titleStyle.Render(title)) + "\n"
|
||||
}
|
||||
|
||||
// RenderConfigShow renders the config display
|
||||
func RenderConfigShow(server, token string, tokenHidden bool, tlsEnabled bool, configPath string) string {
|
||||
lines := []string{
|
||||
KeyValue("Server", server),
|
||||
}
|
||||
|
||||
if token != "" {
|
||||
if tokenHidden {
|
||||
if len(token) > 10 {
|
||||
displayToken := token[:3] + "***" + token[len(token)-3:]
|
||||
lines = append(lines, KeyValue("Token", Muted(displayToken+" (hidden)")))
|
||||
} else {
|
||||
lines = append(lines, KeyValue("Token", Muted(token[:3]+"*** (hidden)")))
|
||||
}
|
||||
} else {
|
||||
lines = append(lines, KeyValue("Token", token))
|
||||
}
|
||||
} else {
|
||||
lines = append(lines, KeyValue("Token", Muted("(not set)")))
|
||||
}
|
||||
|
||||
tlsStatus := "enabled"
|
||||
if !tlsEnabled {
|
||||
tlsStatus = "disabled"
|
||||
}
|
||||
lines = append(lines, KeyValue("TLS", tlsStatus))
|
||||
lines = append(lines, KeyValue("Config", Muted(configPath)))
|
||||
|
||||
return Info("Current Configuration", lines...)
|
||||
}
|
||||
|
||||
// RenderConfigSaved renders config saved message
|
||||
func RenderConfigSaved(configPath string) string {
|
||||
return SuccessBox(
|
||||
"Configuration Saved",
|
||||
Muted("Config saved to: ")+configPath,
|
||||
"",
|
||||
Muted("You can now use 'drip' without --server and --token flags"),
|
||||
)
|
||||
}
|
||||
|
||||
// RenderConfigUpdated renders config updated message
|
||||
func RenderConfigUpdated(updates []string) string {
|
||||
lines := make([]string, len(updates)+1)
|
||||
for i, update := range updates {
|
||||
lines[i] = Success(update)
|
||||
}
|
||||
lines[len(updates)] = ""
|
||||
lines = append(lines, Muted("Configuration has been updated"))
|
||||
return SuccessBox("Configuration Updated", lines...)
|
||||
}
|
||||
|
||||
// RenderConfigDeleted renders config deleted message
|
||||
func RenderConfigDeleted() string {
|
||||
return SuccessBox("Configuration Deleted", Muted("Configuration file has been removed"))
|
||||
}
|
||||
|
||||
// RenderConfigValidation renders config validation results
|
||||
func RenderConfigValidation(serverValid bool, serverMsg string, tokenSet bool, tokenMsg string, tlsEnabled bool) string {
|
||||
lines := []string{}
|
||||
|
||||
if serverValid {
|
||||
lines = append(lines, Success(serverMsg))
|
||||
} else {
|
||||
lines = append(lines, Error(serverMsg))
|
||||
}
|
||||
|
||||
if tokenSet {
|
||||
lines = append(lines, Success(tokenMsg))
|
||||
} else {
|
||||
lines = append(lines, Warning(tokenMsg))
|
||||
}
|
||||
|
||||
if tlsEnabled {
|
||||
lines = append(lines, Success("TLS is enabled"))
|
||||
} else {
|
||||
lines = append(lines, Warning("TLS is disabled (not recommended for production)"))
|
||||
}
|
||||
|
||||
lines = append(lines, "")
|
||||
lines = append(lines, Muted("Configuration validation complete"))
|
||||
|
||||
if serverValid && tokenSet && tlsEnabled {
|
||||
return SuccessBox("Configuration Valid", lines...)
|
||||
}
|
||||
return WarningBox("Configuration Validation", lines...)
|
||||
}
|
||||
|
||||
// RenderDaemonStarted renders daemon started message
|
||||
func RenderDaemonStarted(tunnelType string, port int, pid int, logPath string) string {
|
||||
lines := []string{
|
||||
KeyValue("Type", Highlight(tunnelType)),
|
||||
KeyValue("Port", fmt.Sprintf("%d", port)),
|
||||
KeyValue("PID", fmt.Sprintf("%d", pid)),
|
||||
"",
|
||||
Muted("Commands:"),
|
||||
Cyan(" drip list") + Muted(" Check tunnel status"),
|
||||
Cyan(fmt.Sprintf(" drip attach %s %d", tunnelType, port)) + Muted(" View logs"),
|
||||
Cyan(fmt.Sprintf(" drip stop %s %d", tunnelType, port)) + Muted(" Stop tunnel"),
|
||||
"",
|
||||
Muted("Logs: ") + mutedStyle.Render(logPath),
|
||||
}
|
||||
return SuccessBox("Tunnel Started in Background", lines...)
|
||||
}
|
||||
@@ -1,184 +0,0 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
var (
|
||||
// Colors inspired by Vercel CLI
|
||||
successColor = lipgloss.Color("#0070F3")
|
||||
warningColor = lipgloss.Color("#F5A623")
|
||||
errorColor = lipgloss.Color("#E00")
|
||||
mutedColor = lipgloss.Color("#888")
|
||||
highlightColor = lipgloss.Color("#0070F3")
|
||||
cyanColor = lipgloss.Color("#50E3C2")
|
||||
|
||||
// Box styles - Vercel-like clean box
|
||||
boxStyle = lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
Padding(1, 2).
|
||||
MarginTop(1).
|
||||
MarginBottom(1)
|
||||
|
||||
successBoxStyle = boxStyle.BorderForeground(successColor)
|
||||
|
||||
warningBoxStyle = boxStyle.BorderForeground(warningColor)
|
||||
|
||||
errorBoxStyle = boxStyle.BorderForeground(errorColor)
|
||||
|
||||
// Text styles
|
||||
titleStyle = lipgloss.NewStyle().
|
||||
Bold(true)
|
||||
|
||||
subtitleStyle = lipgloss.NewStyle().
|
||||
Foreground(mutedColor)
|
||||
|
||||
successStyle = lipgloss.NewStyle().
|
||||
Foreground(successColor).
|
||||
Bold(true)
|
||||
|
||||
errorStyle = lipgloss.NewStyle().
|
||||
Foreground(errorColor).
|
||||
Bold(true)
|
||||
|
||||
warningStyle = lipgloss.NewStyle().
|
||||
Foreground(warningColor).
|
||||
Bold(true)
|
||||
|
||||
mutedStyle = lipgloss.NewStyle().
|
||||
Foreground(mutedColor)
|
||||
|
||||
highlightStyle = lipgloss.NewStyle().
|
||||
Foreground(highlightColor).
|
||||
Bold(true)
|
||||
|
||||
cyanStyle = lipgloss.NewStyle().
|
||||
Foreground(cyanColor)
|
||||
|
||||
urlStyle = lipgloss.NewStyle().
|
||||
Foreground(highlightColor).
|
||||
Underline(true).
|
||||
Bold(true)
|
||||
|
||||
labelStyle = lipgloss.NewStyle().
|
||||
Foreground(mutedColor).
|
||||
Width(12)
|
||||
|
||||
valueStyle = lipgloss.NewStyle().
|
||||
Bold(true)
|
||||
|
||||
// Table styles (padding handled manually for consistent Windows output)
|
||||
tableHeaderStyle = lipgloss.NewStyle().
|
||||
Foreground(mutedColor).
|
||||
Bold(true)
|
||||
)
|
||||
|
||||
// Success returns a styled success message
|
||||
func Success(text string) string {
|
||||
return successStyle.Render("✓ " + text)
|
||||
}
|
||||
|
||||
// Error returns a styled error message
|
||||
func Error(text string) string {
|
||||
return errorStyle.Render("✗ " + text)
|
||||
}
|
||||
|
||||
// Warning returns a styled warning message
|
||||
func Warning(text string) string {
|
||||
return warningStyle.Render("⚠ " + text)
|
||||
}
|
||||
|
||||
// Muted returns a styled muted text
|
||||
func Muted(text string) string {
|
||||
return mutedStyle.Render(text)
|
||||
}
|
||||
|
||||
// Highlight returns a styled highlighted text
|
||||
func Highlight(text string) string {
|
||||
return highlightStyle.Render(text)
|
||||
}
|
||||
|
||||
// Cyan returns a styled cyan text
|
||||
func Cyan(text string) string {
|
||||
return cyanStyle.Render(text)
|
||||
}
|
||||
|
||||
// URL returns a styled URL
|
||||
func URL(text string) string {
|
||||
return urlStyle.Render(text)
|
||||
}
|
||||
|
||||
// Title returns a styled title
|
||||
func Title(text string) string {
|
||||
return titleStyle.Render(text)
|
||||
}
|
||||
|
||||
// Subtitle returns a styled subtitle
|
||||
func Subtitle(text string) string {
|
||||
return subtitleStyle.Render(text)
|
||||
}
|
||||
|
||||
// KeyValue returns a styled key-value pair
|
||||
func KeyValue(key, value string) string {
|
||||
return labelStyle.Render(key+":") + " " + valueStyle.Render(value)
|
||||
}
|
||||
|
||||
// Info renders an info box (Vercel-style)
|
||||
func Info(title string, lines ...string) string {
|
||||
content := titleStyle.Render(title)
|
||||
if len(lines) > 0 {
|
||||
content += "\n\n"
|
||||
for i, line := range lines {
|
||||
if i > 0 {
|
||||
content += "\n"
|
||||
}
|
||||
content += line
|
||||
}
|
||||
}
|
||||
return boxStyle.Render(content)
|
||||
}
|
||||
|
||||
// SuccessBox renders a success box
|
||||
func SuccessBox(title string, lines ...string) string {
|
||||
content := successStyle.Render("✓ " + title)
|
||||
if len(lines) > 0 {
|
||||
content += "\n\n"
|
||||
for i, line := range lines {
|
||||
if i > 0 {
|
||||
content += "\n"
|
||||
}
|
||||
content += line
|
||||
}
|
||||
}
|
||||
return successBoxStyle.Render(content)
|
||||
}
|
||||
|
||||
// WarningBox renders a warning box
|
||||
func WarningBox(title string, lines ...string) string {
|
||||
content := warningStyle.Render("⚠ " + title)
|
||||
if len(lines) > 0 {
|
||||
content += "\n\n"
|
||||
for i, line := range lines {
|
||||
if i > 0 {
|
||||
content += "\n"
|
||||
}
|
||||
content += line
|
||||
}
|
||||
}
|
||||
return warningBoxStyle.Render(content)
|
||||
}
|
||||
|
||||
// ErrorBox renders an error box
|
||||
func ErrorBox(title string, lines ...string) string {
|
||||
content := errorStyle.Render("✗ " + title)
|
||||
if len(lines) > 0 {
|
||||
content += "\n\n"
|
||||
for i, line := range lines {
|
||||
if i > 0 {
|
||||
content += "\n"
|
||||
}
|
||||
content += line
|
||||
}
|
||||
}
|
||||
return errorBoxStyle.Render(content)
|
||||
}
|
||||
@@ -1,145 +0,0 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// Table represents a simple table for CLI output
|
||||
type Table struct {
|
||||
headers []string
|
||||
rows [][]string
|
||||
title string
|
||||
}
|
||||
|
||||
// NewTable creates a new table
|
||||
func NewTable(headers []string) *Table {
|
||||
return &Table{
|
||||
headers: headers,
|
||||
rows: [][]string{},
|
||||
}
|
||||
}
|
||||
|
||||
// WithTitle sets the table title
|
||||
func (t *Table) WithTitle(title string) *Table {
|
||||
t.title = title
|
||||
return t
|
||||
}
|
||||
|
||||
// AddRow adds a row to the table
|
||||
func (t *Table) AddRow(row []string) *Table {
|
||||
t.rows = append(t.rows, row)
|
||||
return t
|
||||
}
|
||||
|
||||
// Render renders the table (Vercel-style)
|
||||
func (t *Table) Render() string {
|
||||
if len(t.rows) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Calculate column widths
|
||||
colWidths := make([]int, len(t.headers))
|
||||
for i, header := range t.headers {
|
||||
colWidths[i] = lipgloss.Width(header)
|
||||
}
|
||||
for _, row := range t.rows {
|
||||
for i, cell := range row {
|
||||
if i < len(colWidths) {
|
||||
width := lipgloss.Width(cell)
|
||||
if width > colWidths[i] {
|
||||
colWidths[i] = width
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
|
||||
// Title
|
||||
if t.title != "" {
|
||||
output.WriteString("\n")
|
||||
output.WriteString(titleStyle.Render(t.title))
|
||||
output.WriteString("\n\n")
|
||||
}
|
||||
|
||||
// Header
|
||||
headerParts := make([]string, len(t.headers))
|
||||
for i, header := range t.headers {
|
||||
styled := tableHeaderStyle.Render(header)
|
||||
headerParts[i] = padRight(styled, colWidths[i])
|
||||
}
|
||||
output.WriteString(strings.Join(headerParts, " "))
|
||||
output.WriteString("\n")
|
||||
|
||||
// Separator line
|
||||
separatorChar := "─"
|
||||
if runtime.GOOS == "windows" {
|
||||
separatorChar = "-"
|
||||
}
|
||||
separatorParts := make([]string, len(t.headers))
|
||||
for i := range t.headers {
|
||||
separatorParts[i] = mutedStyle.Render(strings.Repeat(separatorChar, colWidths[i]))
|
||||
}
|
||||
output.WriteString(strings.Join(separatorParts, " "))
|
||||
output.WriteString("\n")
|
||||
|
||||
// Rows
|
||||
for _, row := range t.rows {
|
||||
rowParts := make([]string, len(t.headers))
|
||||
for i, cell := range row {
|
||||
if i < len(colWidths) {
|
||||
rowParts[i] = padRight(cell, colWidths[i])
|
||||
}
|
||||
}
|
||||
output.WriteString(strings.Join(rowParts, " "))
|
||||
output.WriteString("\n")
|
||||
}
|
||||
|
||||
output.WriteString("\n")
|
||||
return output.String()
|
||||
}
|
||||
|
||||
// padRight pads
|
||||
func padRight(text string, targetWidth int) string {
|
||||
visibleWidth := lipgloss.Width(text)
|
||||
if visibleWidth >= targetWidth {
|
||||
return text
|
||||
}
|
||||
padding := strings.Repeat(" ", targetWidth-visibleWidth)
|
||||
return text + padding
|
||||
}
|
||||
|
||||
// Print prints the table
|
||||
func (t *Table) Print() {
|
||||
fmt.Print(t.Render())
|
||||
}
|
||||
|
||||
// RenderList renders a simple list with bullet points
|
||||
func RenderList(items []string) string {
|
||||
bullet := "•"
|
||||
if runtime.GOOS == "windows" {
|
||||
bullet = "*"
|
||||
}
|
||||
var output strings.Builder
|
||||
for _, item := range items {
|
||||
output.WriteString(mutedStyle.Render(" " + bullet + " "))
|
||||
output.WriteString(item)
|
||||
output.WriteString("\n")
|
||||
}
|
||||
return output.String()
|
||||
}
|
||||
|
||||
// RenderNumberedList renders a numbered list
|
||||
func RenderNumberedList(items []string) string {
|
||||
var output strings.Builder
|
||||
for i, item := range items {
|
||||
output.WriteString(mutedStyle.Render(fmt.Sprintf(" %d. ", i+1)))
|
||||
output.WriteString(item)
|
||||
output.WriteString("\n")
|
||||
}
|
||||
return output.String()
|
||||
}
|
||||
@@ -1,246 +0,0 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
const (
|
||||
tunnelCardWidth = 76
|
||||
statsColumnWidth = 32
|
||||
)
|
||||
|
||||
var (
|
||||
latencyFastColor = lipgloss.Color("#22c55e") // green
|
||||
latencyYellowColor = lipgloss.Color("#eab308") // yellow
|
||||
latencyOrangeColor = lipgloss.Color("#f97316") // orange
|
||||
latencyRedColor = lipgloss.Color("#ef4444") // red
|
||||
)
|
||||
|
||||
// TunnelStatus represents the status of a tunnel
|
||||
type TunnelStatus struct {
|
||||
Type string // "http", "https", "tcp"
|
||||
URL string // Public URL
|
||||
LocalAddr string // Local address
|
||||
Latency time.Duration // Current latency
|
||||
BytesIn int64 // Bytes received
|
||||
BytesOut int64 // Bytes sent
|
||||
SpeedIn float64 // Download speed
|
||||
SpeedOut float64 // Upload speed
|
||||
TotalRequest int64 // Total requests
|
||||
}
|
||||
|
||||
// RenderTunnelConnected renders the tunnel connection card
|
||||
func RenderTunnelConnected(status *TunnelStatus) string {
|
||||
icon, typeStr, accent := tunnelVisuals(status.Type)
|
||||
|
||||
card := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(accent).
|
||||
Padding(1, 2).
|
||||
Width(tunnelCardWidth)
|
||||
|
||||
typeBadge := lipgloss.NewStyle().
|
||||
Background(accent).
|
||||
Foreground(lipgloss.Color("#f8fafc")).
|
||||
Bold(true).
|
||||
Padding(0, 1).
|
||||
Render(strings.ToUpper(typeStr) + " TUNNEL")
|
||||
|
||||
headline := lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
lipgloss.NewStyle().Foreground(accent).Render(icon),
|
||||
lipgloss.NewStyle().Bold(true).MarginLeft(1).Render("Tunnel Connected"),
|
||||
lipgloss.NewStyle().MarginLeft(2).Render(typeBadge),
|
||||
)
|
||||
|
||||
urlLine := lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
urlStyle.Foreground(accent).Render(status.URL),
|
||||
lipgloss.NewStyle().MarginLeft(1).Foreground(mutedColor).Render("(forwarded address)"),
|
||||
)
|
||||
|
||||
forwardLine := lipgloss.NewStyle().
|
||||
MarginLeft(2).
|
||||
Render(Muted("⇢ ") + valueStyle.Render(status.LocalAddr))
|
||||
|
||||
hint := lipgloss.NewStyle().
|
||||
Foreground(latencyOrangeColor).
|
||||
Render("Ctrl+C to stop • reconnects automatically")
|
||||
|
||||
content := lipgloss.JoinVertical(
|
||||
lipgloss.Left,
|
||||
headline,
|
||||
"",
|
||||
urlLine,
|
||||
forwardLine,
|
||||
"",
|
||||
hint,
|
||||
)
|
||||
|
||||
return "\n" + card.Render(content) + "\n"
|
||||
}
|
||||
|
||||
// RenderTunnelStats renders real-time tunnel statistics in a card
|
||||
func RenderTunnelStats(status *TunnelStatus) string {
|
||||
latencyStr := formatLatency(status.Latency)
|
||||
trafficStr := fmt.Sprintf("↓ %s ↑ %s", formatBytes(status.BytesIn), formatBytes(status.BytesOut))
|
||||
speedStr := fmt.Sprintf("↓ %s ↑ %s", formatSpeed(status.SpeedIn), formatSpeed(status.SpeedOut))
|
||||
requestsStr := fmt.Sprintf("%d", status.TotalRequest)
|
||||
|
||||
_, _, accent := tunnelVisuals(status.Type)
|
||||
|
||||
header := lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
lipgloss.NewStyle().Foreground(accent).Render("◉"),
|
||||
lipgloss.NewStyle().Bold(true).MarginLeft(1).Render("Live Metrics"),
|
||||
)
|
||||
|
||||
row1 := lipgloss.JoinHorizontal(
|
||||
lipgloss.Top,
|
||||
statColumn("Latency", latencyStr, statsColumnWidth),
|
||||
statColumn("Requests", highlightStyle.Render(requestsStr), statsColumnWidth),
|
||||
)
|
||||
|
||||
row2 := lipgloss.JoinHorizontal(
|
||||
lipgloss.Top,
|
||||
statColumn("Traffic", Cyan(trafficStr), statsColumnWidth),
|
||||
statColumn("Speed", warningStyle.Render(speedStr), statsColumnWidth),
|
||||
)
|
||||
|
||||
card := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(accent).
|
||||
Padding(1, 2).
|
||||
Width(tunnelCardWidth)
|
||||
|
||||
body := lipgloss.JoinVertical(
|
||||
lipgloss.Left,
|
||||
header,
|
||||
"",
|
||||
row1,
|
||||
row2,
|
||||
)
|
||||
|
||||
return "\n" + card.Render(body) + "\n"
|
||||
}
|
||||
|
||||
// RenderConnecting renders the connecting message
|
||||
func RenderConnecting(serverAddr string, attempt int, maxAttempts int) string {
|
||||
if attempt == 0 {
|
||||
return Highlight("◌") + " Connecting to " + Muted(serverAddr) + "..."
|
||||
}
|
||||
return Warning(fmt.Sprintf("◌ Reconnecting to %s (attempt %d/%d)...", serverAddr, attempt, maxAttempts))
|
||||
}
|
||||
|
||||
// RenderConnectionFailed renders connection failure message
|
||||
func RenderConnectionFailed(err error) string {
|
||||
return Error(fmt.Sprintf("Connection failed: %v", err))
|
||||
}
|
||||
|
||||
// RenderShuttingDown renders shutdown message
|
||||
func RenderShuttingDown() string {
|
||||
return Warning("⏹ Shutting down...")
|
||||
}
|
||||
|
||||
// RenderConnectionLost renders connection lost message
|
||||
func RenderConnectionLost() string {
|
||||
return Error("⚠ Connection lost!")
|
||||
}
|
||||
|
||||
// RenderRetrying renders retry message
|
||||
func RenderRetrying(interval time.Duration) string {
|
||||
return Muted(fmt.Sprintf(" Retrying in %v...", interval))
|
||||
}
|
||||
|
||||
// formatLatency formats latency with color
|
||||
func formatLatency(d time.Duration) string {
|
||||
if d == 0 {
|
||||
return mutedStyle.Render("measuring...")
|
||||
}
|
||||
|
||||
ms := d.Milliseconds()
|
||||
var style lipgloss.Style
|
||||
|
||||
switch {
|
||||
case ms < 50:
|
||||
style = lipgloss.NewStyle().Foreground(latencyFastColor)
|
||||
case ms < 150:
|
||||
style = lipgloss.NewStyle().Foreground(latencyYellowColor)
|
||||
case ms < 300:
|
||||
style = lipgloss.NewStyle().Foreground(latencyOrangeColor)
|
||||
default:
|
||||
style = lipgloss.NewStyle().Foreground(latencyRedColor)
|
||||
}
|
||||
|
||||
if ms == 0 {
|
||||
us := d.Microseconds()
|
||||
return style.Render(fmt.Sprintf("%dµs", us))
|
||||
}
|
||||
|
||||
return style.Render(fmt.Sprintf("%dms", ms))
|
||||
}
|
||||
|
||||
// formatBytes formats bytes to human readable format
|
||||
func formatBytes(bytes int64) string {
|
||||
const unit = 1024
|
||||
if bytes < unit {
|
||||
return fmt.Sprintf("%d B", bytes)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := bytes / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
// formatSpeed formats speed to human readable format
|
||||
func formatSpeed(bytesPerSec float64) string {
|
||||
const unit = 1024.0
|
||||
if bytesPerSec < unit {
|
||||
return fmt.Sprintf("%.0f B/s", bytesPerSec)
|
||||
}
|
||||
div, exp := unit, 0
|
||||
for n := bytesPerSec / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB/s", bytesPerSec/div, "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
func statColumn(label, value string, width int) string {
|
||||
labelView := lipgloss.NewStyle().
|
||||
Foreground(mutedColor).
|
||||
Render(strings.ToUpper(label))
|
||||
|
||||
block := lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
labelView,
|
||||
lipgloss.NewStyle().MarginLeft(1).Render(value),
|
||||
)
|
||||
|
||||
if width <= 0 {
|
||||
return block
|
||||
}
|
||||
|
||||
return lipgloss.NewStyle().
|
||||
Width(width).
|
||||
Render(block)
|
||||
}
|
||||
|
||||
func tunnelVisuals(tunnelType string) (string, string, lipgloss.Color) {
|
||||
switch tunnelType {
|
||||
case "http":
|
||||
return "🚀", "HTTP", lipgloss.Color("#0070F3")
|
||||
case "https":
|
||||
return "🔒", "HTTPS", lipgloss.Color("#2D8CFF")
|
||||
case "tcp":
|
||||
return "🔌", "TCP", lipgloss.Color("#50E3C2")
|
||||
default:
|
||||
return "🌐", strings.ToUpper(tunnelType), lipgloss.Color("#0070F3")
|
||||
}
|
||||
}
|
||||
@@ -1,493 +1,50 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"sync"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/recovery"
|
||||
"drip/pkg/config"
|
||||
"drip/internal/shared/stats"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// LatencyCallback is called when latency is measured
|
||||
type LatencyCallback func(latency time.Duration)
|
||||
|
||||
// Connector manages the TCP connection to the server
|
||||
type Connector struct {
|
||||
serverAddr string
|
||||
tlsConfig *tls.Config
|
||||
token string
|
||||
tunnelType protocol.TunnelType
|
||||
localHost string
|
||||
localPort int
|
||||
subdomain string
|
||||
conn net.Conn
|
||||
logger *zap.Logger
|
||||
stopCh chan struct{}
|
||||
once sync.Once
|
||||
registered bool
|
||||
assignedURL string
|
||||
frameHandler *FrameHandler
|
||||
frameWriter *protocol.FrameWriter
|
||||
latencyCallback LatencyCallback
|
||||
heartbeatSentAt time.Time
|
||||
heartbeatMu sync.Mutex
|
||||
lastLatency time.Duration
|
||||
handlerWg sync.WaitGroup // Tracks active data frame handlers
|
||||
closed bool
|
||||
closedMu sync.RWMutex
|
||||
|
||||
// Worker pool for handling data frames
|
||||
dataFrameQueue chan *protocol.Frame
|
||||
workerCount int
|
||||
|
||||
recoverer *recovery.Recoverer
|
||||
panicMetrics *recovery.PanicMetrics
|
||||
}
|
||||
|
||||
// ConnectorConfig holds connector configuration
|
||||
type ConnectorConfig struct {
|
||||
ServerAddr string
|
||||
Token string
|
||||
TunnelType protocol.TunnelType
|
||||
LocalHost string // Local host address (default: 127.0.0.1)
|
||||
LocalHost string
|
||||
LocalPort int
|
||||
Subdomain string // Optional custom subdomain
|
||||
Insecure bool // Skip TLS verification (testing only)
|
||||
Subdomain string
|
||||
Insecure bool
|
||||
|
||||
PoolSize int
|
||||
PoolMin int
|
||||
PoolMax int
|
||||
}
|
||||
|
||||
// NewConnector creates a new connector
|
||||
func NewConnector(cfg *ConnectorConfig, logger *zap.Logger) *Connector {
|
||||
var tlsConfig *tls.Config
|
||||
if cfg.Insecure {
|
||||
tlsConfig = config.GetClientTLSConfigInsecure()
|
||||
} else {
|
||||
host, _, _ := net.SplitHostPort(cfg.ServerAddr)
|
||||
tlsConfig = config.GetClientTLSConfig(host)
|
||||
}
|
||||
|
||||
localHost := cfg.LocalHost
|
||||
if localHost == "" {
|
||||
localHost = "127.0.0.1"
|
||||
}
|
||||
|
||||
numCPU := pool.NumCPU()
|
||||
workerCount := max(numCPU+numCPU/2, 4)
|
||||
|
||||
panicMetrics := recovery.NewPanicMetrics(logger, nil)
|
||||
recoverer := recovery.NewRecoverer(logger, panicMetrics)
|
||||
|
||||
return &Connector{
|
||||
serverAddr: cfg.ServerAddr,
|
||||
tlsConfig: tlsConfig,
|
||||
token: cfg.Token,
|
||||
tunnelType: cfg.TunnelType,
|
||||
localHost: localHost,
|
||||
localPort: cfg.LocalPort,
|
||||
subdomain: cfg.Subdomain,
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
dataFrameQueue: make(chan *protocol.Frame, workerCount*100),
|
||||
workerCount: workerCount,
|
||||
recoverer: recoverer,
|
||||
panicMetrics: panicMetrics,
|
||||
}
|
||||
type TunnelClient interface {
|
||||
Connect() error
|
||||
Close() error
|
||||
Wait()
|
||||
GetURL() string
|
||||
GetSubdomain() string
|
||||
SetLatencyCallback(cb LatencyCallback)
|
||||
GetLatency() time.Duration
|
||||
GetStats() *stats.TrafficStats
|
||||
IsClosed() bool
|
||||
}
|
||||
|
||||
// Connect connects to the server and registers the tunnel
|
||||
func (c *Connector) Connect() error {
|
||||
c.logger.Info("Connecting to server",
|
||||
zap.String("server", c.serverAddr),
|
||||
zap.String("tunnel_type", string(c.tunnelType)),
|
||||
zap.String("local_host", c.localHost),
|
||||
zap.Int("local_port", c.localPort),
|
||||
)
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", c.serverAddr, c.tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
|
||||
state := conn.ConnectionState()
|
||||
if state.Version != tls.VersionTLS13 {
|
||||
conn.Close()
|
||||
return fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
|
||||
}
|
||||
|
||||
c.logger.Info("TLS connection established",
|
||||
zap.String("cipher_suite", tls.CipherSuiteName(state.CipherSuite)),
|
||||
)
|
||||
|
||||
if err := c.register(); err != nil {
|
||||
conn.Close()
|
||||
return fmt.Errorf("registration failed: %w", err)
|
||||
}
|
||||
|
||||
c.frameWriter = protocol.NewFrameWriter(c.conn)
|
||||
bufferPool := pool.NewBufferPool()
|
||||
|
||||
c.frameHandler = NewFrameHandler(
|
||||
c.conn,
|
||||
c.frameWriter,
|
||||
c.localHost,
|
||||
c.localPort,
|
||||
c.tunnelType,
|
||||
c.logger,
|
||||
c.IsClosed,
|
||||
bufferPool,
|
||||
)
|
||||
|
||||
c.frameWriter.EnableHeartbeat(constants.HeartbeatInterval, c.createHeartbeatFrame)
|
||||
|
||||
for i := 0; i < c.workerCount; i++ {
|
||||
c.handlerWg.Add(1)
|
||||
go c.dataFrameWorker(i)
|
||||
}
|
||||
|
||||
go c.frameHandler.WarmupConnectionPool(3)
|
||||
go c.monitorQueuePressure()
|
||||
go c.handleFrames()
|
||||
|
||||
return nil
|
||||
func NewTunnelClient(cfg *ConnectorConfig, logger *zap.Logger) TunnelClient {
|
||||
return NewPoolClient(cfg, logger)
|
||||
}
|
||||
|
||||
// register sends registration request and waits for acknowledgment
|
||||
func (c *Connector) register() error {
|
||||
req := protocol.RegisterRequest{
|
||||
Token: c.token,
|
||||
CustomSubdomain: c.subdomain,
|
||||
TunnelType: c.tunnelType,
|
||||
LocalPort: c.localPort,
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
regFrame := protocol.NewFrame(protocol.FrameTypeRegister, payload)
|
||||
err = protocol.WriteFrame(c.conn, regFrame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send registration: %w", err)
|
||||
}
|
||||
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
ackFrame, err := protocol.ReadFrame(c.conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read ack: %w", err)
|
||||
}
|
||||
defer ackFrame.Release()
|
||||
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if ackFrame.Type == protocol.FrameTypeError {
|
||||
var errMsg protocol.ErrorMessage
|
||||
if err := json.Unmarshal(ackFrame.Payload, &errMsg); err == nil {
|
||||
return fmt.Errorf("registration error: %s - %s", errMsg.Code, errMsg.Message)
|
||||
}
|
||||
return fmt.Errorf("registration error")
|
||||
}
|
||||
|
||||
if ackFrame.Type != protocol.FrameTypeRegisterAck {
|
||||
return fmt.Errorf("unexpected frame type: %s", ackFrame.Type)
|
||||
}
|
||||
|
||||
var resp protocol.RegisterResponse
|
||||
if err := json.Unmarshal(ackFrame.Payload, &resp); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
c.registered = true
|
||||
c.assignedURL = resp.URL
|
||||
c.subdomain = resp.Subdomain
|
||||
|
||||
c.logger.Info("Tunnel registered successfully",
|
||||
zap.String("subdomain", resp.Subdomain),
|
||||
zap.String("url", resp.URL),
|
||||
zap.Int("remote_port", resp.Port),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connector) dataFrameWorker(workerID int) {
|
||||
defer c.handlerWg.Done()
|
||||
defer c.recoverer.Recover(fmt.Sprintf("dataFrameWorker-%d", workerID))
|
||||
|
||||
for {
|
||||
select {
|
||||
case frame, ok := <-c.dataFrameQueue:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
func() {
|
||||
sf := protocol.WithFrame(frame)
|
||||
defer sf.Close()
|
||||
defer c.recoverer.Recover("handleDataFrame")
|
||||
|
||||
if err := c.frameHandler.HandleDataFrame(sf.Frame); err != nil {
|
||||
c.logger.Error("Failed to handle data frame",
|
||||
zap.Int("worker_id", workerID),
|
||||
zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
case <-c.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleFrames handles incoming frames from server
|
||||
func (c *Connector) handleFrames() {
|
||||
defer c.Close()
|
||||
defer c.recoverer.Recover("handleFrames")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
frame, err := protocol.ReadFrame(c.conn)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
c.logger.Warn("Read timeout")
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
default:
|
||||
c.logger.Error("Failed to read frame", zap.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
sf := protocol.WithFrame(frame)
|
||||
|
||||
switch sf.Frame.Type {
|
||||
case protocol.FrameTypeHeartbeatAck:
|
||||
c.heartbeatMu.Lock()
|
||||
if !c.heartbeatSentAt.IsZero() {
|
||||
latency := time.Since(c.heartbeatSentAt)
|
||||
c.lastLatency = latency
|
||||
c.heartbeatMu.Unlock()
|
||||
|
||||
c.logger.Debug("Received heartbeat ack", zap.Duration("latency", latency))
|
||||
|
||||
if c.latencyCallback != nil {
|
||||
c.latencyCallback(latency)
|
||||
}
|
||||
} else {
|
||||
c.heartbeatMu.Unlock()
|
||||
c.logger.Debug("Received heartbeat ack")
|
||||
}
|
||||
sf.Close()
|
||||
|
||||
case protocol.FrameTypeData:
|
||||
select {
|
||||
case c.dataFrameQueue <- sf.Frame:
|
||||
case <-c.stopCh:
|
||||
sf.Close()
|
||||
return
|
||||
default:
|
||||
c.logger.Warn("Data frame queue full, dropping frame")
|
||||
sf.Close()
|
||||
}
|
||||
|
||||
case protocol.FrameTypeClose:
|
||||
sf.Close()
|
||||
c.logger.Info("Server requested close")
|
||||
return
|
||||
|
||||
case protocol.FrameTypeError:
|
||||
var errMsg protocol.ErrorMessage
|
||||
if err := json.Unmarshal(sf.Frame.Payload, &errMsg); err == nil {
|
||||
c.logger.Error("Received error from server",
|
||||
zap.String("code", errMsg.Code),
|
||||
zap.String("message", errMsg.Message),
|
||||
)
|
||||
}
|
||||
sf.Close()
|
||||
return
|
||||
|
||||
default:
|
||||
sf.Close()
|
||||
c.logger.Warn("Unexpected frame type",
|
||||
zap.String("type", sf.Frame.Type.String()),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connector) createHeartbeatFrame() *protocol.Frame {
|
||||
c.closedMu.RLock()
|
||||
if c.closed {
|
||||
c.closedMu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
c.closedMu.RUnlock()
|
||||
|
||||
c.heartbeatMu.Lock()
|
||||
c.heartbeatSentAt = time.Now()
|
||||
c.heartbeatMu.Unlock()
|
||||
|
||||
return protocol.NewFrame(protocol.FrameTypeHeartbeat, nil)
|
||||
}
|
||||
|
||||
// SendFrame sends a frame to the server
|
||||
func (c *Connector) SendFrame(frame *protocol.Frame) error {
|
||||
if !c.registered {
|
||||
return fmt.Errorf("not registered")
|
||||
}
|
||||
|
||||
return c.frameWriter.WriteFrame(frame)
|
||||
}
|
||||
|
||||
func (c *Connector) Close() error {
|
||||
c.once.Do(func() {
|
||||
c.closedMu.Lock()
|
||||
c.closed = true
|
||||
c.closedMu.Unlock()
|
||||
|
||||
close(c.stopCh)
|
||||
close(c.dataFrameQueue)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.handlerWg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
c.logger.Warn("Force closing: some handlers are still active")
|
||||
}
|
||||
|
||||
if c.conn != nil {
|
||||
closeFrame := protocol.NewFrame(protocol.FrameTypeClose, nil)
|
||||
|
||||
if c.frameWriter != nil {
|
||||
c.frameWriter.WriteFrame(closeFrame)
|
||||
c.frameWriter.Close()
|
||||
} else {
|
||||
protocol.WriteFrame(c.conn, closeFrame)
|
||||
}
|
||||
|
||||
c.conn.Close()
|
||||
}
|
||||
c.logger.Info("Connector closed")
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wait blocks until connection is closed
|
||||
func (c *Connector) Wait() {
|
||||
<-c.stopCh
|
||||
}
|
||||
|
||||
// GetURL returns the assigned tunnel URL
|
||||
func (c *Connector) GetURL() string {
|
||||
return c.assignedURL
|
||||
}
|
||||
|
||||
// GetSubdomain returns the assigned subdomain
|
||||
func (c *Connector) GetSubdomain() string {
|
||||
return c.subdomain
|
||||
}
|
||||
|
||||
// SetLatencyCallback sets the callback for latency updates
|
||||
func (c *Connector) SetLatencyCallback(cb LatencyCallback) {
|
||||
c.latencyCallback = cb
|
||||
}
|
||||
|
||||
// GetLatency returns the last measured latency
|
||||
func (c *Connector) GetLatency() time.Duration {
|
||||
c.heartbeatMu.Lock()
|
||||
defer c.heartbeatMu.Unlock()
|
||||
return c.lastLatency
|
||||
}
|
||||
|
||||
// GetStats returns the traffic stats from the frame handler
|
||||
func (c *Connector) GetStats() *TrafficStats {
|
||||
if c.frameHandler != nil {
|
||||
return c.frameHandler.GetStats()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsClosed returns whether the connector has been closed
|
||||
func (c *Connector) IsClosed() bool {
|
||||
c.closedMu.RLock()
|
||||
defer c.closedMu.RUnlock()
|
||||
return c.closed
|
||||
}
|
||||
func (c *Connector) monitorQueuePressure() {
|
||||
defer c.recoverer.Recover("monitorQueuePressure")
|
||||
|
||||
const (
|
||||
pauseThreshold = 0.80
|
||||
resumeThreshold = 0.50
|
||||
checkInterval = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
isPaused := false
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
queueLen := len(c.dataFrameQueue)
|
||||
queueCap := cap(c.dataFrameQueue)
|
||||
usage := float64(queueLen) / float64(queueCap)
|
||||
|
||||
if usage > pauseThreshold && !isPaused {
|
||||
c.sendFlowControl("*", protocol.FlowControlPause)
|
||||
isPaused = true
|
||||
c.logger.Warn("Queue pressure high, sent pause signal",
|
||||
zap.Int("queue_len", queueLen),
|
||||
zap.Int("queue_cap", queueCap),
|
||||
zap.Float64("usage", usage))
|
||||
} else if usage < resumeThreshold && isPaused {
|
||||
c.sendFlowControl("*", protocol.FlowControlResume)
|
||||
isPaused = false
|
||||
c.logger.Info("Queue pressure normal, sent resume signal",
|
||||
zap.Int("queue_len", queueLen),
|
||||
zap.Int("queue_cap", queueCap),
|
||||
zap.Float64("usage", usage))
|
||||
}
|
||||
|
||||
case <-c.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connector) sendFlowControl(streamID string, action protocol.FlowControlAction) {
|
||||
frame := protocol.NewFlowControlFrame(streamID, action)
|
||||
if err := c.SendFrame(frame); err != nil {
|
||||
c.logger.Error("Failed to send flow control",
|
||||
zap.String("action", string(action)),
|
||||
zap.Error(err))
|
||||
}
|
||||
func isExpectedCloseError(err error) bool {
|
||||
s := err.Error()
|
||||
return strings.Contains(s, "EOF") ||
|
||||
strings.Contains(s, "use of closed") ||
|
||||
strings.Contains(s, "connection reset")
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
450
internal/client/tcp/pool_client.go
Normal file
450
internal/client/tcp/pool_client.go
Normal file
@@ -0,0 +1,450 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/hashicorp/yamux"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/stats"
|
||||
"drip/pkg/config"
|
||||
)
|
||||
|
||||
// PoolClient manages a pool of yamux sessions for tunnel connections.
|
||||
type PoolClient struct {
|
||||
serverAddr string
|
||||
tlsConfig *tls.Config
|
||||
token string
|
||||
tunnelType protocol.TunnelType
|
||||
localHost string
|
||||
localPort int
|
||||
subdomain string
|
||||
|
||||
assignedURL string
|
||||
tunnelID string
|
||||
|
||||
minSessions int
|
||||
maxSessions int
|
||||
initialSessions int
|
||||
|
||||
stats *stats.TrafficStats
|
||||
|
||||
httpClient *http.Client
|
||||
|
||||
latencyCallback atomic.Value // LatencyCallback
|
||||
latencyNanos atomic.Int64
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
once sync.Once
|
||||
wg sync.WaitGroup
|
||||
closed atomic.Bool
|
||||
|
||||
primary *sessionHandle
|
||||
|
||||
mu sync.RWMutex
|
||||
dataSessions map[string]*sessionHandle
|
||||
desiredTotal int
|
||||
lastScale time.Time
|
||||
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewPoolClient creates a new pool client.
|
||||
func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
|
||||
var tlsConfig *tls.Config
|
||||
if cfg.Insecure {
|
||||
tlsConfig = config.GetClientTLSConfigInsecure()
|
||||
} else {
|
||||
host, _, _ := net.SplitHostPort(cfg.ServerAddr)
|
||||
tlsConfig = config.GetClientTLSConfig(host)
|
||||
}
|
||||
|
||||
localHost := cfg.LocalHost
|
||||
if localHost == "" {
|
||||
localHost = "127.0.0.1"
|
||||
}
|
||||
|
||||
tunnelType := cfg.TunnelType
|
||||
if tunnelType == "" {
|
||||
tunnelType = protocol.TunnelTypeTCP
|
||||
}
|
||||
|
||||
numCPU := runtime.NumCPU()
|
||||
|
||||
minSessions := cfg.PoolMin
|
||||
if minSessions <= 0 {
|
||||
minSessions = 2
|
||||
}
|
||||
|
||||
maxSessions := cfg.PoolMax
|
||||
if maxSessions <= 0 {
|
||||
maxSessions = max(numCPU*16, minSessions)
|
||||
}
|
||||
if maxSessions < minSessions {
|
||||
maxSessions = minSessions
|
||||
}
|
||||
|
||||
initialSessions := cfg.PoolSize
|
||||
if initialSessions <= 0 {
|
||||
initialSessions = 4
|
||||
}
|
||||
initialSessions = min(max(initialSessions, minSessions), maxSessions)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
c := &PoolClient{
|
||||
serverAddr: cfg.ServerAddr,
|
||||
tlsConfig: tlsConfig,
|
||||
token: cfg.Token,
|
||||
tunnelType: tunnelType,
|
||||
localHost: localHost,
|
||||
localPort: cfg.LocalPort,
|
||||
subdomain: cfg.Subdomain,
|
||||
minSessions: minSessions,
|
||||
maxSessions: maxSessions,
|
||||
initialSessions: initialSessions,
|
||||
stats: stats.NewTrafficStats(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
stopCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
dataSessions: make(map[string]*sessionHandle),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
|
||||
c.httpClient = newLocalHTTPClient(tunnelType)
|
||||
}
|
||||
|
||||
c.latencyCallback.Store(LatencyCallback(func(time.Duration) {}))
|
||||
return c
|
||||
}
|
||||
|
||||
// Connect establishes the primary connection and starts background workers.
|
||||
func (c *PoolClient) Connect() error {
|
||||
primaryConn, err := c.dialTLS()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
maxData := max(c.maxSessions-1, 0)
|
||||
req := protocol.RegisterRequest{
|
||||
Token: c.token,
|
||||
CustomSubdomain: c.subdomain,
|
||||
TunnelType: c.tunnelType,
|
||||
LocalPort: c.localPort,
|
||||
ConnectionType: "primary",
|
||||
PoolCapabilities: &protocol.PoolCapabilities{
|
||||
MaxDataConns: maxData,
|
||||
Version: 1,
|
||||
},
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
_ = primaryConn.Close()
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
if err := protocol.WriteFrame(primaryConn, protocol.NewFrame(protocol.FrameTypeRegister, payload)); err != nil {
|
||||
_ = primaryConn.Close()
|
||||
return fmt.Errorf("failed to send registration: %w", err)
|
||||
}
|
||||
|
||||
_ = primaryConn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
ack, err := protocol.ReadFrame(primaryConn)
|
||||
if err != nil {
|
||||
_ = primaryConn.Close()
|
||||
return fmt.Errorf("failed to read register ack: %w", err)
|
||||
}
|
||||
defer ack.Release()
|
||||
_ = primaryConn.SetReadDeadline(time.Time{})
|
||||
|
||||
if ack.Type == protocol.FrameTypeError {
|
||||
var errMsg protocol.ErrorMessage
|
||||
if e := json.Unmarshal(ack.Payload, &errMsg); e == nil {
|
||||
_ = primaryConn.Close()
|
||||
return fmt.Errorf("registration error: %s - %s", errMsg.Code, errMsg.Message)
|
||||
}
|
||||
_ = primaryConn.Close()
|
||||
return fmt.Errorf("registration error")
|
||||
}
|
||||
if ack.Type != protocol.FrameTypeRegisterAck {
|
||||
_ = primaryConn.Close()
|
||||
return fmt.Errorf("unexpected register ack frame: %s", ack.Type)
|
||||
}
|
||||
|
||||
var resp protocol.RegisterResponse
|
||||
if err := json.Unmarshal(ack.Payload, &resp); err != nil {
|
||||
_ = primaryConn.Close()
|
||||
return fmt.Errorf("failed to parse register response: %w", err)
|
||||
}
|
||||
|
||||
c.assignedURL = resp.URL
|
||||
c.subdomain = resp.Subdomain
|
||||
if resp.SupportsDataConn && resp.TunnelID != "" {
|
||||
c.tunnelID = resp.TunnelID
|
||||
}
|
||||
|
||||
yamuxCfg := yamux.DefaultConfig()
|
||||
yamuxCfg.EnableKeepAlive = false
|
||||
yamuxCfg.LogOutput = io.Discard
|
||||
yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
session, err := yamux.Server(primaryConn, yamuxCfg)
|
||||
if err != nil {
|
||||
_ = primaryConn.Close()
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
|
||||
primary := &sessionHandle{
|
||||
id: "primary",
|
||||
conn: primaryConn,
|
||||
session: session,
|
||||
}
|
||||
primary.touch()
|
||||
c.primary = primary
|
||||
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
<-c.stopCh
|
||||
}()
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.acceptLoop(primary, true)
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.sessionWatcher(primary, true)
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.pingLoop(primary)
|
||||
|
||||
if c.tunnelID != "" {
|
||||
c.mu.Lock()
|
||||
c.desiredTotal = c.initialSessions
|
||||
c.mu.Unlock()
|
||||
|
||||
c.ensureSessions()
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.scalerLoop()
|
||||
}
|
||||
|
||||
go func() {
|
||||
c.wg.Wait()
|
||||
close(c.doneCh)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PoolClient) dialTLS() (net.Conn, error) {
|
||||
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", c.serverAddr, c.tlsConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
state := conn.ConnectionState()
|
||||
if state.Version != tls.VersionTLS13 {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
|
||||
}
|
||||
|
||||
if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok {
|
||||
_ = tcpConn.SetNoDelay(true)
|
||||
_ = tcpConn.SetKeepAlive(true)
|
||||
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||||
_ = tcpConn.SetReadBuffer(256 * 1024)
|
||||
_ = tcpConn.SetWriteBuffer(256 * 1024)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *PoolClient) acceptLoop(h *sessionHandle, isPrimary bool) {
|
||||
defer c.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
stream, err := h.session.Accept()
|
||||
if err != nil {
|
||||
if c.IsClosed() || isExpectedCloseError(err) {
|
||||
return
|
||||
}
|
||||
if isPrimary {
|
||||
c.logger.Debug("Primary session accept failed", zap.Error(err))
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
c.logger.Debug("Data session accept failed", zap.String("session_id", h.id), zap.Error(err))
|
||||
c.removeDataSession(h.id)
|
||||
return
|
||||
}
|
||||
|
||||
h.active.Add(1)
|
||||
h.touch()
|
||||
|
||||
c.stats.AddRequest()
|
||||
c.stats.IncActiveConnections()
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.handleStream(h, stream)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PoolClient) sessionWatcher(h *sessionHandle, isPrimary bool) {
|
||||
defer c.wg.Done()
|
||||
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
case <-h.session.CloseChan():
|
||||
if isPrimary {
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
c.removeDataSession(h.id)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PoolClient) pingLoop(h *sessionHandle) {
|
||||
defer c.wg.Done()
|
||||
|
||||
const maxConsecutiveFailures = 3
|
||||
|
||||
ticker := time.NewTicker(constants.HeartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
consecutiveFailures := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
if h.session == nil || h.session.IsClosed() {
|
||||
return
|
||||
}
|
||||
|
||||
latency, err := h.session.Ping()
|
||||
if err != nil {
|
||||
consecutiveFailures++
|
||||
c.logger.Debug("Ping failed",
|
||||
zap.String("session_id", h.id),
|
||||
zap.Int("consecutive_failures", consecutiveFailures),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
if consecutiveFailures >= maxConsecutiveFailures {
|
||||
c.logger.Warn("Session ping failed too many times, closing",
|
||||
zap.String("session_id", h.id),
|
||||
zap.Int("failures", consecutiveFailures),
|
||||
)
|
||||
if h.id == "primary" {
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
c.removeDataSession(h.id)
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
consecutiveFailures = 0
|
||||
h.touch()
|
||||
|
||||
c.latencyNanos.Store(int64(latency))
|
||||
if cb, ok := c.latencyCallback.Load().(LatencyCallback); ok && cb != nil {
|
||||
cb(latency)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the client and all sessions.
|
||||
func (c *PoolClient) Close() error {
|
||||
var closeErr error
|
||||
|
||||
c.once.Do(func() {
|
||||
c.closed.Store(true)
|
||||
close(c.stopCh)
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
var data []*sessionHandle
|
||||
var primary *sessionHandle
|
||||
|
||||
c.mu.Lock()
|
||||
for _, h := range c.dataSessions {
|
||||
data = append(data, h)
|
||||
}
|
||||
c.dataSessions = make(map[string]*sessionHandle)
|
||||
primary = c.primary
|
||||
c.primary = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
for _, h := range data {
|
||||
if h == nil {
|
||||
continue
|
||||
}
|
||||
if h.session != nil {
|
||||
_ = h.session.Close()
|
||||
}
|
||||
if h.conn != nil {
|
||||
_ = h.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if primary != nil {
|
||||
if primary.session != nil {
|
||||
closeErr = primary.session.Close()
|
||||
}
|
||||
if primary.conn != nil {
|
||||
_ = primary.conn.Close()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return closeErr
|
||||
}
|
||||
|
||||
func (c *PoolClient) Wait() { <-c.doneCh }
|
||||
func (c *PoolClient) GetURL() string { return c.assignedURL }
|
||||
func (c *PoolClient) GetSubdomain() string { return c.subdomain }
|
||||
func (c *PoolClient) GetLatency() time.Duration { return time.Duration(c.latencyNanos.Load()) }
|
||||
func (c *PoolClient) GetStats() *stats.TrafficStats { return c.stats }
|
||||
func (c *PoolClient) IsClosed() bool { return c.closed.Load() }
|
||||
|
||||
func (c *PoolClient) SetLatencyCallback(cb LatencyCallback) {
|
||||
if cb == nil {
|
||||
cb = func(time.Duration) {}
|
||||
}
|
||||
c.latencyCallback.Store(cb)
|
||||
}
|
||||
253
internal/client/tcp/pool_handler.go
Normal file
253
internal/client/tcp/pool_handler.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/httputil"
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// handleStream routes incoming stream to appropriate handler.
|
||||
func (c *PoolClient) handleStream(h *sessionHandle, stream net.Conn) {
|
||||
defer c.wg.Done()
|
||||
defer func() {
|
||||
h.active.Add(-1)
|
||||
c.stats.DecActiveConnections()
|
||||
}()
|
||||
defer stream.Close()
|
||||
|
||||
switch c.tunnelType {
|
||||
case protocol.TunnelTypeHTTP, protocol.TunnelTypeHTTPS:
|
||||
c.handleHTTPStream(stream)
|
||||
default:
|
||||
c.handleTCPStream(stream)
|
||||
}
|
||||
}
|
||||
|
||||
// handleTCPStream handles raw TCP tunneling.
|
||||
func (c *PoolClient) handleTCPStream(stream net.Conn) {
|
||||
localConn, err := net.DialTimeout("tcp", net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort)), 10*time.Second)
|
||||
if err != nil {
|
||||
c.logger.Debug("Dial local failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
defer localConn.Close()
|
||||
|
||||
if tcpConn, ok := localConn.(*net.TCPConn); ok {
|
||||
_ = tcpConn.SetNoDelay(true)
|
||||
_ = tcpConn.SetKeepAlive(true)
|
||||
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||||
_ = tcpConn.SetReadBuffer(256 * 1024)
|
||||
_ = tcpConn.SetWriteBuffer(256 * 1024)
|
||||
}
|
||||
|
||||
_ = netutil.PipeWithCallbacksAndBufferSize(
|
||||
c.ctx,
|
||||
stream,
|
||||
localConn,
|
||||
pool.SizeLarge,
|
||||
func(n int64) { c.stats.AddBytesIn(n) },
|
||||
func(n int64) { c.stats.AddBytesOut(n) },
|
||||
)
|
||||
}
|
||||
|
||||
// handleHTTPStream handles HTTP/HTTPS proxy requests.
|
||||
func (c *PoolClient) handleHTTPStream(stream net.Conn) {
|
||||
_ = stream.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
|
||||
cc := netutil.NewCountingConn(stream,
|
||||
func(n int64) { c.stats.AddBytesIn(n) },
|
||||
func(n int64) { c.stats.AddBytesOut(n) },
|
||||
)
|
||||
|
||||
br := bufio.NewReaderSize(cc, 32*1024)
|
||||
req, err := http.ReadRequest(br)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
_ = stream.SetReadDeadline(time.Time{})
|
||||
|
||||
if httputil.IsWebSocketUpgrade(req) {
|
||||
c.handleWebSocketUpgrade(cc, req)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(c.ctx)
|
||||
defer cancel()
|
||||
|
||||
scheme := "http"
|
||||
if c.tunnelType == protocol.TunnelTypeHTTPS {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
targetURL := fmt.Sprintf("%s://%s:%d%s", scheme, c.localHost, c.localPort, req.URL.RequestURI())
|
||||
outReq, err := http.NewRequestWithContext(ctx, req.Method, targetURL, req.Body)
|
||||
if err != nil {
|
||||
httputil.WriteProxyError(cc, http.StatusBadGateway, "Bad Gateway")
|
||||
return
|
||||
}
|
||||
|
||||
origHost := req.Host
|
||||
httputil.CopyHeaders(outReq.Header, req.Header)
|
||||
httputil.CleanHopByHopHeaders(outReq.Header)
|
||||
|
||||
targetHost := c.localHost
|
||||
if c.localPort != 80 && c.localPort != 443 {
|
||||
targetHost = fmt.Sprintf("%s:%d", c.localHost, c.localPort)
|
||||
}
|
||||
outReq.Host = targetHost
|
||||
outReq.Header.Set("Host", targetHost)
|
||||
if origHost != "" {
|
||||
outReq.Header.Set("X-Forwarded-Host", origHost)
|
||||
}
|
||||
outReq.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
resp, err := c.httpClient.Do(outReq)
|
||||
if err != nil {
|
||||
httputil.WriteProxyError(cc, http.StatusBadGateway, "Local service unavailable")
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
_ = stream.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||
if err := writeResponseHeader(cc, resp); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
stream.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
nr, er := resp.Body.Read(buf)
|
||||
if nr > 0 {
|
||||
_ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
nw, ew := cc.Write(buf[:nr])
|
||||
if ew != nil || nr != nw {
|
||||
break
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
close(done)
|
||||
}
|
||||
|
||||
// handleWebSocketUpgrade handles WebSocket upgrade requests.
|
||||
func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
|
||||
scheme := "ws"
|
||||
if c.tunnelType == protocol.TunnelTypeHTTPS {
|
||||
scheme = "wss"
|
||||
}
|
||||
|
||||
targetAddr := net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort))
|
||||
localConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second)
|
||||
if err != nil {
|
||||
httputil.WriteProxyError(cc, http.StatusBadGateway, "WebSocket backend unavailable")
|
||||
return
|
||||
}
|
||||
defer localConn.Close()
|
||||
|
||||
if c.tunnelType == protocol.TunnelTypeHTTPS {
|
||||
tlsConn := tls.Client(localConn, &tls.Config{InsecureSkipVerify: true})
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
httputil.WriteProxyError(cc, http.StatusBadGateway, "TLS handshake failed")
|
||||
return
|
||||
}
|
||||
localConn = tlsConn
|
||||
}
|
||||
|
||||
req.URL.Scheme = scheme
|
||||
req.URL.Host = targetAddr
|
||||
if err := req.Write(localConn); err != nil {
|
||||
httputil.WriteProxyError(cc, http.StatusBadGateway, "Failed to forward upgrade request")
|
||||
return
|
||||
}
|
||||
|
||||
localBr := bufio.NewReader(localConn)
|
||||
resp, err := http.ReadResponse(localBr, req)
|
||||
if err != nil {
|
||||
httputil.WriteProxyError(cc, http.StatusBadGateway, "Failed to read upgrade response")
|
||||
return
|
||||
}
|
||||
|
||||
if err := resp.Write(cc); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusSwitchingProtocols {
|
||||
_ = netutil.PipeWithCallbacksAndBufferSize(
|
||||
c.ctx,
|
||||
cc,
|
||||
localConn,
|
||||
pool.SizeLarge,
|
||||
func(n int64) { c.stats.AddBytesIn(n) },
|
||||
func(n int64) { c.stats.AddBytesOut(n) },
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// newLocalHTTPClient creates an HTTP client for local service requests.
|
||||
func newLocalHTTPClient(tunnelType protocol.TunnelType) *http.Client {
|
||||
var tlsConfig *tls.Config
|
||||
if tunnelType == protocol.TunnelTypeHTTPS {
|
||||
tlsConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 2000,
|
||||
MaxIdleConnsPerHost: 1000,
|
||||
MaxConnsPerHost: 0,
|
||||
IdleConnTimeout: 180 * time.Second,
|
||||
DisableCompression: true,
|
||||
DisableKeepAlives: false,
|
||||
TLSHandshakeTimeout: 5 * time.Second,
|
||||
TLSClientConfig: tlsConfig,
|
||||
ResponseHeaderTimeout: 15 * time.Second,
|
||||
ExpectContinueTimeout: 500 * time.Millisecond,
|
||||
WriteBufferSize: 32 * 1024,
|
||||
ReadBufferSize: 32 * 1024,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 3 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
},
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func writeResponseHeader(w io.Writer, resp *http.Response) error {
|
||||
statusLine := fmt.Sprintf("HTTP/%d.%d %d %s\r\n",
|
||||
resp.ProtoMajor, resp.ProtoMinor,
|
||||
resp.StatusCode, http.StatusText(resp.StatusCode))
|
||||
if _, err := io.WriteString(w, statusLine); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := resp.Header.Write(w); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := io.WriteString(w, "\r\n")
|
||||
return err
|
||||
}
|
||||
312
internal/client/tcp/pool_session.go
Normal file
312
internal/client/tcp/pool_session.go
Normal file
@@ -0,0 +1,312 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/hashicorp/yamux"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/protocol"
|
||||
)
|
||||
|
||||
// sessionHandle wraps a yamux session with metadata.
|
||||
type sessionHandle struct {
|
||||
id string
|
||||
conn net.Conn
|
||||
session *yamux.Session
|
||||
active atomic.Int64
|
||||
lastActive atomic.Int64 // unix nanos
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func (h *sessionHandle) touch() {
|
||||
h.lastActive.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func (h *sessionHandle) lastActiveTime() time.Time {
|
||||
n := h.lastActive.Load()
|
||||
if n == 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
return time.Unix(0, n)
|
||||
}
|
||||
|
||||
// scalerLoop monitors load and adjusts session count.
|
||||
func (c *PoolClient) scalerLoop() {
|
||||
defer c.wg.Done()
|
||||
|
||||
const (
|
||||
checkInterval = 5 * time.Second
|
||||
scaleUpCooldown = 5 * time.Second
|
||||
scaleDownCooldown = 60 * time.Second
|
||||
capacityPerSession = int64(64)
|
||||
scaleUpLoad = 0.7
|
||||
scaleDownLoad = 0.3
|
||||
)
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
desired := c.desiredTotal
|
||||
if desired == 0 {
|
||||
desired = c.initialSessions
|
||||
c.desiredTotal = desired
|
||||
}
|
||||
lastScale := c.lastScale
|
||||
c.mu.Unlock()
|
||||
|
||||
current := c.sessionCount()
|
||||
if current <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
active := c.stats.GetActiveConnections()
|
||||
load := float64(active) / float64(int64(current)*capacityPerSession)
|
||||
|
||||
sinceLastScale := time.Since(lastScale)
|
||||
if sinceLastScale >= scaleUpCooldown && load > scaleUpLoad && desired < c.maxSessions {
|
||||
c.mu.Lock()
|
||||
c.desiredTotal = min(c.desiredTotal+1, c.maxSessions)
|
||||
c.lastScale = time.Now()
|
||||
c.mu.Unlock()
|
||||
} else if sinceLastScale >= scaleDownCooldown && load < scaleDownLoad && desired > c.minSessions {
|
||||
c.mu.Lock()
|
||||
c.desiredTotal = max(c.desiredTotal-1, c.minSessions)
|
||||
c.lastScale = time.Now()
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
c.ensureSessions()
|
||||
}
|
||||
}
|
||||
|
||||
// ensureSessions adjusts session count to match desired.
|
||||
func (c *PoolClient) ensureSessions() {
|
||||
if c.IsClosed() || c.tunnelID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
desired := c.desiredTotal
|
||||
c.mu.RUnlock()
|
||||
|
||||
desired = min(max(desired, c.minSessions), c.maxSessions)
|
||||
|
||||
current := c.sessionCount()
|
||||
if current < desired {
|
||||
for i := 0; i < desired-current; i++ {
|
||||
if err := c.addDataSession(); err != nil {
|
||||
c.logger.Debug("Add data session failed", zap.Error(err))
|
||||
break
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if current > desired {
|
||||
c.removeIdleSessions(current - desired)
|
||||
}
|
||||
}
|
||||
|
||||
// addDataSession creates a new data session.
|
||||
func (c *PoolClient) addDataSession() error {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return net.ErrClosed
|
||||
default:
|
||||
}
|
||||
|
||||
if c.tunnelID == "" {
|
||||
return fmt.Errorf("server does not support data connections")
|
||||
}
|
||||
|
||||
conn, err := c.dialTLS()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connID := fmt.Sprintf("data-%d", time.Now().UnixNano())
|
||||
|
||||
req := protocol.DataConnectRequest{
|
||||
TunnelID: c.tunnelID,
|
||||
Token: c.token,
|
||||
ConnectionID: connID,
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("failed to marshal data connect request: %w", err)
|
||||
}
|
||||
|
||||
if err := protocol.WriteFrame(conn, protocol.NewFrame(protocol.FrameTypeDataConnect, payload)); err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("failed to send data connect: %w", err)
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
ack, err := protocol.ReadFrame(conn)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("failed to read data connect ack: %w", err)
|
||||
}
|
||||
defer ack.Release()
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if ack.Type == protocol.FrameTypeError {
|
||||
var errMsg protocol.ErrorMessage
|
||||
if e := json.Unmarshal(ack.Payload, &errMsg); e == nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("data connect error: %s - %s", errMsg.Code, errMsg.Message)
|
||||
}
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("data connect error")
|
||||
}
|
||||
if ack.Type != protocol.FrameTypeDataConnectAck {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("unexpected data connect ack frame: %s", ack.Type)
|
||||
}
|
||||
|
||||
var resp protocol.DataConnectResponse
|
||||
if err := json.Unmarshal(ack.Payload, &resp); err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("failed to parse data connect response: %w", err)
|
||||
}
|
||||
if !resp.Accepted {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("data connection rejected: %s", resp.Message)
|
||||
}
|
||||
|
||||
yamuxCfg := yamux.DefaultConfig()
|
||||
yamuxCfg.EnableKeepAlive = false
|
||||
yamuxCfg.LogOutput = io.Discard
|
||||
yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
session, err := yamux.Server(conn, yamuxCfg)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
|
||||
h := &sessionHandle{
|
||||
id: connID,
|
||||
conn: conn,
|
||||
session: session,
|
||||
}
|
||||
h.touch()
|
||||
|
||||
c.mu.Lock()
|
||||
c.dataSessions[connID] = h
|
||||
c.mu.Unlock()
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.acceptLoop(h, false)
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.sessionWatcher(h, false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeIdleSessions removes n idle sessions.
|
||||
func (c *PoolClient) removeIdleSessions(n int) {
|
||||
if n <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
type candidate struct {
|
||||
id string
|
||||
active int64
|
||||
lastActive time.Time
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
candidates := make([]candidate, 0, len(c.dataSessions))
|
||||
for id, h := range c.dataSessions {
|
||||
candidates = append(candidates, candidate{
|
||||
id: id,
|
||||
active: h.active.Load(),
|
||||
lastActive: h.lastActiveTime(),
|
||||
})
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
removed := 0
|
||||
for removed < n {
|
||||
var best candidate
|
||||
found := false
|
||||
for _, cand := range candidates {
|
||||
if cand.active != 0 {
|
||||
continue
|
||||
}
|
||||
if !found || cand.lastActive.Before(best.lastActive) {
|
||||
best = cand
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
if c.removeDataSession(best.id) {
|
||||
removed++
|
||||
}
|
||||
for i := range candidates {
|
||||
if candidates[i].id == best.id {
|
||||
candidates[i].active = 1
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// removeDataSession removes a data session by ID.
|
||||
func (c *PoolClient) removeDataSession(id string) bool {
|
||||
var h *sessionHandle
|
||||
|
||||
c.mu.Lock()
|
||||
h = c.dataSessions[id]
|
||||
delete(c.dataSessions, id)
|
||||
c.mu.Unlock()
|
||||
|
||||
if h == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !h.closed.CompareAndSwap(false, true) {
|
||||
return false
|
||||
}
|
||||
|
||||
if h.session != nil {
|
||||
_ = h.session.Close()
|
||||
}
|
||||
if h.conn != nil {
|
||||
_ = h.conn.Close()
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// sessionCount returns the total number of active sessions.
|
||||
func (c *PoolClient) sessionCount() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
count := len(c.dataSessions)
|
||||
if c.primary != nil {
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
@@ -1,227 +0,0 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TrafficStats tracks traffic statistics for a tunnel connection
|
||||
type TrafficStats struct {
|
||||
// Total bytes
|
||||
totalBytesIn int64
|
||||
totalBytesOut int64
|
||||
|
||||
// Request counts
|
||||
totalRequests int64
|
||||
|
||||
// For speed calculation
|
||||
lastBytesIn int64
|
||||
lastBytesOut int64
|
||||
lastTime time.Time
|
||||
speedMu sync.Mutex
|
||||
|
||||
// Current speed (bytes per second)
|
||||
speedIn int64
|
||||
speedOut int64
|
||||
|
||||
// Start time
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
// NewTrafficStats creates a new traffic stats tracker
|
||||
func NewTrafficStats() *TrafficStats {
|
||||
now := time.Now()
|
||||
return &TrafficStats{
|
||||
startTime: now,
|
||||
lastTime: now,
|
||||
}
|
||||
}
|
||||
|
||||
// AddBytesIn adds incoming bytes to the counter
|
||||
func (s *TrafficStats) AddBytesIn(n int64) {
|
||||
atomic.AddInt64(&s.totalBytesIn, n)
|
||||
}
|
||||
|
||||
// AddBytesOut adds outgoing bytes to the counter
|
||||
func (s *TrafficStats) AddBytesOut(n int64) {
|
||||
atomic.AddInt64(&s.totalBytesOut, n)
|
||||
}
|
||||
|
||||
// AddRequest increments the request counter
|
||||
func (s *TrafficStats) AddRequest() {
|
||||
atomic.AddInt64(&s.totalRequests, 1)
|
||||
}
|
||||
|
||||
// GetTotalBytesIn returns total incoming bytes
|
||||
func (s *TrafficStats) GetTotalBytesIn() int64 {
|
||||
return atomic.LoadInt64(&s.totalBytesIn)
|
||||
}
|
||||
|
||||
// GetTotalBytesOut returns total outgoing bytes
|
||||
func (s *TrafficStats) GetTotalBytesOut() int64 {
|
||||
return atomic.LoadInt64(&s.totalBytesOut)
|
||||
}
|
||||
|
||||
// GetTotalRequests returns total request count
|
||||
func (s *TrafficStats) GetTotalRequests() int64 {
|
||||
return atomic.LoadInt64(&s.totalRequests)
|
||||
}
|
||||
|
||||
// GetTotalBytes returns total bytes (in + out)
|
||||
func (s *TrafficStats) GetTotalBytes() int64 {
|
||||
return s.GetTotalBytesIn() + s.GetTotalBytesOut()
|
||||
}
|
||||
|
||||
// UpdateSpeed calculates current transfer speed
|
||||
// Should be called periodically (e.g., every second)
|
||||
func (s *TrafficStats) UpdateSpeed() {
|
||||
s.speedMu.Lock()
|
||||
defer s.speedMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(s.lastTime).Seconds()
|
||||
if elapsed < 0.1 {
|
||||
return // Avoid division by zero or too frequent updates
|
||||
}
|
||||
|
||||
currentIn := atomic.LoadInt64(&s.totalBytesIn)
|
||||
currentOut := atomic.LoadInt64(&s.totalBytesOut)
|
||||
|
||||
deltaIn := currentIn - s.lastBytesIn
|
||||
deltaOut := currentOut - s.lastBytesOut
|
||||
|
||||
s.speedIn = int64(float64(deltaIn) / elapsed)
|
||||
s.speedOut = int64(float64(deltaOut) / elapsed)
|
||||
|
||||
s.lastBytesIn = currentIn
|
||||
s.lastBytesOut = currentOut
|
||||
s.lastTime = now
|
||||
}
|
||||
|
||||
// GetSpeedIn returns current incoming speed in bytes per second
|
||||
func (s *TrafficStats) GetSpeedIn() int64 {
|
||||
s.speedMu.Lock()
|
||||
defer s.speedMu.Unlock()
|
||||
return s.speedIn
|
||||
}
|
||||
|
||||
// GetSpeedOut returns current outgoing speed in bytes per second
|
||||
func (s *TrafficStats) GetSpeedOut() int64 {
|
||||
s.speedMu.Lock()
|
||||
defer s.speedMu.Unlock()
|
||||
return s.speedOut
|
||||
}
|
||||
|
||||
// GetUptime returns how long the connection has been active
|
||||
func (s *TrafficStats) GetUptime() time.Duration {
|
||||
return time.Since(s.startTime)
|
||||
}
|
||||
|
||||
// Snapshot returns a snapshot of all stats
|
||||
type StatsSnapshot struct {
|
||||
TotalBytesIn int64
|
||||
TotalBytesOut int64
|
||||
TotalBytes int64
|
||||
TotalRequests int64
|
||||
SpeedIn int64 // bytes per second
|
||||
SpeedOut int64 // bytes per second
|
||||
Uptime time.Duration
|
||||
}
|
||||
|
||||
// GetSnapshot returns a snapshot of current stats
|
||||
func (s *TrafficStats) GetSnapshot() StatsSnapshot {
|
||||
s.speedMu.Lock()
|
||||
speedIn := s.speedIn
|
||||
speedOut := s.speedOut
|
||||
s.speedMu.Unlock()
|
||||
|
||||
totalIn := atomic.LoadInt64(&s.totalBytesIn)
|
||||
totalOut := atomic.LoadInt64(&s.totalBytesOut)
|
||||
|
||||
return StatsSnapshot{
|
||||
TotalBytesIn: totalIn,
|
||||
TotalBytesOut: totalOut,
|
||||
TotalBytes: totalIn + totalOut,
|
||||
TotalRequests: atomic.LoadInt64(&s.totalRequests),
|
||||
SpeedIn: speedIn,
|
||||
SpeedOut: speedOut,
|
||||
Uptime: time.Since(s.startTime),
|
||||
}
|
||||
}
|
||||
|
||||
// FormatBytes formats bytes to human readable string
|
||||
func FormatBytes(bytes int64) string {
|
||||
const (
|
||||
KB = 1024
|
||||
MB = KB * 1024
|
||||
GB = MB * 1024
|
||||
)
|
||||
|
||||
switch {
|
||||
case bytes >= GB:
|
||||
return formatFloat(float64(bytes)/float64(GB)) + " GB"
|
||||
case bytes >= MB:
|
||||
return formatFloat(float64(bytes)/float64(MB)) + " MB"
|
||||
case bytes >= KB:
|
||||
return formatFloat(float64(bytes)/float64(KB)) + " KB"
|
||||
default:
|
||||
return formatInt(bytes) + " B"
|
||||
}
|
||||
}
|
||||
|
||||
// FormatSpeed formats speed (bytes per second) to human readable string
|
||||
func FormatSpeed(bytesPerSec int64) string {
|
||||
if bytesPerSec == 0 {
|
||||
return "0 B/s"
|
||||
}
|
||||
return FormatBytes(bytesPerSec) + "/s"
|
||||
}
|
||||
|
||||
func formatFloat(f float64) string {
|
||||
if f >= 100 {
|
||||
return formatInt(int64(f))
|
||||
} else if f >= 10 {
|
||||
return formatOneDecimal(f)
|
||||
}
|
||||
return formatTwoDecimal(f)
|
||||
}
|
||||
|
||||
func formatInt(i int64) string {
|
||||
return intToStr(i)
|
||||
}
|
||||
|
||||
func formatOneDecimal(f float64) string {
|
||||
i := int64(f * 10)
|
||||
whole := i / 10
|
||||
frac := i % 10
|
||||
return intToStr(whole) + "." + intToStr(frac)
|
||||
}
|
||||
|
||||
func formatTwoDecimal(f float64) string {
|
||||
i := int64(f * 100)
|
||||
whole := i / 100
|
||||
frac := i % 100
|
||||
if frac < 10 {
|
||||
return intToStr(whole) + ".0" + intToStr(frac)
|
||||
}
|
||||
return intToStr(whole) + "." + intToStr(frac)
|
||||
}
|
||||
|
||||
func intToStr(i int64) string {
|
||||
if i == 0 {
|
||||
return "0"
|
||||
}
|
||||
if i < 0 {
|
||||
return "-" + intToStr(-i)
|
||||
}
|
||||
|
||||
var buf [20]byte
|
||||
pos := len(buf)
|
||||
for i > 0 {
|
||||
pos--
|
||||
buf[pos] = byte('0' + i%10)
|
||||
i /= 10
|
||||
}
|
||||
return string(buf[pos:])
|
||||
}
|
||||
Reference in New Issue
Block a user